Mirror of https://github.com/vllm-project/vllm-ascend.git, synced 2025-10-20 13:43:53 +08:00
[Refactor] Adapt deepseek-v3.2 to vllm 0.11.0 (#3432)
### What this PR does / why we need it?
Adapt deepseek-v3.2 to vllm 0.11.0, removing the now-useless patches. The final goal is to remove all the patches and align the code architecture with vllm, so the following work is planned for the next PRs.

TODO:
- [x] remove patch on attention spec
- [ ] refactor the kvcache creation logic

### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?
1. CI passed with existing tests.
2. Tests pass with deepseek-v3.2-exp.

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

Signed-off-by: MengqingCao <cmq0113@163.com>
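The common thread in the diff below is that DeepSeek V3.2 sparse-attention (SFA) support is no longer read from `ascend_config.use_sfa`; each component now derives a `use_sparse` flag directly from the model's HF config, which exposes `index_topk` for `deepseek_v32`. A minimal sketch of that detection pattern follows; the helper name is illustrative and is not itself introduced by this PR:

```python
def uses_sparse_attention(vllm_config) -> bool:
    """Return True when the loaded model uses DeepSeek V3.2-style sparse attention.

    Sketch only: it mirrors the ``hasattr(hf_config, "index_topk")`` checks used
    throughout this PR; the helper itself is not part of the diff.
    """
    model_config = vllm_config.model_config
    if model_config is None:
        return False
    return hasattr(model_config.hf_config, "index_topk")
```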
@@ -18,6 +18,7 @@ class TestNPUWorker(TestBase):
         self.model_config_mock = MagicMock(spec=ModelConfig)
         self.model_config_mock.dtype = torch.float16
         self.model_config_mock.trust_remote_code = False
+        self.model_config_mock.hf_config = None
 
         self.parallel_config_mock = MagicMock(spec=ParallelConfig)
 
@@ -23,7 +23,6 @@ def register():
 
 
 def register_model():
-    import vllm_ascend.patch.worker.patch_common.patch_attention_selector  # noqa
 
     from .models import register_model
     register_model()
@@ -34,8 +34,6 @@ class AscendConfig:
 
     def __init__(self, vllm_config):
        additional_config = vllm_config.additional_config if vllm_config.additional_config is not None else {}
-        self.is_deepseek_sfa = vllm_config.model_config is not None and vllm_config.model_config.is_deepseek_mla and vllm_config.model_config.hf_text_config.model_type == "deepseek_v32"
-        self.use_sfa = self.is_deepseek_sfa
 
         torchair_graph_config = additional_config.get("torchair_graph_config",
                                                       {})
@@ -510,7 +510,6 @@ class AscendSFAImpl(MLAAttentionImpl):
 
         ascend_config = get_ascend_config()
         self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
-        self.enable_prefetch = ascend_config.enable_prefetch
         self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
 
         vllm_config = get_current_vllm_config()
@@ -690,6 +689,8 @@ class AscendSFAImpl(MLAAttentionImpl):
             topk_indices = self.indexer_select(hidden_states_decode,
                                                decode_q_c,
                                                attn_metadata=attn_metadata,
+                                               cos=cos,
+                                               sin=sin,
                                                kv_cache=kv_cache)
 
             query_states = (decode_q_nope, decode_q_pe)
@@ -778,6 +779,8 @@ class AscendSFAImpl(MLAAttentionImpl):
             topk_indices = self.indexer_select(x=hidden_states_prefill,
                                                qr=prefill_qr,
                                                kv_cache=kv_cache,
+                                               cos=cos,
+                                               sin=sin,
                                                attn_metadata=attn_metadata)
             query_states = (prefill_q_nope, prefill_q_pe)
             key_states = (prefill_k_nope, prefill_k_pe)
@@ -920,17 +923,15 @@ class AscendSFAImpl(MLAAttentionImpl):
         x: torch.Tensor,
         qr: torch.Tensor,
         kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
+        cos,
+        sin,
         attn_metadata: M,
     ):
         if attn_metadata.prefill is not None:
-            cos = attn_metadata.prefill.cos
-            sin = attn_metadata.prefill.sin
            actual_seq_lengths_query = attn_metadata.prefill.query_lens
            actual_seq_lengths_key = attn_metadata.prefill.seq_lens
            block_table = attn_metadata.prefill.block_table
        elif attn_metadata.decode is not None:
-            cos = attn_metadata.decode.cos
-            sin = attn_metadata.decode.sin
            actual_seq_lengths_query = attn_metadata.decode.actual_seq_lengths_q
            actual_seq_lengths_key = attn_metadata.decode.seq_lens
            block_table = attn_metadata.decode.block_table
@@ -501,7 +501,7 @@ class LLMDataDistCMgrConnectorWorker():
         self.use_mla: bool = first_kv_cache_tuple[0].size(
             -1) != first_kv_cache_tuple[1].size(-1) and len(
                 first_kv_cache_tuple) == 2
-        self.use_sfa: bool = len(first_kv_cache_tuple) == 3
+        self.use_sparse: bool = len(first_kv_cache_tuple) == 3
         # MLA case. [2 (k_normed, k_pe), num_blocks, ...]
         # SFA case. [3 (k_normed, k_pe, k_idx), num_blocks, ...]
         # MHA case. [2 (k and v), num_blocks, ...]
@@ -549,7 +549,7 @@ class LLMDataDistCMgrConnectorWorker():
             raise RuntimeError(
                 f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to register_block_cache, receiving [cache_desc: {self.cache_desc}, cache_addr: {self.cache_addr}, cache_key: {self.cache_key}]"
             )
-        elif self.use_sfa:
+        elif self.use_sparse:
             cache_k_normed_addr_list = []
             cache_k_pe_addr_list = []
             cache_k_idx_addr_list = []
@@ -887,7 +887,7 @@ class LLMDataDistCMgrConnectorWorker():
             raise RuntimeError(
                 "LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status"
             )
-        elif self.use_sfa:
+        elif self.use_sparse:
             remote_cache_key_k_normed = BlocksCacheKey(
                 cluster_id=remote_cluster_id, model_id=0)
             remote_cache_key_k_pe = BlocksCacheKey(
@@ -242,7 +242,7 @@ class KVCacheRecvingThread(threading.Thread):
         self.block_len = block_len
         # TODO(jianzs): find a better way to detect MLA.
         self.use_mla = len(block_len) == 2
-        self.use_sfa = len(block_len) == 3
+        self.use_sparse = len(block_len) == 3
 
         self.request_queue: queue.Queue[Any] = queue.Queue()
         self.executor = ThreadPoolExecutor(max_workers=32)
@@ -373,7 +373,7 @@ class KVCacheRecvingThread(threading.Thread):
                 zip(local_kv_caches_base_addrs, remote_kv_caches_base_addrs)):
             if self.use_mla:
                 block_len = (self.block_len[k % 2])
-            elif self.use_sfa:
+            elif self.use_sparse:
                 block_len = (self.block_len[k % 3])
             else:
                 block_len = (self.block_len[0])
@@ -850,7 +850,8 @@ class MooncakeConnectorScheduler:
         assert "tp_size" in decode_parallel_config.keys()
         self._decode_tp_size = decode_parallel_config["tp_size"]
         num_key_value_heads = self.vllm_config.model_config.hf_config.num_key_value_heads
-        if self.vllm_config.model_config.use_mla or self.ascend_config.use_sfa:
+        if self.vllm_config.model_config.use_mla or hasattr(
+                self.vllm_config.model_config.hf_config, "index_topk"):
             num_need_pulls = 1
         else:
             num_p_block_heads = max(
@@ -942,7 +943,7 @@ class MooncakeConnectorWorker:
         # kv_transfer variables
         self.vllm_config = vllm_config
         self.block_size = vllm_config.cache_config.block_size
-        if self.vllm_config.model_config.is_deepseek_mla or self.ascend_config.use_sfa:
+        if self.vllm_config.model_config.is_deepseek_mla or self.use_sparse:
             self.num_need_pulls = 1
         else:
             num_d_block_heads = max(1,
@@ -995,7 +996,7 @@ class MooncakeConnectorWorker:
         self.use_mla = first_kv_cache_tuple[0].size(
             -1) != first_kv_cache_tuple[1].size(-1) and len(
                 first_kv_cache_tuple) == 2
-        self.use_sfa = len(first_kv_cache_tuple) == 3
+        self.use_sparse = len(first_kv_cache_tuple) == 3
         if self.use_mla:
             # MLA case.[num_block, block_size, 1, hidden_dim]
             self.num_blocks = first_kv_cache.shape[0]
@@ -1009,7 +1010,7 @@ class MooncakeConnectorWorker:
             logger.info(
                 "num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s",
                 self.num_blocks, block_shape_norm, block_shape_pe)
-        elif self.use_sfa:
+        elif self.use_sparse:
             self.num_blocks = first_kv_cache.shape[0]
             block_rank = 3  # [block_size, latent_dim]
             block_shape_norm = first_kv_cache_tuple[0].shape[-block_rank:]
@@ -1037,8 +1038,8 @@ class MooncakeConnectorWorker:
             logger.info("num_blocks: %s, block_shape: %s", self.num_blocks,
                         block_shape)
         logger.info(
-            "Registering KV_Caches. use_mla: %s, use_sfa: %s, shape %s",
-            self.use_mla, self.use_sfa, first_kv_cache.shape)
+            "Registering KV_Caches. use_mla: %s, use_sparse: %s, shape %s",
+            self.use_mla, self.use_sparse, first_kv_cache.shape)
 
         self.kv_caches = kv_caches
         kv_caches_base_addr = []
@@ -1050,7 +1051,7 @@ class MooncakeConnectorWorker:
                     region_len = self.num_blocks * self.block_len[i % 2]
                     kv_caches_base_addr.append(base_addr)
                     self._register(base_addr, region_len)
-            elif self.use_sfa:
+            elif self.use_sparse:
                 for i, cache in enumerate(cache_or_caches, 0):
                     base_addr = cache.data_ptr()
                     region_len = self.num_blocks * self.block_len[i % 3]
@@ -1059,7 +1060,7 @@ class MooncakeConnectorWorker:
             else:
                 cache_list = [
                     cache_or_caches
-                ] if self.use_mla or self.use_sfa else cache_or_caches
+                ] if self.use_mla or self.use_sparse else cache_or_caches
                 for cache in cache_list:
                     base_addr = cache.data_ptr()
                     region_len = self.num_blocks * self.block_len[0]
@@ -1156,9 +1157,9 @@ class MooncakeConnectorWorker:
         sampled_nums = []
         ori_data = np.arange(self._prefill_tp_size)
         # random split prefill tp list
-        if self._prefill_tp_size > self.num_key_value_heads or self.vllm_config.model_config.is_deepseek_mla or self.ascend_config.use_sfa:
+        if self._prefill_tp_size > self.num_key_value_heads or self.vllm_config.model_config.is_deepseek_mla or self.use_sparse:
             # use deepseek mla, num_key_value_heads == 128, but consider as 1
-            if self.vllm_config.model_config.is_deepseek_mla or self.ascend_config.use_sfa:
+            if self.vllm_config.model_config.is_deepseek_mla or self.use_sparse:
                 num_kv_head = 1
             else:
                 num_kv_head = self.num_key_value_heads
@@ -1279,4 +1280,4 @@ def ensure_zmq_recv(
             logger.error(f"Receive failed after all retries: {e}")
             raise RuntimeError(
                 f"Failed to receive data after {max_retries} "
                 f"retries: {e}")
@@ -31,6 +31,7 @@ import torch
 from torch import nn
 from transformers import PretrainedConfig
 from vllm.attention import AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import (divide, get_pp_group,
                               get_tensor_model_parallel_rank,
@@ -47,7 +48,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.mla import MultiHeadLatentAttention
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
 from vllm.model_executor.models.deepseek_v2 import \
@@ -56,10 +58,11 @@ from vllm.model_executor.models.deepseek_v2 import (
     DeepseekV2Attention, DeepseekV2DecoderLayer, DeepseekV2ForCausalLM,
     DeepseekV2MLAAttention, DeepseekV2MLP, DeepseekV2Model, DeepseekV2MoE,
     get_spec_layer_idx_from_weight_name)
-from vllm.model_executor.models.utils import (PPMissingLayer,
-                                              is_pp_missing_parameter,
-                                              maybe_prefix)
+from vllm.model_executor.models.utils import (
+    PPMissingLayer, is_pp_missing_parameter,
+    make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
 from vllm.model_executor.utils import set_weight_attrs
+from vllm.platforms import current_platform
 
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.models.layers.mla import AscendMLAModules
@@ -69,6 +72,53 @@ from vllm_ascend.ops.common_fused_moe import AscendFusedMoE
 from vllm_ascend.ops.linear import AscendLinearBase
 
 
+@support_torch_compile
+class AscendDeepseekV2Model(DeepseekV2Model, nn.Module):
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        # Rewrite this init func mainly for removing cuda-hard code
+        nn.Module.__init__(self)
+
+        config = vllm_config.model_config.hf_config
+        quant_config = vllm_config.quant_config
+        self.config = config
+
+        self.vocab_size = config.vocab_size
+        self.is_v32 = hasattr(config, "index_topk")
+        if self.is_v32:
+            topk_tokens = config.index_topk
+            topk_indices_buffer = torch.empty(
+                vllm_config.scheduler_config.max_num_batched_tokens,
+                topk_tokens,
+                dtype=torch.int32,
+                device=current_platform.device_type)
+        else:
+            topk_indices_buffer = None
+
+        if get_pp_group().is_first_rank:
+            self.embed_tokens = VocabParallelEmbedding(
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                prefix=f"{prefix}.embed_tokens")
+        else:
+            self.embed_tokens = PPMissingLayer()
+
+        self.start_layer, self.end_layer, self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda prefix: DeepseekV2DecoderLayer(vllm_config, prefix,
+                                                  topk_indices_buffer),
+            prefix=f"{prefix}.layers")
+
+        if get_pp_group().is_last_rank:
+            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            self.norm = PPMissingLayer()
+        self.make_empty_intermediate_tensors = (
+            make_empty_intermediate_tensors_factory(
+                ["hidden_states", "residual"], config.hidden_size))
+
+
 class CustomDeepseekV2RowParallelLinear(RowParallelLinear):
 
     def __init__(
@@ -270,6 +320,7 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
             scaling_factor = rope_scaling["factor"]
             mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
             self.scaling = self.scaling * mscale * mscale
+        self.indexer = None
 
         mla_modules = AscendMLAModules(
             q_a_proj=self.q_a_proj if self.q_lora_rank is not None else None,
@@ -281,6 +332,8 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
             kv_b_proj=self.kv_b_proj,
             o_proj=self.o_proj,
             rotary_emb=self.rotary_emb,
+            indexer=None,
+            is_sparse=hasattr(config, "index_topk"),
         )
 
         self.mla_attn = MultiHeadLatentAttention(
@@ -499,7 +552,6 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
         parallel_config = vllm_config.parallel_config
-        ascend_config = get_ascend_config()
 
         self.hidden_size = config.hidden_size
         rope_theta = getattr(config, "rope_theta", 10000)
@@ -515,7 +567,7 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
         self.tp_rank = get_tp_group().rank_in_group
         # TODO: enable mla in vllm-ascend
         if model_config.use_mla:
-            if ascend_config.use_sfa:
+            if hasattr(model_config.hf_config, "index_topk"):
                 attn_cls = CustomDeepseekV2SFAAttention
             else:
                 attn_cls = CustomDeepseekV2MLAAttention
@@ -590,8 +642,9 @@ class CustomDeepseekV2ForCausalLM(DeepseekV2ForCausalLM):
             "kv_a_proj_with_mqa",
         ]
 
-        self.model = DeepseekV2Model(vllm_config=vllm_config,
-                                     prefix=maybe_prefix(prefix, "model"))
+        self.model = AscendDeepseekV2Model(vllm_config=vllm_config,
+                                           prefix=maybe_prefix(
+                                               prefix, "model"))
         if get_pp_group().is_last_rank:
             self.lm_head = ParallelLMHead(config.vocab_size,
                                           config.hidden_size,
@@ -42,6 +42,8 @@ class AscendMLAModules:
     kv_b_proj: torch.nn.Module
     o_proj: torch.nn.Module
     rotary_emb: torch.nn.Module
+    indexer: Optional[torch.nn.Module]
+    is_sparse: bool
 
 
 class AscendMultiHeadLatentAttention(MultiHeadLatentAttention):
@@ -94,7 +94,7 @@ class AscendSparseFlashAttention(MultiHeadLatentAttention):
             quant_config=quant_config,
             prefix=f"{prefix}.attn",
             use_mla=True,
-            use_sfa=True,
+            use_sparse=True,
             # SFA Args
             q_lora_rank=self.q_lora_rank,
             kv_lora_rank=self.kv_lora_rank,
@@ -18,4 +18,3 @@
 import vllm_ascend.patch.platform.patch_common.patch_config  # noqa
 import vllm_ascend.patch.platform.patch_common.patch_distributed  # noqa
 import vllm_ascend.patch.platform.patch_common.patch_mamba_config  # noqa
-import vllm_ascend.patch.worker.patch_common.patch_attention_selector  # noqa
@@ -21,8 +21,6 @@ if HAS_TRITON:
     import vllm_ascend.patch.worker.patch_common.patch_triton
 
 # isort: off
-import vllm_ascend.patch.worker.patch_common.patch_attention_selector  # noqa
-import vllm_ascend.patch.worker.patch_common.patch_attention_layer  # noqa
 import vllm_ascend.patch.worker.patch_common.patch_distributed  # noqa
 import vllm_ascend.patch.worker.patch_common.patch_logits  # noqa
 import vllm_ascend.patch.worker.patch_common.patch_roberta  # noqa
@@ -1,188 +0,0 @@
-from typing import List, Optional
-
-import torch
-import vllm
-import vllm.envs as envs
-from torch import nn
-from vllm.attention import Attention, AttentionType, get_attn_backend
-from vllm.attention.backends.abstract import AttentionBackend
-from vllm.attention.selector import backend_name_to_enum
-from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
-from vllm.config import CacheConfig, get_current_vllm_config
-from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
-from vllm.model_executor.layers.linear import UnquantizedLinearMethod
-from vllm.model_executor.layers.quantization.base_config import \
-    QuantizationConfig
-from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
-from vllm.platforms import current_platform
-
-
-class AscendAttention(Attention, nn.Module, AttentionLayerBase):
-    """Attention layer.
-
-    This class takes query, key, and value tensors as input. The input tensors
-    can either contain prompt tokens or generation tokens.
-    The class does the following:
-
-    1. Store the input key and value tensors in the KV cache.
-    2. Perform (multi-head/multi-query/grouped-query) attention.
-    3. Return the output tensor.
-    """
-
-    def __init__(
-        self,
-        num_heads: int,
-        head_size: int,
-        scale: float,
-        num_kv_heads: Optional[int] = None,
-        alibi_slopes: Optional[List[float]] = None,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        logits_soft_cap: Optional[float] = None,
-        per_layer_sliding_window: Optional[int] = None,
-        use_mla: bool = False,
-        use_sfa: bool = False,
-        prefix: str = "",
-        attn_type: str = AttentionType.DECODER,
-        kv_sharing_target_layer_name: Optional[str] = None,
-        attn_backend: Optional[type[AttentionBackend]] = None,
-        **extra_impl_args,
-    ) -> None:
-        """
-        The KV cache is stored inside this class and is accessed via
-        `self.kv_cache`.
-        """
-        nn.Module.__init__(self)
-        AttentionLayerBase.__init__(self)
-
-        if per_layer_sliding_window is not None:
-            # per-layer sliding window
-            sliding_window = per_layer_sliding_window
-        elif cache_config is not None:
-            # model-level sliding window
-            sliding_window = cache_config.sliding_window
-        else:
-            sliding_window = None
-
-        if cache_config is not None:
-            kv_cache_dtype = cache_config.cache_dtype
-            block_size = cache_config.block_size
-            calculate_kv_scales = cache_config.calculate_kv_scales
-        else:
-            kv_cache_dtype = "auto"
-            block_size = 16
-            calculate_kv_scales = False
-        if num_kv_heads is None:
-            num_kv_heads = num_heads
-        assert num_heads % num_kv_heads == 0, \
-            f"num_heads ({num_heads}) is not " \
-            f"divisible by num_kv_heads ({num_kv_heads})"
-
-        # The default k/v_scale is set to 1.0. This is ignored
-        # when kv-cache is not fp8, and should be used with
-        # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
-        # expect the pre-quantized k/v_scale to be loaded along
-        # with the model weights.
-        self.kv_cache_dtype = kv_cache_dtype
-        self.calculate_kv_scales = calculate_kv_scales
-        self._k_scale = torch.tensor(1.0, dtype=torch.float32)
-        self._v_scale = torch.tensor(1.0, dtype=torch.float32)
-        # FlashAttn doesn't support quantizing the kv-cache only
-        # but requires q to be quantized as well.
-        self._q_scale = torch.tensor(1.0, dtype=torch.float32)
-        self._prob_scale = torch.tensor(1.0, dtype=torch.float32)
-
-        # We also keep q/k/v_scale on host (cpu) memory for attention
-        # backends that require the scales to be on host instead of on device.
-        # e.g. Flashinfer
-        self._q_scale_float = 1.0
-        self._k_scale_float = 1.0
-        self._v_scale_float = 1.0
-
-        # The output scale on host memory. This should be the input scale of
-        # the quant op after this attention layer.
-        self._o_scale_float: Optional[float] = None
-
-        self.use_mla = use_mla
-        self.num_heads = num_heads
-        self.head_size = head_size
-        self.num_kv_heads = num_kv_heads
-        self.sliding_window = sliding_window
-        self.has_sink = extra_impl_args.get("sinks") is not None
-
-        quant_method = quant_config.get_quant_method(
-            self, prefix=prefix) if quant_config else None
-        if quant_method is not None and not isinstance(
-                quant_method, UnquantizedLinearMethod):
-            assert isinstance(quant_method, BaseKVCacheMethod)
-            # TODO (mgoin): kv cache dtype should be specified in the FP8
-            # checkpoint config and become the "auto" behavior
-            if self.kv_cache_dtype == "fp8_e5m2":
-                raise ValueError("fp8_e5m2 kv-cache is not supported with "
-                                 "fp8 checkpoints.")
-            # If quantization is enabled, we make "k_scale" and "v_scale"
-            # parameters so that it can be loaded from the model checkpoint.
-            # The k/v_scale will then be converted back to native float32
-            # values after weight loading.
-            self.quant_method = quant_method
-            self.quant_method.create_weights(self)
-
-        # During model initialization, the default dtype is set as the model
-        # weight and activation dtype.
-        dtype = torch.get_default_dtype()
-        if attn_backend is None:
-            self.attn_backend = get_attn_backend(head_size,
-                                                 dtype,
-                                                 kv_cache_dtype,
-                                                 block_size,
-                                                 use_mla=use_mla,
-                                                 use_sfa=use_sfa,
-                                                 has_sink=self.has_sink)
-        else:
-            self.attn_backend = attn_backend
-
-        impl_cls = self.attn_backend.get_impl_cls()
-        self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
-                             alibi_slopes, sliding_window, kv_cache_dtype,
-                             logits_soft_cap, attn_type,
-                             kv_sharing_target_layer_name, **extra_impl_args)
-        self.backend = backend_name_to_enum(self.attn_backend.get_name())
-        self.dtype = dtype
-
-        # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
-        # torch.compile works by registering the attention as one giant
-        # opaque custom op. For other platforms, we directly call them
-        # and let torch.compile handle them.
-        self.use_direct_call = not current_platform.opaque_attention_op()
-
-        self.use_output = self.attn_backend.accept_output_buffer
-        compilation_config = get_current_vllm_config().compilation_config
-        if prefix in compilation_config.static_forward_context:
-            raise ValueError(f"Duplicate layer name: {prefix}")
-        compilation_config.static_forward_context[prefix] = self
-        self.layer_name = prefix
-        self.attn_type = attn_type
-
-        if kv_sharing_target_layer_name is not None:
-            validate_kv_sharing_target(
-                prefix,
-                kv_sharing_target_layer_name,
-                compilation_config.static_forward_context,
-            )
-        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
-
-        # use a placeholder kv cache tensor during init, which will be replaced
-        # by bind_kv_cache
-        # this variable will not be accessed if use_direct_call is True
-        self.kv_cache = [
-            torch.tensor([]) for _ in range(get_current_vllm_config(
-            ).parallel_config.pipeline_parallel_size)
-        ]
-
-        self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
-        self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
-        self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)
-        self.query_quant = None
-
-
-vllm.attention.Attention = AscendAttention
@@ -164,6 +164,9 @@ class NPUPlatform(Platform):
                 "kv_cache_dtype", None)
             if kv_cache_dtype is not None:
                 vllm_config.cache_config.cache_dtype = kv_cache_dtype
+            elif model_config and hasattr(model_config.hf_config, "index_topk"):
+                vllm_config.cache_config.cache_dtype = str(
+                    model_config.dtype).replace("torch.", "")
         if model_config is None:
             logger.warning("Model config is missing. This may indicate "
                            "that we are running a test case")
@@ -284,25 +287,27 @@ class NPUPlatform(Platform):
             vllm_config.scheduler_config = ascend_scheduler_config
 
     @classmethod
-    def get_attn_backend_cls(cls,
-                             selected_backend,
-                             head_size,
-                             dtype,
-                             kv_cache_dtype,
-                             block_size,
-                             use_v1,
-                             use_mla,
-                             use_sfa,
-                             has_sink=False):
+    def get_attn_backend_cls(
+        cls,
+        selected_backend,
+        head_size,
+        dtype,
+        kv_cache_dtype,
+        block_size,
+        use_v1,
+        use_mla,
+        has_sink=False,
+        use_sparse=False,
+    ):
         if not use_v1:
             raise ValueError("vLLM Ascend does not support V0 engine.")
 
         ascend_config = get_ascend_config()
 
         if use_mla and ascend_config.enable_shared_expert_dp:
-            if use_mla and not use_sfa:
+            if use_mla and not use_sparse:
                 return "vllm_ascend.torchair.torchair_mla.AscendMLATorchairBackend"
-            if use_mla and use_sfa:
+            if use_mla and use_sparse:
                 return "vllm_ascend.torchair.torchair_sfa.AscendSFATorchairBackend"
 
         use_torchair = ascend_config.torchair_graph_config.enabled
@@ -321,7 +326,7 @@ class NPUPlatform(Platform):
             (True, True, True):
             "vllm_ascend.torchair.torchair_sfa.AscendSFATorchairBackend",
         }
-        return backend_map[(use_mla, use_sfa, use_torchair)]
+        return backend_map[(use_mla, use_sparse, use_torchair)]
 
     @classmethod
     def get_punica_wrapper(cls) -> str:
@@ -183,6 +183,11 @@ packed_modules_model_mapping = {
         "experts":
        ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
     },
+    "deepseek_v32": {
+        "gate_up_proj": ["gate_proj", "up_proj"],
+        "experts":
+        ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
+    },
     # NOTE 1.The quantized MTP layer of deepseek on the NPU is not quantized;
     # NOTE 2.The description file generated by the current msmodelslim tool does not have
     # MTP layer info. Please manually add it and set the value to FLOAT.
@@ -67,6 +67,8 @@ class MtpProposer(Proposer):
                                    1,
                                    device=self.runner.device,
                                    dtype=torch.int32)
+        self.use_sparse = hasattr(vllm_config.model_config.hf_config,
+                                  "index_topk")
 
     def load_model(self, model) -> None:
         loader = get_model_loader(self.vllm_config.load_config)
@@ -613,7 +615,7 @@ class MtpProposer(Proposer):
             npu_backend = torchair.get_npu_backend(compiler_config=config)
             self.torchair_compiled_model = torch.compile(
                 self.model,
-                dynamic=not get_ascend_config().use_sfa,
+                dynamic=not self.use_sparse,
                 fullgraph=True,
                 backend=npu_backend)
             return self.torchair_compiled_model
@@ -636,7 +638,7 @@ class MtpProposer(Proposer):
             self.torchair_compiled_models[
                 batch_size] = torchair.inference.cache_compile(
                     self.model.__dict__[forward_proxy_name],
-                    dynamic=not get_ascend_config().use_sfa,
+                    dynamic=not self.use_sparse,
                    fullgraph=True,
                    cache_dir=TORCHAIR_CACHE_DIR,
                    config=config,
@@ -791,7 +791,7 @@ class TorchairDeepseekV2SFAAttention(DeepseekV2MLAAttention):
             quant_config=quant_config,
             prefix=f"{prefix}.attn",
             use_mla=True,
-            use_sfa=True,
+            use_sparse=True,
             # SFA Args
             q_lora_rank=self.q_lora_rank,
             kv_lora_rank=self.kv_lora_rank,
@@ -879,12 +879,12 @@ class TorchairDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
         self.tp_rank = get_tp_group().rank_in_group
         ascend_config = get_ascend_config()
         self.use_mla = False
-        self.use_sfa = False
+        self.use_sparse = False
         # TODO: enable mla in vllm-ascend
         if model_config.use_mla:
-            if ascend_config.use_sfa:
+            if hasattr(model_config.hf_config, "index_topk"):
                 attn_cls = TorchairDeepseekV2SFAAttention
-                self.use_sfa = True
+                self.use_sparse = True
             else:
                 attn_cls = TorchairDeepseekV2MLAAttention  # type: ignore[assignment]
                 self.use_mla = True
@@ -950,7 +950,7 @@ class TorchairDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
         forward_context = get_forward_context()
         if attn_metadata is not None:
             decoding_condition_met = (
-                not attn_metadata.is_prefill if self.use_sfa else
+                not attn_metadata.is_prefill if self.use_sparse else
                 not forward_context.with_prefill if self.use_mla else False)
             mla_moe_communication = decoding_condition_met and self.mla_moe_communication and replace_allreduce
         else:
@@ -376,7 +376,7 @@ class NPUTorchairModelRunner(NPUModelRunner):
             npu_backend = torchair.get_npu_backend(compiler_config=config)
             self.torchair_compiled_model = torch.compile(
                 self.model,
-                dynamic=not self.ascend_config.use_sfa,
+                dynamic=not self.use_sparse,
                 fullgraph=True,
                 backend=npu_backend)
             return self.torchair_compiled_model
@@ -399,7 +399,7 @@ class NPUTorchairModelRunner(NPUModelRunner):
             self.torchair_compiled_models[
                 batch_size] = torchair.inference.cache_compile(
                    self.model.__dict__[forward_proxy_name],
-                    dynamic=not self.ascend_config.use_sfa,
+                    dynamic=not self.use_sparse,
                    fullgraph=True,
                    cache_dir=TORCHAIR_CACHE_DIR,
                    config=config,
@@ -738,7 +738,7 @@ class AscendSFATorchairImpl(MLAAttentionImpl):
 
         ascend_config = get_ascend_config()
         self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
-        self.enable_prefetch = ascend_config.enable_prefetch
+        self.enable_prefetch = ascend_config.weight_prefetch_config.enabled
         self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
         if ascend_config.torchair_graph_config.enabled:
             self.graph_batch_size = ascend_config.torchair_graph_config.graph_batch_sizes[
@@ -309,13 +309,14 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             dtype=self.dtype,
             device=self.device)
         # Set up Attention
-        self.attn_backend = get_attn_backend(
-            0,
-            self.dtype,
-            None,
-            self.block_size,
-            use_mla=self.model_config.use_mla,
-            use_sfa=self.ascend_config.use_sfa)
+        self.use_sparse = hasattr(self.vllm_config.model_config.hf_config,
+                                  "index_topk")
+        self.attn_backend = get_attn_backend(0,
+                                             self.dtype,
+                                             None,
+                                             self.block_size,
+                                             use_mla=self.model_config.use_mla,
+                                             use_sparse=self.use_sparse)
         if torch.version.cann.startswith("8.3"):
             self.attn_mask_builder = AttentionMaskBuilder(
                 self.scheduler_config.max_num_batched_tokens, self.dtype,
@@ -871,7 +872,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
        if self.model_config.runner_type == "pooling" and self.model_config.pooler_config.pooling_type == "CLS":
            return self.attn_mask_builder.get_pooling_mask(self.device)
        # Chunk Prefill situation.
-        elif attn_state == AscendAttentionState.ChunkedPrefill and not self.vllm_config.model_config.use_mla and not self.ascend_config.use_sfa:
+        elif attn_state == AscendAttentionState.ChunkedPrefill and not self.vllm_config.model_config.use_mla and not self.use_sparse:
            if torch.version.cann.startswith("8.3"):
                return self.attn_mask_builder.get_splitfuse_attn_mask()
            else:
@@ -1507,7 +1508,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                model=self.get_model(),
                **extra_attn_metadata_args)
 
-            if self.vllm_config.model_config.use_mla or self.ascend_config.use_sfa:
+            if self.vllm_config.model_config.use_mla or self.use_sparse:
                attn_metadata_i.num_input_tokens = num_input_tokens
            for layer_name in attn_group.layer_names:
                attn_metadata[layer_name] = attn_metadata_i
@@ -2655,7 +2656,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
        self.may_add_encoder_only_layers_to_kv_cache_config()
        self.initialize_attn_backend(kv_cache_config)
 
-        if self.ascend_config.is_deepseek_sfa:
+        if self.use_sparse:
            kv_caches = self.initialize_kv_cache_tensors_deepseek_sfa(
                kv_cache_config)
        elif self.model_config.is_deepseek_mla:
@@ -2699,7 +2700,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                    kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
            elif hasattr(
                    attn_backend, "get_supported_block_size"
-            ) and not self.model_config.is_deepseek_mla and not self.ascend_config.is_deepseek_sfa:
+            ) and not self.model_config.is_deepseek_mla and not self.use_sparse:
                block_size = attn_backend.get_supported_block_size()[0]
                block_size_chunk = kv_cache_spec.block_size // block_size
                kv_cache_shape = attn_backend.get_kv_cache_shape(
@@ -3245,7 +3246,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
 
        block_size = self.vllm_config.cache_config.block_size
        use_mla = self.vllm_config.model_config.use_mla
-        use_sfa = self.ascend_config.use_sfa
+        use_sparse = self.use_sparse
        kv_cache_spec: dict[str, KVCacheSpec] = {}
        attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
        for layer_name, attn_module in attn_layers.items():
@@ -3267,7 +3268,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
            # TODO(lucas): move the attention specs into the model layers like
            # the attention backends
            if attn_module.attn_type == AttentionType.DECODER:
-                if use_mla and not use_sfa:
+                if use_mla and not use_sparse:
                    kv_cache_spec[layer_name] = MLAAttentionSpec(
                        block_size=block_size,
                        num_kv_heads=attn_module.num_kv_heads,
@@ -43,7 +43,7 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
 from vllm.v1.worker.worker_base import WorkerBase
 
 import vllm_ascend.envs as envs_ascend
-from vllm_ascend.ascend_config import get_ascend_config, init_ascend_config
+from vllm_ascend.ascend_config import init_ascend_config
 from vllm_ascend.device_allocator.camem import CaMemAllocator
 from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
 from vllm_ascend.platform import NPUPlatform
@@ -88,7 +88,11 @@ class NPUWorker(WorkerBase):
        # init ascend config and soc version
        init_ascend_config(vllm_config)
        init_ascend_soc_version()
-        if get_ascend_config().use_sfa:
+        use_sparse = False
+        if vllm_config.model_config is not None:
+            use_sparse = hasattr(vllm_config.model_config.hf_config,
+                                 "index_topk")
+        if use_sparse:
            # Direct import instead of using try_register_lib to ensure proper error handling when
            # custom_ops is necessary but not available (e.g., in DeepSeek v3.2 deployments)
            # yapf: disable