[Misc] Refactor get_kv_cache_spec into AttentionLayerBase (#26587)

Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
Nicolò Lucchesi
2025-10-18 15:51:21 +02:00
committed by GitHub
parent ab4be40fc5
commit b26b70bec4
10 changed files with 151 additions and 118 deletions

View File

@ -16,6 +16,7 @@ from vllm.attention.backends.registry import _Backend, backend_name_to_enum
from vllm.attention.selector import get_attn_backend
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
from vllm.config import CacheConfig, get_current_vllm_config
from vllm.config.vllm import VllmConfig
from vllm.distributed.kv_transfer import (
get_kv_transfer_group,
has_kv_transfer_group,
@ -34,7 +35,16 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.models.vision import get_vit_attn_backend
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
from vllm.utils import (
direct_register_custom_op,
kv_cache_dtype_str_to_dtype,
)
from vllm.v1.kv_cache_interface import (
FullAttentionSpec,
KVCacheSpec,
MLAAttentionSpec,
SlidingWindowSpec,
)
FP8_DTYPE = current_platform.fp8_dtype()
logger = init_logger(__name__)
@ -152,6 +162,7 @@ class Attention(nn.Module, AttentionLayerBase):
else:
sliding_window = None
vllm_config = get_current_vllm_config()
if cache_config is not None:
kv_cache_dtype = cache_config.cache_dtype
block_size = cache_config.block_size
@ -160,6 +171,9 @@ class Attention(nn.Module, AttentionLayerBase):
kv_cache_dtype = "auto"
block_size = 16
calculate_kv_scales = False
self.kv_cache_torch_dtype = kv_cache_dtype_str_to_dtype(
kv_cache_dtype, vllm_config.model_config
)
if num_kv_heads is None:
num_kv_heads = num_heads
assert num_heads % num_kv_heads == 0, (
@ -256,7 +270,7 @@ class Attention(nn.Module, AttentionLayerBase):
self.use_direct_call = not current_platform.opaque_attention_op()
self.use_output = self.attn_backend.accept_output_buffer
compilation_config = get_current_vllm_config().compilation_config
compilation_config = vllm_config.compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
@ -276,9 +290,7 @@ class Attention(nn.Module, AttentionLayerBase):
# this variable will not be accessed if use_direct_call is True
self.kv_cache = [
torch.tensor([])
for _ in range(
get_current_vllm_config().parallel_config.pipeline_parallel_size
)
for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
]
# Initialize q/k/v range constants.
@ -394,6 +406,30 @@ class Attention(nn.Module, AttentionLayerBase):
def get_attn_backend(self) -> type[AttentionBackend]:
return self.attn_backend
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
# Block size may get updated after model loading, refresh it
block_size = vllm_config.cache_config.block_size
# Should not be called for enc-dec or encoder-only attention.
assert self.attn_type == AttentionType.DECODER
if self.sliding_window is not None:
assert not vllm_config.model_config.use_mla, (
"MLA is not supported for slidingwindow"
)
return SlidingWindowSpec(
block_size=block_size,
num_kv_heads=self.num_kv_heads,
head_size=self.head_size,
dtype=self.kv_cache_torch_dtype,
sliding_window=self.sliding_window,
)
else:
return FullAttentionSpec(
block_size=block_size,
num_kv_heads=self.num_kv_heads,
head_size=self.head_size,
dtype=self.kv_cache_torch_dtype,
)
class MultiHeadAttention(nn.Module):
"""Multi-headed attention without any cache, used for ViT."""
@ -749,6 +785,18 @@ class MLAAttention(nn.Module, AttentionLayerBase):
def get_attn_backend(self) -> type[AttentionBackend]:
return self.attn_backend
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
kv_cache_dtype = kv_cache_dtype_str_to_dtype(
self.kv_cache_dtype, vllm_config.model_config
)
return MLAAttentionSpec(
block_size=vllm_config.cache_config.block_size,
num_kv_heads=1,
head_size=self.head_size,
dtype=kv_cache_dtype,
cache_dtype_str=vllm_config.cache_config.cache_dtype,
)
def wait_for_kv_layer_from_connector(layer_name: str):
if not has_kv_transfer_group() or not is_v1_kv_transfer_group():

View File

@ -9,6 +9,7 @@ from vllm import envs
from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig
from vllm.config.vllm import VllmConfig
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.v1.attention.backends.utils import (
AttentionCGSupport,
@ -16,6 +17,7 @@ from vllm.v1.attention.backends.utils import (
make_local_attention_virtual_batches,
subclass_attention_backend,
)
from vllm.v1.kv_cache_interface import ChunkedLocalAttentionSpec, KVCacheSpec
from ..layer import Attention
@ -67,6 +69,7 @@ class ChunkedLocalAttention(Attention):
kv_sharing_target_layer_name: str | None = None,
prefix: str = "",
):
self.attention_chunk_size = attention_chunk_size
dtype = torch.get_default_dtype()
if cache_config is not None:
kv_cache_dtype = cache_config.cache_dtype
@ -99,3 +102,13 @@ class ChunkedLocalAttention(Attention):
kv_sharing_target_layer_name=kv_sharing_target_layer_name,
attn_backend=attn_backend,
)
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
assert self.attention_chunk_size
return ChunkedLocalAttentionSpec(
block_size=vllm_config.cache_config.block_size,
num_kv_heads=self.num_kv_heads,
head_size=self.head_size,
dtype=self.kv_cache_torch_dtype,
attention_chunk_size=self.attention_chunk_size,
)

View File

@ -21,7 +21,7 @@ from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata,
subclass_attention_backend,
)
from vllm.v1.kv_cache_interface import CrossAttentionSpec
from vllm.v1.kv_cache_interface import CrossAttentionSpec, KVCacheSpec
logger = init_logger(__name__)
@ -174,3 +174,11 @@ class CrossAttention(Attention):
attn_type=AttentionType.ENCODER_DECODER,
**kwargs,
)
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
return CrossAttentionSpec(
block_size=vllm_config.cache_config.block_size,
num_kv_heads=self.num_kv_heads,
head_size=self.head_size,
dtype=self.kv_cache_torch_dtype,
)

