mirror of
https://github.com/vllm-project/vllm-ascend.git
synced 2025-10-20 21:53:54 +08:00
### What this PR does / why we need it?
Fixes #3096
1. Fix a KV-cache initialization error in attention layers. Some models
name their attention layer `attn.attn` instead of `self_attn`, but the
initialization of KV-cache tensors only checked for `self_attn` and
`attn.attn`, which led to the error `AssertionError: Some layers are
not correctly initialized`.
2. Give the `sampling_metadata` argument of `compute_logits` a default
value (`None`) in the vllm-ascend modeling files. This fixes the error
`Qwen3NextForCausalLM.compute_logits() missing 1 required positional
argument: 'sampling_metadata'`.
### Does this PR introduce _any_ user-facing change?
N/A
### How was this patch tested?
Tested locally with InternLM.
- vLLM version: v0.10.2
- vLLM main:
5aeb925452
---------
Signed-off-by: MengqingCao <cmq0113@163.com>
204 lines
8.2 KiB
Python
204 lines
8.2 KiB
Python
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Adapted from vllm/model_executor/models/deepseek_mtp.py
# Copyright 2023 The vLLM team.
#
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
|
|
|
from typing import List, Optional
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from transformers import PretrainedConfig
|
|
from vllm.attention.backends.abstract import AttentionMetadata
|
|
from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
|
|
get_current_vllm_config)
|
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
|
ParallelLMHead, VocabParallelEmbedding)
|
|
from vllm.model_executor.models.deepseek_mtp import (
|
|
DeepSeekMTP, DeepSeekMultiTokenPredictor, DeepSeekMultiTokenPredictorLayer,
|
|
SharedHead)
|
|
from vllm.model_executor.models.deepseek_v2 import DeepseekV2DecoderLayer
|
|
from vllm.model_executor.models.utils import maybe_prefix
|
|
from vllm.sequence import IntermediateTensors
|
|
|
|
|
|
class CustomDeepSeekShareHead(SharedHead):
    """Per-MTP-layer output head: an RMSNorm followed by a ParallelLMHead.

    ``SharedHead.__init__`` is intentionally not invoked; ``nn.Module`` is
    initialised directly and both sub-modules are built here. The LM head is
    registered under ``<prefix>.head``.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = "") -> None:
        nn.Module.__init__(self)
        hidden_size = config.hidden_size
        self.norm = RMSNorm(hidden_size, eps=config.rms_norm_eps)
        head_prefix = maybe_prefix(prefix, "head")
        self.head = ParallelLMHead(config.vocab_size,
                                   hidden_size,
                                   quant_config=quant_config,
                                   prefix=head_prefix)
|
|
|
|
|
|
class CustomDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer):
    """A single DeepSeek multi-token-prediction (MTP) layer.

    The layer normalises the token embeddings (``enorm``) and the previous
    step's hidden states (``hnorm``). It concatenates them, projects the
    result back to ``hidden_size`` with ``eh_proj``, and runs one
    ``DeepseekV2DecoderLayer`` (``mtp_block``) over it.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        prefix: str,
        model_config: ModelConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        # The parent __init__ is intentionally skipped; every sub-module is
        # created below. ``model_config`` and ``cache_config`` are accepted
        # for signature compatibility but are not read here.
        nn.Module.__init__(self)
        # The decoder block is configured from the currently active vLLM
        # config rather than from the arguments above.
        active_vllm_config = get_current_vllm_config()

        hidden = config.hidden_size
        norm_eps = config.rms_norm_eps
        self.enorm = RMSNorm(hidden, eps=norm_eps)
        self.hnorm = RMSNorm(hidden, eps=norm_eps)
        # Maps concat(embeddings, previous hidden states) back to hidden size.
        self.eh_proj = nn.Linear(hidden * 2, hidden, bias=False)
        self.shared_head = CustomDeepSeekShareHead(
            config=config,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "shared_head"),
        )
        self.mtp_block = DeepseekV2DecoderLayer(vllm_config=active_vllm_config,
                                               prefix=prefix)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
        previous_hidden_states: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
        spec_step_index: int = 0,
    ) -> torch.Tensor:
        """Run one MTP step and return ``residual + block_output``.

        ``inputs_embeds`` is required (asserted non-None). ``input_ids``,
        ``kv_cache``, ``attn_metadata`` and ``spec_step_index`` are accepted
        for interface compatibility but are not used in this body.
        """
        assert inputs_embeds is not None
        # Inputs at position 0 are not needed by MTP, so they are masked to 0.
        at_first_position = (positions == 0).unsqueeze(-1)
        masked_embeds = torch.where(at_first_position,
                                    torch.zeros_like(inputs_embeds),
                                    inputs_embeds)

        fused = torch.cat(
            [self.enorm(masked_embeds),
             self.hnorm(previous_hidden_states)],
            dim=-1,
        )
        projected = self.eh_proj(fused)

        block_out, residual = self.mtp_block(positions=positions,
                                             hidden_states=projected,
                                             residual=None)
        return residual + block_out
