Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 23:03:52 +08:00)

Compare commits: v0.11.0rc1...split_kv_c (4 commits)

Commits:
  6e1e31a66a
  50e80db4ef
  d3d6afb355
  808fa43d76

.github/CODEOWNERS (vendored): 2 changed lines
@@ -37,8 +37,8 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson
/vllm/v1/attention/backends/triton_attn.py @tdoublep
/vllm/v1/core @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @comaniac @alexm-redhat @heheda12345 @ApostaC
/vllm/v1/kv_cache_interface.py @heheda12345
/vllm/v1/worker/kv_cache_initializer_mixin.py @heheda12345
/vllm/v1/offloading @ApostaC

# Test ownership
/.buildkite/lm-eval-harness @mgoin @simon-mo
/tests/async_engine @njhill @robertgshaw2-redhat @simon-mo
vllm/v1/worker/gpu_model_runner.py

@@ -7,7 +7,6 @@ import time
from collections import defaultdict
from collections.abc import Iterator
from contextlib import contextmanager
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Optional, Union, cast

import numpy as np
@@ -18,18 +17,14 @@ from tqdm import tqdm
from typing_extensions import TypeAlias

import vllm.envs as envs
from vllm.attention import Attention, AttentionType
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention
from vllm.compilation.counter import compilation_counter
from vllm.compilation.cuda_graph import CUDAGraphWrapper
from vllm.compilation.monitor import set_cudagraph_capturing_enabled
from vllm.config import (CompilationLevel, CUDAGraphMode, VllmConfig,
                         get_layers_from_vllm_config, update_config)
from vllm.distributed.eplb.eplb_state import EplbState
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                          has_kv_transfer_group)
from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks
from vllm.distributed.kv_transfer import has_kv_transfer_group
from vllm.distributed.parallel_state import (
    get_pp_group, get_tp_group, graph_capture, is_global_first_rank,
    prepare_communication_buffer_for_model)
@@ -37,7 +32,6 @@ from vllm.forward_context import (BatchDescriptor, DPMetadata,
                                  set_forward_context)
from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader
from vllm.model_executor.models.interfaces import (is_mixture_of_experts,
@@ -54,7 +48,7 @@ from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors, PoolerOutput
from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                        GiB_bytes, LazyLoader, check_use_alibi, get_dtype_size,
                        GiB_bytes, LazyLoader, check_use_alibi,
                        is_pin_memory_available, round_up, supports_dynamo)
from vllm.v1.attention.backends.flash_attn import AttentionMetadata
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
@@ -70,8 +64,7 @@ from vllm.v1.kv_cache_interface import (AttentionSpec,
                                        CrossAttentionSpec,
                                        EncoderOnlyAttentionSpec,
                                        FullAttentionSpec, KVCacheConfig,
                                        KVCacheGroupSpec, KVCacheSpec,
                                        MambaSpec, SlidingWindowSpec)
                                        KVCacheSpec, SlidingWindowSpec)
# yapf: enable
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
                             DraftTokenIds, LogprobsLists, LogprobsTensors,
@@ -88,6 +81,7 @@ from vllm.v1.spec_decode.ngram_proposer import NgramProposer
from vllm.v1.utils import CpuGpuBuffer, record_function_or_nullcontext
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.gpu_ubatch_wrapper import UBatchWrapper
from vllm.v1.worker.kv_cache_initializer_mixin import KVCacheInitializerMixin
from vllm.v1.worker.kv_connector_model_runner_mixin import (
    KVConnectorModelRunnerMixin)
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
@@ -95,10 +89,8 @@ from vllm.v1.worker.ubatch_splitting import get_dp_padding_ubatch, ubatch_split
from vllm.v1.worker.ubatch_utils import UBatchSlice, UBatchSlices
from vllm.v1.worker.utils import is_residual_scattered_for_sp

from .utils import (AttentionGroup, MultiModalBudget,
                    add_kv_sharing_layers_to_kv_cache_groups, bind_kv_cache,
                    gather_mm_placeholders, sanity_check_mm_encoder_outputs,
                    scatter_mm_placeholders)
from .utils import (AttentionGroup, MultiModalBudget, gather_mm_placeholders,
                    sanity_check_mm_encoder_outputs, scatter_mm_placeholders)

if TYPE_CHECKING:
    import xgrammar as xgr
@@ -163,7 +155,8 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
        return output


class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
class GPUModelRunner(KVCacheInitializerMixin, LoRAModelRunnerMixin,
                     KVConnectorModelRunnerMixin):

    def __init__(
        self,
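The new base class relies on a typing.Protocol plus a cast of self, so the mixin's methods can reference runner attributes without importing the full GPUModelRunner definition. A minimal, self-contained sketch of that pattern follows; the class and attribute names here are invented for illustration, not taken from vLLM.

from typing import Protocol, cast


class _HostProtocol(Protocol):
    # Attributes the mixin expects its host class to provide.
    block_size: int


class CacheSetupMixin:

    def _host(self) -> _HostProtocol:
        # No effect at runtime; only tells the type checker what `self` has.
        return cast(_HostProtocol, self)

    def describe(self) -> str:
        return f"block_size={self._host().block_size}"


class Runner(CacheSetupMixin):

    def __init__(self) -> None:
        self.block_size = 16


print(Runner().describe())  # block_size=16

The cast is free at runtime; it only gives the type checker a concrete description of what the host class must supply, which is exactly the role of _KVCacheInitializerSelf in the new file below.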
@@ -255,7 +248,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        self.kv_caches: list[torch.Tensor] = []
        # indexes: [kv_cache_group_id][attn_group]
        self.attn_groups: list[list[AttentionGroup]] = []
        # self.kv_cache_config: KVCacheConfig
        # a fake value to satisfy the type checker
        self.kv_cache_config: KVCacheConfig = cast(KVCacheConfig, None)

        # mm_hash -> encoder_output
        self.encoder_cache: dict[str, torch.Tensor] = {}
@@ -3529,418 +3523,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
        else:
            self.reorder_batch_threshold = reorder_batch_threshold_i

def may_reinitialize_input_batch(self,
|
||||
kv_cache_config: KVCacheConfig) -> None:
|
||||
"""
|
||||
Re-initialize the input batch if the block sizes are different from
|
||||
`[self.cache_config.block_size]`. This usually happens when there
|
||||
are multiple KV cache groups.
|
||||
|
||||
Args:
|
||||
kv_cache_config: The KV cache configuration.
|
||||
"""
|
||||
block_sizes = [
|
||||
kv_cache_group.kv_cache_spec.block_size
|
||||
for kv_cache_group in kv_cache_config.kv_cache_groups
|
||||
]
|
||||
if block_sizes != [self.cache_config.block_size]:
|
||||
assert self.cache_config.cpu_offload_gb == 0, (
|
||||
"Cannot re-initialize the input batch when CPU weight "
|
||||
"offloading is enabled. See https://github.com/vllm-project/vllm/pull/18298 " # noqa: E501
|
||||
"for more details.")
|
||||
self.input_batch = InputBatch(
|
||||
max_num_reqs=self.max_num_reqs,
|
||||
max_model_len=max(self.max_model_len, self.max_encoder_len),
|
||||
max_num_batched_tokens=self.max_num_tokens,
|
||||
device=self.device,
|
||||
pin_memory=self.pin_memory,
|
||||
vocab_size=self.model_config.get_vocab_size(),
|
||||
block_sizes=block_sizes,
|
||||
is_spec_decode=bool(self.vllm_config.speculative_config),
|
||||
logitsprocs=self.input_batch.logitsprocs,
|
||||
is_pooling_model=self.is_pooling_model,
|
||||
num_speculative_tokens=(
|
||||
self.vllm_config.speculative_config.num_speculative_tokens
|
||||
if self.vllm_config.speculative_config else 0),
|
||||
)
|
||||
|
||||
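The check that triggers re-initialization boils down to comparing the per-group block sizes against the single default block size. A toy version with invented numbers:

# Hypothetical block sizes: a full-attention group plus a mamba group whose
# block size was set to the max model length.
default_block_size = 16
group_block_sizes = [16, 4096]

# The input batch is rebuilt only when the groups no longer match the single
# default block size.
needs_reinit = group_block_sizes != [default_block_size]
assert needs_reinit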
def _allocate_kv_cache_tensors(
|
||||
self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Initializes the KV cache buffer with the correct size. The buffer needs
|
||||
to be reshaped to the desired shape before being used by the models.
|
||||
|
||||
Args:
|
||||
kv_cache_config: The KV cache config
|
||||
Returns:
|
||||
dict[str, torch.Tensor]: A map between layer names to their
|
||||
corresponding memory buffer for KV cache.
|
||||
"""
|
||||
kv_cache_raw_tensors: dict[str, torch.Tensor] = {}
|
||||
for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
|
||||
tensor = torch.zeros(kv_cache_tensor.size,
|
||||
dtype=torch.int8,
|
||||
device=self.device)
|
||||
for layer_name in kv_cache_tensor.shared_by:
|
||||
kv_cache_raw_tensors[layer_name] = tensor
|
||||
|
||||
layer_names = set()
|
||||
for group in kv_cache_config.kv_cache_groups:
|
||||
for layer_name in group.layer_names:
|
||||
if layer_name in self.runner_only_attn_layers:
|
||||
continue
|
||||
layer_names.add(layer_name)
|
||||
assert layer_names == set(kv_cache_raw_tensors.keys(
|
||||
)), "Some layers are not correctly initialized"
|
||||
return kv_cache_raw_tensors
|
||||
|
||||
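A minimal sketch of the allocation scheme above, with made-up layer names and sizes: each entry in kv_cache_tensors is one flat int8 buffer, and every layer listed in its shared_by list gets a reference to the same storage.

import torch

# Hypothetical config: one 4 KiB buffer shared by two layers that reuse the
# same KV cache, plus a separate buffer for a third layer.
kv_cache_tensors = [
    {"size": 4096, "shared_by": ["layers.0.attn", "layers.1.attn"]},
    {"size": 4096, "shared_by": ["layers.2.attn"]},
]

kv_cache_raw_tensors: dict[str, torch.Tensor] = {}
for spec in kv_cache_tensors:
    buf = torch.zeros(spec["size"], dtype=torch.int8)
    for layer_name in spec["shared_by"]:
        kv_cache_raw_tensors[layer_name] = buf

# Both sharing layers point at the same underlying storage.
assert (kv_cache_raw_tensors["layers.0.attn"].data_ptr()
        == kv_cache_raw_tensors["layers.1.attn"].data_ptr())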
def _attn_group_iterator(self) -> Iterator[AttentionGroup]:
|
||||
return itertools.chain.from_iterable(self.attn_groups)
|
||||
|
||||
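Since attn_groups is indexed as [kv_cache_group_id][attn_group], flattening it is a single chain.from_iterable call. A tiny standalone illustration with placeholder group names:

import itertools

attn_groups = [["group0_backendA", "group0_backendB"], ["group1_backendA"]]
flat = list(itertools.chain.from_iterable(attn_groups))
assert flat == ["group0_backendA", "group0_backendB", "group1_backendA"]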
def _kv_cache_spec_attn_group_iterator(
|
||||
self) -> Iterator[tuple[KVCacheSpec, AttentionGroup]]:
|
||||
if not self.kv_cache_config.kv_cache_groups:
|
||||
return
|
||||
for kv_cache_spec_id, attn_groups in enumerate(self.attn_groups):
|
||||
for attn_group in attn_groups:
|
||||
yield self.kv_cache_config.kv_cache_groups[
|
||||
kv_cache_spec_id].kv_cache_spec, attn_group
|
||||
|
||||
def _reshape_kv_cache_tensors(
|
||||
self,
|
||||
kv_cache_config: KVCacheConfig,
|
||||
kv_cache_raw_tensors: dict[str, torch.Tensor],
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Reshape the KV cache tensors to the desired shape and dtype.
|
||||
|
||||
Args:
|
||||
kv_cache_config: The KV cache config
|
||||
kv_cache_raw_tensors: The KV cache buffer of each layer, with
|
||||
correct size but uninitialized shape.
|
||||
Returns:
|
||||
Dict[str, torch.Tensor]: A map between layer names to their
|
||||
corresponding memory buffer for KV cache.
|
||||
"""
|
||||
kv_caches: dict[str, torch.Tensor] = {}
|
||||
has_attn, has_mamba = False, False
|
||||
for kv_cache_spec, group in self._kv_cache_spec_attn_group_iterator():
|
||||
attn_backend = group.backend
|
||||
for layer_name in group.layer_names:
|
||||
if layer_name in self.runner_only_attn_layers:
|
||||
continue
|
||||
raw_tensor = kv_cache_raw_tensors[layer_name]
|
||||
assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0
|
||||
num_blocks = (raw_tensor.numel() //
|
||||
kv_cache_spec.page_size_bytes)
|
||||
if isinstance(kv_cache_spec, AttentionSpec):
|
||||
has_attn = True
|
||||
kv_cache_shape = attn_backend.get_kv_cache_shape(
|
||||
num_blocks, kv_cache_spec.block_size,
|
||||
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
|
||||
dtype = kv_cache_spec.dtype
|
||||
try:
|
||||
kv_cache_stride_order = \
|
||||
attn_backend.get_kv_cache_stride_order()
|
||||
assert len(kv_cache_stride_order) == len(
|
||||
kv_cache_shape)
|
||||
except (AttributeError, NotImplementedError):
|
||||
kv_cache_stride_order = tuple(
|
||||
range(len(kv_cache_shape)))
|
||||
# The allocation respects the backend-defined stride order
|
||||
# to ensure the semantic remains consistent for each
|
||||
# backend. We first obtain the generic kv cache shape and
|
||||
# then permute it according to the stride order which could
|
||||
# result in a non-contiguous tensor.
|
||||
kv_cache_shape = tuple(kv_cache_shape[i]
|
||||
for i in kv_cache_stride_order)
|
||||
# Maintain original KV shape view.
|
||||
inv_order = [
|
||||
kv_cache_stride_order.index(i)
|
||||
for i in range(len(kv_cache_stride_order))
|
||||
]
|
||||
kv_caches[layer_name] = kv_cache_raw_tensors[
|
||||
layer_name].view(dtype).view(kv_cache_shape).permute(
|
||||
*inv_order)
|
||||
elif isinstance(kv_cache_spec, MambaSpec):
|
||||
has_mamba = True
|
||||
raw_tensor = kv_cache_raw_tensors[layer_name]
|
||||
state_tensors = []
|
||||
storage_offset_bytes = 0
|
||||
for (shape, dtype) in zip(kv_cache_spec.shapes,
|
||||
kv_cache_spec.dtypes):
|
||||
dtype_size = get_dtype_size(dtype)
|
||||
num_element_per_page = (
|
||||
kv_cache_spec.page_size_bytes // dtype_size)
|
||||
target_shape = (num_blocks, *shape)
|
||||
stride = torch.empty(target_shape).stride()
|
||||
target_stride = (num_element_per_page, *stride[1:])
|
||||
assert storage_offset_bytes % dtype_size == 0
|
||||
tensor = torch.as_strided(
|
||||
raw_tensor.view(dtype),
|
||||
size=target_shape,
|
||||
stride=target_stride,
|
||||
storage_offset=storage_offset_bytes // dtype_size,
|
||||
)
|
||||
state_tensors.append(tensor)
|
||||
storage_offset_bytes += stride[0] * dtype_size
|
||||
|
||||
kv_caches[layer_name] = state_tensors
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
if has_attn and has_mamba:
|
||||
self._update_hybrid_attention_mamba_layout(kv_caches)
|
||||
|
||||
return kv_caches
|
||||
|
||||
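The attention branch above reinterprets each flat int8 buffer as the backend dtype and shape, allocates it in the backend's preferred stride order, and then permutes it back so the returned tensor still indexes in the generic shape order (the permuted view may be non-contiguous). A self-contained sketch with made-up sizes, an assumed (2, num_blocks, block, heads, head_size) generic shape, and a made-up stride order:

import torch

# Invented sizes; the real shape comes from attn_backend.get_kv_cache_shape().
num_blocks, block_size, num_kv_heads, head_size = 4, 16, 2, 8
dtype = torch.float16
dtype_size = torch.tensor([], dtype=dtype).element_size()
page_size_bytes = 2 * block_size * num_kv_heads * head_size * dtype_size

# Flat int8 buffer, as handed over by _allocate_kv_cache_tensors().
raw_tensor = torch.zeros(num_blocks * page_size_bytes, dtype=torch.int8)

kv_cache_shape = (2, num_blocks, block_size, num_kv_heads, head_size)
# Assume the backend prefers blocks outermost (a made-up stride order).
stride_order = (1, 0, 2, 3, 4)

permuted_shape = tuple(kv_cache_shape[i] for i in stride_order)
inv_order = [stride_order.index(i) for i in range(len(stride_order))]

# Allocate in the preferred layout, then expose the canonical index order.
kv_cache = raw_tensor.view(dtype).view(permuted_shape).permute(*inv_order)
assert kv_cache.shape == kv_cache_shape
assert not kv_cache.is_contiguous()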
def _update_hybrid_attention_mamba_layout(
|
||||
self, kv_caches: dict[str, torch.Tensor]) -> None:
|
||||
"""
|
||||
Update the layout of attention layers from (2, num_blocks, ...) to
|
||||
(num_blocks, 2, ...).
|
||||
|
||||
Args:
|
||||
kv_caches: The KV cache buffer of each layer.
|
||||
"""
|
||||
|
||||
for kv_cache_spec, group in self._kv_cache_spec_attn_group_iterator():
|
||||
for layer_name in group.layer_names:
|
||||
kv_cache = kv_caches[layer_name]
|
||||
if (isinstance(kv_cache_spec, AttentionSpec)
|
||||
and kv_cache.shape[0] == 2):
|
||||
assert kv_cache.shape[1] != 2, \
|
||||
"Fail to determine whether the layout is " \
|
||||
"(2, num_blocks, ...) or (num_blocks, 2, ...) for " \
|
||||
f"a tensor of shape {kv_cache.shape}"
|
||||
hidden_size = kv_cache.shape[2:].numel()
|
||||
kv_cache.as_strided_(size=kv_cache.shape,
|
||||
stride=(hidden_size, 2 * hidden_size,
|
||||
*kv_cache.stride()[2:]))
|
||||
|
||||
def initialize_kv_cache_tensors(
|
||||
self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Initialize the memory buffer for KV cache.
|
||||
|
||||
Args:
|
||||
kv_cache_config: The KV cache config
|
||||
Returns:
|
||||
Dict[str, torch.Tensor]: A map between layer names to their
|
||||
corresponding memory buffer for KV cache.
|
||||
"""
|
||||
# Initialize the memory buffer for KV cache
|
||||
kv_cache_raw_tensors = self._allocate_kv_cache_tensors(kv_cache_config)
|
||||
# Change the memory buffer to the desired shape
|
||||
kv_caches = self._reshape_kv_cache_tensors(kv_cache_config,
|
||||
kv_cache_raw_tensors)
|
||||
|
||||
# Set up cross-layer KV cache sharing
|
||||
for layer_name, target_layer_name in self.shared_kv_cache_layers.items(
|
||||
):
|
||||
logger.debug("%s reuses KV cache of %s", layer_name,
|
||||
target_layer_name)
|
||||
kv_caches[layer_name] = kv_caches[target_layer_name]
|
||||
|
||||
bind_kv_cache(kv_caches,
|
||||
self.compilation_config.static_forward_context,
|
||||
self.kv_caches)
|
||||
return kv_caches
|
||||
|
||||
def maybe_add_kv_sharing_layers_to_kv_cache_groups(
|
||||
self, kv_cache_config: KVCacheConfig) -> None:
|
||||
"""
|
||||
Add layers that re-use KV cache to KV cache group of its target layer.
|
||||
Mapping of KV cache tensors happens in `initialize_kv_cache_tensors()`
|
||||
"""
|
||||
if not self.shared_kv_cache_layers:
|
||||
# No cross-layer KV sharing, return
|
||||
return
|
||||
|
||||
add_kv_sharing_layers_to_kv_cache_groups(
|
||||
self.shared_kv_cache_layers,
|
||||
kv_cache_config.kv_cache_groups,
|
||||
self.runner_only_attn_layers,
|
||||
)
|
||||
|
||||
if self.cache_config.kv_sharing_fast_prefill:
|
||||
# In You Only Cache Once (https://arxiv.org/abs/2405.05254) or other
|
||||
# similar KV sharing setups, only the layers that generate KV caches
|
||||
# are involved in the prefill phase, enabling prefill to early exit.
|
||||
attn_layers = get_layers_from_vllm_config(self.vllm_config,
|
||||
Attention)
|
||||
for layer_name in reversed(attn_layers):
|
||||
if layer_name in self.shared_kv_cache_layers:
|
||||
self.kv_sharing_fast_prefill_eligible_layers.add(
|
||||
layer_name)
|
||||
else:
|
||||
break
|
||||
|
||||
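The fast-prefill scan walks the attention layers in reverse insertion order and stops at the first layer that still produces its own KV cache, so only a trailing run of KV-sharing layers becomes eligible. A toy version with invented layer names:

layers = ["layers.0.attn", "layers.1.attn", "layers.2.attn", "layers.3.attn"]
shared_kv_cache_layers = {
    "layers.2.attn": "layers.1.attn",
    "layers.3.attn": "layers.1.attn",
}

eligible = set()
for name in reversed(layers):
    if name in shared_kv_cache_layers:
        eligible.add(name)
    else:
        break
assert eligible == {"layers.2.attn", "layers.3.attn"}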
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
|
||||
"""
|
||||
Initialize KV cache based on `kv_cache_config`.
|
||||
Args:
|
||||
kv_cache_config: Configuration for the KV cache, including the KV
|
||||
cache size of each layer
|
||||
"""
|
||||
kv_cache_config = deepcopy(kv_cache_config)
|
||||
self.kv_cache_config = kv_cache_config
|
||||
self.may_reinitialize_input_batch(kv_cache_config)
|
||||
self.may_add_encoder_only_layers_to_kv_cache_config()
|
||||
self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config)
|
||||
self.initialize_attn_backend(kv_cache_config)
|
||||
kv_caches = self.initialize_kv_cache_tensors(kv_cache_config)
|
||||
|
||||
if self.speculative_config and self.speculative_config.use_eagle():
|
||||
assert isinstance(self.drafter, EagleProposer)
|
||||
# validate all draft model layers belong to the same kv cache
|
||||
# group
|
||||
self.drafter.validate_same_kv_cache_group(kv_cache_config)
|
||||
|
||||
if has_kv_transfer_group():
|
||||
get_kv_transfer_group().register_kv_caches(kv_caches)
|
||||
if self.device.type == 'xpu':
|
||||
get_kv_transfer_group().set_host_xfer_buffer_ops(
|
||||
copy_kv_blocks)
|
||||
|
||||
if self.dcp_world_size > 1:
|
||||
layer_names = self.attn_groups[0][0].layer_names
|
||||
layers = get_layers_from_vllm_config(self.vllm_config,
|
||||
AttentionLayerBase,
|
||||
layer_names)
|
||||
for layer in layers.values():
|
||||
assert layer.impl.need_to_return_lse_for_decode, (
|
||||
"DCP requires attention impls to return"
|
||||
" the softmax lse for decode, but the impl "
|
||||
f"{layer.impl.__class__.__name__} "
|
||||
"does not return the softmax lse for decode.")
|
||||
|
||||
def may_add_encoder_only_layers_to_kv_cache_config(self) -> None:
|
||||
"""
|
||||
Add encoder-only layers to the KV cache config.
|
||||
"""
|
||||
block_size = self.vllm_config.cache_config.block_size
|
||||
use_mla = self.vllm_config.model_config.use_mla
|
||||
encoder_only_attn_specs: dict[AttentionSpec,
|
||||
list[str]] = defaultdict(list)
|
||||
attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
|
||||
for layer_name, attn_module in attn_layers.items():
|
||||
if attn_module.attn_type == AttentionType.ENCODER_ONLY:
|
||||
attn_spec: AttentionSpec = EncoderOnlyAttentionSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=attn_module.num_kv_heads,
|
||||
head_size=attn_module.head_size,
|
||||
dtype=self.kv_cache_dtype,
|
||||
use_mla=use_mla)
|
||||
encoder_only_attn_specs[attn_spec].append(layer_name)
|
||||
self.runner_only_attn_layers.add(layer_name)
|
||||
if len(encoder_only_attn_specs) > 0:
|
||||
assert len(
|
||||
encoder_only_attn_specs
|
||||
) == 1, "Only support one encoder-only attention spec now"
|
||||
spec, layer_names = encoder_only_attn_specs.popitem()
|
||||
self.kv_cache_config.kv_cache_groups.append(
|
||||
KVCacheGroupSpec(layer_names=layer_names, kv_cache_spec=spec))
|
||||
|
||||
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
|
||||
"""
|
||||
Generates the KVCacheSpec by parsing the kv cache format from each
|
||||
Attention module in the static forward context.
|
||||
Returns:
|
||||
KVCacheSpec: A dictionary mapping layer names to their KV cache
|
||||
format. Layers that do not need KV cache are not included.
|
||||
"""
|
||||
|
||||
block_size = self.vllm_config.cache_config.block_size
|
||||
use_mla = self.vllm_config.model_config.use_mla
|
||||
kv_cache_spec: dict[str, KVCacheSpec] = {}
|
||||
attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
|
||||
for layer_name, attn_module in attn_layers.items():
|
||||
if (kv_tgt_layer :=
|
||||
attn_module.kv_sharing_target_layer_name) is not None:
|
||||
# The layer doesn't need its own KV cache and will use that of
|
||||
# the target layer. We skip creating a KVCacheSpec for it, so
|
||||
# that KV cache management logic will act as this layer does
|
||||
# not exist, and doesn't allocate KV cache for the layer. This
|
||||
# enables the memory saving of cross-layer kv sharing, allowing
|
||||
# a given amount of memory to accommodate longer context lengths
|
||||
# or enable more requests to be processed simultaneously.
|
||||
self.shared_kv_cache_layers[layer_name] = kv_tgt_layer
|
||||
continue
|
||||
|
||||
# TODO(lucas): move the attention specs into the model layers like
|
||||
# the attention backends
|
||||
if attn_module.attn_type == AttentionType.DECODER:
|
||||
if attn_module.sliding_window is not None:
|
||||
kv_cache_spec[layer_name] = SlidingWindowSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=attn_module.num_kv_heads,
|
||||
head_size=attn_module.head_size,
|
||||
dtype=self.kv_cache_dtype,
|
||||
sliding_window=attn_module.sliding_window,
|
||||
use_mla=use_mla)
|
||||
elif self.attention_chunk_size is not None \
|
||||
and isinstance(attn_module, ChunkedLocalAttention):
|
||||
kv_cache_spec[layer_name] = ChunkedLocalAttentionSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=attn_module.num_kv_heads,
|
||||
head_size=attn_module.head_size,
|
||||
dtype=self.kv_cache_dtype,
|
||||
attention_chunk_size=self.attention_chunk_size,
|
||||
use_mla=use_mla)
|
||||
else:
|
||||
kv_cache_spec[layer_name] = FullAttentionSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=attn_module.num_kv_heads,
|
||||
head_size=attn_module.head_size,
|
||||
dtype=self.kv_cache_dtype,
|
||||
use_mla=use_mla)
|
||||
elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
|
||||
kv_cache_spec[layer_name] = CrossAttentionSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=attn_module.num_kv_heads,
|
||||
head_size=attn_module.head_size,
|
||||
dtype=self.kv_cache_dtype,
|
||||
use_mla=use_mla)
|
||||
elif attn_module.attn_type in (AttentionType.ENCODER,
|
||||
AttentionType.ENCODER_ONLY):
|
||||
# encoder-only attention does not need KV cache.
|
||||
continue
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown attention type: {attn_module.attn_type}")
|
||||
|
||||
mamba_layers = get_layers_from_vllm_config(self.vllm_config, MambaBase)
|
||||
if len(mamba_layers) > 0:
|
||||
if (self.vllm_config.speculative_config is not None
|
||||
and self.vllm_config.model_config.hf_config.model_type
|
||||
not in ["qwen3_next"]):
|
||||
raise NotImplementedError(
|
||||
"Mamba with speculative decoding is not supported yet.")
|
||||
if self.vllm_config.cache_config.enable_prefix_caching:
|
||||
raise NotImplementedError(
|
||||
"Prefix caching is not supported for Mamba yet.")
|
||||
max_model_len = self.vllm_config.model_config.max_model_len
|
||||
|
||||
page_size_padded = (
|
||||
self.vllm_config.cache_config.mamba_page_size_padded)
|
||||
|
||||
# Set block_size to max_model_len, so that mamba model will always
|
||||
# have only one block in the KV cache.
|
||||
for layer_name, mamba_module in mamba_layers.items():
|
||||
kv_cache_spec[layer_name] = MambaSpec(
|
||||
shapes=mamba_module.get_state_shape(),
|
||||
dtypes=mamba_module.get_state_dtype(),
|
||||
block_size=max_model_len,
|
||||
page_size_padded=page_size_padded,
|
||||
mamba_type=mamba_module.mamba_type,
|
||||
num_speculative_blocks=(
|
||||
self.speculative_config.num_speculative_tokens
|
||||
if self.speculative_config else 0),
|
||||
)
|
||||
|
||||
return kv_cache_spec
|
||||
|
||||
def _to_list(self, sampled_token_ids: torch.Tensor) -> list[list[int]]:
|
||||
# This is a short term mitigation for issue mentioned in
|
||||
# https://github.com/vllm-project/vllm/issues/22754.
|
||||
|
vllm/v1/worker/kv_cache_initializer_mixin.py (new file, 484 lines)
@@ -0,0 +1,484 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterator
|
||||
from copy import deepcopy
|
||||
from typing import Any, Protocol, cast
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention import Attention, AttentionType
|
||||
from vllm.attention.layers.chunked_local_attention import ChunkedLocalAttention
|
||||
from vllm.config import get_layers_from_vllm_config
|
||||
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
|
||||
has_kv_transfer_group)
|
||||
from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||
from vllm.model_executor.layers.mamba.abstract import MambaBase
|
||||
from vllm.utils import get_dtype_size
|
||||
# yapf: disable
|
||||
from vllm.v1.kv_cache_interface import (AttentionSpec,
|
||||
ChunkedLocalAttentionSpec,
|
||||
CrossAttentionSpec,
|
||||
EncoderOnlyAttentionSpec,
|
||||
FullAttentionSpec, KVCacheConfig,
|
||||
KVCacheGroupSpec, KVCacheSpec,
|
||||
MambaSpec, SlidingWindowSpec)
|
||||
# yapf: enable
|
||||
from vllm.v1.spec_decode.eagle import EagleProposer
|
||||
from vllm.v1.worker.gpu_input_batch import InputBatch
|
||||
|
||||
from .utils import (AttentionGroup, add_kv_sharing_layers_to_kv_cache_groups,
|
||||
bind_kv_cache)
|
||||
|
||||
|
||||
class _KVCacheInitializerSelf(Protocol):
|
||||
cache_config: Any
|
||||
max_num_reqs: int
|
||||
max_model_len: int
|
||||
max_encoder_len: int
|
||||
max_num_tokens: int
|
||||
device: Any
|
||||
pin_memory: bool
|
||||
model_config: Any
|
||||
vllm_config: Any
|
||||
input_batch: InputBatch
|
||||
is_pooling_model: bool
|
||||
shared_kv_cache_layers: dict[str, str]
|
||||
kv_sharing_fast_prefill_eligible_layers: set[str]
|
||||
attention_chunk_size: int
|
||||
runner_only_attn_layers: set[str]
|
||||
kv_cache_dtype: torch.dtype
|
||||
kv_cache_config: KVCacheConfig
|
||||
compilation_config: Any
|
||||
kv_caches: Any
|
||||
speculative_config: Any
|
||||
drafter: Any
|
||||
dcp_world_size: int
|
||||
attn_groups: list[list[AttentionGroup]]
|
||||
|
||||
def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
|
||||
...
|
||||
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
# Defined as a mixin for GPUModelRunner
|
||||
class KVCacheInitializerMixin:
|
||||
|
||||
def _runner(self) -> _KVCacheInitializerSelf:
|
||||
return cast(_KVCacheInitializerSelf, self)
|
||||
|
||||
def may_reinitialize_input_batch(self,
|
||||
kv_cache_config: KVCacheConfig) -> None:
|
||||
"""
|
||||
Re-initialize the input batch if the block sizes are different from
|
||||
`[self.cache_config.block_size]`. This usually happens when there
|
||||
are multiple KV cache groups.
|
||||
|
||||
Args:
|
||||
kv_cache_config: The KV cache configuration.
|
||||
"""
|
||||
runner = self._runner()
|
||||
block_sizes = [
|
||||
kv_cache_group.kv_cache_spec.block_size
|
||||
for kv_cache_group in kv_cache_config.kv_cache_groups
|
||||
]
|
||||
if block_sizes != [runner.cache_config.block_size]:
|
||||
assert runner.cache_config.cpu_offload_gb == 0, (
|
||||
"Cannot re-initialize the input batch when CPU weight "
|
||||
"offloading is enabled. See https://github.com/vllm-project/vllm/pull/18298 " # noqa: E501
|
||||
"for more details.")
|
||||
runner.input_batch = InputBatch(
|
||||
max_num_reqs=runner.max_num_reqs,
|
||||
max_model_len=max(runner.max_model_len,
|
||||
runner.max_encoder_len),
|
||||
max_num_batched_tokens=runner.max_num_tokens,
|
||||
device=runner.device,
|
||||
pin_memory=runner.pin_memory,
|
||||
vocab_size=runner.model_config.get_vocab_size(),
|
||||
block_sizes=block_sizes,
|
||||
is_spec_decode=bool(runner.vllm_config.speculative_config),
|
||||
logitsprocs=runner.input_batch.logitsprocs,
|
||||
is_pooling_model=runner.is_pooling_model,
|
||||
num_speculative_tokens=(runner.vllm_config.speculative_config.
|
||||
num_speculative_tokens if
|
||||
runner.vllm_config.speculative_config
|
||||
else 0),
|
||||
)
|
||||
|
||||
def _allocate_kv_cache_tensors(
|
||||
self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Initializes the KV cache buffer with the correct size. The buffer needs
|
||||
to be reshaped to the desired shape before being used by the models.
|
||||
|
||||
Args:
|
||||
kv_cache_config: The KV cache config
|
||||
Returns:
|
||||
dict[str, torch.Tensor]: A map between layer names to their
|
||||
corresponding memory buffer for KV cache.
|
||||
"""
|
||||
runner = self._runner()
|
||||
kv_cache_raw_tensors: dict[str, torch.Tensor] = {}
|
||||
for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
|
||||
tensor = torch.zeros(kv_cache_tensor.size,
|
||||
dtype=torch.int8,
|
||||
device=runner.device)
|
||||
for layer_name in kv_cache_tensor.shared_by:
|
||||
kv_cache_raw_tensors[layer_name] = tensor
|
||||
|
||||
layer_names = set()
|
||||
for group in kv_cache_config.kv_cache_groups:
|
||||
for layer_name in group.layer_names:
|
||||
if layer_name in runner.runner_only_attn_layers:
|
||||
continue
|
||||
layer_names.add(layer_name)
|
||||
assert layer_names == set(kv_cache_raw_tensors.keys(
|
||||
)), "Some layers are not correctly initialized"
|
||||
return kv_cache_raw_tensors
|
||||
|
||||
def _kv_cache_spec_attn_group_iterator(
|
||||
self) -> Iterator[tuple[KVCacheSpec, AttentionGroup]]:
|
||||
runner = self._runner()
|
||||
if not runner.kv_cache_config.kv_cache_groups:
|
||||
return
|
||||
for kv_cache_spec_id, attn_groups in enumerate(runner.attn_groups):
|
||||
for attn_group in attn_groups:
|
||||
yield runner.kv_cache_config.kv_cache_groups[
|
||||
kv_cache_spec_id].kv_cache_spec, attn_group
|
||||
|
||||
def _reshape_kv_cache_tensors(
|
||||
self,
|
||||
kv_cache_config: KVCacheConfig,
|
||||
kv_cache_raw_tensors: dict[str, torch.Tensor],
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Reshape the KV cache tensors to the desired shape and dtype.
|
||||
|
||||
Args:
|
||||
kv_cache_config: The KV cache config
|
||||
kv_cache_raw_tensors: The KV cache buffer of each layer, with
|
||||
correct size but uninitialized shape.
|
||||
Returns:
|
||||
Dict[str, torch.Tensor]: A map between layer names to their
|
||||
corresponding memory buffer for KV cache.
|
||||
"""
|
||||
runner = self._runner()
|
||||
kv_caches: dict[str, torch.Tensor] = {}
|
||||
has_attn, has_mamba = False, False
|
||||
for kv_cache_spec, group in self._kv_cache_spec_attn_group_iterator():
|
||||
attn_backend = group.backend
|
||||
for layer_name in group.layer_names:
|
||||
if layer_name in runner.runner_only_attn_layers:
|
||||
continue
|
||||
raw_tensor = kv_cache_raw_tensors[layer_name]
|
||||
assert raw_tensor.numel() % kv_cache_spec.page_size_bytes == 0
|
||||
num_blocks = (raw_tensor.numel() //
|
||||
kv_cache_spec.page_size_bytes)
|
||||
if isinstance(kv_cache_spec, AttentionSpec):
|
||||
has_attn = True
|
||||
kv_cache_shape = attn_backend.get_kv_cache_shape(
|
||||
num_blocks, kv_cache_spec.block_size,
|
||||
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
|
||||
dtype = kv_cache_spec.dtype
|
||||
try:
|
||||
kv_cache_stride_order = \
|
||||
attn_backend.get_kv_cache_stride_order()
|
||||
assert len(kv_cache_stride_order) == len(
|
||||
kv_cache_shape)
|
||||
except (AttributeError, NotImplementedError):
|
||||
kv_cache_stride_order = tuple(
|
||||
range(len(kv_cache_shape)))
|
||||
kv_cache_shape = tuple(kv_cache_shape[i]
|
||||
for i in kv_cache_stride_order)
|
||||
inv_order = [
|
||||
kv_cache_stride_order.index(i)
|
||||
for i in range(len(kv_cache_stride_order))
|
||||
]
|
||||
kv_caches[layer_name] = kv_cache_raw_tensors[
|
||||
layer_name].view(dtype).view(kv_cache_shape).permute(
|
||||
*inv_order)
|
||||
elif isinstance(kv_cache_spec, MambaSpec):
|
||||
has_mamba = True
|
||||
raw_tensor = kv_cache_raw_tensors[layer_name]
|
||||
state_tensors = []
|
||||
storage_offset_bytes = 0
|
||||
for (shape, dtype) in zip(kv_cache_spec.shapes,
|
||||
kv_cache_spec.dtypes):
|
||||
dtype_size = get_dtype_size(dtype)
|
||||
num_element_per_page = (
|
||||
kv_cache_spec.page_size_bytes // dtype_size)
|
||||
target_shape = (num_blocks, *shape)
|
||||
stride = torch.empty(target_shape).stride()
|
||||
target_stride = (num_element_per_page, *stride[1:])
|
||||
assert storage_offset_bytes % dtype_size == 0
|
||||
tensor = torch.as_strided(
|
||||
raw_tensor.view(dtype),
|
||||
size=target_shape,
|
||||
stride=target_stride,
|
||||
storage_offset=storage_offset_bytes // dtype_size,
|
||||
)
|
||||
state_tensors.append(tensor)
|
||||
storage_offset_bytes += stride[0] * dtype_size
|
||||
|
||||
kv_caches[layer_name] = state_tensors
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
if has_attn and has_mamba:
|
||||
self._update_hybrid_attention_mamba_layout(kv_caches)
|
||||
|
||||
return kv_caches
|
||||
|
||||
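For MambaSpec layers, the same flat page buffer is sliced into several state tensors of different shapes by combining a per-page stride with a growing storage offset. A standalone toy version with invented shapes and a float32-only page (the real code mixes dtypes and may pad pages):

import math

import torch

# Invented per-layer state shapes; one page per block holds both states.
num_blocks = 3
shapes = ((4, 8), (16,))
dtypes = (torch.float32, torch.float32)

dtype_size = 4  # bytes per float32 element
num_element_per_page = sum(math.prod(s) for s in shapes)
page_size_bytes = num_element_per_page * dtype_size

raw_tensor = torch.zeros(num_blocks * page_size_bytes, dtype=torch.int8)

state_tensors = []
storage_offset_bytes = 0
for shape, dtype in zip(shapes, dtypes):
    target_shape = (num_blocks, *shape)
    stride = torch.empty(target_shape).stride()
    # Jump one full page between blocks; inner strides stay row-major.
    target_stride = (num_element_per_page, *stride[1:])
    state_tensors.append(
        torch.as_strided(
            raw_tensor.view(dtype),
            size=target_shape,
            stride=target_stride,
            storage_offset=storage_offset_bytes // dtype_size,
        ))
    storage_offset_bytes += stride[0] * dtype_size

assert state_tensors[0].shape == (num_blocks, 4, 8)
assert state_tensors[1].shape == (num_blocks, 16)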
def _update_hybrid_attention_mamba_layout(
|
||||
self, kv_caches: dict[str, torch.Tensor]) -> None:
|
||||
"""
|
||||
Update the layout of attention layers from (2, num_blocks, ...) to
|
||||
(num_blocks, 2, ...).
|
||||
|
||||
Args:
|
||||
kv_caches: The KV cache buffer of each layer.
|
||||
"""
|
||||
|
||||
for kv_cache_spec, group in self._kv_cache_spec_attn_group_iterator():
|
||||
for layer_name in group.layer_names:
|
||||
kv_cache = kv_caches[layer_name]
|
||||
if (isinstance(kv_cache_spec, AttentionSpec)
|
||||
and kv_cache.shape[0] == 2):
|
||||
assert kv_cache.shape[1] != 2, \
|
||||
"Fail to determine whether the layout is " \
|
||||
"(2, num_blocks, ...) or (num_blocks, 2, ...) for " \
|
||||
f"a tensor of shape {kv_cache.shape}"
|
||||
hidden_size = kv_cache.shape[2:].numel()
|
||||
kv_cache.as_strided_(size=kv_cache.shape,
|
||||
stride=(hidden_size, 2 * hidden_size,
|
||||
*kv_cache.stride()[2:]))
|
||||
|
||||
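The in-place as_strided_ above keeps each attention tensor's logical (2, num_blocks, ...) shape but makes it address memory in the (num_blocks, 2, ...) order that the hybrid layout uses. The same relationship, demonstrated on a small stand-alone tensor:

import torch

num_blocks, hidden_size = 3, 4
# Memory physically laid out block-major: for each block, K then V.
physical = torch.arange(num_blocks * 2 * hidden_size).reshape(
    num_blocks, 2, hidden_size)

# A (2, num_blocks, hidden) view over the same values: stride
# (hidden, 2 * hidden, 1) skips a whole block per block index.
kv_view = physical.flatten().as_strided(
    size=(2, num_blocks, hidden_size),
    stride=(hidden_size, 2 * hidden_size, 1),
)
assert torch.equal(kv_view[0, 1], physical[1, 0])  # K of block 1
assert torch.equal(kv_view[1, 2], physical[2, 1])  # V of block 2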
def initialize_kv_cache_tensors(
|
||||
self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Initialize the memory buffer for KV cache.
|
||||
|
||||
Args:
|
||||
kv_cache_config: The KV cache config
|
||||
Returns:
|
||||
Dict[str, torch.Tensor]: A map between layer names to their
|
||||
corresponding memory buffer for KV cache.
|
||||
"""
|
||||
runner = self._runner()
|
||||
kv_cache_raw_tensors = self._allocate_kv_cache_tensors(kv_cache_config)
|
||||
kv_caches = self._reshape_kv_cache_tensors(kv_cache_config,
|
||||
kv_cache_raw_tensors)
|
||||
|
||||
for layer_name, target_layer_name in (
|
||||
runner.shared_kv_cache_layers.items()):
|
||||
logger.debug("%s reuses KV cache of %s", layer_name,
|
||||
target_layer_name)
|
||||
kv_caches[layer_name] = kv_caches[target_layer_name]
|
||||
|
||||
bind_kv_cache(kv_caches,
|
||||
runner.compilation_config.static_forward_context,
|
||||
runner.kv_caches)
|
||||
return kv_caches
|
||||
|
||||
def maybe_add_kv_sharing_layers_to_kv_cache_groups(
|
||||
self, kv_cache_config: KVCacheConfig) -> None:
|
||||
"""
|
||||
Add layers that re-use KV cache to KV cache group of its target layer.
|
||||
Mapping of KV cache tensors happens in `initialize_kv_cache_tensors()`
|
||||
"""
|
||||
runner = self._runner()
|
||||
if not runner.shared_kv_cache_layers:
|
||||
return
|
||||
|
||||
add_kv_sharing_layers_to_kv_cache_groups(
|
||||
runner.shared_kv_cache_layers,
|
||||
kv_cache_config.kv_cache_groups,
|
||||
runner.runner_only_attn_layers,
|
||||
)
|
||||
|
||||
if runner.cache_config.kv_sharing_fast_prefill:
|
||||
attn_layers = get_layers_from_vllm_config(runner.vllm_config,
|
||||
Attention)
|
||||
for layer_name in reversed(attn_layers):
|
||||
if layer_name in runner.shared_kv_cache_layers:
|
||||
runner.kv_sharing_fast_prefill_eligible_layers.add(
|
||||
layer_name)
|
||||
else:
|
||||
break
|
||||
|
||||
def may_add_encoder_only_layers_to_kv_cache_config(self) -> None:
|
||||
"""
|
||||
Add encoder-only layers to the KV cache config.
|
||||
"""
|
||||
runner = self._runner()
|
||||
block_size = runner.vllm_config.cache_config.block_size
|
||||
use_mla = runner.vllm_config.model_config.use_mla
|
||||
encoder_only_attn_specs: dict[AttentionSpec,
|
||||
list[str]] = defaultdict(list)
|
||||
attn_layers = get_layers_from_vllm_config(runner.vllm_config,
|
||||
Attention)
|
||||
for layer_name, attn_module in attn_layers.items():
|
||||
if attn_module.attn_type == AttentionType.ENCODER_ONLY:
|
||||
attn_spec: AttentionSpec = EncoderOnlyAttentionSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=attn_module.num_kv_heads,
|
||||
head_size=attn_module.head_size,
|
||||
dtype=runner.kv_cache_dtype,
|
||||
use_mla=use_mla)
|
||||
encoder_only_attn_specs[attn_spec].append(layer_name)
|
||||
runner.runner_only_attn_layers.add(layer_name)
|
||||
if len(encoder_only_attn_specs) > 0:
|
||||
assert len(
|
||||
encoder_only_attn_specs
|
||||
) == 1, "Only support one encoder-only attention spec now"
|
||||
spec, layer_names = encoder_only_attn_specs.popitem()
|
||||
runner.kv_cache_config.kv_cache_groups.append(
|
||||
KVCacheGroupSpec(layer_names=layer_names, kv_cache_spec=spec))
|
||||
|
||||
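Grouping encoder-only layers by their spec relies on the spec objects being hashable and comparing by value, so identical specs collapse into one dict key. A minimal sketch using a stand-in frozen dataclass instead of the real kv_cache_interface specs:

from collections import defaultdict
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyAttentionSpec:
    # Stand-in for EncoderOnlyAttentionSpec; frozen=True makes it hashable.
    block_size: int
    num_kv_heads: int
    head_size: int


specs_to_layers: dict[ToyAttentionSpec, list[str]] = defaultdict(list)
for layer_name in ["encoder.0.attn", "encoder.1.attn"]:
    spec = ToyAttentionSpec(block_size=16, num_kv_heads=8, head_size=64)
    specs_to_layers[spec].append(layer_name)

# Identical specs collapse into a single group holding both layers.
assert len(specs_to_layers) == 1
assert specs_to_layers[ToyAttentionSpec(16, 8, 64)] == [
    "encoder.0.attn", "encoder.1.attn"
]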
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
|
||||
"""
|
||||
Initialize KV cache based on `kv_cache_config`.
|
||||
|
||||
Args:
|
||||
kv_cache_config: Configuration for the KV cache, including the KV
|
||||
cache size of each layer
|
||||
"""
|
||||
runner = self._runner()
|
||||
kv_cache_config = deepcopy(kv_cache_config)
|
||||
runner.kv_cache_config = kv_cache_config
|
||||
self.may_reinitialize_input_batch(kv_cache_config)
|
||||
self.may_add_encoder_only_layers_to_kv_cache_config()
|
||||
self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config)
|
||||
runner.initialize_attn_backend(kv_cache_config)
|
||||
kv_caches = self.initialize_kv_cache_tensors(kv_cache_config)
|
||||
|
||||
if runner.speculative_config and runner.speculative_config.use_eagle():
|
||||
assert isinstance(runner.drafter, EagleProposer)
|
||||
runner.drafter.validate_same_kv_cache_group(kv_cache_config)
|
||||
|
||||
if has_kv_transfer_group():
|
||||
get_kv_transfer_group().register_kv_caches(kv_caches)
|
||||
if runner.device.type == 'xpu':
|
||||
get_kv_transfer_group().set_host_xfer_buffer_ops(
|
||||
copy_kv_blocks)
|
||||
|
||||
if runner.dcp_world_size > 1:
|
||||
layer_names = runner.attn_groups[0][0].layer_names
|
||||
layers = get_layers_from_vllm_config(
|
||||
runner.vllm_config,
|
||||
AttentionLayerBase, # type: ignore[type-abstract]
|
||||
layer_names,
|
||||
)
|
||||
for layer in layers.values():
|
||||
layer_impl = cast(Any, layer).impl
|
||||
assert layer_impl.need_to_return_lse_for_decode, (
|
||||
"DCP requires attention impls to return"
|
||||
" the softmax lse for decode, but the impl "
|
||||
f"{layer_impl.__class__.__name__} "
|
||||
"does not return the softmax lse for decode.")
|
||||
|
||||
def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
|
||||
"""
|
||||
Generates the KVCacheSpec by parsing the kv cache format from each
|
||||
Attention module in the static forward context.
|
||||
Returns:
|
||||
KVCacheSpec: A dictionary mapping layer names to their KV cache
|
||||
format. Layers that do not need KV cache are not included.
|
||||
"""
|
||||
runner = self._runner()
|
||||
block_size = runner.vllm_config.cache_config.block_size
|
||||
use_mla = runner.vllm_config.model_config.use_mla
|
||||
kv_cache_spec: dict[str, KVCacheSpec] = {}
|
||||
attn_layers = get_layers_from_vllm_config(runner.vllm_config,
|
||||
Attention)
|
||||
for layer_name, attn_module in attn_layers.items():
|
||||
if (kv_tgt_layer :=
|
||||
attn_module.kv_sharing_target_layer_name) is not None:
|
||||
# The layer doesn't need its own KV cache and will use that of
|
||||
# the target layer. We skip creating a KVCacheSpec for it, so
|
||||
# that KV cache management logic will act as this layer does
|
||||
# not exist, and doesn't allocate KV cache for the layer. This
|
||||
# enables the memory saving of cross-layer kv sharing, allowing
|
||||
# a given amount of memory to accommodate longer context lengths
|
||||
# or enable more requests to be processed simultaneously.
|
||||
runner.shared_kv_cache_layers[layer_name] = kv_tgt_layer
|
||||
continue
|
||||
|
||||
# TODO(lucas): move the attention specs into the model layers like
|
||||
# the attention backends
|
||||
if attn_module.attn_type == AttentionType.DECODER:
|
||||
if attn_module.sliding_window is not None:
|
||||
kv_cache_spec[layer_name] = SlidingWindowSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=attn_module.num_kv_heads,
|
||||
head_size=attn_module.head_size,
|
||||
dtype=runner.kv_cache_dtype,
|
||||
sliding_window=attn_module.sliding_window,
|
||||
use_mla=use_mla)
|
||||
elif runner.attention_chunk_size is not None \
|
||||
and isinstance(attn_module, ChunkedLocalAttention):
|
||||
kv_cache_spec[layer_name] = ChunkedLocalAttentionSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=attn_module.num_kv_heads,
|
||||
head_size=attn_module.head_size,
|
||||
dtype=runner.kv_cache_dtype,
|
||||
attention_chunk_size=runner.attention_chunk_size,
|
||||
use_mla=use_mla)
|
||||
else:
|
||||
kv_cache_spec[layer_name] = FullAttentionSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=attn_module.num_kv_heads,
|
||||
head_size=attn_module.head_size,
|
||||
dtype=runner.kv_cache_dtype,
|
||||
use_mla=use_mla)
|
||||
elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
|
||||
kv_cache_spec[layer_name] = CrossAttentionSpec(
|
||||
block_size=block_size,
|
||||
num_kv_heads=attn_module.num_kv_heads,
|
||||
head_size=attn_module.head_size,
|
||||
dtype=runner.kv_cache_dtype,
|
||||
use_mla=use_mla)
|
||||
elif attn_module.attn_type in (AttentionType.ENCODER,
|
||||
AttentionType.ENCODER_ONLY):
|
||||
# encoder-only attention does not need KV cache.
|
||||
continue
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown attention type: {attn_module.attn_type}")
|
||||
|
||||
mamba_layers = get_layers_from_vllm_config(runner.vllm_config,
|
||||
MambaBase)
|
||||
if len(mamba_layers) > 0:
|
||||
if (runner.vllm_config.speculative_config is not None
|
||||
and runner.vllm_config.model_config.hf_config.model_type
|
||||
not in ["qwen3_next"]):
|
||||
raise NotImplementedError(
|
||||
"Mamba with speculative decoding is not supported yet.")
|
||||
if runner.vllm_config.cache_config.enable_prefix_caching:
|
||||
raise NotImplementedError(
|
||||
"Prefix caching is not supported for Mamba yet.")
|
||||
max_model_len = runner.vllm_config.model_config.max_model_len
|
||||
|
||||
page_size_padded = (
|
||||
runner.vllm_config.cache_config.mamba_page_size_padded)
|
||||
|
||||
# Set block_size to max_model_len, so that mamba model will always
|
||||
# have only one block in the KV cache.
|
||||
for layer_name, mamba_module in mamba_layers.items():
|
||||
kv_cache_spec[layer_name] = MambaSpec(
|
||||
shapes=mamba_module.get_state_shape(),
|
||||
dtypes=mamba_module.get_state_dtype(),
|
||||
block_size=max_model_len,
|
||||
page_size_padded=page_size_padded,
|
||||
mamba_type=mamba_module.mamba_type,
|
||||
num_speculative_blocks=(
|
||||
runner.speculative_config.num_speculative_tokens
|
||||
if runner.speculative_config else 0),
|
||||
)
|
||||
|
||||
return kv_cache_spec
|