[KVCache][Bugfix] Fix kv cache initialization error of attention layer (#3113)
### What this PR does / why we need it?
Fixes #3096
1. Fix the kv cache initialization error for attention layers. Some models name their
attention layers e.g. `attn.attn` rather than `self_attn`, but the kv cache tensor
initialization only checked for `self_attn` and `linear_attn`, leading to the error
`AssertionError: Some layers are not correctly initialized`. A minimal sketch of the
new matching rule follows this list.
2. Set a default value for the `sampling_metadata` argument of `compute_logits` in the
vllm-ascend modeling files, fixing the error `Qwen3NextForCausalLM.compute_logits()
missing 1 required positional argument: 'sampling_metadata'`.
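The new matching rule can be summarized with a small standalone sketch (the layer names below are hypothetical examples, not taken from a specific model):

```python
# Sketch of the per-layer dispatch used when allocating kv cache tensors.
# The layer names are invented for illustration only.
def classify(layer_name: str) -> str:
    if "linear_attn" in layer_name:
        return "mamba linear attention"
    elif "attn" in layer_name:  # previously "self_attn", which missed names like "attn.attn"
        return "regular attention (self_attn, sliding window attn, attn.attn, ...)"
    return "no kv cache handling"


for name in ["model.layers.0.self_attn.attn",
             "model.layers.0.attn.attn",
             "model.layers.1.linear_attn"]:
    print(name, "->", classify(name))
```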
### Does this PR introduce _any_ user-facing change?
N/A
### How was this patch tested?
Tested locally with internlm.
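A rough local reproduction sketch (the internlm checkpoint and sampling settings are assumptions, not the exact command used for testing):

```python
# Hypothetical smoke test with vLLM on Ascend; the model id and settings
# are illustrative only.
from vllm import LLM, SamplingParams

llm = LLM(model="internlm/internlm2_5-7b-chat", trust_remote_code=True)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)
```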
- vLLM version: v0.10.2
- vLLM main: 5aeb925452
---------
Signed-off-by: MengqingCao <cmq0113@163.com>
@@ -166,7 +166,7 @@ class CustomDeepSeekMultiTokenPredictor(DeepSeekMultiTokenPredictor):
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata,  # type: ignore
+        sampling_metadata=None,  # type: ignore
         spec_step_idx: int = 0,
     ) -> torch.Tensor:
         current_step_idx = (spec_step_idx % self.num_mtp_layers)
@@ -986,7 +986,7 @@ class Qwen3NextForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata,  # type: ignore
+        sampling_metadata=None,  # type: ignore
     ) -> Optional[torch.Tensor]:
         return self.logits_processor(self.lm_head, hidden_states,
                                      sampling_metadata)
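For context on why the default matters: the caller invokes `compute_logits(hidden_states)` without passing `sampling_metadata`, so a required positional parameter raises the "missing 1 required positional argument" error. A minimal standalone illustration (the class below is a stand-in, not the actual vLLM-Ascend code):

```python
import torch


class TinyLM:
    """Illustrative stand-in for a modeling class."""

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata=None) -> torch.Tensor:
        # With a default of None, callers that pass only hidden_states still work.
        return hidden_states @ torch.ones(hidden_states.shape[-1], 4)


logits = TinyLM().compute_logits(torch.randn(2, 8))  # no sampling_metadata needed
print(logits.shape)  # torch.Size([2, 4])
```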
@@ -344,7 +344,7 @@ class CustomQwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata,  # type: ignore
+        sampling_metadata=None,  # type: ignore
     ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
@@ -170,7 +170,7 @@ class TorchairDeepSeekMultiTokenPredictor(DeepSeekMultiTokenPredictor):
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata,  # type: ignore
+        sampling_metadata=None,  # type: ignore
         spec_step_idx: int = 0,
     ) -> torch.Tensor:
         current_step_idx = (spec_step_idx % self.num_mtp_layers)
@@ -936,7 +936,7 @@ class PanguProMoEForCausalLM(nn.Module, SupportsPP):
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-        sampling_metadata,  # type: ignore
+        sampling_metadata=None,  # type: ignore
     ) -> Optional[torch.Tensor]:
         logits = self.logits_processor(self.lm_head, hidden_states,
                                        sampling_metadata)
@@ -2784,9 +2784,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             for idx in range(len(kv_cache_tensor.shared_by)):
                 layer_name = kv_cache_tensor.shared_by[idx]
                 if "linear_attn" in layer_name:
                     # for mamba linear attention
                     for layer_name_inner in kv_cache_tensor.shared_by:
-                        if "self_attn" in layer_name_inner or layer_name_inner in kv_cache_raw_tensors.keys(
-                        ):
+                        if ("attn" in layer_name_inner and "linear_attn" not in layer_name_inner) or \
+                            layer_name_inner in kv_cache_raw_tensors.keys():
                             continue
                         if self.vllm_config.kv_transfer_config is None:
                             tensor = torch.zeros(kv_cache_tensor.size,
@@ -2800,7 +2801,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                             tensor = self._align_memory(
                                 tensor, alignment)[:kv_cache_tensor.size]
                         kv_cache_raw_tensors[layer_name_inner] = tensor
-                elif "self_attn" in layer_name:
+                elif "attn" in layer_name:
+                    # for other attentions, e.g., self_attn, sliding window attn
                     if self.vllm_config.kv_transfer_config is None:
                         k_tensor = torch.zeros(kv_cache_tensor.size // 2,
                                                dtype=torch.int8,
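To make the changed condition in the shared-tensor loop concrete, here is a small standalone sketch (the `shared_by` layer names and the empty allocation dict are invented for illustration):

```python
# Hypothetical group of layers sharing one kv cache tensor.
shared_by = ["model.layers.3.linear_attn", "model.layers.4.attn.attn"]
kv_cache_raw_tensors = {}  # layers that already received a raw tensor

for layer_name_inner in shared_by:
    # New condition: skip every non-linear attention layer (self_attn, attn.attn,
    # sliding window attn, ...) and anything that was already allocated.
    if ("attn" in layer_name_inner and "linear_attn" not in layer_name_inner) or \
            layer_name_inner in kv_cache_raw_tensors.keys():
        continue
    print("allocate mamba kv cache for", layer_name_inner)
```

The same widening from `"self_attn"` to `"attn"` is applied to the outer `elif`, so layers named e.g. `attn.attn` now take the regular attention path instead of falling through and triggering the initialization assertion.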