View File

@ -14,10 +14,12 @@ from vllm.attention.backends.abstract import (
from vllm.attention.layer import Attention
from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig
from vllm.config.vllm import VllmConfig
from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata,
subclass_attention_backend,
)
from vllm.v1.kv_cache_interface import KVCacheSpec
@functools.lru_cache
@ -98,3 +100,7 @@ class EncoderOnlyAttention(Attention):
attn_type=AttentionType.ENCODER_ONLY,
**kwargs,
)
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
# Does not need KV cache
return None

View File

@ -5,6 +5,9 @@
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING
from vllm.config import VllmConfig
from vllm.v1.kv_cache_interface import KVCacheSpec
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
@ -22,3 +25,11 @@ class AttentionLayerBase(ABC):
def get_attn_backend(self) -> type["AttentionBackend"]:
"""Get the attention backend class for this layer."""
pass
@abstractmethod
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None:
"""
Get the KV cache spec for this layer.
May be None if the layer does not need KV cache.
"""
pass

View File

@ -6,7 +6,9 @@ from typing import TYPE_CHECKING
import torch
from vllm.config import VllmConfig
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.v1.kv_cache_interface import KVCacheSpec, MambaSpec
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
@ -40,3 +42,30 @@ class MambaBase(AttentionLayerBase):
def get_attn_backend(self) -> type["AttentionBackend"]:
"""Get the attention backend class for this Mamba layer."""
pass
@abstractmethod
def get_state_dtype(self) -> tuple[torch.dtype, ...]:
pass
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None:
if (
vllm_config.speculative_config is not None
and vllm_config.model_config.hf_config.model_type not in ["qwen3_next"]
):
raise NotImplementedError(
"Mamba with speculative decoding is not supported yet."
)
mamba_block_size = vllm_config.cache_config.mamba_block_size
page_size_padded = vllm_config.cache_config.mamba_page_size_padded
return MambaSpec(
shapes=self.get_state_shape(),
dtypes=self.get_state_dtype(),
block_size=mamba_block_size,
page_size_padded=page_size_padded,
mamba_type=self.mamba_type,
num_speculative_blocks=(
vllm_config.speculative_config.num_speculative_tokens
if vllm_config.speculative_config
else 0
),
)

View File

@ -481,7 +481,7 @@ class DeepseekV32IndexerCache(torch.nn.Module, AttentionLayerBase):
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
def get_kv_cache_spec(self) -> KVCacheSpec:
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
return MLAAttentionSpec( # Only has one vector instead of K + V
block_size=self.cache_config.block_size,
num_kv_heads=1,

View File

@ -137,6 +137,15 @@ def set_default_torch_num_threads(num_threads: int):
torch.set_num_threads(old_num_threads)
def kv_cache_dtype_str_to_dtype(
kv_cache_dtype: str, model_config: ModelConfig
) -> torch.dtype:
if kv_cache_dtype == "auto":
# Model config may not be specified for unit tests, default to float16
return model_config.dtype if model_config else torch.half
return STR_DTYPE_TO_TORCH_DTYPE[kv_cache_dtype]
T = TypeVar("T")
U = TypeVar("U")

View File

@ -948,7 +948,7 @@ class EagleProposer:
indexer_layers[first_layer]
.get_attn_backend()
.get_builder_cls()(
indexer_layers[first_layer].get_kv_cache_spec(),
indexer_layers[first_layer].get_kv_cache_spec(self.vllm_config),
self.indexer_layer_names,
self.vllm_config,
self.device,

View File

@ -19,8 +19,6 @@ from tqdm import tqdm
import vllm.envs as envs
from vllm.attention import Attention, AttentionType
from vllm.attention.backends.abstract import AttentionBackend, MultipleOf
from vllm.attention.layer import MLAAttention
from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention
from vllm.compilation.counter import compilation_counter
from vllm.compilation.cuda_graph import CUDAGraphWrapper
from vllm.compilation.monitor import set_cudagraph_capturing_enabled
@ -44,10 +42,8 @@ from vllm.distributed.parallel_state import (
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
from vllm.model_executor.models.interfaces import (
SupportsMultiModal,
is_mixture_of_experts,
@ -73,11 +69,11 @@ from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors
from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
from vllm.utils import (
STR_DTYPE_TO_TORCH_DTYPE,
cdiv,
check_use_alibi,
get_dtype_size,
is_pin_memory_available,
kv_cache_dtype_str_to_dtype,
length_from_prompt_token_ids_or_embeds,
round_up,
supports_dynamo,
@ -106,7 +102,6 @@ from vllm.v1.kv_cache_interface import (
KVCacheGroupSpec,
KVCacheSpec,
MambaSpec,
MLAAttentionSpec,
SlidingWindowSpec,
UniformTypeKVCacheSpecs,
)
@ -239,10 +234,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.device = device
self.pin_memory = is_pin_memory_available()
self.dtype = self.model_config.dtype
if cache_config.cache_dtype == "auto":
self.kv_cache_dtype = self.dtype
else:
self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
self.kv_cache_dtype = kv_cache_dtype_str_to_dtype(
cache_config.cache_dtype, self.model_config
)
self.is_pooling_model = model_config.runner_type == "pooling"
self.enable_prompt_embeds = model_config.enable_prompt_embeds
@ -4577,109 +4571,24 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
format. Layers that do not need KV cache are not included.
"""
block_size = self.vllm_config.cache_config.block_size
use_mla = self.vllm_config.model_config.use_mla
cache_dtype_str = self.vllm_config.cache_config.cache_dtype
kv_cache_spec: dict[str, KVCacheSpec] = {}
attn_layers = get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase)
for layer_name, attn_module in attn_layers.items():
if isinstance(attn_module, Attention):
if (
kv_tgt_layer := attn_module.kv_sharing_target_layer_name
) is not None:
# The layer doesn't need its own KV cache and will use that of
# the target layer. We skip creating a KVCacheSpec for it, so
# that KV cache management logic will act as this layer does
# not exist, and doesn't allocate KV cache for the layer. This
# enables the memory saving of cross-layer kv sharing, allowing
# a given amount of memory to accommodate longer context lengths
# or enable more requests to be processed simultaneously.
self.shared_kv_cache_layers[layer_name] = kv_tgt_layer
continue
# TODO(lucas): move the attention specs into the model layers like
# the attention backends
if attn_module.attn_type == AttentionType.DECODER:
if attn_module.sliding_window is not None:
assert not use_mla, "MLA is not supported for slidingwindow"
kv_cache_spec[layer_name] = SlidingWindowSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=self.kv_cache_dtype,
sliding_window=attn_module.sliding_window,
)
elif self.attention_chunk_size is not None and isinstance(
attn_module, ChunkedLocalAttention
):
kv_cache_spec[layer_name] = ChunkedLocalAttentionSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=self.kv_cache_dtype,
attention_chunk_size=self.attention_chunk_size,
)
else:
kv_cache_spec[layer_name] = FullAttentionSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=self.kv_cache_dtype,
)
elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
kv_cache_spec[layer_name] = CrossAttentionSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=self.kv_cache_dtype,
)
elif attn_module.attn_type in (
AttentionType.ENCODER,
AttentionType.ENCODER_ONLY,
):
# encoder-only attention does not need KV cache.
continue
else:
raise ValueError(f"Unknown attention type: {attn_module.attn_type}")
elif isinstance(attn_module, MLAAttention):
kv_cache_spec[layer_name] = MLAAttentionSpec(
block_size=block_size,
num_kv_heads=1,
head_size=attn_module.head_size,
dtype=self.kv_cache_dtype,
cache_dtype_str=cache_dtype_str,
)
elif isinstance(attn_module, MambaBase):
if (
self.vllm_config.speculative_config is not None
and self.vllm_config.model_config.hf_config.model_type
not in ["qwen3_next"]
):
raise NotImplementedError(
"Mamba with speculative decoding is not supported yet."
)
mamba_block_size = self.vllm_config.cache_config.mamba_block_size
page_size_padded = self.vllm_config.cache_config.mamba_page_size_padded
kv_cache_spec[layer_name] = MambaSpec(
shapes=attn_module.get_state_shape(),
dtypes=attn_module.get_state_dtype(),
block_size=mamba_block_size,
page_size_padded=page_size_padded,
mamba_type=attn_module.mamba_type,
num_speculative_blocks=(
self.speculative_config.num_speculative_tokens
if self.speculative_config
else 0
),
)
ds_indexer_layers = get_layers_from_vllm_config(
self.vllm_config, DeepseekV32IndexerCache
)
for layer_name, ds_indexer_module in ds_indexer_layers.items():
kv_cache_spec[layer_name] = ds_indexer_module.get_kv_cache_spec()
if isinstance(attn_module, Attention) and (
kv_tgt_layer := attn_module.kv_sharing_target_layer_name
):
# The layer doesn't need its own KV cache and will use that of
# the target layer. We skip creating a KVCacheSpec for it, so
# that KV cache management logic will act as this layer does
# not exist, and doesn't allocate KV cache for the layer. This
# enables the memory saving of cross-layer kv sharing, allowing
# a given amount of memory to accommodate longer context lengths
# or enable more requests to be processed simultaneously.
self.shared_kv_cache_layers[layer_name] = kv_tgt_layer
continue
# Skip modules that don't need KV cache (eg encoder-only attention)
if spec := attn_module.get_kv_cache_spec(self.vllm_config):
kv_cache_spec[layer_name] = spec
return kv_cache_spec