[Refactor] Adapt deepseek-v3.2 to vllm 0.11.0 (#3432)

### What this PR does / why we need it?
Adapt deepseek-v3.2 to vLLM 0.11.0 and remove patches that are no longer needed.

The final goal is to remove all patches and align the code architecture with vLLM, so the following work is planned in upcoming PRs.
TODO:
- [x] remove patch on attention spec
- [ ] refactor the kvcache creation logic
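
The recurring pattern in this change is to stop reading a vllm-ascend-only `use_sfa` flag from `AscendConfig` and instead derive sparse-attention (SFA) support for DeepSeek V3.2 directly from the Hugging Face config, which is also what gets forwarded as `use_sparse` to vLLM 0.11.0's attention-backend selection. A minimal sketch of that detection, using a hypothetical helper name (the actual diff inlines the `hasattr` check at each call site):

```python
from vllm.config import VllmConfig


def detect_use_sparse(vllm_config: VllmConfig) -> bool:
    """Illustrative helper, not part of the diff: DeepSeek V3.2 exposes
    `index_topk` on its HF config, and the refactored code uses that
    attribute to enable the sparse (SFA) attention path instead of the
    removed `ascend_config.use_sfa` flag."""
    model_config = vllm_config.model_config
    if model_config is None:
        return False
    return hasattr(model_config.hf_config, "index_topk")
```

Every former `ascend_config.use_sfa` / `is_deepseek_sfa` check in the files below is replaced by this kind of `use_sparse` check.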

### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?
1. CI passed with the existing tests.
2. Manual test passed with deepseek-v3.2-exp (see the smoke-test sketch below).
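
For reference, an offline smoke test along these lines exercises the new path; the model path and parallel settings here are illustrative assumptions, not the exact CI configuration:

```python
from vllm import LLM, SamplingParams

# Hypothetical smoke test: model path and tensor_parallel_size are
# assumptions for illustration, not the configuration used in CI.
llm = LLM(model="deepseek-ai/DeepSeek-V3.2-Exp",
          tensor_parallel_size=16,
          trust_remote_code=True)
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)
```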


- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

Signed-off-by: MengqingCao <cmq0113@163.com>
Mengqing Cao
2025-10-15 17:48:58 +08:00
committed by GitHub
parent 099255e933
commit 8abe517870
20 changed files with 143 additions and 262 deletions

View File

@ -18,6 +18,7 @@ class TestNPUWorker(TestBase):
self.model_config_mock = MagicMock(spec=ModelConfig)
self.model_config_mock.dtype = torch.float16
self.model_config_mock.trust_remote_code = False
self.model_config_mock.hf_config = None
self.parallel_config_mock = MagicMock(spec=ParallelConfig)

View File

@ -23,7 +23,6 @@ def register():
def register_model():
import vllm_ascend.patch.worker.patch_common.patch_attention_selector # noqa
from .models import register_model
register_model()

View File

@ -34,8 +34,6 @@ class AscendConfig:
def __init__(self, vllm_config):
additional_config = vllm_config.additional_config if vllm_config.additional_config is not None else {}
self.is_deepseek_sfa = vllm_config.model_config is not None and vllm_config.model_config.is_deepseek_mla and vllm_config.model_config.hf_text_config.model_type == "deepseek_v32"
self.use_sfa = self.is_deepseek_sfa
torchair_graph_config = additional_config.get("torchair_graph_config",
{})

View File

@ -510,7 +510,6 @@ class AscendSFAImpl(MLAAttentionImpl):
ascend_config = get_ascend_config()
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
self.enable_prefetch = ascend_config.enable_prefetch
self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
vllm_config = get_current_vllm_config()
@ -690,6 +689,8 @@ class AscendSFAImpl(MLAAttentionImpl):
topk_indices = self.indexer_select(hidden_states_decode,
decode_q_c,
attn_metadata=attn_metadata,
cos=cos,
sin=sin,
kv_cache=kv_cache)
query_states = (decode_q_nope, decode_q_pe)
@ -778,6 +779,8 @@ class AscendSFAImpl(MLAAttentionImpl):
topk_indices = self.indexer_select(x=hidden_states_prefill,
qr=prefill_qr,
kv_cache=kv_cache,
cos=cos,
sin=sin,
attn_metadata=attn_metadata)
query_states = (prefill_q_nope, prefill_q_pe)
key_states = (prefill_k_nope, prefill_k_pe)
@ -920,17 +923,15 @@ class AscendSFAImpl(MLAAttentionImpl):
x: torch.Tensor,
qr: torch.Tensor,
kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
cos,
sin,
attn_metadata: M,
):
if attn_metadata.prefill is not None:
cos = attn_metadata.prefill.cos
sin = attn_metadata.prefill.sin
actual_seq_lengths_query = attn_metadata.prefill.query_lens
actual_seq_lengths_key = attn_metadata.prefill.seq_lens
block_table = attn_metadata.prefill.block_table
elif attn_metadata.decode is not None:
cos = attn_metadata.decode.cos
sin = attn_metadata.decode.sin
actual_seq_lengths_query = attn_metadata.decode.actual_seq_lengths_q
actual_seq_lengths_key = attn_metadata.decode.seq_lens
block_table = attn_metadata.decode.block_table

View File

@ -501,7 +501,7 @@ class LLMDataDistCMgrConnectorWorker():
self.use_mla: bool = first_kv_cache_tuple[0].size(
-1) != first_kv_cache_tuple[1].size(-1) and len(
first_kv_cache_tuple) == 2
self.use_sfa: bool = len(first_kv_cache_tuple) == 3
self.use_sparse: bool = len(first_kv_cache_tuple) == 3
# MLA case. [2 (k_normed, k_pe), num_blocks, ...]
# SFA case. [3 (k_normed, k_pe, k_idx), num_blocks, ...]
# MHA case. [2 (k and v), num_blocks, ...]
@ -549,7 +549,7 @@ class LLMDataDistCMgrConnectorWorker():
raise RuntimeError(
f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to register_block_cache, receiving [cache_desc: {self.cache_desc}, cache_addr: {self.cache_addr}, cache_key: {self.cache_key}]"
)
elif self.use_sfa:
elif self.use_sparse:
cache_k_normed_addr_list = []
cache_k_pe_addr_list = []
cache_k_idx_addr_list = []
@ -887,7 +887,7 @@ class LLMDataDistCMgrConnectorWorker():
raise RuntimeError(
"LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status"
)
elif self.use_sfa:
elif self.use_sparse:
remote_cache_key_k_normed = BlocksCacheKey(
cluster_id=remote_cluster_id, model_id=0)
remote_cache_key_k_pe = BlocksCacheKey(

View File

@ -242,7 +242,7 @@ class KVCacheRecvingThread(threading.Thread):
self.block_len = block_len
# TODO(jianzs): find a better way to detect MLA.
self.use_mla = len(block_len) == 2
self.use_sfa = len(block_len) == 3
self.use_sparse = len(block_len) == 3
self.request_queue: queue.Queue[Any] = queue.Queue()
self.executor = ThreadPoolExecutor(max_workers=32)
@ -373,7 +373,7 @@ class KVCacheRecvingThread(threading.Thread):
zip(local_kv_caches_base_addrs, remote_kv_caches_base_addrs)):
if self.use_mla:
block_len = (self.block_len[k % 2])
elif self.use_sfa:
elif self.use_sparse:
block_len = (self.block_len[k % 3])
else:
block_len = (self.block_len[0])
@ -850,7 +850,8 @@ class MooncakeConnectorScheduler:
assert "tp_size" in decode_parallel_config.keys()
self._decode_tp_size = decode_parallel_config["tp_size"]
num_key_value_heads = self.vllm_config.model_config.hf_config.num_key_value_heads
if self.vllm_config.model_config.use_mla or self.ascend_config.use_sfa:
if self.vllm_config.model_config.use_mla or hasattr(
self.vllm_config.model_config.hf_config, "index_topk"):
num_need_pulls = 1
else:
num_p_block_heads = max(
@ -942,7 +943,7 @@ class MooncakeConnectorWorker:
# kv_transfer variables
self.vllm_config = vllm_config
self.block_size = vllm_config.cache_config.block_size
if self.vllm_config.model_config.is_deepseek_mla or self.ascend_config.use_sfa:
if self.vllm_config.model_config.is_deepseek_mla or self.use_sparse:
self.num_need_pulls = 1
else:
num_d_block_heads = max(1,
@ -995,7 +996,7 @@ class MooncakeConnectorWorker:
self.use_mla = first_kv_cache_tuple[0].size(
-1) != first_kv_cache_tuple[1].size(-1) and len(
first_kv_cache_tuple) == 2
self.use_sfa = len(first_kv_cache_tuple) == 3
self.use_sparse = len(first_kv_cache_tuple) == 3
if self.use_mla:
# MLA case.[num_block, block_size, 1, hidden_dim]
self.num_blocks = first_kv_cache.shape[0]
@ -1009,7 +1010,7 @@ class MooncakeConnectorWorker:
logger.info(
"num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s",
self.num_blocks, block_shape_norm, block_shape_pe)
elif self.use_sfa:
elif self.use_sparse:
self.num_blocks = first_kv_cache.shape[0]
block_rank = 3 # [block_size, latent_dim]
block_shape_norm = first_kv_cache_tuple[0].shape[-block_rank:]
@ -1037,8 +1038,8 @@ class MooncakeConnectorWorker:
logger.info("num_blocks: %s, block_shape: %s", self.num_blocks,
block_shape)
logger.info(
"Registering KV_Caches. use_mla: %s, use_sfa: %s, shape %s",
self.use_mla, self.use_sfa, first_kv_cache.shape)
"Registering KV_Caches. use_mla: %s, use_sparse: %s, shape %s",
self.use_mla, self.use_sparse, first_kv_cache.shape)
self.kv_caches = kv_caches
kv_caches_base_addr = []
@ -1050,7 +1051,7 @@ class MooncakeConnectorWorker:
region_len = self.num_blocks * self.block_len[i % 2]
kv_caches_base_addr.append(base_addr)
self._register(base_addr, region_len)
elif self.use_sfa:
elif self.use_sparse:
for i, cache in enumerate(cache_or_caches, 0):
base_addr = cache.data_ptr()
region_len = self.num_blocks * self.block_len[i % 3]
@ -1059,7 +1060,7 @@ class MooncakeConnectorWorker:
else:
cache_list = [
cache_or_caches
] if self.use_mla or self.use_sfa else cache_or_caches
] if self.use_mla or self.use_sparse else cache_or_caches
for cache in cache_list:
base_addr = cache.data_ptr()
region_len = self.num_blocks * self.block_len[0]
@ -1156,9 +1157,9 @@ class MooncakeConnectorWorker:
sampled_nums = []
ori_data = np.arange(self._prefill_tp_size)
# random split prefill tp list
if self._prefill_tp_size > self.num_key_value_heads or self.vllm_config.model_config.is_deepseek_mla or self.ascend_config.use_sfa:
if self._prefill_tp_size > self.num_key_value_heads or self.vllm_config.model_config.is_deepseek_mla or self.use_sparse:
# use deepseek mla, num_key_value_heads == 128, but consider as 1
if self.vllm_config.model_config.is_deepseek_mla or self.ascend_config.use_sfa:
if self.vllm_config.model_config.is_deepseek_mla or self.use_sparse:
num_kv_head = 1
else:
num_kv_head = self.num_key_value_heads

View File

@ -31,6 +31,7 @@ import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.attention import AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (divide, get_pp_group,
get_tensor_model_parallel_rank,
@ -47,7 +48,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mla import MultiHeadLatentAttention
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.deepseek_v2 import \
@ -56,10 +58,11 @@ from vllm.model_executor.models.deepseek_v2 import (
DeepseekV2Attention, DeepseekV2DecoderLayer, DeepseekV2ForCausalLM,
DeepseekV2MLAAttention, DeepseekV2MLP, DeepseekV2Model, DeepseekV2MoE,
get_spec_layer_idx_from_weight_name)
from vllm.model_executor.models.utils import (PPMissingLayer,
is_pp_missing_parameter,
maybe_prefix)
from vllm.model_executor.models.utils import (
PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.models.layers.mla import AscendMLAModules
@ -69,6 +72,53 @@ from vllm_ascend.ops.common_fused_moe import AscendFusedMoE
from vllm_ascend.ops.linear import AscendLinearBase
@support_torch_compile
class AscendDeepseekV2Model(DeepseekV2Model, nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
# This init is rewritten mainly to remove hard-coded CUDA device usage
nn.Module.__init__(self)
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
self.vocab_size = config.vocab_size
self.is_v32 = hasattr(config, "index_topk")
if self.is_v32:
topk_tokens = config.index_topk
topk_indices_buffer = torch.empty(
vllm_config.scheduler_config.max_num_batched_tokens,
topk_tokens,
dtype=torch.int32,
device=current_platform.device_type)
else:
topk_indices_buffer = None
if get_pp_group().is_first_rank:
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.embed_tokens")
else:
self.embed_tokens = PPMissingLayer()
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: DeepseekV2DecoderLayer(vllm_config, prefix,
topk_indices_buffer),
prefix=f"{prefix}.layers")
if get_pp_group().is_last_rank:
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
else:
self.norm = PPMissingLayer()
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
class CustomDeepseekV2RowParallelLinear(RowParallelLinear):
def __init__(
@ -270,6 +320,7 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
scaling_factor = rope_scaling["factor"]
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
self.scaling = self.scaling * mscale * mscale
self.indexer = None
mla_modules = AscendMLAModules(
q_a_proj=self.q_a_proj if self.q_lora_rank is not None else None,
@ -281,6 +332,8 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
kv_b_proj=self.kv_b_proj,
o_proj=self.o_proj,
rotary_emb=self.rotary_emb,
indexer=None,
is_sparse=hasattr(config, "index_topk"),
)
self.mla_attn = MultiHeadLatentAttention(
@ -499,7 +552,6 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
parallel_config = vllm_config.parallel_config
ascend_config = get_ascend_config()
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
@ -515,7 +567,7 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
self.tp_rank = get_tp_group().rank_in_group
# TODO: enable mla in vllm-ascend
if model_config.use_mla:
if ascend_config.use_sfa:
if hasattr(model_config.hf_config, "index_topk"):
attn_cls = CustomDeepseekV2SFAAttention
else:
attn_cls = CustomDeepseekV2MLAAttention
@ -590,8 +642,9 @@ class CustomDeepseekV2ForCausalLM(DeepseekV2ForCausalLM):
"kv_a_proj_with_mqa",
]
self.model = DeepseekV2Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
self.model = AscendDeepseekV2Model(vllm_config=vllm_config,
prefix=maybe_prefix(
prefix, "model"))
if get_pp_group().is_last_rank:
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,

View File

@ -42,6 +42,8 @@ class AscendMLAModules:
kv_b_proj: torch.nn.Module
o_proj: torch.nn.Module
rotary_emb: torch.nn.Module
indexer: Optional[torch.nn.Module]
is_sparse: bool
class AscendMultiHeadLatentAttention(MultiHeadLatentAttention):

View File

@ -94,7 +94,7 @@ class AscendSparseFlashAttention(MultiHeadLatentAttention):
quant_config=quant_config,
prefix=f"{prefix}.attn",
use_mla=True,
use_sfa=True,
use_sparse=True,
# SFA Args
q_lora_rank=self.q_lora_rank,
kv_lora_rank=self.kv_lora_rank,

View File

@ -18,4 +18,3 @@
import vllm_ascend.patch.platform.patch_common.patch_config # noqa
import vllm_ascend.patch.platform.patch_common.patch_distributed # noqa
import vllm_ascend.patch.platform.patch_common.patch_mamba_config # noqa
import vllm_ascend.patch.worker.patch_common.patch_attention_selector # noqa

View File

@ -21,8 +21,6 @@ if HAS_TRITON:
import vllm_ascend.patch.worker.patch_common.patch_triton
# isort: off
import vllm_ascend.patch.worker.patch_common.patch_attention_selector # noqa
import vllm_ascend.patch.worker.patch_common.patch_attention_layer # noqa
import vllm_ascend.patch.worker.patch_common.patch_distributed # noqa
import vllm_ascend.patch.worker.patch_common.patch_logits # noqa
import vllm_ascend.patch.worker.patch_common.patch_roberta # noqa

View File

@ -1,188 +0,0 @@
from typing import List, Optional
import torch
import vllm
import vllm.envs as envs
from torch import nn
from vllm.attention import Attention, AttentionType, get_attn_backend
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.selector import backend_name_to_enum
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
from vllm.config import CacheConfig, get_current_vllm_config
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.layers.quantization.base_config import \
QuantizationConfig
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.platforms import current_platform
class AscendAttention(Attention, nn.Module, AttentionLayerBase):
"""Attention layer.
This class takes query, key, and value tensors as input. The input tensors
can either contain prompt tokens or generation tokens.
The class does the following:
1. Store the input key and value tensors in the KV cache.
2. Perform (multi-head/multi-query/grouped-query) attention.
3. Return the output tensor.
"""
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: Optional[int] = None,
alibi_slopes: Optional[List[float]] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
logits_soft_cap: Optional[float] = None,
per_layer_sliding_window: Optional[int] = None,
use_mla: bool = False,
use_sfa: bool = False,
prefix: str = "",
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[str] = None,
attn_backend: Optional[type[AttentionBackend]] = None,
**extra_impl_args,
) -> None:
"""
The KV cache is stored inside this class and is accessed via
`self.kv_cache`.
"""
nn.Module.__init__(self)
AttentionLayerBase.__init__(self)
if per_layer_sliding_window is not None:
# per-layer sliding window
sliding_window = per_layer_sliding_window
elif cache_config is not None:
# model-level sliding window
sliding_window = cache_config.sliding_window
else:
sliding_window = None
if cache_config is not None:
kv_cache_dtype = cache_config.cache_dtype
block_size = cache_config.block_size
calculate_kv_scales = cache_config.calculate_kv_scales
else:
kv_cache_dtype = "auto"
block_size = 16
calculate_kv_scales = False
if num_kv_heads is None:
num_kv_heads = num_heads
assert num_heads % num_kv_heads == 0, \
f"num_heads ({num_heads}) is not " \
f"divisible by num_kv_heads ({num_kv_heads})"
# The default k/v_scale is set to 1.0. This is ignored
# when kv-cache is not fp8, and should be used with
# kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
# expect the pre-quantized k/v_scale to be loaded along
# with the model weights.
self.kv_cache_dtype = kv_cache_dtype
self.calculate_kv_scales = calculate_kv_scales
self._k_scale = torch.tensor(1.0, dtype=torch.float32)
self._v_scale = torch.tensor(1.0, dtype=torch.float32)
# FlashAttn doesn't support quantizing the kv-cache only
# but requires q to be quantized as well.
self._q_scale = torch.tensor(1.0, dtype=torch.float32)
self._prob_scale = torch.tensor(1.0, dtype=torch.float32)
# We also keep q/k/v_scale on host (cpu) memory for attention
# backends that require the scales to be on host instead of on device.
# e.g. Flashinfer
self._q_scale_float = 1.0
self._k_scale_float = 1.0
self._v_scale_float = 1.0
# The output scale on host memory. This should be the input scale of
# the quant op after this attention layer.
self._o_scale_float: Optional[float] = None
self.use_mla = use_mla
self.num_heads = num_heads
self.head_size = head_size
self.num_kv_heads = num_kv_heads
self.sliding_window = sliding_window
self.has_sink = extra_impl_args.get("sinks") is not None
quant_method = quant_config.get_quant_method(
self, prefix=prefix) if quant_config else None
if quant_method is not None and not isinstance(
quant_method, UnquantizedLinearMethod):
assert isinstance(quant_method, BaseKVCacheMethod)
# TODO (mgoin): kv cache dtype should be specified in the FP8
# checkpoint config and become the "auto" behavior
if self.kv_cache_dtype == "fp8_e5m2":
raise ValueError("fp8_e5m2 kv-cache is not supported with "
"fp8 checkpoints.")
# If quantization is enabled, we make "k_scale" and "v_scale"
# parameters so that it can be loaded from the model checkpoint.
# The k/v_scale will then be converted back to native float32
# values after weight loading.
self.quant_method = quant_method
self.quant_method.create_weights(self)
# During model initialization, the default dtype is set as the model
# weight and activation dtype.
dtype = torch.get_default_dtype()
if attn_backend is None:
self.attn_backend = get_attn_backend(head_size,
dtype,
kv_cache_dtype,
block_size,
use_mla=use_mla,
use_sfa=use_sfa,
has_sink=self.has_sink)
else:
self.attn_backend = attn_backend
impl_cls = self.attn_backend.get_impl_cls()
self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window, kv_cache_dtype,
logits_soft_cap, attn_type,
kv_sharing_target_layer_name, **extra_impl_args)
self.backend = backend_name_to_enum(self.attn_backend.get_name())
self.dtype = dtype
# For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
# torch.compile works by registering the attention as one giant
# opaque custom op. For other platforms, we directly call them
# and let torch.compile handle them.
self.use_direct_call = not current_platform.opaque_attention_op()
self.use_output = self.attn_backend.accept_output_buffer
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
self.layer_name = prefix
self.attn_type = attn_type
if kv_sharing_target_layer_name is not None:
validate_kv_sharing_target(
prefix,
kv_sharing_target_layer_name,
compilation_config.static_forward_context,
)
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
# use a placeholder kv cache tensor during init, which will be replaced
# by bind_kv_cache
# this variable will not be accessed if use_direct_call is True
self.kv_cache = [
torch.tensor([]) for _ in range(get_current_vllm_config(
).parallel_config.pipeline_parallel_size)
]
self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)
self.query_quant = None
vllm.attention.Attention = AscendAttention

View File

@ -164,6 +164,9 @@ class NPUPlatform(Platform):
"kv_cache_dtype", None)
if kv_cache_dtype is not None:
vllm_config.cache_config.cache_dtype = kv_cache_dtype
elif model_config and hasattr(model_config.hf_config, "index_topk"):
vllm_config.cache_config.cache_dtype = str(
model_config.dtype).replace("torch.", "")
if model_config is None:
logger.warning("Model config is missing. This may indicate "
"that we are running a test case")
@ -284,7 +287,8 @@ class NPUPlatform(Platform):
vllm_config.scheduler_config = ascend_scheduler_config
@classmethod
def get_attn_backend_cls(cls,
def get_attn_backend_cls(
cls,
selected_backend,
head_size,
dtype,
@ -292,17 +296,18 @@ class NPUPlatform(Platform):
block_size,
use_v1,
use_mla,
use_sfa,
has_sink=False):
has_sink=False,
use_sparse=False,
):
if not use_v1:
raise ValueError("vLLM Ascend does not support V0 engine.")
ascend_config = get_ascend_config()
if use_mla and ascend_config.enable_shared_expert_dp:
if use_mla and not use_sfa:
if use_mla and not use_sparse:
return "vllm_ascend.torchair.torchair_mla.AscendMLATorchairBackend"
if use_mla and use_sfa:
if use_mla and use_sparse:
return "vllm_ascend.torchair.torchair_sfa.AscendSFATorchairBackend"
use_torchair = ascend_config.torchair_graph_config.enabled
@ -321,7 +326,7 @@ class NPUPlatform(Platform):
(True, True, True):
"vllm_ascend.torchair.torchair_sfa.AscendSFATorchairBackend",
}
return backend_map[(use_mla, use_sfa, use_torchair)]
return backend_map[(use_mla, use_sparse, use_torchair)]
@classmethod
def get_punica_wrapper(cls) -> str:

View File

@ -183,6 +183,11 @@ packed_modules_model_mapping = {
"experts":
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
},
"deepseek_v32": {
"gate_up_proj": ["gate_proj", "up_proj"],
"experts":
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
},
# NOTE 1.The quantized MTP layer of deepseek on the NPU is not quantized;
# NOTE 2.The description file generated by the current msmodelslim tool does not have
# MTP layer info. Please manually add it and set the value to FLOAT.

View File

@ -67,6 +67,8 @@ class MtpProposer(Proposer):
1,
device=self.runner.device,
dtype=torch.int32)
self.use_sparse = hasattr(vllm_config.model_config.hf_config,
"index_topk")
def load_model(self, model) -> None:
loader = get_model_loader(self.vllm_config.load_config)
@ -613,7 +615,7 @@ class MtpProposer(Proposer):
npu_backend = torchair.get_npu_backend(compiler_config=config)
self.torchair_compiled_model = torch.compile(
self.model,
dynamic=not get_ascend_config().use_sfa,
dynamic=not self.use_sparse,
fullgraph=True,
backend=npu_backend)
return self.torchair_compiled_model
@ -636,7 +638,7 @@ class MtpProposer(Proposer):
self.torchair_compiled_models[
batch_size] = torchair.inference.cache_compile(
self.model.__dict__[forward_proxy_name],
dynamic=not get_ascend_config().use_sfa,
dynamic=not self.use_sparse,
fullgraph=True,
cache_dir=TORCHAIR_CACHE_DIR,
config=config,

View File

@ -791,7 +791,7 @@ class TorchairDeepseekV2SFAAttention(DeepseekV2MLAAttention):
quant_config=quant_config,
prefix=f"{prefix}.attn",
use_mla=True,
use_sfa=True,
use_sparse=True,
# SFA Args
q_lora_rank=self.q_lora_rank,
kv_lora_rank=self.kv_lora_rank,
@ -879,12 +879,12 @@ class TorchairDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
self.tp_rank = get_tp_group().rank_in_group
ascend_config = get_ascend_config()
self.use_mla = False
self.use_sfa = False
self.use_sparse = False
# TODO: enable mla in vllm-ascend
if model_config.use_mla:
if ascend_config.use_sfa:
if hasattr(model_config.hf_config, "index_topk"):
attn_cls = TorchairDeepseekV2SFAAttention
self.use_sfa = True
self.use_sparse = True
else:
attn_cls = TorchairDeepseekV2MLAAttention # type: ignore[assignment]
self.use_mla = True
@ -950,7 +950,7 @@ class TorchairDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
forward_context = get_forward_context()
if attn_metadata is not None:
decoding_condition_met = (
not attn_metadata.is_prefill if self.use_sfa else
not attn_metadata.is_prefill if self.use_sparse else
not forward_context.with_prefill if self.use_mla else False)
mla_moe_communication = decoding_condition_met and self.mla_moe_communication and replace_allreduce
else:

View File

@ -376,7 +376,7 @@ class NPUTorchairModelRunner(NPUModelRunner):
npu_backend = torchair.get_npu_backend(compiler_config=config)
self.torchair_compiled_model = torch.compile(
self.model,
dynamic=not self.ascend_config.use_sfa,
dynamic=not self.use_sparse,
fullgraph=True,
backend=npu_backend)
return self.torchair_compiled_model
@ -399,7 +399,7 @@ class NPUTorchairModelRunner(NPUModelRunner):
self.torchair_compiled_models[
batch_size] = torchair.inference.cache_compile(
self.model.__dict__[forward_proxy_name],
dynamic=not self.ascend_config.use_sfa,
dynamic=not self.use_sparse,
fullgraph=True,
cache_dir=TORCHAIR_CACHE_DIR,
config=config,

View File

@ -738,7 +738,7 @@ class AscendSFATorchairImpl(MLAAttentionImpl):
ascend_config = get_ascend_config()
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
self.enable_prefetch = ascend_config.enable_prefetch
self.enable_prefetch = ascend_config.weight_prefetch_config.enabled
self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
if ascend_config.torchair_graph_config.enabled:
self.graph_batch_size = ascend_config.torchair_graph_config.graph_batch_sizes[

View File

@ -309,13 +309,14 @@ class NPUModelRunner(LoRAModelRunnerMixin):
dtype=self.dtype,
device=self.device)
# Set up Attention
self.attn_backend = get_attn_backend(
0,
self.use_sparse = hasattr(self.vllm_config.model_config.hf_config,
"index_topk")
self.attn_backend = get_attn_backend(0,
self.dtype,
None,
self.block_size,
use_mla=self.model_config.use_mla,
use_sfa=self.ascend_config.use_sfa)
use_sparse=self.use_sparse)
if torch.version.cann.startswith("8.3"):
self.attn_mask_builder = AttentionMaskBuilder(
self.scheduler_config.max_num_batched_tokens, self.dtype,
@ -871,7 +872,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
if self.model_config.runner_type == "pooling" and self.model_config.pooler_config.pooling_type == "CLS":
return self.attn_mask_builder.get_pooling_mask(self.device)
# Chunk Prefill situation.
elif attn_state == AscendAttentionState.ChunkedPrefill and not self.vllm_config.model_config.use_mla and not self.ascend_config.use_sfa:
elif attn_state == AscendAttentionState.ChunkedPrefill and not self.vllm_config.model_config.use_mla and not self.use_sparse:
if torch.version.cann.startswith("8.3"):
return self.attn_mask_builder.get_splitfuse_attn_mask()
else:
@ -1507,7 +1508,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
model=self.get_model(),
**extra_attn_metadata_args)
if self.vllm_config.model_config.use_mla or self.ascend_config.use_sfa:
if self.vllm_config.model_config.use_mla or self.use_sparse:
attn_metadata_i.num_input_tokens = num_input_tokens
for layer_name in attn_group.layer_names:
attn_metadata[layer_name] = attn_metadata_i
@ -2655,7 +2656,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.may_add_encoder_only_layers_to_kv_cache_config()
self.initialize_attn_backend(kv_cache_config)
if self.ascend_config.is_deepseek_sfa:
if self.use_sparse:
kv_caches = self.initialize_kv_cache_tensors_deepseek_sfa(
kv_cache_config)
elif self.model_config.is_deepseek_mla:
@ -2699,7 +2700,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
elif hasattr(
attn_backend, "get_supported_block_size"
) and not self.model_config.is_deepseek_mla and not self.ascend_config.is_deepseek_sfa:
) and not self.model_config.is_deepseek_mla and not self.use_sparse:
block_size = attn_backend.get_supported_block_size()[0]
block_size_chunk = kv_cache_spec.block_size // block_size
kv_cache_shape = attn_backend.get_kv_cache_shape(
@ -3245,7 +3246,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
block_size = self.vllm_config.cache_config.block_size
use_mla = self.vllm_config.model_config.use_mla
use_sfa = self.ascend_config.use_sfa
use_sparse = self.use_sparse
kv_cache_spec: dict[str, KVCacheSpec] = {}
attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
for layer_name, attn_module in attn_layers.items():
@ -3267,7 +3268,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# TODO(lucas): move the attention specs into the model layers like
# the attention backends
if attn_module.attn_type == AttentionType.DECODER:
if use_mla and not use_sfa:
if use_mla and not use_sparse:
kv_cache_spec[layer_name] = MLAAttentionSpec(
block_size=block_size,
num_kv_heads=attn_module.num_kv_heads,

View File

@ -43,7 +43,7 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
from vllm.v1.worker.worker_base import WorkerBase
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config, init_ascend_config
from vllm_ascend.ascend_config import init_ascend_config
from vllm_ascend.device_allocator.camem import CaMemAllocator
from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
from vllm_ascend.platform import NPUPlatform
@ -88,7 +88,11 @@ class NPUWorker(WorkerBase):
# init ascend config and soc version
init_ascend_config(vllm_config)
init_ascend_soc_version()
if get_ascend_config().use_sfa:
use_sparse = False
if vllm_config.model_config is not None:
use_sparse = hasattr(vllm_config.model_config.hf_config,
"index_topk")
if use_sparse:
# Direct import instead of using try_register_lib to ensure proper error handling when
# custom_ops is necessary but not available (e.g., in DeepSeek v3.2 deployments)
# yapf: disable