[V1] Allocate kv_cache with stride order for V1 (#18775)

Signed-off-by: nicklucche <nlucches@redhat.com>
Author: Nicolò Lucchesi
Date: 2025-05-29 19:54:16 +02:00
Committed by: GitHub
Parent: d58f9c7f7a
Commit: 32ce3cf7c9
2 changed files with 81 additions and 16 deletions
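
At its core, the change allocates the KV cache buffer in whatever dimension order the attention backend prefers and then permutes the view back, so the rest of the code keeps seeing the usual (2, num_blocks, block_size, num_kv_heads, head_size) shape while the physical layout follows the backend. Below is a minimal standalone sketch of the idea, with a made-up shape and stride order rather than anything taken from a real backend:

import torch

# Canonical (logical) KV cache shape: (2, num_blocks, block_size, num_kv_heads, head_size).
kv_cache_shape = (2, 10, 16, 4, 64)
# Hypothetical backend-preferred ordering of those dimensions.
stride_order = (1, 0, 2, 3, 4)

# Allocate physically in the backend's preferred order...
physical_shape = tuple(kv_cache_shape[i] for i in stride_order)
# ...then permute with the inverse order to present the canonical shape again.
inv_order = [stride_order.index(i) for i in range(len(stride_order))]
kv_cache = torch.zeros(physical_shape, dtype=torch.float16).permute(*inv_order)

assert tuple(kv_cache.shape) == kv_cache_shape  # logical shape unchanged
assert not kv_cache.is_contiguous()             # memory layout follows the backend order

With the identity stride order the permute is a no-op and the tensor stays contiguous; that contrast is exactly what the new test below asserts.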


@@ -1,7 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
import random
import pytest
from vllm.attention import Attention
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
SchedulerConfig, VllmConfig)
from vllm.sampling_params import SamplingParams
@@ -13,27 +16,30 @@ from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
BLOCK_SIZE = 16
NUM_BLOCKS = 10
def initialize_kv_cache(runner: GPUModelRunner):
"""
Only perform necessary steps in GPUModelRunner.initialize_kv_cache()
"""
attn_spec = FullAttentionSpec(
block_size=BLOCK_SIZE,
num_kv_heads=runner.model_config.get_num_kv_heads(
runner.parallel_config),
head_size=runner.model_config.get_head_size(),
dtype=runner.kv_cache_dtype,
use_mla=False,
)
tensor_size = attn_spec.page_size_bytes * NUM_BLOCKS
kv_cache_config = KVCacheConfig(
num_blocks=10,
num_blocks=NUM_BLOCKS,
tensors={
"layer.0": KVCacheTensor(size=1024),
"layer.0": KVCacheTensor(size=tensor_size),
},
kv_cache_groups=[
KVCacheGroupSpec(
layer_names=["layer.0"],
kv_cache_spec=FullAttentionSpec(
block_size=16,
num_kv_heads=runner.model_config.get_num_kv_heads(
runner.parallel_config),
head_size=runner.model_config.get_head_size(),
dtype=runner.kv_cache_dtype,
use_mla=False,
))
KVCacheGroupSpec(layer_names=["layer.0"], kv_cache_spec=attn_spec)
])
runner.kv_cache_config = kv_cache_config
runner.input_batch = InputBatch(
@@ -65,7 +71,7 @@ def model_runner():
seed=42,
)
cache_config = CacheConfig(
block_size=16,
block_size=BLOCK_SIZE,
gpu_memory_utilization=0.9,
swap_space=0,
cache_dtype="auto",
@@ -77,6 +83,10 @@ def model_runner():
scheduler_config=scheduler_config,
parallel_config=parallel_config,
)
num_heads = model_config.get_num_kv_heads(parallel_config)
head_size = model_config.get_head_size()
vllm_config.compilation_config.static_forward_context[
"layer.0"] = Attention(num_heads, head_size, 0.1)
device = "cuda"
runner = GPUModelRunner(vllm_config, device)
@@ -321,3 +331,38 @@ def test_update_states_request_unscheduled(model_runner):
assert _is_req_added(model_runner, req_ids[1])
assert not _is_req_scheduled(model_runner, req_ids[1])
def test_kv_cache_stride_order(monkeypatch, model_runner):
# This test checks that GPUModelRunner initializes correctly when an
# attention backend enforces a non-default KV cache stride order.
n_heads = model_runner.model_config.get_num_kv_heads(
model_runner.parallel_config)
expected_kv_cache_shape = [
2, NUM_BLOCKS, BLOCK_SIZE, n_heads,
model_runner.model_config.get_head_size()
]
# TODO: add MLA test
default_stride = list(range(5))
# Random permutation of the dims; permuting back must recover the expected KV shape.
rnd_stride = tuple(random.sample(default_stride, len(default_stride)))
def rnd_stride_order():
return rnd_stride
# Patch the attention backend class and re-trigger the KV cache creation.
for attn_backend in model_runner.attn_backends:
monkeypatch.setattr(attn_backend, "get_kv_cache_stride_order",
rnd_stride_order)
model_runner.attn_backends = []
model_runner.attn_metadata_builders = []
model_runner.initialize_kv_cache(model_runner.kv_cache_config)
# Shape is unchanged, but layout may differ
kv_cache_shape = model_runner.kv_caches[0].shape
assert list(kv_cache_shape) == expected_kv_cache_shape
if default_stride == rnd_stride:
assert all(kv.is_contiguous() for kv in model_runner.kv_caches)
else:
assert all(not kv.is_contiguous() for kv in model_runner.kv_caches)
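
The second changed file is the GPU model runner, where the KV cache allocation (triggered via initialize_kv_cache, as in the test above) now asks each attention backend for a stride order before allocating. The inverse permutation it builds can be checked by hand; the stride order below is illustrative only, not taken from a real backend:

# Hypothetical stride order chosen by a backend for a 5-D KV cache.
kv_cache_stride_order = (2, 0, 1, 3, 4)

# Same construction as in the runner: inv_order[i] is the position where
# logical dimension i ended up in the physically allocated tensor.
inv_order = [
    kv_cache_stride_order.index(i)
    for i in range(len(kv_cache_stride_order))
]
print(inv_order)  # [1, 2, 0, 3, 4]

# Sanity check: applying the stride order and then the inverse is a no-op.
dims = list(range(5))
permuted = [dims[i] for i in kv_cache_stride_order]  # [2, 0, 1, 3, 4]
restored = [permuted[i] for i in inv_order]          # [0, 1, 2, 3, 4]
assert restored == dims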


@@ -2033,9 +2033,29 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_blocks, kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
dtype = kv_cache_spec.dtype
kv_caches[layer_name] = torch.zeros(kv_cache_shape,
dtype=dtype,
device=self.device)
try:
kv_cache_stride_order = self.attn_backends[
i].get_kv_cache_stride_order()
assert len(kv_cache_stride_order) == len(
kv_cache_shape)
except (AttributeError, NotImplementedError):
kv_cache_stride_order = tuple(
range(len(kv_cache_shape)))
# The allocation respects the backend-defined stride order so that
# the semantics remain consistent for each backend: we first obtain
# the generic kv cache shape and then permute it according to the
# stride order, which can result in a non-contiguous tensor.
kv_cache_shape = tuple(kv_cache_shape[i]
for i in kv_cache_stride_order)
# Compute the inverse permutation so callers still see the original KV shape view.
inv_order = [
kv_cache_stride_order.index(i)
for i in range(len(kv_cache_stride_order))
]
kv_caches[layer_name] = torch.zeros(
kv_cache_shape, dtype=dtype,
device=self.device).permute(*inv_order)
else:
# TODO: add new branches when introducing more types of
# KV cache specs.
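
Backends that do not implement get_kv_cache_stride_order keep the previous behaviour: the except branch above falls back to the identity order, which makes inv_order the identity as well, so the permute is a no-op and the cache is allocated contiguously as before. A short sketch of that fallback pattern, using a stand-in backend class rather than a real vLLM backend:

from typing import Sequence, Tuple


class DummyBackend:
    """Stand-in backend; real backends may or may not define the hook."""

    @staticmethod
    def get_kv_cache_stride_order() -> Tuple[int, ...]:
        # A backend could also raise NotImplementedError here.
        return (1, 0, 2, 3, 4)


def resolve_stride_order(backend, kv_cache_shape: Sequence[int]) -> Tuple[int, ...]:
    """Return the backend's preferred stride order, or the identity order
    when the hook is missing or unimplemented (mirrors the try/except above)."""
    try:
        order = backend.get_kv_cache_stride_order()
        assert len(order) == len(kv_cache_shape)
    except (AttributeError, NotImplementedError):
        order = tuple(range(len(kv_cache_shape)))
    return tuple(order)


print(resolve_stride_order(DummyBackend, (2, 10, 16, 4, 64)))  # (1, 0, 2, 3, 4)
print(resolve_stride_order(object(), (2, 10, 16, 4, 64)))      # (0, 1, 2, 3, 4)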