mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[V1] Allocate kv_cache with stride order for V1 (#18775)
Signed-off-by: nicklucche <nlucches@redhat.com>
This commit is contained in:
@ -1,7 +1,10 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import random
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.attention import Attention
|
||||
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
|
||||
SchedulerConfig, VllmConfig)
|
||||
from vllm.sampling_params import SamplingParams
|
||||
@ -13,27 +16,30 @@ from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.worker.gpu_input_batch import InputBatch
|
||||
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
|
||||
|
||||
BLOCK_SIZE = 16
|
||||
NUM_BLOCKS = 10
|
||||
|
||||
|
||||
def initialize_kv_cache(runner: GPUModelRunner):
|
||||
"""
|
||||
Only perform necessary steps in GPUModelRunner.initialize_kv_cache()
|
||||
"""
|
||||
attn_spec = FullAttentionSpec(
|
||||
block_size=BLOCK_SIZE,
|
||||
num_kv_heads=runner.model_config.get_num_kv_heads(
|
||||
runner.parallel_config),
|
||||
head_size=runner.model_config.get_head_size(),
|
||||
dtype=runner.kv_cache_dtype,
|
||||
use_mla=False,
|
||||
)
|
||||
tensor_size = attn_spec.page_size_bytes * NUM_BLOCKS
|
||||
kv_cache_config = KVCacheConfig(
|
||||
num_blocks=10,
|
||||
num_blocks=NUM_BLOCKS,
|
||||
tensors={
|
||||
"layer.0": KVCacheTensor(size=1024),
|
||||
"layer.0": KVCacheTensor(size=tensor_size),
|
||||
},
|
||||
kv_cache_groups=[
|
||||
KVCacheGroupSpec(
|
||||
layer_names=["layer.0"],
|
||||
kv_cache_spec=FullAttentionSpec(
|
||||
block_size=16,
|
||||
num_kv_heads=runner.model_config.get_num_kv_heads(
|
||||
runner.parallel_config),
|
||||
head_size=runner.model_config.get_head_size(),
|
||||
dtype=runner.kv_cache_dtype,
|
||||
use_mla=False,
|
||||
))
|
||||
KVCacheGroupSpec(layer_names=["layer.0"], kv_cache_spec=attn_spec)
|
||||
])
|
||||
runner.kv_cache_config = kv_cache_config
|
||||
runner.input_batch = InputBatch(
|
||||
@ -65,7 +71,7 @@ def model_runner():
|
||||
seed=42,
|
||||
)
|
||||
cache_config = CacheConfig(
|
||||
block_size=16,
|
||||
block_size=BLOCK_SIZE,
|
||||
gpu_memory_utilization=0.9,
|
||||
swap_space=0,
|
||||
cache_dtype="auto",
|
||||
@ -77,6 +83,10 @@ def model_runner():
|
||||
scheduler_config=scheduler_config,
|
||||
parallel_config=parallel_config,
|
||||
)
|
||||
num_heads = model_config.get_num_kv_heads(parallel_config)
|
||||
head_size = model_config.get_head_size()
|
||||
vllm_config.compilation_config.static_forward_context[
|
||||
"layer.0"] = Attention(num_heads, head_size, 0.1)
|
||||
|
||||
device = "cuda"
|
||||
runner = GPUModelRunner(vllm_config, device)
|
||||
@ -321,3 +331,38 @@ def test_update_states_request_unscheduled(model_runner):
|
||||
|
||||
assert _is_req_added(model_runner, req_ids[1])
|
||||
assert not _is_req_scheduled(model_runner, req_ids[1])
|
||||
|
||||
|
||||
def test_kv_cache_stride_order(monkeypatch, model_runner):
|
||||
# This test checks if GPUModelRunner initializes correctly when an attention
|
||||
# backend enforces a non-default KV cache stride order.
|
||||
n_heads = model_runner.model_config.get_num_kv_heads(
|
||||
model_runner.parallel_config)
|
||||
expected_kv_cache_shape = [
|
||||
2, NUM_BLOCKS, BLOCK_SIZE, n_heads,
|
||||
model_runner.model_config.get_head_size()
|
||||
]
|
||||
# TODO mla test
|
||||
default_stride = list(range(5))
|
||||
# Permutation that gets you back to expected kv shape
|
||||
rnd_stride = tuple(random.sample(default_stride, len(default_stride)))
|
||||
|
||||
def rnd_stride_order():
|
||||
return rnd_stride
|
||||
|
||||
# Patch the attention backend class and re-trigger the KV cache creation.
|
||||
for attn_backend in model_runner.attn_backends:
|
||||
monkeypatch.setattr(attn_backend, "get_kv_cache_stride_order",
|
||||
rnd_stride_order)
|
||||
|
||||
model_runner.attn_backends = []
|
||||
model_runner.attn_metadata_builders = []
|
||||
model_runner.initialize_kv_cache(model_runner.kv_cache_config)
|
||||
|
||||
# Shape is unchanged, but layout may differ
|
||||
kv_cache_shape = model_runner.kv_caches[0].shape
|
||||
assert list(kv_cache_shape) == expected_kv_cache_shape
|
||||
if default_stride == rnd_stride:
|
||||
assert all(kv.is_contiguous() for kv in model_runner.kv_caches)
|
||||
else:
|
||||
assert all(not kv.is_contiguous() for kv in model_runner.kv_caches)
|
||||
|
@ -2033,9 +2033,29 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
num_blocks, kv_cache_spec.block_size,
|
||||
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
|
||||
dtype = kv_cache_spec.dtype
|
||||
kv_caches[layer_name] = torch.zeros(kv_cache_shape,
|
||||
dtype=dtype,
|
||||
device=self.device)
|
||||
try:
|
||||
kv_cache_stride_order = self.attn_backends[
|
||||
i].get_kv_cache_stride_order()
|
||||
assert len(kv_cache_stride_order) == len(
|
||||
kv_cache_shape)
|
||||
except (AttributeError, NotImplementedError):
|
||||
kv_cache_stride_order = tuple(
|
||||
range(len(kv_cache_shape)))
|
||||
# The allocation respects the backend-defined stride order
|
||||
# to ensure the semantic remains consistent for each
|
||||
# backend. We first obtain the generic kv cache shape and
|
||||
# then permute it according to the stride order which could
|
||||
# result in a non-contiguous tensor.
|
||||
kv_cache_shape = tuple(kv_cache_shape[i]
|
||||
for i in kv_cache_stride_order)
|
||||
# Maintain original KV shape view.
|
||||
inv_order = [
|
||||
kv_cache_stride_order.index(i)
|
||||
for i in range(len(kv_cache_stride_order))
|
||||
]
|
||||
kv_caches[layer_name] = torch.zeros(
|
||||
kv_cache_shape, dtype=dtype,
|
||||
device=self.device).permute(*inv_order)
|
||||
else:
|
||||
# TODO: add new branches when introducing more types of
|
||||
# KV cache specs.
|
||||
|
Reference in New Issue
Block a user