mirror of https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[Bugfix] Get a specific type of layer from forward context (#17222)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>

@@ -38,7 +38,7 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
                                            is_block_tables_empty)
 from vllm.attention.layer import Attention
 from vllm.attention.ops.paged_attn import PagedAttention
-from vllm.config import VllmConfig
+from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.logger import init_logger
 from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype,
                         make_tensor_with_pad)

@@ -140,12 +140,10 @@ def get_per_layer_parameters(
     to use during `plan`.
     """
 
-    layers = vllm_config.compilation_config.static_forward_context
+    layers = get_layers_from_vllm_config(vllm_config, Attention)
     per_layer_params: Dict[str, PerLayerParameters] = {}
 
     for key, layer in layers.items():
-        assert isinstance(layer, Attention)
-
         impl = layer.impl
         assert isinstance(impl, FlashInferImpl)

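The dropped `assert isinstance(layer, Attention)` is redundant once the helper filters by type. As a rough hand-inlined sketch (not code from the commit), the new call amounts to:

    # Hand-inlined equivalent of get_layers_from_vllm_config(vllm_config, Attention);
    # `vllm_config` is the VllmConfig already available in get_per_layer_parameters.
    layers = {
        layer_name: layer
        for layer_name, layer in
        vllm_config.compilation_config.static_forward_context.items()
        if isinstance(layer, Attention)
    }
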
@@ -3445,7 +3445,8 @@ class CompilationConfig(BaseModel):
     compilation_time: float = PrivateAttr
 
     # Per-model forward context
-    # Map from layer name to the attention cls
+    # Map from layer name to layer objects that need to be accessed outside
+    # model code, e.g., Attention, FusedMOE when dp_size>1.
     static_forward_context: dict[str, Any] = PrivateAttr
 
     def compute_hash(self) -> str:

@@ -4079,3 +4080,16 @@ def assert_hashable(text):
         f"vLLM tried to hash some configs that may have Python objects ids "
         f"in them. This is a bug, please file an issue. "
         f"Text being hashed: {text}")
+
+
+T = TypeVar("T")
+
+
+def get_layers_from_vllm_config(vllm_config: VllmConfig,
+                                layer_type: type[T]) -> dict[str, T]:
+    return {
+        layer_name: layer
+        for layer_name, layer in
+        vllm_config.compilation_config.static_forward_context.items()
+        if isinstance(layer, layer_type)
+    }

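A minimal usage sketch of the new helper, assuming a fully constructed `VllmConfig` whose `static_forward_context` has already been populated during model initialization; the function name `dump_attention_layers` is made up for illustration:

    from vllm.attention.layer import Attention
    from vllm.config import VllmConfig, get_layers_from_vllm_config


    def dump_attention_layers(vllm_config: VllmConfig) -> None:
        # Only Attention entries come back; other layer types stored in the
        # forward context (e.g. FusedMoE when dp_size > 1) are filtered out.
        attn_layers = get_layers_from_vllm_config(vllm_config, Attention)
        for layer_name, attn in attn_layers.items():
            # The result is typed as dict[str, Attention], so call sites no
            # longer need per-item isinstance asserts.
            print(layer_name, attn.attn_type)
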
@@ -14,7 +14,8 @@ import vllm.envs as envs
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionType)
 from vllm.attention.layer import Attention
-from vllm.config import VllmConfig, get_current_vllm_config
+from vllm.config import (VllmConfig, get_current_vllm_config,
+                         get_layers_from_vllm_config)
 from vllm.logger import init_logger
 from vllm.v1.attention.backends.flash_attn import use_cascade_attention

@@ -81,12 +82,10 @@ def get_per_layer_parameters(
     to use during `plan`.
     """
 
-    layers = vllm_config.compilation_config.static_forward_context
+    layers = get_layers_from_vllm_config(vllm_config, Attention)
     per_layer_params: dict[str, PerLayerParameters] = {}
 
     for key, layer in layers.items():
-        assert isinstance(layer, Attention)
-
         impl = layer.impl
         assert isinstance(impl, FlashInferImpl)

@@ -12,13 +12,13 @@ import torch.nn as nn
 
 from vllm.attention import AttentionType, get_attn_backend
 from vllm.attention.layer import Attention
-from vllm.config import CompilationLevel, VllmConfig
+from vllm.config import (CompilationLevel, VllmConfig,
+                         get_layers_from_vllm_config)
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           has_kv_transfer_group)
 from vllm.distributed.parallel_state import get_pp_group, graph_capture
 from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
-from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
 from vllm.model_executor.model_loader import get_model
 from vllm.multimodal import MULTIMODAL_REGISTRY

@@ -1733,17 +1733,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         format. Layers that do not need KV cache are not included.
         """
 
-        forward_ctx = self.vllm_config.compilation_config.static_forward_context
+        layers = get_layers_from_vllm_config(self.vllm_config, Attention)
         block_size = self.vllm_config.cache_config.block_size
         use_mla = self.vllm_config.model_config.use_mla
         kv_cache_spec: dict[str, KVCacheSpec] = {}
-        for layer_name, attn_module in forward_ctx.items():
-            if isinstance(attn_module, FusedMoE):
-                continue
-
-            # TODO: Support other attention modules, e.g., sliding window,
-            # cross-attention
-            assert isinstance(attn_module, Attention)
+        for layer_name, attn_module in layers.items():
+            # TODO: Support other attention modules, e.g., cross-attention
             if attn_module.attn_type == AttentionType.DECODER:
                 if attn_module.sliding_window is not None:
                     kv_cache_spec[layer_name] = SlidingWindowSpec(

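Because the helper is generic over the layer type, the same pattern could in principle fetch the FusedMoE entries mentioned in the updated CompilationConfig comment. This is only a sketch of that possibility, not something the commit adds:

    from vllm.config import VllmConfig, get_layers_from_vllm_config
    from vllm.model_executor.layers.fused_moe import FusedMoE


    def dump_moe_layers(vllm_config: VllmConfig) -> None:
        # Hypothetical: list the FusedMoE layers registered in the forward
        # context (present when dp_size > 1, per the updated comment).
        moe_layers = get_layers_from_vllm_config(vllm_config, FusedMoE)
        for layer_name, moe in moe_layers.items():
            print(layer_name, type(moe).__name__)
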
@@ -17,7 +17,7 @@ import vllm.envs as envs
 from vllm.attention.backends.abstract import AttentionType
 from vllm.attention.layer import Attention
 from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
-from vllm.config import VllmConfig
+from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.forward_context import set_forward_context
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model

@@ -429,11 +429,10 @@ class TPUModelRunner:
         format. Layers that do not need KV cache are not included.
         """
 
-        forward_ctx = self.vllm_config.compilation_config.static_forward_context
+        layers = get_layers_from_vllm_config(self.vllm_config, Attention)
         block_size = self.vllm_config.cache_config.block_size
         kv_cache_spec: dict[str, KVCacheSpec] = {}
-        for layer_name, attn_module in forward_ctx.items():
-            assert isinstance(attn_module, Attention)
+        for layer_name, attn_module in layers.items():
             if attn_module.attn_type == AttentionType.DECODER:
                 if attn_module.sliding_window is not None:
                     kv_cache_spec[layer_name] = SlidingWindowSpec(