Signed-off-by: Chen Zhang <zhangch99@outlook.com>
@@ -19,7 +19,8 @@ from vllm.v1.core.kv_cache_utils import (FreeKVCacheBlockQueue, KVCacheBlock,
                                          hash_request_tokens,
                                          unify_kv_cache_configs)
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                        KVCacheGroupSpec, KVCacheTensor)
                                        KVCacheGroupSpec, KVCacheTensor,
                                        SlidingWindowSpec)
from vllm.v1.metrics.stats import PrefixCacheStats
from vllm.v1.request import Request

@@ -54,12 +55,14 @@ def new_kv_cache_spec(block_size=16,
                      num_kv_heads=2,
                      head_size=64,
                      dtype=torch.float32,
                      use_mla=False):
                      use_mla=False,
                      sliding_window=None):
    return FullAttentionSpec(block_size=block_size,
                             num_kv_heads=num_kv_heads,
                             head_size=head_size,
                             dtype=dtype,
                             use_mla=use_mla)
                             use_mla=use_mla,
                             sliding_window=sliding_window)


def test_none_hash(monkeypatch):
@@ -492,6 +495,68 @@ def test_unify_kv_cache_configs():
        unify_kv_cache_configs(diff_kv_cache_config)


def test_merge_kv_cache_spec():
    same_layer_specs = [
        new_kv_cache_spec(num_kv_heads=32),
        new_kv_cache_spec(num_kv_heads=32),
    ]
    merged_layer_spec = same_layer_specs[0].merge(same_layer_specs)
    assert merged_layer_spec.block_size == 16
    assert merged_layer_spec.num_kv_heads == 32
    assert merged_layer_spec.head_size == 64
    assert merged_layer_spec.dtype == torch.float32
    assert merged_layer_spec.sliding_window is None

    different_layer_specs = [
        new_kv_cache_spec(num_kv_heads=32),
        new_kv_cache_spec(num_kv_heads=16),
    ]
    with pytest.raises(AssertionError):
        different_layer_specs[0].merge(different_layer_specs)

    full_spec = new_kv_cache_spec(num_kv_heads=32)
    different_type_layer_specs = [
        full_spec,
        SlidingWindowSpec(
            block_size=full_spec.block_size,
            num_kv_heads=full_spec.num_kv_heads,
            head_size=full_spec.head_size,
            dtype=full_spec.dtype,
            use_mla=full_spec.use_mla,
            sliding_window=1,
        ),
    ]
    with pytest.raises(AssertionError):
        different_type_layer_specs[0].merge(different_type_layer_specs)
    with pytest.raises(AssertionError):
        different_type_layer_specs[1].merge(different_type_layer_specs)

    different_sliding_window_layer_specs = [
        new_kv_cache_spec(num_kv_heads=32),
        new_kv_cache_spec(num_kv_heads=32, sliding_window=1),
        new_kv_cache_spec(num_kv_heads=32, sliding_window=2),
    ]
    with pytest.raises(ValueError):
        different_sliding_window_layer_specs[0].merge(
            different_sliding_window_layer_specs)

    same_sliding_window_layer_specs = [
        new_kv_cache_spec(num_kv_heads=32, sliding_window=1),
        new_kv_cache_spec(num_kv_heads=32, sliding_window=1),
    ]
    merged_layer_spec = same_sliding_window_layer_specs[0].merge(
        same_sliding_window_layer_specs)
    assert merged_layer_spec.sliding_window == 1

    same_sliding_window_layer_spec_with_none = [
        new_kv_cache_spec(num_kv_heads=32, sliding_window=1),
        new_kv_cache_spec(num_kv_heads=32, sliding_window=None),
    ]
    merged_layer_spec = same_sliding_window_layer_spec_with_none[0].merge(
        same_sliding_window_layer_spec_with_none)
    assert merged_layer_spec.sliding_window == 1


@pytest.mark.parametrize(
    ("model_id", "max_model_len", "want_estimated_max_len"), [
        ("Qwen/Qwen1.5-7B", 16385, 16384),
@@ -84,7 +84,7 @@ def test_prefill(hash_algo):
    blocks = manager.allocate_slots(req0, 55,
                                    len(computed_blocks.blocks) * 16,
                                    computed_blocks)
    assert blocks.get_block_ids() == [1, 2, 3, 4]
    assert blocks.get_block_ids() == [[1, 2, 3, 4]]

    # Check full block metadata
    parent_block_hash = None
@@ -107,13 +107,13 @@ def test_prefill(hash_algo):
    req1 = make_request("1", common_token_ids + unique_token_ids)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
    assert len(manager.req_to_block_hashes[req1.request_id]) == 3
    assert computed_blocks.get_block_ids() == [1, 2, 3]
    assert computed_blocks.get_block_ids() == [[1, 2, 3]]
    assert num_computed_tokens == 3 * 16
    num_new_tokens = 53 - 3 * 16
    blocks = manager.allocate_slots(req1, num_new_tokens,
                                    len(computed_blocks.blocks) * 16,
                                    computed_blocks)
    assert blocks.get_block_ids() == [5]
    assert blocks.get_block_ids() == [[5]]
    for block in computed_blocks.blocks:
        assert block.ref_cnt == 2

@@ -141,13 +141,13 @@ def test_prefill(hash_algo):
    req2 = make_request("2", common_token_ids + unique_token_ids)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
    assert len(manager.req_to_block_hashes[req2.request_id]) == 3
    assert computed_blocks.get_block_ids() == [1, 2, 3]
    assert computed_blocks.get_block_ids() == [[1, 2, 3]]
    assert num_computed_tokens == 3 * 16
    num_new_tokens = 53 - 3 * 16
    blocks = manager.allocate_slots(req2, num_new_tokens,
                                    len(computed_blocks.blocks) * 16,
                                    computed_blocks)
    assert blocks.get_block_ids() == [6]
    assert blocks.get_block_ids() == [[6]]

    # Although we only have 6 free blocks, we have 8 blocks in
    # the free block queue due to lazy removal.
@@ -171,7 +171,7 @@ def test_prefill(hash_algo):
                                    len(computed_blocks.blocks) * 16,
                                    computed_blocks)
    # This block ID order also checks the eviction order.
    assert blocks.get_block_ids() == [7, 8, 9, 10, 4, 5, 6, 3, 2, 1]
    assert blocks.get_block_ids() == [[7, 8, 9, 10, 4, 5, 6, 3, 2, 1]]
    assert manager.block_pool.free_block_queue.num_free_blocks == 0
    assert manager.block_pool.free_block_queue.free_list_head is None
    assert manager.block_pool.free_block_queue.free_list_tail is None
@@ -208,7 +208,7 @@ def test_prefill_plp():
    blocks = manager.allocate_slots(req0, 55,
                                    len(computed_blocks.blocks) * 16,
                                    computed_blocks)
    assert blocks.get_block_ids() == [1, 2, 3, 4]
    assert blocks.get_block_ids() == [[1, 2, 3, 4]]
    req0_block_hashes = [b.block_hash for b in blocks.blocks]

    # Check full block metadata
@@ -233,13 +233,13 @@ def test_prefill_plp():
    req1 = make_request("1", common_token_ids + unique_token_ids)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
    assert len(manager.req_to_block_hashes[req1.request_id]) == 3
    assert computed_blocks.get_block_ids() == [1, 2, 3]
    assert computed_blocks.get_block_ids() == [[1, 2, 3]]
    assert num_computed_tokens == 3 * 16
    num_new_tokens = 53 - 3 * 16
    blocks = manager.allocate_slots(req1, num_new_tokens,
                                    len(computed_blocks.blocks) * 16,
                                    computed_blocks)
    assert blocks.get_block_ids() == [5]
    assert blocks.get_block_ids() == [[5]]
    for block in computed_blocks.blocks:
        assert block.ref_cnt == 2

@@ -277,11 +277,11 @@ def test_prefill_plp():
    block_ids = blocks.get_block_ids()
    # Duplicate cached blocks have different ids but same hashes vs request #0
    assert [b.block_hash for b in blocks.blocks] == req0_block_hashes
    assert block_ids != [1, 2, 3, 4]
    assert block_ids != [[1, 2, 3, 4]]

    # Request #2 block hashes are valid since request #0 hashes are.
    # Check block reference counts.
    for block_id in block_ids:
    for block_id in block_ids[0]:
        assert manager.block_pool.blocks[block_id].ref_cnt == 1

    manager.free(req2)
@@ -307,7 +307,7 @@ def test_decode():
    blocks = manager.allocate_slots(req0, 55,
                                    len(computed_blocks.blocks) * 16,
                                    computed_blocks)
    assert blocks.get_block_ids() == [1, 2, 3, 4]
    assert blocks.get_block_ids() == [[1, 2, 3, 4]]

    # Append slots without allocating a new block.
    req0.num_computed_tokens = 55
@@ -379,12 +379,12 @@ def test_evict():
    # Touch the first 2 blocks.
    req2 = make_request("2", list(range(2 * 16 + 3)))
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
    assert computed_blocks.get_block_ids() == [1, 2]
    assert computed_blocks.get_block_ids() == [[1, 2]]
    assert num_computed_tokens == 2 * 16
    blocks = manager.allocate_slots(req2, 3,
                                    len(computed_blocks.blocks) * 16,
                                    computed_blocks)
    assert blocks.get_block_ids() == [10]
    assert blocks.get_block_ids() == [[10]]
    assert manager.block_pool.free_block_queue.num_free_blocks == 7


@@ -625,7 +625,7 @@ def test_mm_prefix_caching():
    blocks = manager.allocate_slots(req0, 59,
                                    len(computed_blocks.blocks) * 16,
                                    computed_blocks)
    assert blocks.get_block_ids() == [1, 2, 3, 4]
    assert blocks.get_block_ids() == [[1, 2, 3, 4]]
    req0.num_computed_tokens = 59

    # Append slots without allocating a new block.
@@ -686,7 +686,7 @@ def test_cache_key_salting():
    blocks = manager.allocate_slots(req0, 59,
                                    len(computed_blocks.blocks) * 16,
                                    computed_blocks)
    assert blocks.get_block_ids() == [1, 2, 3, 4]
    assert blocks.get_block_ids() == [[1, 2, 3, 4]]
    req0.num_computed_tokens = 59

    # Append slots without allocating a new block.
@@ -797,7 +797,7 @@ def test_reset_prefix_cache():
    all_token_ids = full_block_token_ids + unique_token_ids
    req0 = make_request("0", all_token_ids)
    blocks = manager.allocate_slots(req0, 55)
    assert blocks.get_block_ids() == [1, 2, 3, 4]
    assert blocks.get_block_ids() == [[1, 2, 3, 4]]

    unique_token_ids = [4] * 7
    all_token_ids = full_block_token_ids + unique_token_ids
@@ -808,7 +808,7 @@ def test_reset_prefix_cache():
    blocks = manager.allocate_slots(req1, 7,
                                    len(computed_blocks.blocks) * 16,
                                    computed_blocks)
    assert blocks.get_block_ids() == [5]
    assert blocks.get_block_ids() == [[5]]

    # Failed to reset prefix cache because some blocks are not freed yet.
    assert not manager.reset_prefix_cache()
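The assertions in this test file change from flat lists to nested lists because `get_block_ids()` now returns one list of block IDs per KV cache group. A minimal standalone sketch of consuming the new shape (plain Python, no vLLM imports; the values are illustrative only):

# Hypothetical value mirroring the new return shape: one inner list per
# KV cache group. Today there is exactly one group, hence one inner list.
block_ids_per_group = [[1, 2, 3, 4]]

# Callers that used to iterate the flat list now select a group first.
for block_id in block_ids_per_group[0]:
    print("block", block_id)

assert block_ids_per_group == [[1, 2, 3, 4]]   # matches the updated tests
assert block_ids_per_group != [1, 2, 3, 4]     # the old flat shape no longer compares equal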
@@ -9,9 +9,11 @@ import torch

from vllm.sampling_params import SamplingParams
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                        KVCacheGroupSpec, KVCacheTensor)
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.worker.gpu_input_batch import (BlockTable, CachedRequestState,
                                            InputBatch)
from vllm.v1.worker.block_table import BlockTable, MultiGroupBlockTable
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch

VOCAB_SIZE = 1024
NUM_OUTPUT_TOKENS = 20
@@ -22,6 +24,27 @@ CUDA_DEVICES = [
MAX_NUM_PROMPT_TOKENS = 64


def get_kv_cache_config() -> KVCacheConfig:
    return KVCacheConfig(
        num_blocks=10,
        tensors={
            "layer.0": KVCacheTensor(size=1024),
        },
        kv_cache_groups=[
            KVCacheGroupSpec(
                layer_names=["layer.0"],
                kv_cache_spec=FullAttentionSpec(
                    block_size=1,
                    num_kv_heads=1,
                    head_size=16,
                    dtype=torch.float16,
                    use_mla=False,
                ),
            ),
        ],
    )


def _compare_objs(obj1, obj2):
    attrs = inspect.getmembers(obj1, lambda a: not (inspect.isroutine(a)))
    attr_names = set([
@@ -41,6 +64,10 @@ def _compare_objs(obj1, obj2):
        elif isinstance(a, np.ndarray):
            if np.allclose(a, b):
                is_same = True
        elif isinstance(a, MultiGroupBlockTable):
            for a_i, b_i in zip(a.block_tables, b.block_tables):
                _compare_objs(a_i, b_i)
            is_same = True
        elif isinstance(a, (BlockTable, SamplingMetadata)):
            _compare_objs(a, b)
            is_same = True  # if we make it here must be same
@@ -198,7 +225,7 @@ def _construct_cached_request_state(req_id_suffix: int):
        sampling_params=_create_sampling_params(),
        mm_inputs=[],
        mm_positions=[],
        block_ids=[],
        block_ids=[[]],
        generator=None,
        num_computed_tokens=len(output_token_ids),
        output_token_ids=output_token_ids,
@@ -220,11 +247,11 @@ def test_sampling_metadata_in_input_batch(device: str, batch_size: int):
    input_batch: InputBatch = InputBatch(
        max_num_reqs=batch_size,
        max_model_len=1024,
        max_num_blocks_per_req=10,
        max_num_batched_tokens=1024,
        device=torch.device(device),
        pin_memory=is_pin_memory_available(),
        vocab_size=1024,
        block_size=1,
    )
    reqs: list[CachedRequestState] = []
    req_id_reqs = {}
@@ -310,20 +337,20 @@ def test_swap_states_in_input_batch(device: str, batch_size: int,
    input_batch: InputBatch = InputBatch(
        max_num_reqs=batch_size,
        max_model_len=1024,
        max_num_blocks_per_req=10,
        max_num_batched_tokens=1024,
        device=torch.device(device),
        pin_memory=is_pin_memory_available(),
        vocab_size=1024,
        block_size=1,
    )
    ref_input_batch: InputBatch = InputBatch(
        max_num_reqs=batch_size,
        max_model_len=1024,
        max_num_blocks_per_req=10,
        max_num_batched_tokens=1024,
        device=torch.device(device),
        pin_memory=is_pin_memory_available(),
        vocab_size=1024,
        block_size=1,
    )

    reqs: list[CachedRequestState] = []
@@ -1,15 +1,16 @@
# SPDX-License-Identifier: Apache-2.0
import weakref

import pytest
import torch

from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
                         SchedulerConfig, VllmConfig)
from vllm.sampling_params import SamplingParams
from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData,
                                       SchedulerOutput)
from vllm.v1.kv_cache_interface import FullAttentionSpec
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                        KVCacheGroupSpec, KVCacheTensor)
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm.v1.worker.gpu_model_runner import GPUModelRunner


@@ -17,13 +18,34 @@ def initialize_kv_cache(runner: GPUModelRunner):
    """
    Only perform necessary steps in GPUModelRunner.initialize_kv_cache()
    """
    kv_cache_spec = FullAttentionSpec(block_size=16,
                                      num_kv_heads=1,
                                      head_size=64,
                                      dtype=torch.float16,
                                      use_mla=False)
    runner.attn_metadata_builder = runner.attn_backend.get_builder_cls()(
        weakref.proxy(runner), kv_cache_spec, runner.input_batch.block_table)
    kv_cache_config = KVCacheConfig(
        num_blocks=10,
        tensors={
            "layer.0": KVCacheTensor(size=1024),
        },
        kv_cache_groups=[
            KVCacheGroupSpec(
                layer_names=["layer.0"],
                kv_cache_spec=FullAttentionSpec(
                    block_size=16,
                    num_kv_heads=runner.model_config.get_num_kv_heads(
                        runner.parallel_config),
                    head_size=runner.model_config.get_head_size(),
                    dtype=runner.kv_cache_dtype,
                    use_mla=False,
                ))
        ])
    runner.kv_cache_config = kv_cache_config
    runner.input_batch = InputBatch(
        max_num_reqs=runner.max_num_reqs,
        max_model_len=runner.max_model_len,
        max_num_batched_tokens=runner.max_num_tokens,
        device=runner.device,
        pin_memory=runner.pin_memory,
        vocab_size=runner.model_config.get_vocab_size(),
        block_size=kv_cache_config.kv_cache_groups[0].kv_cache_spec.block_size,
    )
    runner.initialize_attn_backend(kv_cache_config)


@pytest.fixture
@@ -48,10 +70,12 @@ def model_runner():
        swap_space=0,
        cache_dtype="auto",
    )
    parallel_config = ParallelConfig()
    vllm_config = VllmConfig(
        model_config=model_config,
        cache_config=cache_config,
        scheduler_config=scheduler_config,
        parallel_config=parallel_config,
    )

    device = "cuda"
@@ -73,7 +97,7 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
            mm_hashes=[],
            mm_positions=[],
            sampling_params=SamplingParams(),
            block_ids=[0],
            block_ids=[[0]],
            num_computed_tokens=0,
            lora_request=None,
        ))
@@ -111,13 +135,14 @@ def _is_sampling_metadata_changed(model_runner,

def _is_req_state_block_table_match(model_runner, req_id: str) -> bool:
    req_index = model_runner.input_batch.req_id_to_index[req_id]
    block_table = model_runner.input_batch.block_table
    block_table = model_runner.input_batch.block_table[0]
    req_state = model_runner.requests[req_id]
    if block_table.num_blocks_per_row[req_index] != len(req_state.block_ids):
    if block_table.num_blocks_per_row[req_index] != len(
            req_state.block_ids[0]):
        return False
    num_blocks = block_table.num_blocks_per_row[req_index]
    return (block_table.block_table_np[req_index, :num_blocks] ==
            req_state.block_ids).all()
            req_state.block_ids[0]).all()


def test_update_states_new_request(model_runner):
@@ -200,7 +225,7 @@ def test_update_states_request_resumed(model_runner):
        req_id=req_id,
        resumed_from_preemption=False,
        new_token_ids=[],
        new_block_ids=[],
        new_block_ids=[[]],
        num_computed_tokens=0,
    )

@@ -288,7 +288,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
        for new_req in scheduler_output.scheduled_new_reqs:
            if new_req.req_id in self._requests_need_load:
                meta.add_request(token_ids=new_req.prompt_token_ids,
                                 block_ids=new_req.block_ids,
                                 block_ids=new_req.block_ids[0],
                                 block_size=self._block_size,
                                 is_store=False)
                total_need_load += 1
@@ -299,7 +299,7 @@ class SharedStorageConnector(KVConnectorBase_V1):
                # the original prompt tokens.
                if not self._found_match_for_request(new_req):
                    meta.add_request(token_ids=new_req.prompt_token_ids,
                                     block_ids=new_req.block_ids,
                                     block_ids=new_req.block_ids[0],
                                     block_size=self._block_size,
                                     is_store=True)

@@ -319,7 +319,7 @@ class SharedStorageConnector(KVConnectorBase_V1):

                # NOTE(rob): For resumed req, new_block_ids is all
                # of the block_ids for the request.
                block_ids = cached_req.new_block_ids
                block_ids = cached_req.new_block_ids[0]

                meta.add_request(token_ids=token_ids,
                                 block_ids=block_ids,
@@ -69,13 +69,13 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
        max_model_len = self.runner.model_config.max_model_len
        assert max_model_len == 32768,\
            "AITER MLA requires max_model_len=32768"
        assert self.runner.block_size == 1, "AITER MLA" \
        assert self.kv_cache_spec.block_size == 1, "AITER MLA" \
            "only supports block size 1."

    def _get_paged_kv_tensors(
            self, block_table: torch.Tensor,
            seq_lens: torch.Tensor) -> tuple[torch.Tensor, ...]:
        page_size = self.runner.block_size
        page_size = self.kv_cache_spec.block_size
        block_table_bounds = (seq_lens + page_size - 1) // page_size
        device = self.runner.device
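The builder now reads the block size from its own kv_cache_spec instead of a runner-wide setting. The ceil-division it performs is easy to check with plain numbers; this is an illustrative sketch, not the backend code:

# page_size stands in for kv_cache_spec.block_size (asserted to be 1 above).
page_size = 1
seq_lens = [7, 16, 33]

# Same arithmetic as block_table_bounds = (seq_lens + page_size - 1) // page_size
block_table_bounds = [(s + page_size - 1) // page_size for s in seq_lens]
assert block_table_bounds == [7, 16, 33]   # block size 1: one block per token

# With a more typical block size, the rounding up becomes visible.
page_size = 16
assert [(s + page_size - 1) // page_size for s in seq_lens] == [1, 1, 3]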
@@ -32,9 +32,16 @@ class KVCacheBlocks:
        """Creates a new KVCacheBlocks instance with no blocks."""
        return cls([])

    def get_block_ids(self) -> list[int]:
        """Converts the KVCacheBlocks instance to a list of block IDs."""
        return [block.block_id for block in self.blocks]
    def get_block_ids(self) -> list[list[int]]:
        """
        Converts the KVCacheBlocks instance to block_ids.

        Returns:
            list[list[int]]: A two-level list where
            * the outer list corresponds to KV cache groups (only 1 group now)
            * each inner list contains the block_ids of the blocks in that group
        """
        return [[block.block_id for block in self.blocks]]

    def get_unhashed_block_ids(self) -> list[int]:
        """Get block_ids of unhashed blocks from KVCacheBlocks instance."""
@@ -300,9 +307,9 @@ class KVCacheManager:
        self,
        request: Request,
        num_running_requests: int,
    ) -> int:
    ) -> list[int]:
        """Calculate the number of common prefix blocks shared by all requests
        in the RUNNING state.
        in the RUNNING state for each kv cache group.

        The function determines this by selecting any request and iterating
        through its blocks. A block is considered a common prefix block if its
@@ -332,11 +339,14 @@ class KVCacheManager:
                requests in the current step.

        Returns:
            int: The number of common prefix blocks.
            list[int]: The number of common prefix blocks for each kv cache
            group.
        """
        assert request.status == RequestStatus.RUNNING
        return self.single_type_manager.get_num_common_prefix_blocks(
            request.request_id, num_running_requests)
        return [
            self.single_type_manager.get_num_common_prefix_blocks(
                request.request_id, num_running_requests)
        ]

    def free_block_hashes(self, request: Request) -> None:
        """Discard the block hashes for the request.
@@ -354,10 +364,8 @@ class KVCacheManager:
        """
        return self.block_pool.take_events()

    def get_block_ids(self, request_id: str) -> list[int]:
    def get_block_ids(self, request_id: str) -> list[list[int]]:
        """Get the block ids of a request."""
        assert request_id in self.single_type_manager.req_to_blocks
        return [
            block.block_id
            for block in self.single_type_manager.req_to_blocks[request_id]
        ]
        return KVCacheBlocks(self.single_type_manager.req_to_blocks[request_id]
                             ).get_block_ids()
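Both manager entry points above now report results per KV cache group rather than as a single scalar or flat list. A small sketch of what a caller sees, using plain data in place of the manager (the values are made up):

# One entry per KV cache group; with the single full-attention group used
# today each structure has length 1.
num_common_prefix_blocks = [3]       # shape of get_num_common_prefix_blocks(...)
block_ids = [[7, 8, 9, 10]]          # shape of get_block_ids(request_id)

for group_id, common_blocks in enumerate(num_common_prefix_blocks):
    group_block_ids = block_ids[group_id]
    print(f"group {group_id}: {common_blocks} common prefix blocks, "
          f"{len(group_block_ids)} blocks total")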
@@ -577,14 +577,12 @@ def create_kv_cache_group_specs(
    """
    kv_cache_groups = []
    for layer_names_one_group in grouped_layer_names:
        layer_spec = kv_cache_spec[layer_names_one_group[0]]
        assert all(
            kv_cache_spec[layer_name] == layer_spec
            for layer_name in layer_names_one_group[1:]), (
                "All layers in the same KV cache group must share the same "
                "KVCacheSpec.")
        layer_specs = [
            kv_cache_spec[layer_name] for layer_name in layer_names_one_group
        ]
        merged_layer_spec = layer_specs[0].merge(layer_specs)
        kv_cache_groups.append(
            KVCacheGroupSpec(layer_names_one_group, layer_spec))
            KVCacheGroupSpec(layer_names_one_group, merged_layer_spec))
    return kv_cache_groups


@@ -683,6 +681,7 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]):
                head_size=spec.head_size,
                dtype=spec.dtype,
                use_mla=spec.use_mla,
                sliding_window=spec.sliding_window,
            )

@@ -26,7 +26,7 @@ class NewRequestData:
    mm_hashes: list[str]
    mm_positions: list[PlaceholderRange]
    sampling_params: SamplingParams
    block_ids: list[int]
    block_ids: list[list[int]]
    num_computed_tokens: int
    lora_request: Optional[LoRARequest]

@@ -34,7 +34,7 @@ class NewRequestData:
    def from_request(
        cls,
        request: Request,
        block_ids: list[int],
        block_ids: list[list[int]],
    ) -> NewRequestData:
        return cls(
            req_id=request.request_id,
@@ -85,7 +85,7 @@ class CachedRequestData:
    # request's block IDs instead of appending to the existing block IDs.
    resumed_from_preemption: bool
    new_token_ids: list[int]
    new_block_ids: list[int]
    new_block_ids: list[list[int]]
    num_computed_tokens: int

    @classmethod
@@ -94,7 +94,7 @@ class CachedRequestData:
        request: Request,
        resumed_from_preemption: bool,
        new_token_ids: list[int],
        new_block_ids: list[int],
        new_block_ids: list[list[int]],
    ) -> CachedRequestData:
        return cls(
            req_id=request.request_id,
@@ -131,9 +131,9 @@ class SchedulerOutput:
    # E.g., if a request has [0, 1], it could mean the vision encoder needs
    # to process that the request's 0-th and 1-th images in the current step.
    scheduled_encoder_inputs: dict[str, list[int]]
    # Number of common prefix blocks for all requests.
    # Number of common prefix blocks for all requests in each KV cache group.
    # This can be used for cascade attention.
    num_common_prefix_blocks: int
    num_common_prefix_blocks: list[int]

    # Request IDs that are finished in between the previous and the current
    # steps. This is used to notify the workers about the finished requests
@@ -173,7 +173,7 @@ class Scheduler(SchedulerInterface):
        # uses structured decoding.
        structured_output_request_ids: dict[str, int] = {}

        req_to_new_block_ids: dict[str, list[int]] = {}
        req_to_new_block_ids: dict[str, list[list[int]]] = {}
        num_scheduled_tokens: dict[str, int] = {}
        token_budget = self.max_num_scheduled_tokens
        # Encoder-related.
@@ -486,7 +486,8 @@ class Scheduler(SchedulerInterface):

        # Get the longest common prefix among all requests in the running queue.
        # This can be potentially used for cascade attention.
        num_common_prefix_blocks = 0
        num_common_prefix_blocks = [0] * len(
            self.kv_cache_config.kv_cache_groups)
        if self.running:
            any_request = self.running[0]
            num_common_prefix_blocks = (
@@ -573,7 +574,7 @@ class Scheduler(SchedulerInterface):
        request: Request,
        num_scheduled_tokens: int,
        num_scheduled_spec_tokens: int,
        new_block_ids: list[int],
        new_block_ids: list[list[int]],
        resumed_from_preemption: bool,
    ) -> CachedRequestData:
        # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating
@@ -949,7 +950,9 @@ class Scheduler(SchedulerInterface):
        """
        if self.connector is None:
            return False, None
        block_ids = self.kv_cache_manager.get_block_ids(request.request_id)
        assert len(self.kv_cache_config.kv_cache_groups
                   ) == 1, "KV connector only supports one KV cache group now"
        block_ids = self.kv_cache_manager.get_block_ids(request.request_id)[0]
        return self.connector.request_finished(request, block_ids)

    def _update_waiting_for_remote_kv(self, request: Request) -> bool:
@@ -966,9 +969,10 @@ class Scheduler(SchedulerInterface):
        """
        if request.request_id not in self.finished_recving_kv_req_ids:
            return False

        assert len(self.kv_cache_config.kv_cache_groups
                   ) == 1, "KV connector only supports one KV cache group now"
        # Now that the blocks are ready, actually cache them.
        block_ids = self.kv_cache_manager.get_block_ids(request.request_id)
        block_ids = self.kv_cache_manager.get_block_ids(request.request_id)[0]
        num_computed_tokens = len(block_ids) * self.block_size
        if num_computed_tokens == request.num_tokens:
            num_computed_tokens -= 1
@@ -1,8 +1,11 @@
# SPDX-License-Identifier: Apache-2.0

import copy
from dataclasses import dataclass
from typing import Optional

import torch
from typing_extensions import Self

from vllm.config import VllmConfig
from vllm.logger import init_logger
@@ -53,6 +56,16 @@ class KVCacheSpec:
        """
        raise NotImplementedError

    @classmethod
    def merge(cls, specs: list[Self]) -> Self:
        """
        Merge a list of KVCacheSpec objects into a single KVCacheSpec object.
        """
        assert all(spec.type_id == specs[0].type_id for spec in specs[1:]), (
            "All layers in the same KV cache group must share the same "
            "type_id.")
        return copy.deepcopy(specs[0])


@dataclass
class AttentionSpec(KVCacheSpec):
@@ -71,6 +84,16 @@ class AttentionSpec(KVCacheSpec):

@dataclass
class FullAttentionSpec(AttentionSpec):
    sliding_window: Optional[int] = None
    """
    When hybrid allocator is disabled and the model contains both full
    attention layers and sliding window attention layers, sliding
    window attention are regarded as full attention in KV cache manager
    (blocks are allocated for all tokens), while computed as sliding window
    attention in model runner.
    In this case, we use FullAttentionSpec and record the sliding window size.
    Default to None for not using sliding window attention.
    """

    @property
    def type_id(self) -> str:
@@ -80,6 +103,25 @@ class FullAttentionSpec(AttentionSpec):
        max_model_len = vllm_config.model_config.max_model_len
        return cdiv(max_model_len, self.block_size) * self.page_size_bytes

    @classmethod
    def merge(cls, specs: list[Self]) -> Self:
        """
        Merge a list of FullAttentionSpec objects into a single
        FullAttentionSpec object.
        """
        merged_spec = super().merge(specs)
        sliding_window = set(spec.sliding_window for spec in specs
                             if spec.sliding_window is not None)
        if len(sliding_window) == 0:
            merged_spec.sliding_window = None
        elif len(sliding_window) == 1:
            merged_spec.sliding_window = sliding_window.pop()
        else:
            raise ValueError(
                "All sliding window layers in the same KV cache group "
                "must have the same window size.")
        return merged_spec


@dataclass
class SlidingWindowSpec(AttentionSpec):
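The sliding-window handling in FullAttentionSpec.merge reduces to a small set computation: collect the non-None window sizes across the group's layers, allow at most one distinct value, and keep None only when no layer uses a window. A standalone sketch of that rule (plain Python, not the vLLM class itself):

from typing import Optional

def merge_sliding_window(windows: list[Optional[int]]) -> Optional[int]:
    # Distinct non-None window sizes across the group's layers.
    distinct = {w for w in windows if w is not None}
    if not distinct:
        return None                  # no layer uses sliding window
    if len(distinct) == 1:
        return distinct.pop()        # every windowed layer agrees
    raise ValueError("All sliding window layers in the same KV cache group "
                     "must have the same window size.")

assert merge_sliding_window([None, None]) is None
assert merge_sliding_window([1, None]) == 1   # mirrors the new unit test above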
@@ -4,6 +4,7 @@ import numpy as np
import torch

from vllm.logger import init_logger
from vllm.utils import cdiv

logger = init_logger(__name__)

@@ -96,3 +97,43 @@ class BlockTable:
    def get_numpy_array(self) -> np.ndarray:
        """Returns the numpy array of the block table."""
        return self.block_table_np


class MultiGroupBlockTable:
    """The BlockTables for each KV cache group."""

    def __init__(self, max_num_reqs: int, max_model_len: int,
                 max_num_batched_tokens: int, pin_memory: bool,
                 device: torch.device, block_size: int) -> None:
        self.block_tables = [
            BlockTable(max_num_reqs, cdiv(max_model_len, block_size),
                       max_num_batched_tokens, pin_memory, device)
        ]

    def append_row(self, block_ids: list[list[int]], row_idx: int) -> None:
        for i, block_table in enumerate(self.block_tables):
            block_table.append_row(block_ids[i], row_idx)

    def add_row(self, block_ids: list[list[int]], row_idx: int) -> None:
        for i, block_table in enumerate(self.block_tables):
            block_table.add_row(block_ids[i], row_idx)

    def move_row(self, src: int, tgt: int) -> None:
        for block_table in self.block_tables:
            block_table.move_row(src, tgt)

    def swap_row(self, src: int, tgt: int) -> None:
        for block_table in self.block_tables:
            block_table.swap_row(src, tgt)

    def commit(self, num_reqs: int) -> None:
        for block_table in self.block_tables:
            block_table.commit(num_reqs)

    def clear(self) -> None:
        for block_table in self.block_tables:
            block_table.clear()

    def __getitem__(self, idx: int) -> "BlockTable":
        """Returns the BlockTable for the i-th KV cache group."""
        return self.block_tables[idx]
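A brief usage sketch of the wrapper above, with a stand-in single-group table so it runs without the tensor-backed BlockTable (the FakeBlockTable class is hypothetical and only illustrates the per-group fan-out):

class FakeBlockTable:
    """Stand-in for vllm.v1.worker.block_table.BlockTable."""
    def __init__(self) -> None:
        self.rows: dict[int, list[int]] = {}

    def add_row(self, block_ids: list[int], row_idx: int) -> None:
        self.rows[row_idx] = list(block_ids)

    def append_row(self, block_ids: list[int], row_idx: int) -> None:
        self.rows.setdefault(row_idx, []).extend(block_ids)

# The multi-group wrapper simply forwards each group's block IDs to the
# matching per-group table, exactly like the loops above.
tables = [FakeBlockTable()]        # one table per KV cache group
new_block_ids = [[1, 2, 3]]        # outer list: KV cache groups
more_block_ids = [[4]]

for i, table in enumerate(tables):
    table.add_row(new_block_ids[i], row_idx=0)
for i, table in enumerate(tables):
    table.append_row(more_block_ids[i], row_idx=0)

assert tables[0].rows[0] == [1, 2, 3, 4]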
@@ -14,7 +14,7 @@ from vllm.utils import swap_dict_values
from vllm.v1.outputs import LogprobsTensors
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.utils import copy_slice
from vllm.v1.worker.block_table import BlockTable
from vllm.v1.worker.block_table import MultiGroupBlockTable

_SAMPLING_EPS = 1e-5

@@ -29,7 +29,7 @@ class CachedRequestState:
    sampling_params: SamplingParams
    generator: Optional[torch.Generator]

    block_ids: list[int]
    block_ids: list[list[int]]
    num_computed_tokens: int
    output_token_ids: list[int]

@@ -58,15 +58,14 @@ class InputBatch:
        self,
        max_num_reqs: int,
        max_model_len: int,
        max_num_blocks_per_req: int,
        max_num_batched_tokens: int,
        device: torch.device,
        pin_memory: bool,
        vocab_size: int,
        block_size: int,
    ):
        self.max_num_reqs = max_num_reqs
        self.max_model_len = max_model_len
        self.max_num_blocks_per_req = max_num_blocks_per_req
        self.max_num_batched_tokens = max_num_batched_tokens
        self.device = device
        self.pin_memory = pin_memory
@@ -99,12 +98,13 @@ class InputBatch:
            self.num_computed_tokens_cpu_tensor.numpy()

        # Block table.
        self.block_table = BlockTable(
        self.block_table = MultiGroupBlockTable(
            max_num_reqs=max_num_reqs,
            max_num_blocks_per_req=max_num_blocks_per_req,
            max_model_len=max_model_len,
            max_num_batched_tokens=max_num_batched_tokens,
            pin_memory=pin_memory,
            device=device,
            block_size=block_size,
        )

        # Sampling-related.
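With block_size passed to InputBatch, the per-request block capacity that used to be a separate constructor argument can be derived from max_model_len. A tiny sketch of that arithmetic (plain Python; cdiv is reimplemented here rather than imported from vllm.utils):

def cdiv(a: int, b: int) -> int:
    """Ceiling division, as used in cdiv(max_model_len, block_size)."""
    return -(-a // b)

max_model_len, block_size = 1000, 16
assert cdiv(max_model_len, block_size) == 63   # 62 full blocks plus one partial block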
@@ -12,6 +12,8 @@ import torch.distributed
import torch.nn as nn

from vllm.attention import AttentionType, get_attn_backend
from vllm.attention.backends.abstract import (AttentionBackend,
                                              AttentionMetadataBuilder)
from vllm.attention.layer import Attention
from vllm.attention.utils.fa_utils import get_flash_attn_version
from vllm.config import (CompilationLevel, VllmConfig,
@@ -32,8 +34,8 @@ from vllm.multimodal.utils import group_mm_inputs_by_modality
from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                        GiB_bytes, LayerBlockType, LazyLoader, cdiv,
                        check_use_alibi, is_pin_memory_available)
                        GiB_bytes, LazyLoader, cdiv, check_use_alibi,
                        is_pin_memory_available)
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
@@ -51,6 +53,7 @@ from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.spec_decode.utils import is_spec_decode_supported
from vllm.v1.utils import bind_kv_cache
from vllm.v1.worker.block_table import BlockTable
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin

@@ -103,59 +106,17 @@ class GPUModelRunner(LoRAModelRunnerMixin):
        self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
            cache_config.cache_dtype]

        # NOTE(woosuk): sliding_window is None for models with interleaved
        # attention. Use interleaved_sliding_window instead.
        self.sliding_window = model_config.get_sliding_window()
        self.interleaved_sliding_window = getattr(
            model_config.hf_text_config, "interleaved_sliding_window", None)
        self.window_size = (self.sliding_window
                            or self.interleaved_sliding_window)

        self.is_multimodal_model = model_config.is_multimodal_model
        self.block_size = cache_config.block_size
        self.max_model_len = model_config.max_model_len
        self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
        self.max_num_tokens = scheduler_config.max_num_batched_tokens
        self.max_num_reqs = scheduler_config.max_num_seqs

        # Model-related.
        self.num_attn_layers = model_config.get_num_layers_by_block_type(
            parallel_config, LayerBlockType.attention)
        self.num_query_heads = model_config.get_num_attention_heads(
            parallel_config)
        self.num_kv_heads = model_config.get_num_kv_heads(parallel_config)
        self.head_size = model_config.get_head_size()
        self.hidden_size = model_config.get_hidden_size()
        self.attention_chunk_size = model_config.attention_chunk_size

        self.attn_backend = get_attn_backend(
            self.head_size,
            self.dtype,
            self.kv_cache_dtype,
            self.block_size,
            self.model_config.is_attention_free,
            use_mla=self.model_config.use_mla,
        )
        if self.attn_backend is None:
            error_msg = (
                f"Error with get_att_backend: {self.head_size=}, "
                f"{self.dtype=}, {self.kv_cache_dtype=}, {self.block_size=}, "
                f"{self.model_config.is_attention_free=}, "
                f"{self.model_config.use_mla=}")
            logger.error(error_msg)
            raise NotImplementedError(
                "Non-Attention backend is not supported by V1 GPUModelRunner.")

        if self.vllm_config.compilation_config.full_cuda_graph:
            attn_backend_name = self.attn_backend.__name__
            flash_attn_version = get_flash_attn_version()
            if attn_backend_name != "FlashAttentionBackend" or \
                flash_attn_version != 3:
                raise ValueError(
                    f"full_cuda_graph is only supported with "
                    f"FA3. Current attention backend is {attn_backend_name}, "
                    f"FlashAttention version is {flash_attn_version}.")

        self.cascade_attn_enabled = not self.model_config.disable_cascade_attn

        # Multi-modal data support
@@ -177,8 +138,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
        # self.model: nn.Module  # Set after load_model
        # Initialize in initialize_kv_cache
        self.kv_caches: list[torch.Tensor] = []
        self.attn_metadata_builders: list[AttentionMetadataBuilder] = []
        self.attn_backends: list[type[AttentionBackend]] = []
        # self.kv_cache_config: KVCacheConfig
        # self.attn_metadata_builder: type[AttentionMetadataBuilder]
        # self.input_batch: InputBatch  # Persistent batch.

        # req_id -> (input_id -> encoder_output)
        self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}
@@ -207,15 +170,15 @@ class GPUModelRunner(LoRAModelRunnerMixin):

        # Request states.
        self.requests: dict[str, CachedRequestState] = {}
        # Persistent batch.

        self.input_batch = InputBatch(
            max_num_reqs=self.max_num_reqs,
            max_model_len=self.max_model_len,
            max_num_blocks_per_req=self.max_num_blocks_per_req,
            max_num_batched_tokens=self.max_num_tokens,
            device=self.device,
            pin_memory=self.pin_memory,
            vocab_size=model_config.get_vocab_size(),
            vocab_size=self.model_config.get_vocab_size(),
            block_size=self.cache_config.block_size,
        )

        self.use_cuda_graph = (self.vllm_config.compilation_config.level
@@ -311,6 +274,31 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                                         pin_memory=self.pin_memory)
        self.seq_lens_np = self.seq_lens_cpu.numpy()

    def _may_reorder_batch(self, scheduler_output: "SchedulerOutput") -> bool:
        """
        Update the order of requests in the batch based on the attention
        backend's needs. For example, some attention backends (namely MLA) may
        want to separate requests based on if the attention computation will be
        compute-bound or memory-bound.

        Args:
            scheduler_output: The scheduler output.

        Returns:
            True if the batch was reordered, False otherwise.
        """
        batch_reordered = self.attn_metadata_builders[0].reorder_batch(
            self.input_batch, scheduler_output)

        # For models with multiple KV cache groups, the groups should agree on
        # the same order of requests. We ensure this by only allowing the first
        # group to reorder the batch and asserting that all other groups do not
        # reorder the batch.
        for i in range(1, len(self.kv_cache_config.kv_cache_groups)):
            assert not self.attn_metadata_builders[i].reorder_batch(
                self.input_batch, scheduler_output)
        return batch_reordered

    def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
        """Update the cached states and the persistent batch with the scheduler
        output.
@@ -447,7 +435,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
            # Update the block IDs.
            if not req_data.resumed_from_preemption:
                # Append the new blocks to the existing block IDs.
                req_state.block_ids.extend(req_data.new_block_ids)
                for i in range(len(self.kv_cache_config.kv_cache_groups)):
                    req_state.block_ids[i].extend(req_data.new_block_ids[i])
            else:
                # The request is resumed from preemption.
                # Replace the existing block IDs with the new ones.
@@ -505,11 +494,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
        if removed_req_indices:
            self.input_batch.condense(removed_req_indices)

        # Some attention backends (namely MLA) may want to separate requests
        # based on if the attention computation will be compute-bound or
        # memory-bound. This gives them a hook to do that.
        batch_reordered = self.attn_metadata_builder.reorder_batch(
            self.input_batch, scheduler_output)
        batch_reordered = self._may_reorder_batch(scheduler_output)

        if batch_changed or batch_reordered:
            self.input_batch.refresh_sampling_metadata()
@@ -577,21 +562,29 @@ class GPUModelRunner(LoRAModelRunnerMixin):
            torch.from_numpy(token_indices),
            out=self.input_ids_cpu[:total_num_scheduled_tokens])

        # Calculate the slot mapping.
        # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
        # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
        # where K is the max_num_blocks_per_req and the block size is 2.
        # NOTE(woosuk): We can't simply use `token_indices // block_size` here
        # because M (max_model_len) is not necessarily divisible by block_size.
        block_table_indices = (req_indices * self.max_num_blocks_per_req +
                               positions_np // self.block_size)
        block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
        block_numbers = block_table_cpu.flatten()[block_table_indices].numpy()
        block_offsets = positions_np % self.block_size
        np.add(block_numbers * self.block_size,
               block_offsets,
               out=self.input_batch.block_table.
               slot_mapping_np[:total_num_scheduled_tokens])
        # Calculate the slot mapping for each KV cache group.
        for kv_cache_group_id, kv_cache_group_spec in enumerate(
                self.kv_cache_config.kv_cache_groups):
            block_size = kv_cache_group_spec.kv_cache_spec.block_size
            block_table: BlockTable = self.input_batch.block_table[
                kv_cache_group_id]
            # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
            # -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
            # where K is the max_num_blocks_per_req and the block size is 2.
            # NOTE(woosuk): We can't simply use `token_indices // block_size`
            # here because M (max_model_len) is not necessarily divisible by
            # block_size.
            block_table_indices = (
                req_indices * block_table.max_num_blocks_per_req +
                positions_np // block_size)
            block_table_cpu = block_table.get_cpu_tensor()
            block_numbers = block_table_cpu.flatten(
            )[block_table_indices].numpy()
            block_offsets = positions_np % block_size
            np.add(
                block_numbers * block_size,
                block_offsets,
                out=block_table.slot_mapping_np[:total_num_scheduled_tokens])

        # Prepare the attention metadata.
        self.query_start_loc_np[0] = 0
@@ -633,10 +626,6 @@
        attn_metadata: dict[str, FlashAttentionMetadata] = {}
        # Prepare the attention metadata for each KV cache group and make layers
        # in the same group share the same metadata.
        # NOTE(Chen): there is exactly one KV cache group that contains all
        # attetnion layers in the model for now, so the current logic for
        # getting attn_metadata is not related to kv_cache_group information.
        # Will extend this part to support multiple KV cache groups later.
        for kv_cache_group_id, kv_cache_group_spec in enumerate(
                self.kv_cache_config.kv_cache_groups):

@@ -645,15 +634,19 @@
            if self.cascade_attn_enabled:
                common_prefix_len = self._compute_cascade_attn_prefix_len(
                    num_scheduled_tokens,
                    scheduler_output.num_common_prefix_blocks,
                    scheduler_output.
                    num_common_prefix_blocks[kv_cache_group_id],
                    kv_cache_group_spec.kv_cache_spec,
                    self.attn_metadata_builders[kv_cache_group_id],
                )

            attn_metadata_i = self.attn_metadata_builder.build(
                num_reqs=num_reqs,
                num_actual_tokens=total_num_scheduled_tokens,
                max_query_len=max_num_scheduled_tokens,
                common_prefix_len=common_prefix_len,
                common_attn_metadata=common_attn_metadata)
            attn_metadata_i = (
                self.attn_metadata_builders[kv_cache_group_id].build(
                    num_reqs=num_reqs,
                    num_actual_tokens=total_num_scheduled_tokens,
                    max_query_len=max_num_scheduled_tokens,
                    common_prefix_len=common_prefix_len,
                    common_attn_metadata=common_attn_metadata))
            for layer_name in kv_cache_group_spec.layer_names:
                attn_metadata[layer_name] = attn_metadata_i

@@ -691,6 +684,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
        self,
        num_scheduled_tokens: np.ndarray,
        num_common_prefix_blocks: int,
        kv_cache_spec: KVCacheSpec,
        attn_metadata_builder: AttentionMetadataBuilder,
    ) -> int:
        """Compute the length of the common prefix for cascade attention.

@@ -709,7 +704,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
        Returns:
            int: Length of common prefix in tokens.
        """
        common_prefix_len = num_common_prefix_blocks * self.block_size
        common_prefix_len = num_common_prefix_blocks * kv_cache_spec.block_size
        if common_prefix_len == 0:
            # Common case.
            return 0
@@ -758,15 +753,19 @@ class GPUModelRunner(LoRAModelRunnerMixin):
            common_prefix_len,
            self.input_batch.num_computed_tokens_cpu[:num_reqs].min())
        # common_prefix_len should be a multiple of the block size.
        common_prefix_len = (common_prefix_len // self.block_size *
                             self.block_size)
        use_cascade = self.attn_metadata_builder.use_cascade_attention(
        common_prefix_len = (common_prefix_len // kv_cache_spec.block_size *
                             kv_cache_spec.block_size)
        use_sliding_window = (isinstance(kv_cache_spec, SlidingWindowSpec) or
                              (isinstance(kv_cache_spec, FullAttentionSpec)
                               and kv_cache_spec.sliding_window is not None))
        assert isinstance(kv_cache_spec, AttentionSpec)
        use_cascade = attn_metadata_builder.use_cascade_attention(
            common_prefix_len=common_prefix_len,
            query_lens=num_scheduled_tokens,
            num_query_heads=self.num_query_heads,
            num_kv_heads=self.num_kv_heads,
            num_kv_heads=kv_cache_spec.num_kv_heads,
            use_alibi=self.use_alibi,
            use_sliding_window=self.window_size is not None,
            use_sliding_window=use_sliding_window,
            num_sms=self.num_sms,
        )
        return common_prefix_len if use_cascade else 0
@@ -1661,7 +1660,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                                        dtype=np.int32)

        if skip_attn:
            attn_metadata = None
            attn_metadata: Optional[dict[str, FlashAttentionMetadata]] = None
        else:
            query_start_loc = self.query_start_loc[:num_reqs + 1]
            seq_lens = self.seq_lens[:num_reqs]
@@ -1669,13 +1668,19 @@ class GPUModelRunner(LoRAModelRunnerMixin):
            common_attn_metadata = CommonAttentionMetadata(
                query_start_loc=query_start_loc, seq_lens=seq_lens)

            attn_metadata = self.attn_metadata_builder.build(
                num_reqs=num_tokens,
                num_actual_tokens=num_tokens,
                max_query_len=num_tokens,
                common_prefix_len=0,
                common_attn_metadata=common_attn_metadata,
            )
            attn_metadata = {}
            for kv_cache_group_id, kv_cache_group_spec in enumerate(
                    self.kv_cache_config.kv_cache_groups):
                attn_metadata_i = (
                    self.attn_metadata_builders[kv_cache_group_id].build(
                        num_reqs=num_tokens,
                        num_actual_tokens=num_tokens,
                        max_query_len=num_tokens,
                        common_prefix_len=0,
                        common_attn_metadata=common_attn_metadata,
                    ))
                for layer_name in kv_cache_group_spec.layer_names:
                    attn_metadata[layer_name] = attn_metadata_i

        with self.maybe_dummy_run_with_lora(self.lora_config,
                                            num_scheduled_tokens):
@@ -1909,6 +1914,56 @@ class GPUModelRunner(LoRAModelRunnerMixin):
        logger.info("Graph capturing finished in %.0f secs, took %.2f GiB",
                    elapsed_time, cuda_graph_size / (1 << 30))

    def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
        """
        Initialize the attention backends and attention metadata builders.
        """
        assert len(self.attn_backends) == 0 and len(
            self.attn_metadata_builders
        ) == 0, "Attention backends are already initialized"
        for i, kv_cache_group_spec in enumerate(
                kv_cache_config.kv_cache_groups):
            kv_cache_spec = kv_cache_group_spec.kv_cache_spec
            if not isinstance(kv_cache_spec, AttentionSpec):
                raise NotImplementedError(
                    "Only AttentionSpec is supported for now.")
            attn_backend_i = get_attn_backend(
                kv_cache_spec.head_size,
                self.dtype,
                kv_cache_spec.dtype,
                kv_cache_spec.block_size,
                self.model_config.is_attention_free,
                use_mla=kv_cache_spec.use_mla,
            )
            if attn_backend_i is None:
                error_msg = (
                    f"Error with get_attn_backend: {kv_cache_spec.head_size=}, "
                    f"{self.dtype=}, {kv_cache_spec.dtype=}, "
                    f"{kv_cache_spec.block_size=}, "
                    f"{self.model_config.is_attention_free=}, "
                    f"{kv_cache_spec.use_mla=}")
                logger.error(error_msg)
                raise NotImplementedError(
                    "Non-Attention backend is not supported by V1 "
                    "GPUModelRunner.")

            if self.vllm_config.compilation_config.full_cuda_graph:
                attn_backend_name = attn_backend_i.__name__
                flash_attn_version = get_flash_attn_version()
                if attn_backend_name != "FlashAttentionBackend" or \
                    flash_attn_version != 3:
                    raise ValueError(
                        f"full_cuda_graph is only supported with "
                        f"FA3. Current attention backend is "
                        f"{attn_backend_name}, FlashAttention version is "
                        f"{flash_attn_version}.")

            block_table_i = self.input_batch.block_table[i]
            attn_metadata_builder_i = attn_backend_i.get_builder_cls()(
                weakref.proxy(self), kv_cache_spec, block_table_i)
            self.attn_backends.append(attn_backend_i)
            self.attn_metadata_builders.append(attn_metadata_builder_i)

    def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
        """
        Initialize KV cache based on `kv_cache_config`.
@@ -1921,10 +1976,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                "Hybrid models with more than one KV cache type are not "
                "supported yet.")
        self.kv_cache_config = kv_cache_config
        self.initialize_attn_backend(kv_cache_config)

        kv_caches: dict[str, torch.Tensor] = {}

        for kv_cache_group in kv_cache_config.kv_cache_groups:
        for i, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups):
            kv_cache_spec = kv_cache_group.kv_cache_spec
            for layer_name in kv_cache_group.layer_names:
                tensor_config = kv_cache_config.tensors[layer_name]
@@ -1939,7 +1995,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
                # the min of all `num_blocks`. Verify it here.
                assert num_blocks >= kv_cache_config.num_blocks
                if isinstance(kv_cache_spec, AttentionSpec):
                    kv_cache_shape = self.attn_backend.get_kv_cache_shape(
                    kv_cache_shape = self.attn_backends[i].get_kv_cache_shape(
                        num_blocks, kv_cache_spec.block_size,
                        kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
                    dtype = kv_cache_spec.dtype
@@ -1959,11 +2015,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
        if has_kv_transfer_group():
            get_kv_transfer_group().register_kv_caches(kv_caches)

        self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
            weakref.proxy(self),
            kv_cache_config.kv_cache_groups[0].kv_cache_spec,
            self.input_batch.block_table)

    def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
        """
        Generates the KVCacheSpec by parsing the kv cache format from each
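The per-group slot-mapping loop above boils down to indexing each group's block table with that group's block size. A compact NumPy sketch of the same arithmetic with made-up numbers (it mirrors the computation, not the runner's actual buffers):

import numpy as np

block_size = 2
max_num_blocks_per_req = 8          # K in the comment above

# Two requests: tokens at positions [0, 1, 2] and [0, 1].
req_indices = np.array([0, 0, 0, 1, 1])
positions = np.array([0, 1, 2, 0, 1])

# Flattened per-group block table: request 0 owns blocks [5, 6],
# request 1 owns block [9].
block_table = np.zeros((2, max_num_blocks_per_req), dtype=np.int64)
block_table[0, :2] = [5, 6]
block_table[1, :1] = [9]

block_table_indices = req_indices * max_num_blocks_per_req + positions // block_size
block_numbers = block_table.flatten()[block_table_indices]
block_offsets = positions % block_size
slot_mapping = block_numbers * block_size + block_offsets

# Request 0: block 5 slots 10, 11 then block 6 slot 12; request 1: block 9 slots 18, 19.
assert slot_mapping.tolist() == [10, 11, 12, 18, 19]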
@@ -171,19 +171,10 @@ class TPUModelRunner(LoRAModelRunnerMixin):
        self.kv_caches: list[torch.Tensor] = []
        # req_id -> (input_id -> encoder_output)
        self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}
        # self.input_batch: InputBatch  # Persistent batch.

        # Request states.
        self.requests: dict[str, CachedRequestState] = {}
        # Persistent batch.
        self.input_batch = InputBatch(
            max_num_reqs=self.max_num_reqs,
            max_model_len=self.max_model_len,
            max_num_blocks_per_req=self.max_num_blocks_per_req,
            max_num_batched_tokens=self.max_num_tokens,
            device=self.device,
            pin_memory=self.pin_memory,
            vocab_size=self.vocab_size,
        )

        # Cached torch/numpy tensor
        # The pytorch tensor and numpy array share the same buffer.
@@ -199,7 +190,7 @@ class TPUModelRunner(LoRAModelRunnerMixin):

        self.block_table_cpu = torch.zeros(
            (self.max_num_reqs, self.max_num_blocks_per_req),
            dtype=self.input_batch.block_table.get_cpu_tensor().dtype,
            dtype=torch.int32,
            device="cpu")

        self.query_start_loc_cpu = torch.zeros(self.max_num_tokens + 1,
@@ -524,12 +515,12 @@ class TPUModelRunner(LoRAModelRunnerMixin):
        # NOTE(woosuk): We use torch.index_select instead of np.take here
        # because torch.index_select is much faster than np.take for large
        # tensors.
        block_table_cpu = self.input_batch.block_table.get_cpu_tensor()
        block_table_cpu = self.input_batch.block_table[0].get_cpu_tensor()
        block_numbers = block_table_cpu.flatten()[block_table_indices].numpy()
        block_offsets = positions_np % self.block_size
        np.add(block_numbers * self.block_size,
               block_offsets,
               out=self.input_batch.block_table.
               out=self.input_batch.block_table[0].
               slot_mapping_np[:total_num_scheduled_tokens])

        # Prepare the attention metadata.
@@ -554,15 +545,15 @@ class TPUModelRunner(LoRAModelRunnerMixin):
        self.position_ids = self.positions_cpu[:
                                               padded_total_num_scheduled_tokens].to(
                                                   self.device)
        self.input_batch.block_table.slot_mapping_cpu[
        self.input_batch.block_table[0].slot_mapping_cpu[
            total_num_scheduled_tokens:] = _PAD_SLOT_ID
        slot_mapping = (
            self.input_batch.block_table.
            self.input_batch.block_table[0].
            slot_mapping_cpu[:padded_total_num_scheduled_tokens].to(
                self.device))
        block_tables = self.block_table_cpu[:self.max_num_reqs]
        block_tables[:num_reqs, :self.max_num_blocks_per_req] = (
            self.input_batch.block_table.get_cpu_tensor()[:num_reqs])
            self.input_batch.block_table[0].get_cpu_tensor()[:num_reqs])
        block_tables = block_tables.to(self.device)
        query_start_loc = self.query_start_loc_cpu[:self.max_num_reqs + 1].to(
            self.device)
@@ -1263,6 +1254,19 @@ class TPUModelRunner(LoRAModelRunnerMixin):
                "Hybrid models with more than one KV cache type are not "
                "supported yet.")

        self.input_batch = InputBatch(
            max_num_reqs=self.max_num_reqs,
            max_model_len=self.max_model_len,
            max_num_batched_tokens=self.max_num_tokens,
            device=self.device,
            pin_memory=self.pin_memory,
            vocab_size=self.model_config.get_vocab_size(),
            block_size=kv_cache_config.kv_cache_groups[0].kv_cache_spec.
            block_size,
        )
        assert self.block_table_cpu.dtype == self.input_batch.block_table[
            0].get_cpu_tensor().dtype

        kv_caches: dict[str, torch.Tensor] = {}

        for kv_cache_group in kv_cache_config.kv_cache_groups: