mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)

[Misc] Refactor get_kv_cache_spec into AttentionLayerBase (#26587)
Signed-off-by: NickLucche <nlucches@redhat.com>
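At a glance, this change moves KV cache spec construction out of GPUModelRunner's isinstance-based dispatch and onto the layers themselves, via a new abstract get_kv_cache_spec method on AttentionLayerBase. A minimal illustrative sketch of the resulting pattern (simplified, untyped signatures; the real method takes a VllmConfig and returns the spec classes shown in the diff below):

# Illustrative sketch only; see the actual diff below for the real vLLM code.
from abc import ABC, abstractmethod
from typing import Any


class AttentionLayerBase(ABC):
    @abstractmethod
    def get_kv_cache_spec(self, vllm_config: Any) -> Any | None:
        """Return this layer's KV cache spec, or None if it needs none."""


def collect_kv_cache_specs(
    layers: dict[str, AttentionLayerBase], vllm_config: Any
) -> dict[str, Any]:
    # Runner side: every layer reports its own spec; layers that return None
    # (e.g. encoder-only attention) are skipped, mirroring the new
    # GPUModelRunner.get_kv_cache_spec loop at the end of this diff.
    specs: dict[str, Any] = {}
    for name, layer in layers.items():
        if (spec := layer.get_kv_cache_spec(vllm_config)) is not None:
            specs[name] = spec
    return specs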
@@ -16,6 +16,7 @@ from vllm.attention.backends.registry import _Backend, backend_name_to_enum
 from vllm.attention.selector import get_attn_backend
 from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
 from vllm.config import CacheConfig, get_current_vllm_config
+from vllm.config.vllm import VllmConfig
 from vllm.distributed.kv_transfer import (
     get_kv_transfer_group,
     has_kv_transfer_group,
@@ -34,7 +35,16 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
 from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
 from vllm.model_executor.models.vision import get_vit_attn_backend
 from vllm.platforms import current_platform
-from vllm.utils import direct_register_custom_op
+from vllm.utils import (
+    direct_register_custom_op,
+    kv_cache_dtype_str_to_dtype,
+)
+from vllm.v1.kv_cache_interface import (
+    FullAttentionSpec,
+    KVCacheSpec,
+    MLAAttentionSpec,
+    SlidingWindowSpec,
+)

 FP8_DTYPE = current_platform.fp8_dtype()
 logger = init_logger(__name__)
@@ -152,6 +162,7 @@ class Attention(nn.Module, AttentionLayerBase):
         else:
             sliding_window = None

+        vllm_config = get_current_vllm_config()
         if cache_config is not None:
             kv_cache_dtype = cache_config.cache_dtype
             block_size = cache_config.block_size
@@ -160,6 +171,9 @@ class Attention(nn.Module, AttentionLayerBase):
             kv_cache_dtype = "auto"
             block_size = 16
             calculate_kv_scales = False
+        self.kv_cache_torch_dtype = kv_cache_dtype_str_to_dtype(
+            kv_cache_dtype, vllm_config.model_config
+        )
         if num_kv_heads is None:
             num_kv_heads = num_heads
         assert num_heads % num_kv_heads == 0, (
@@ -256,7 +270,7 @@ class Attention(nn.Module, AttentionLayerBase):
         self.use_direct_call = not current_platform.opaque_attention_op()

         self.use_output = self.attn_backend.accept_output_buffer
-        compilation_config = get_current_vllm_config().compilation_config
+        compilation_config = vllm_config.compilation_config
         if prefix in compilation_config.static_forward_context:
             raise ValueError(f"Duplicate layer name: {prefix}")
         compilation_config.static_forward_context[prefix] = self
@@ -276,9 +290,7 @@ class Attention(nn.Module, AttentionLayerBase):
         # this variable will not be accessed if use_direct_call is True
         self.kv_cache = [
             torch.tensor([])
-            for _ in range(
-                get_current_vllm_config().parallel_config.pipeline_parallel_size
-            )
+            for _ in range(vllm_config.parallel_config.pipeline_parallel_size)
         ]

         # Initialize q/k/v range constants.
@@ -394,6 +406,30 @@ class Attention(nn.Module, AttentionLayerBase):
     def get_attn_backend(self) -> type[AttentionBackend]:
         return self.attn_backend

+    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
+        # Block size may get updated after model loading, refresh it
+        block_size = vllm_config.cache_config.block_size
+        # Should not be called for enc-dec or encoder-only attention.
+        assert self.attn_type == AttentionType.DECODER
+        if self.sliding_window is not None:
+            assert not vllm_config.model_config.use_mla, (
+                "MLA is not supported for slidingwindow"
+            )
+            return SlidingWindowSpec(
+                block_size=block_size,
+                num_kv_heads=self.num_kv_heads,
+                head_size=self.head_size,
+                dtype=self.kv_cache_torch_dtype,
+                sliding_window=self.sliding_window,
+            )
+        else:
+            return FullAttentionSpec(
+                block_size=block_size,
+                num_kv_heads=self.num_kv_heads,
+                head_size=self.head_size,
+                dtype=self.kv_cache_torch_dtype,
+            )
+

 class MultiHeadAttention(nn.Module):
     """Multi-headed attention without any cache, used for ViT."""
@@ -749,6 +785,18 @@ class MLAAttention(nn.Module, AttentionLayerBase):
     def get_attn_backend(self) -> type[AttentionBackend]:
         return self.attn_backend

+    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
+        kv_cache_dtype = kv_cache_dtype_str_to_dtype(
+            self.kv_cache_dtype, vllm_config.model_config
+        )
+        return MLAAttentionSpec(
+            block_size=vllm_config.cache_config.block_size,
+            num_kv_heads=1,
+            head_size=self.head_size,
+            dtype=kv_cache_dtype,
+            cache_dtype_str=vllm_config.cache_config.cache_dtype,
+        )
+

 def wait_for_kv_layer_from_connector(layer_name: str):
     if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
@@ -9,6 +9,7 @@ from vllm import envs
 from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
 from vllm.attention.selector import get_attn_backend
 from vllm.config import CacheConfig
+from vllm.config.vllm import VllmConfig
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.v1.attention.backends.utils import (
     AttentionCGSupport,
@@ -16,6 +17,7 @@ from vllm.v1.attention.backends.utils import (
     make_local_attention_virtual_batches,
     subclass_attention_backend,
 )
+from vllm.v1.kv_cache_interface import ChunkedLocalAttentionSpec, KVCacheSpec

 from ..layer import Attention

@@ -67,6 +69,7 @@ class ChunkedLocalAttention(Attention):
         kv_sharing_target_layer_name: str | None = None,
         prefix: str = "",
     ):
+        self.attention_chunk_size = attention_chunk_size
         dtype = torch.get_default_dtype()
         if cache_config is not None:
             kv_cache_dtype = cache_config.cache_dtype
@@ -99,3 +102,13 @@ class ChunkedLocalAttention(Attention):
             kv_sharing_target_layer_name=kv_sharing_target_layer_name,
             attn_backend=attn_backend,
         )
+
+    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
+        assert self.attention_chunk_size
+        return ChunkedLocalAttentionSpec(
+            block_size=vllm_config.cache_config.block_size,
+            num_kv_heads=self.num_kv_heads,
+            head_size=self.head_size,
+            dtype=self.kv_cache_torch_dtype,
+            attention_chunk_size=self.attention_chunk_size,
+        )
@@ -21,7 +21,7 @@ from vllm.v1.attention.backends.utils import (
     CommonAttentionMetadata,
     subclass_attention_backend,
 )
-from vllm.v1.kv_cache_interface import CrossAttentionSpec
+from vllm.v1.kv_cache_interface import CrossAttentionSpec, KVCacheSpec

 logger = init_logger(__name__)

@@ -174,3 +174,11 @@ class CrossAttention(Attention):
             attn_type=AttentionType.ENCODER_DECODER,
             **kwargs,
         )
+
+    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
+        return CrossAttentionSpec(
+            block_size=vllm_config.cache_config.block_size,
+            num_kv_heads=self.num_kv_heads,
+            head_size=self.head_size,
+            dtype=self.kv_cache_torch_dtype,
+        )
@@ -14,10 +14,12 @@ from vllm.attention.backends.abstract import (
 from vllm.attention.layer import Attention
 from vllm.attention.selector import get_attn_backend
 from vllm.config import CacheConfig
+from vllm.config.vllm import VllmConfig
 from vllm.v1.attention.backends.utils import (
     CommonAttentionMetadata,
     subclass_attention_backend,
 )
+from vllm.v1.kv_cache_interface import KVCacheSpec


 @functools.lru_cache
@@ -98,3 +100,7 @@ class EncoderOnlyAttention(Attention):
             attn_type=AttentionType.ENCODER_ONLY,
             **kwargs,
         )
+
+    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
+        # Does not need KV cache
+        return None
@@ -5,6 +5,9 @@
 from abc import ABC, abstractmethod
 from typing import TYPE_CHECKING

+from vllm.config import VllmConfig
+from vllm.v1.kv_cache_interface import KVCacheSpec
+
 if TYPE_CHECKING:
     from vllm.attention.backends.abstract import AttentionBackend

@@ -22,3 +25,11 @@ class AttentionLayerBase(ABC):
     def get_attn_backend(self) -> type["AttentionBackend"]:
         """Get the attention backend class for this layer."""
         pass
+
+    @abstractmethod
+    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None:
+        """
+        Get the KV cache spec for this layer.
+        May be None if the layer does not need KV cache.
+        """
+        pass
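As an editorial illustration of the contract above (not part of this commit), a hypothetical layer would now implement both abstract methods; the spec class and its fields follow FullAttentionSpec as used elsewhere in this diff, while the layer class itself is invented for the example:

# Hypothetical subclass, for illustration only.
from vllm.attention.backends.abstract import AttentionBackend
from vllm.config import VllmConfig
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheSpec


class MyDecoderAttention(AttentionLayerBase):
    def __init__(self, num_kv_heads: int, head_size: int, dtype) -> None:
        self.num_kv_heads = num_kv_heads
        self.head_size = head_size
        self.dtype = dtype

    def get_attn_backend(self) -> type[AttentionBackend]:
        # Backend selection is out of scope for this sketch.
        raise NotImplementedError

    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None:
        # Report a per-layer full-attention spec, mirroring
        # Attention.get_kv_cache_spec in this commit.
        return FullAttentionSpec(
            block_size=vllm_config.cache_config.block_size,
            num_kv_heads=self.num_kv_heads,
            head_size=self.head_size,
            dtype=self.dtype,
        )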
@@ -6,7 +6,9 @@ from typing import TYPE_CHECKING

 import torch

+from vllm.config import VllmConfig
 from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
+from vllm.v1.kv_cache_interface import KVCacheSpec, MambaSpec

 if TYPE_CHECKING:
     from vllm.attention.backends.abstract import AttentionBackend
@@ -40,3 +42,30 @@ class MambaBase(AttentionLayerBase):
     def get_attn_backend(self) -> type["AttentionBackend"]:
         """Get the attention backend class for this Mamba layer."""
         pass
+
+    @abstractmethod
+    def get_state_dtype(self) -> tuple[torch.dtype, ...]:
+        pass
+
+    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None:
+        if (
+            vllm_config.speculative_config is not None
+            and vllm_config.model_config.hf_config.model_type not in ["qwen3_next"]
+        ):
+            raise NotImplementedError(
+                "Mamba with speculative decoding is not supported yet."
+            )
+        mamba_block_size = vllm_config.cache_config.mamba_block_size
+        page_size_padded = vllm_config.cache_config.mamba_page_size_padded
+        return MambaSpec(
+            shapes=self.get_state_shape(),
+            dtypes=self.get_state_dtype(),
+            block_size=mamba_block_size,
+            page_size_padded=page_size_padded,
+            mamba_type=self.mamba_type,
+            num_speculative_blocks=(
+                vllm_config.speculative_config.num_speculative_tokens
+                if vllm_config.speculative_config
+                else 0
+            ),
+        )
@@ -481,7 +481,7 @@ class DeepseekV32IndexerCache(torch.nn.Module, AttentionLayerBase):
             raise ValueError(f"Duplicate layer name: {prefix}")
         compilation_config.static_forward_context[prefix] = self

-    def get_kv_cache_spec(self) -> KVCacheSpec:
+    def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
         return MLAAttentionSpec(  # Only has one vector instead of K + V
             block_size=self.cache_config.block_size,
             num_kv_heads=1,
@@ -137,6 +137,15 @@ def set_default_torch_num_threads(num_threads: int):
         torch.set_num_threads(old_num_threads)


+def kv_cache_dtype_str_to_dtype(
+    kv_cache_dtype: str, model_config: ModelConfig
+) -> torch.dtype:
+    if kv_cache_dtype == "auto":
+        # Model config may not be specified for unit tests, default to float16
+        return model_config.dtype if model_config else torch.half
+    return STR_DTYPE_TO_TORCH_DTYPE[kv_cache_dtype]
+
+
 T = TypeVar("T")
 U = TypeVar("U")

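A short usage note on the helper added above: with "auto" it defers to the model's dtype (falling back to float16 when no model config is given, as the in-code comment notes for unit tests); any other string is resolved through STR_DTYPE_TO_TORCH_DTYPE. A hedged example:

# Illustrative only; relies solely on the helper introduced above.
import torch

from vllm.utils import kv_cache_dtype_str_to_dtype

# Without a model config (e.g. in a unit test), "auto" falls back to float16.
assert kv_cache_dtype_str_to_dtype("auto", None) is torch.half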
@@ -948,7 +948,7 @@ class EagleProposer:
                 indexer_layers[first_layer]
                 .get_attn_backend()
                 .get_builder_cls()(
-                    indexer_layers[first_layer].get_kv_cache_spec(),
+                    indexer_layers[first_layer].get_kv_cache_spec(self.vllm_config),
                     self.indexer_layer_names,
                     self.vllm_config,
                     self.device,
@@ -19,8 +19,6 @@ from tqdm import tqdm
 import vllm.envs as envs
 from vllm.attention import Attention, AttentionType
 from vllm.attention.backends.abstract import AttentionBackend, MultipleOf
-from vllm.attention.layer import MLAAttention
-from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.cuda_graph import CUDAGraphWrapper
 from vllm.compilation.monitor import set_cudagraph_capturing_enabled
@@ -44,10 +42,8 @@ from vllm.distributed.parallel_state import (
 from vllm.forward_context import BatchDescriptor, set_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
-from vllm.model_executor.layers.mamba.abstract import MambaBase
 from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
 from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
-from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
 from vllm.model_executor.models.interfaces import (
     SupportsMultiModal,
     is_mixture_of_experts,
@@ -73,11 +69,11 @@ from vllm.sampling_params import SamplingType
 from vllm.sequence import IntermediateTensors
 from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
 from vllm.utils import (
-    STR_DTYPE_TO_TORCH_DTYPE,
     cdiv,
     check_use_alibi,
     get_dtype_size,
     is_pin_memory_available,
+    kv_cache_dtype_str_to_dtype,
     length_from_prompt_token_ids_or_embeds,
     round_up,
     supports_dynamo,
@@ -106,7 +102,6 @@ from vllm.v1.kv_cache_interface import (
     KVCacheGroupSpec,
     KVCacheSpec,
     MambaSpec,
-    MLAAttentionSpec,
     SlidingWindowSpec,
     UniformTypeKVCacheSpecs,
 )
@@ -239,10 +234,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         self.device = device
         self.pin_memory = is_pin_memory_available()
         self.dtype = self.model_config.dtype
-        if cache_config.cache_dtype == "auto":
-            self.kv_cache_dtype = self.dtype
-        else:
-            self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype]
+        self.kv_cache_dtype = kv_cache_dtype_str_to_dtype(
+            cache_config.cache_dtype, self.model_config
+        )

         self.is_pooling_model = model_config.runner_type == "pooling"
         self.enable_prompt_embeds = model_config.enable_prompt_embeds
@@ -4577,109 +4571,24 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         format. Layers that do not need KV cache are not included.
         """

-        block_size = self.vllm_config.cache_config.block_size
-        use_mla = self.vllm_config.model_config.use_mla
-        cache_dtype_str = self.vllm_config.cache_config.cache_dtype
         kv_cache_spec: dict[str, KVCacheSpec] = {}
         attn_layers = get_layers_from_vllm_config(self.vllm_config, AttentionLayerBase)
         for layer_name, attn_module in attn_layers.items():
-            if isinstance(attn_module, Attention):
-                if (
-                    kv_tgt_layer := attn_module.kv_sharing_target_layer_name
-                ) is not None:
-                    # The layer doesn't need its own KV cache and will use that of
-                    # the target layer. We skip creating a KVCacheSpec for it, so
-                    # that KV cache management logic will act as this layer does
-                    # not exist, and doesn't allocate KV cache for the layer. This
-                    # enables the memory saving of cross-layer kv sharing, allowing
-                    # a given amount of memory to accommodate longer context lengths
-                    # or enable more requests to be processed simultaneously.
-                    self.shared_kv_cache_layers[layer_name] = kv_tgt_layer
-                    continue
-
-                # TODO(lucas): move the attention specs into the model layers like
-                # the attention backends
-                if attn_module.attn_type == AttentionType.DECODER:
-                    if attn_module.sliding_window is not None:
-                        assert not use_mla, "MLA is not supported for slidingwindow"
-                        kv_cache_spec[layer_name] = SlidingWindowSpec(
-                            block_size=block_size,
-                            num_kv_heads=attn_module.num_kv_heads,
-                            head_size=attn_module.head_size,
-                            dtype=self.kv_cache_dtype,
-                            sliding_window=attn_module.sliding_window,
-                        )
-                    elif self.attention_chunk_size is not None and isinstance(
-                        attn_module, ChunkedLocalAttention
-                    ):
-                        kv_cache_spec[layer_name] = ChunkedLocalAttentionSpec(
-                            block_size=block_size,
-                            num_kv_heads=attn_module.num_kv_heads,
-                            head_size=attn_module.head_size,
-                            dtype=self.kv_cache_dtype,
-                            attention_chunk_size=self.attention_chunk_size,
-                        )
-                    else:
-                        kv_cache_spec[layer_name] = FullAttentionSpec(
-                            block_size=block_size,
-                            num_kv_heads=attn_module.num_kv_heads,
-                            head_size=attn_module.head_size,
-                            dtype=self.kv_cache_dtype,
-                        )
-                elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
-                    kv_cache_spec[layer_name] = CrossAttentionSpec(
-                        block_size=block_size,
-                        num_kv_heads=attn_module.num_kv_heads,
-                        head_size=attn_module.head_size,
-                        dtype=self.kv_cache_dtype,
-                    )
-                elif attn_module.attn_type in (
-                    AttentionType.ENCODER,
-                    AttentionType.ENCODER_ONLY,
-                ):
-                    # encoder-only attention does not need KV cache.
-                    continue
-                else:
-                    raise ValueError(f"Unknown attention type: {attn_module.attn_type}")
-
-            elif isinstance(attn_module, MLAAttention):
-                kv_cache_spec[layer_name] = MLAAttentionSpec(
-                    block_size=block_size,
-                    num_kv_heads=1,
-                    head_size=attn_module.head_size,
-                    dtype=self.kv_cache_dtype,
-                    cache_dtype_str=cache_dtype_str,
-                )
-
-            elif isinstance(attn_module, MambaBase):
-                if (
-                    self.vllm_config.speculative_config is not None
-                    and self.vllm_config.model_config.hf_config.model_type
-                    not in ["qwen3_next"]
-                ):
-                    raise NotImplementedError(
-                        "Mamba with speculative decoding is not supported yet."
-                    )
-                mamba_block_size = self.vllm_config.cache_config.mamba_block_size
-                page_size_padded = self.vllm_config.cache_config.mamba_page_size_padded
-                kv_cache_spec[layer_name] = MambaSpec(
-                    shapes=attn_module.get_state_shape(),
-                    dtypes=attn_module.get_state_dtype(),
-                    block_size=mamba_block_size,
-                    page_size_padded=page_size_padded,
-                    mamba_type=attn_module.mamba_type,
-                    num_speculative_blocks=(
-                        self.speculative_config.num_speculative_tokens
-                        if self.speculative_config
-                        else 0
-                    ),
-                )
-
-        ds_indexer_layers = get_layers_from_vllm_config(
-            self.vllm_config, DeepseekV32IndexerCache
-        )
-        for layer_name, ds_indexer_module in ds_indexer_layers.items():
-            kv_cache_spec[layer_name] = ds_indexer_module.get_kv_cache_spec()
+            if isinstance(attn_module, Attention) and (
+                kv_tgt_layer := attn_module.kv_sharing_target_layer_name
+            ):
+                # The layer doesn't need its own KV cache and will use that of
+                # the target layer. We skip creating a KVCacheSpec for it, so
+                # that KV cache management logic will act as this layer does
+                # not exist, and doesn't allocate KV cache for the layer. This
+                # enables the memory saving of cross-layer kv sharing, allowing
+                # a given amount of memory to accommodate longer context lengths
+                # or enable more requests to be processed simultaneously.
+                self.shared_kv_cache_layers[layer_name] = kv_tgt_layer
+                continue
+            # Skip modules that don't need KV cache (eg encoder-only attention)
+            if spec := attn_module.get_kv_cache_spec(self.vllm_config):
+                kv_cache_spec[layer_name] = spec

         return kv_cache_spec