|
|
|
|
|
|
class CustomDeepSeekMultiTokenPredictor(DeepSeekMultiTokenPredictor):
    """Holds all MTP layers plus the shared token embedding and logits logic.

    Layers are stored in ``self.layers`` keyed by their absolute layer index,
    starting at ``config.num_hidden_layers``. They are also kept in the
    int-indexable ``self.layers_list``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        nn.Module.__init__(self)
        config = vllm_config.model_config.hf_config
        # MTP layers are numbered directly after the main model's layers.
        self.mtp_start_layer_idx = config.num_hidden_layers
        self.num_mtp_layers = config.num_nextn_predict_layers
        layer_indices = range(self.mtp_start_layer_idx,
                              self.mtp_start_layer_idx + self.num_mtp_layers)
        # to map the exact layer index from weights
        self.layers = torch.nn.ModuleDict({
            str(idx):
            CustomDeepSeekMultiTokenPredictorLayer(
                config,
                # maybe_prefix avoids a malformed ".layers.N" name when
                # ``prefix`` is empty; identical to the previous
                # f"{prefix}.layers.{idx}" for any non-empty prefix.
                maybe_prefix(prefix, f"layers.{idx}"),
                model_config=vllm_config.model_config,
                cache_config=vllm_config.cache_config,
                quant_config=vllm_config.quant_config,
            )
            for idx in layer_indices
        })
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )

        # Note: torch._dynamo.exc.Unsupported: builtin: str
        # A plain list lets forward() index layers by int without str().
        self.layers_list = [self.layers[str(idx)] for idx in layer_indices]
        self.logits_processor = LogitsProcessor(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: torch.Tensor,
        attn_metadata: AttentionMetadata,
        previous_hidden_states: torch.Tensor,
        inputs_embeds: Optional[torch.Tensor] = None,
        spec_step_idx: int = 0,
    ) -> torch.Tensor:
        """Run the MTP layer selected by ``spec_step_idx``.

        The layer used is ``spec_step_idx % num_mtp_layers``. When
        ``inputs_embeds`` is None it is computed from ``input_ids`` with
        ``embed_tokens``. When ``kv_caches`` is given, the entry at the
        selected step index is passed to the layer; otherwise None is passed.
        """
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        current_step_idx = (spec_step_idx % self.num_mtp_layers)
        step_kv_cache = kv_caches[
            current_step_idx] if kv_caches is not None else None
        return self.layers_list[current_step_idx](
            input_ids,
            positions,
            step_kv_cache,
            attn_metadata,
            previous_hidden_states,
            inputs_embeds,
            current_step_idx,
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata=None,  # type: ignore
        spec_step_idx: int = 0,
    ) -> torch.Tensor:
        """Compute logits with the shared head of the selected MTP layer.

        The layer is chosen as ``spec_step_idx % num_mtp_layers``. Its shared
        head normalises ``hidden_states``, and ``logits_processor`` applies the
        layer's LM head. ``sampling_metadata`` defaults to None so that
        callers which no longer pass it still work; it is forwarded unchanged.
        """
        current_step_idx = (spec_step_idx % self.num_mtp_layers)
        mtp_layer = self.layers_list[current_step_idx]
        logits = self.logits_processor(mtp_layer.shared_head.head,
                                       mtp_layer.shared_head(hidden_states),
                                       sampling_metadata)
        return logits
|
|
|
|
|
|
class CustomDeepSeekMTP(DeepSeekMTP):
    """vllm-ascend DeepSeek MTP model backed by ``CustomDeepSeekMultiTokenPredictor``.

    ``DeepSeekMTP.__init__`` is intentionally not invoked. Only ``config``
    (the HF config) and ``model`` (registered under ``<prefix>.model``) are
    set up here.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        nn.Module.__init__(self)
        self.config = vllm_config.model_config.hf_config
        model_prefix = maybe_prefix(prefix, "model")
        self.model = CustomDeepSeekMultiTokenPredictor(
            vllm_config=vllm_config,
            prefix=model_prefix,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: Optional[List[torch.Tensor]] = None,
        attn_metadata: Optional[AttentionMetadata] = None,
        previous_hidden_states: Optional[torch.Tensor] = None,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        spec_step_idx: int = 0,
    ) -> torch.Tensor:
        """Delegate to the wrapped multi-token predictor.

        ``intermediate_tensors`` is accepted for interface compatibility and
        is not forwarded.
        """
        return self.model(input_ids, positions, kv_caches, attn_metadata,
                          previous_hidden_states, inputs_embeds,
                          spec_step_idx)
|