mirror of
https://github.com/vllm-project/vllm-ascend.git
synced 2025-10-20 13:43:53 +08:00
Contains on #1111 for completeness. <!-- Thanks for sending a pull request! BEFORE SUBMITTING, PLEASE READ https://docs.vllm.ai/en/latest/contributing/overview.html --> ### What this PR does / why we need it? Implement multi-stream parallelism for MoE layers with shared experts, where computation of shared experts will be overlapped with expert token dispatch and combine. Also, when multi-stream is enabled, weights of shared experts will be force to replicate across all cards, regardless of any tensor parallelism configurations, to avoid AllReduce operations. With the expected overlaping being: ``` | shared gate_up | shared act | | shared down | | dispatch | routed gate_up, act, down | combine | ``` <!-- - Please clarify what changes you are proposing. The purpose of this section is to outline the changes and how this PR fixes the issue. If possible, please consider writing useful notes for better and faster reviews in your PR. - Please clarify why the changes are needed. For instance, the use case and bug description. - Fixes # --> ### Does this PR introduce _any_ user-facing change? No. <!-- Note that it means *any* user-facing change including all aspects such as API, interface or other behavior changes. Documentation-only updates are not considered user-facing changes. --> ### How was this patch tested? Tested on 1x16 910 node, with tailored 2 layer DSKv2. <!-- CI passed with new added/existing test. If it was tested in a way different from regular unit tests, please clarify how you tested step by step, ideally copy and paste-able, so that other reviewers can test and check, and descendants can verify in the future. If tests were not added, please describe why they were not added and/or why it was difficult to add. --> --------- Signed-off-by: sdmyzlp <lrwei2@petalmail.com>
1017 lines
44 KiB
Python
1017 lines
44 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
|
# Copyright 2023 The vLLM team.
|
|
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
|
# and OPT implementations in this library. It has been modified from its
|
|
# original forms to accommodate minor architectural differences compared
|
|
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# # Adapted from
|
|
# # vllm-project/vllm/blob/main/vllm/model_executor/models/deepseek_v2.py
|
|
# # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
|
|
# # vllm-project/vllm/vllm/model_executor/models/deepseek_v2.py
|
|
# """Inference-only DeepseekV2/DeepseekV3 model."""
|
|
|
|
from typing import Any, Dict, List, Optional, Union
|
|
|
|
import torch
|
|
import torch.distributed as dist
|
|
import torch_npu # noqa: F401
|
|
import vllm.envs as envs
|
|
from torch import nn
|
|
from transformers import PretrainedConfig
|
|
from vllm.attention import Attention, AttentionMetadata
|
|
from vllm.config import CacheConfig, ModelConfig, VllmConfig
|
|
from vllm.distributed import (get_pp_group,
|
|
get_tensor_model_parallel_world_size,
|
|
get_tp_group, tensor_model_parallel_all_reduce)
|
|
from vllm.distributed.parallel_state import get_dp_group
|
|
from vllm.forward_context import get_forward_context
|
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
|
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
|
ReplicatedLinear,
|
|
RowParallelLinear)
|
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
|
from vllm.model_executor.layers.sampler import get_sampler
|
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
|
ParallelLMHead, VocabParallelEmbedding)
|
|
from vllm.model_executor.models.deepseek_v2 import \
|
|
DeepseekV2ForCausalLM # ruff: noqa: E501
|
|
from vllm.model_executor.models.deepseek_v2 import \
|
|
yarn_get_mscale # ruff: noqa: E501
|
|
from vllm.model_executor.models.deepseek_v2 import (DeepseekV2Attention,
|
|
DeepseekV2DecoderLayer,
|
|
DeepseekV2MLAAttention)
|
|
from vllm.model_executor.models.utils import (
|
|
PPMissingLayer, make_empty_intermediate_tensors_factory, make_layers,
|
|
maybe_prefix)
|
|
from vllm.sequence import IntermediateTensors
|
|
|
|
import vllm_ascend.envs as envs_ascend
|
|
from vllm_ascend.ascend_config import get_ascend_config
|
|
from vllm_ascend.models.deepseek_v2 import CustomDeepseekV2MLP
|
|
from vllm_ascend.multistream.base import MSEventKey
|
|
from vllm_ascend.multistream.context import (
|
|
advance_step_multistream_layer_context, get_multistream_comm_context,
|
|
get_multistream_layer_context, set_multistream_context)
|
|
from vllm_ascend.multistream.layers import (MultiStreamPostTransformerLayer,
|
|
MultiStreamPreTransformerLayer)
|
|
from vllm_ascend.multistream.metadata import (MultiStreamConfig,
|
|
MultiStreamStepMetadata,
|
|
make_multistream_metadata_ds)
|
|
from vllm_ascend.multistream.ms_split import compute_split_seq_index
|
|
from vllm_ascend.ops.fused_moe import AscendFusedMoE
|
|
from vllm_ascend.utils import dispose_tensor
|
|
|
|
VLLM_ASCEND_ENABLE_DBO: bool = envs_ascend.VLLM_ASCEND_ENABLE_DBO
|
|
VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
|
|
|
|
|
|
class CustomDeepseekDBOMLP(CustomDeepseekV2MLP):
|
|
|
|
def _forward_ms_mlp(self, x):
|
|
current_ms_metadata = get_multistream_comm_context()
|
|
assert current_ms_metadata is not None
|
|
gate_up, _ = self.gate_up_proj(x)
|
|
x = self.act_fn(gate_up)
|
|
current_ms_metadata.before_comm_event.record()
|
|
with torch.npu.stream(current_ms_metadata.comm_stream):
|
|
current_ms_metadata.before_comm_event.wait()
|
|
x, _ = self.down_proj(x)
|
|
current_ms_metadata.after_comm_event.record()
|
|
return x
|
|
|
|
|
|
class CustomDeepseekDBOMoE(nn.Module):
|
|
|
|
top_k: int
|
|
|
|
def __init__(
|
|
self,
|
|
config: PretrainedConfig,
|
|
quant_config: Optional[QuantizationConfig] = None,
|
|
prefix: str = "",
|
|
):
|
|
super().__init__()
|
|
self.tp_size = get_tensor_model_parallel_world_size()
|
|
self.routed_scaling_factor = config.routed_scaling_factor
|
|
self.n_shared_experts = config.n_shared_experts
|
|
self.routed_scaling_factor = config.routed_scaling_factor
|
|
if self.tp_size > config.n_routed_experts:
|
|
raise ValueError(
|
|
f"Tensor parallel size {self.tp_size} is greater than "
|
|
f"the number of experts {config.n_routed_experts}.")
|
|
|
|
if config.hidden_act != "silu":
|
|
raise ValueError(f"Unsupported activation: {config.hidden_act}. "
|
|
"Only silu is supported for now.")
|
|
|
|
self.gate = ReplicatedLinear(config.hidden_size,
|
|
config.n_routed_experts,
|
|
bias=False,
|
|
quant_config=None,
|
|
prefix=f"{prefix}.gate")
|
|
if config.topk_method == "noaux_tc":
|
|
self.gate.e_score_correction_bias = nn.Parameter(
|
|
torch.empty(config.n_routed_experts))
|
|
else:
|
|
self.gate.e_score_correction_bias = None
|
|
|
|
self.experts = AscendFusedMoE(
|
|
num_experts=config.n_routed_experts,
|
|
top_k=config.num_experts_per_tok,
|
|
hidden_size=config.hidden_size,
|
|
intermediate_size=config.moe_intermediate_size,
|
|
reduce_results=False,
|
|
renormalize=config.norm_topk_prob,
|
|
quant_config=quant_config,
|
|
use_grouped_topk=True,
|
|
num_expert_group=config.n_group,
|
|
topk_group=config.topk_group,
|
|
prefix=f"{prefix}.experts",
|
|
scoring_func=config.scoring_func,
|
|
e_score_correction_bias=self.gate.e_score_correction_bias)
|
|
|
|
if config.n_shared_experts is not None:
|
|
intermediate_size = (config.moe_intermediate_size *
|
|
config.n_shared_experts)
|
|
self.shared_experts = CustomDeepseekDBOMLP(
|
|
hidden_size=config.hidden_size,
|
|
intermediate_size=intermediate_size,
|
|
hidden_act=config.hidden_act,
|
|
quant_config=quant_config,
|
|
reduce_results=True,
|
|
prefix=f"{prefix}.shared_experts",
|
|
)
|
|
CustomDeepseekDBOMoE.top_k = config.num_experts_per_tok
|
|
|
|
self.dp_size = get_dp_group().world_size
|
|
|
|
self.tp_group = get_tp_group().device_group
|
|
self.tp_rank = get_tp_group().rank_in_group
|
|
|
|
self.params_dtype = torch.get_default_dtype()
|
|
|
|
ascend_config = get_ascend_config()
|
|
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
|
|
|
def forward(
|
|
self,
|
|
hidden_states: torch.Tensor,
|
|
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
|
|
if attn_metadata is None:
|
|
attn_metadata = get_forward_context().attn_metadata
|
|
# when profile runs, force experts to load balanced tokens
|
|
# to avoid high memory consumption on a single rank.
|
|
# TODO: need a better flag to indicate whether in profile run or not.
|
|
if attn_metadata is None:
|
|
# for profile run
|
|
is_prefill = True
|
|
enable_force_load_balance = True
|
|
else:
|
|
is_prefill = attn_metadata.num_prefills > 0
|
|
enable_force_load_balance = False
|
|
if hasattr(attn_metadata, 'with_prefill_across_dp'):
|
|
is_prefill = is_prefill or attn_metadata.with_prefill_across_dp
|
|
|
|
num_tokens, hidden_size = hidden_states.shape
|
|
|
|
old_hidden_states = hidden_states.clone()
|
|
|
|
if self.tp_size > 1:
|
|
if envs_ascend.VLLM_ENABLE_MC2 and not is_prefill:
|
|
chunks = torch.chunk(hidden_states, self.tp_size, dim=0)
|
|
hidden_states = chunks[self.tp_rank]
|
|
elif not self.torchair_graph_enabled:
|
|
num_padding_tokens = (self.tp_size -
|
|
num_tokens % self.tp_size) % self.tp_size
|
|
# Pad hidden_states to make it divisible by tp_size to avoid cross-ring AllGatherV on 910B2C
|
|
if num_padding_tokens > 0:
|
|
hidden_states = nn.functional.pad(
|
|
hidden_states, (0, 0, 0, num_padding_tokens))
|
|
chunk_hidden_states = torch.tensor_split(hidden_states,
|
|
self.tp_size,
|
|
dim=0)
|
|
hidden_states = chunk_hidden_states[self.tp_rank]
|
|
|
|
# router_logits: (num_tokens, n_experts)
|
|
router_logits, _ = self.gate(hidden_states)
|
|
|
|
hidden_states = self.experts(
|
|
hidden_states=hidden_states,
|
|
router_logits=router_logits,
|
|
is_prefill=is_prefill,
|
|
top_k=CustomDeepseekDBOMoE.top_k,
|
|
enable_force_load_balance=enable_force_load_balance,
|
|
) * self.routed_scaling_factor
|
|
|
|
if self.tp_size > 1:
|
|
if self.torchair_graph_enabled:
|
|
if envs_ascend.VLLM_ENABLE_MC2 and not is_prefill:
|
|
final_hidden_states = torch.zeros(
|
|
[num_tokens, hidden_size],
|
|
dtype=self.params_dtype,
|
|
device="npu")
|
|
dist.all_gather_into_tensor(final_hidden_states,
|
|
hidden_states, self.tp_group)
|
|
hidden_states = final_hidden_states
|
|
else:
|
|
hidden_states = tensor_model_parallel_all_reduce(
|
|
hidden_states)
|
|
else:
|
|
dist.all_gather(list(chunk_hidden_states), hidden_states,
|
|
self.tp_group)
|
|
hidden_states = torch.cat(chunk_hidden_states, dim=0)
|
|
if num_padding_tokens > 0:
|
|
hidden_states = hidden_states[:-num_padding_tokens]
|
|
|
|
if self.n_shared_experts is not None:
|
|
shared_output = self.shared_experts(old_hidden_states)
|
|
|
|
if shared_output is not None:
|
|
hidden_states = hidden_states + shared_output
|
|
|
|
return hidden_states.view(num_tokens, hidden_size)
|
|
|
|
# ----------------------------------------- TBO-related --------------------------------------------
|
|
def _forward_ms_op_shared_expert(
|
|
self,
|
|
hidden_states: torch.Tensor,
|
|
):
|
|
shared_output = self.shared_experts._forward_ms_mlp(hidden_states)
|
|
return shared_output
|
|
|
|
def _forward_ms_op_gate(
|
|
self,
|
|
hidden_states: torch.Tensor,
|
|
):
|
|
# router_logits: (num_tokens, n_experts)
|
|
router_logits, _ = self.gate(hidden_states)
|
|
return router_logits
|
|
|
|
def _forward_ms_op_tp_allgather(
|
|
self,
|
|
hidden_states: torch.Tensor,
|
|
chunk_hidden_states: torch.Tensor,
|
|
num_tokens: int = 0,
|
|
):
|
|
current_ms_metadata = get_multistream_comm_context()
|
|
if current_ms_metadata is None:
|
|
dist.all_gather(list(chunk_hidden_states), hidden_states,
|
|
self.tp_group)
|
|
final_hidden_states = torch.cat(chunk_hidden_states, dim=0)
|
|
if num_tokens > 0:
|
|
final_hidden_states = final_hidden_states[:-num_tokens]
|
|
else:
|
|
current_ms_metadata.before_comm_event.record()
|
|
with torch.npu.stream(current_ms_metadata.comm_stream):
|
|
current_ms_metadata.before_comm_event.wait()
|
|
dist.all_gather(list(chunk_hidden_states), hidden_states,
|
|
self.tp_group)
|
|
final_hidden_states = torch.cat(chunk_hidden_states, dim=0)
|
|
if num_tokens > 0:
|
|
final_hidden_states = final_hidden_states[:-num_tokens]
|
|
current_ms_metadata.after_comm_event.record()
|
|
return final_hidden_states
|
|
|
|
|
|
class CustomDeepseekDBOMLAAttention(DeepseekV2MLAAttention):
|
|
|
|
def __init__(
|
|
self,
|
|
config: PretrainedConfig,
|
|
hidden_size: int,
|
|
num_heads: int,
|
|
qk_nope_head_dim: int,
|
|
qk_rope_head_dim: int,
|
|
v_head_dim: int,
|
|
q_lora_rank: Optional[int],
|
|
kv_lora_rank: int,
|
|
rope_theta: float = 10000,
|
|
rope_scaling: Optional[Dict[str, Any]] = None,
|
|
max_position_embeddings: int = 8192,
|
|
cache_config: Optional[CacheConfig] = None,
|
|
quant_config: Optional[QuantizationConfig] = None,
|
|
prefix: str = "",
|
|
) -> None:
|
|
nn.Module.__init__(self)
|
|
self.hidden_size = hidden_size
|
|
self.qk_nope_head_dim = qk_nope_head_dim
|
|
self.qk_rope_head_dim = qk_rope_head_dim
|
|
self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
|
|
self.v_head_dim = v_head_dim
|
|
|
|
self.q_lora_rank = q_lora_rank
|
|
self.kv_lora_rank = kv_lora_rank
|
|
|
|
self.num_heads = num_heads
|
|
tp_size = get_tensor_model_parallel_world_size()
|
|
assert num_heads % tp_size == 0
|
|
self.num_local_heads = num_heads // tp_size
|
|
|
|
self.scaling = self.qk_head_dim**-0.5
|
|
self.rope_theta = rope_theta
|
|
self.max_position_embeddings = max_position_embeddings
|
|
|
|
if self.q_lora_rank is not None:
|
|
self.q_a_proj = ReplicatedLinear(self.hidden_size,
|
|
self.q_lora_rank,
|
|
bias=False,
|
|
quant_config=quant_config,
|
|
prefix=f"{prefix}.q_a_proj")
|
|
self.q_a_layernorm = RMSNorm(self.q_lora_rank,
|
|
eps=config.rms_norm_eps)
|
|
self.q_b_proj = ColumnParallelLinear(q_lora_rank,
|
|
self.num_heads *
|
|
self.qk_head_dim,
|
|
bias=False,
|
|
quant_config=quant_config,
|
|
prefix=f"{prefix}.q_b_proj")
|
|
else:
|
|
self.q_proj = ColumnParallelLinear(self.hidden_size,
|
|
self.num_heads *
|
|
self.qk_head_dim,
|
|
bias=False,
|
|
quant_config=quant_config,
|
|
prefix=f"{prefix}.q_proj")
|
|
|
|
self.kv_a_proj_with_mqa = ReplicatedLinear(
|
|
self.hidden_size,
|
|
self.kv_lora_rank + self.qk_rope_head_dim,
|
|
bias=False,
|
|
quant_config=quant_config,
|
|
prefix=f"{prefix}.kv_a_proj_with_mqa")
|
|
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
|
|
eps=config.rms_norm_eps)
|
|
self.kv_b_proj = ColumnParallelLinear(
|
|
self.kv_lora_rank,
|
|
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
|
|
bias=False,
|
|
quant_config=quant_config,
|
|
prefix=f"{prefix}.kv_b_proj")
|
|
self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim,
|
|
self.hidden_size,
|
|
bias=False,
|
|
quant_config=quant_config,
|
|
prefix=f"{prefix}.o_proj")
|
|
|
|
if rope_scaling:
|
|
rope_scaling["rope_type"] = 'deepseek_yarn'
|
|
self.rotary_emb = get_rope(qk_rope_head_dim,
|
|
rotary_dim=qk_rope_head_dim,
|
|
max_position=max_position_embeddings,
|
|
base=rope_theta,
|
|
rope_scaling=rope_scaling,
|
|
is_neox_style=False)
|
|
if rope_scaling:
|
|
mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
|
|
scaling_factor = rope_scaling["factor"]
|
|
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
|
|
self.scaling = self.scaling * mscale * mscale
|
|
|
|
# In the MLA backend, kv_cache includes both k_c and
|
|
# pe (i.e. decoupled position embeddings). In particular,
|
|
# the concat_and_cache_mla op requires
|
|
# k_c.size(1) + k_pe.size(1) == kv_cache.size(2)
|
|
# i.e.
|
|
# kv_lora_rank + qk_rope_head_dim == head_size
|
|
self.mla_attn = Attention(
|
|
num_heads=self.num_local_heads,
|
|
head_size=self.kv_lora_rank + self.qk_rope_head_dim,
|
|
scale=self.scaling,
|
|
num_kv_heads=1,
|
|
cache_config=cache_config,
|
|
quant_config=quant_config,
|
|
prefix=f"{prefix}.attn",
|
|
use_mla=True,
|
|
# MLA Args
|
|
q_lora_rank=self.q_lora_rank,
|
|
kv_lora_rank=self.kv_lora_rank,
|
|
qk_nope_head_dim=self.qk_nope_head_dim,
|
|
qk_rope_head_dim=self.qk_rope_head_dim,
|
|
qk_head_dim=self.qk_head_dim,
|
|
v_head_dim=self.v_head_dim,
|
|
rotary_emb=self.rotary_emb,
|
|
q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj,
|
|
kv_a_proj_with_mqa=self.kv_a_proj_with_mqa,
|
|
kv_a_layernorm=self.kv_a_layernorm,
|
|
kv_b_proj=self.kv_b_proj,
|
|
o_proj=self.o_proj,
|
|
)
|
|
|
|
self.prefix = prefix
|
|
self.debug_layer_idx = int(self.prefix.split(".")[-2])
|
|
|
|
ascend_config = get_ascend_config()
|
|
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
|
|
|
def forward(
|
|
self,
|
|
positions: torch.Tensor,
|
|
hidden_states: torch.Tensor,
|
|
kv_cache: Optional[torch.Tensor] = None,
|
|
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
|
|
if self.q_lora_rank is not None:
|
|
ckq = self.q_a_proj(hidden_states)[0]
|
|
hidden_states_or_q_c = self.q_a_layernorm(ckq)
|
|
else:
|
|
hidden_states_or_q_c = hidden_states
|
|
if self.torchair_graph_enabled:
|
|
forward_kwargs = {}
|
|
if envs.VLLM_USE_V1:
|
|
output_shape = hidden_states.shape
|
|
output = torch.empty(output_shape,
|
|
dtype=hidden_states_or_q_c.dtype,
|
|
device=hidden_states_or_q_c.device)
|
|
forward_kwargs['output'] = output
|
|
|
|
output = self.mla_attn.impl.forward(self.mla_attn,
|
|
hidden_states_or_q_c,
|
|
hidden_states, None, kv_cache,
|
|
attn_metadata,
|
|
**forward_kwargs)
|
|
if envs.VLLM_USE_V1:
|
|
output = output.view(-1, output_shape[-1])
|
|
return output
|
|
else:
|
|
kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
|
|
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
|
|
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
|
|
return self.mla_attn(hidden_states_or_q_c,
|
|
kv_c_normed,
|
|
k_pe,
|
|
output_shape=hidden_states.shape)
|
|
|
|
|
|
class CustomDeepseekDBODecoderLayer(DeepseekV2DecoderLayer):
|
|
|
|
def __init__(
|
|
self,
|
|
config: PretrainedConfig,
|
|
prefix: str,
|
|
model_config: ModelConfig,
|
|
cache_config: Optional[CacheConfig] = None,
|
|
quant_config: Optional[QuantizationConfig] = None,
|
|
) -> None:
|
|
nn.Module.__init__(self)
|
|
self.hidden_size = config.hidden_size
|
|
rope_theta = getattr(config, "rope_theta", 10000)
|
|
rope_scaling = getattr(config, "rope_scaling", None)
|
|
max_position_embeddings = getattr(config, "max_position_embeddings",
|
|
8192)
|
|
# DecoderLayers are created with `make_layers` which passes the prefix
|
|
# with the layer's index.
|
|
layer_idx = int(prefix.split(sep='.')[-1])
|
|
self.layer_idx = layer_idx
|
|
# TODO: enable mla in vllm-ascend
|
|
if model_config.use_mla:
|
|
attn_cls = CustomDeepseekDBOMLAAttention
|
|
else:
|
|
attn_cls = DeepseekV2Attention
|
|
self.self_attn = attn_cls(
|
|
config=config,
|
|
hidden_size=self.hidden_size,
|
|
num_heads=config.num_attention_heads,
|
|
qk_nope_head_dim=config.qk_nope_head_dim,
|
|
qk_rope_head_dim=config.qk_rope_head_dim,
|
|
v_head_dim=config.v_head_dim,
|
|
q_lora_rank=config.q_lora_rank
|
|
if hasattr(config, "q_lora_rank") else None,
|
|
kv_lora_rank=config.kv_lora_rank,
|
|
rope_theta=rope_theta,
|
|
rope_scaling=rope_scaling,
|
|
max_position_embeddings=max_position_embeddings,
|
|
cache_config=cache_config,
|
|
quant_config=quant_config,
|
|
prefix=f"{prefix}.self_attn",
|
|
)
|
|
|
|
if (config.n_routed_experts is not None
|
|
and layer_idx >= config.first_k_dense_replace
|
|
and layer_idx % config.moe_layer_freq == 0):
|
|
self.mlp = CustomDeepseekDBOMoE(
|
|
config=config,
|
|
quant_config=quant_config,
|
|
prefix=f"{prefix}.mlp",
|
|
)
|
|
else:
|
|
self.mlp = CustomDeepseekDBOMLP(
|
|
hidden_size=config.hidden_size,
|
|
intermediate_size=config.intermediate_size,
|
|
hidden_act=config.hidden_act,
|
|
quant_config=quant_config,
|
|
prefix=f"{prefix}.mlp",
|
|
)
|
|
self.input_layernorm = RMSNorm(config.hidden_size,
|
|
eps=config.rms_norm_eps)
|
|
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
|
eps=config.rms_norm_eps)
|
|
self.routed_scaling_factor = config.routed_scaling_factor
|
|
|
|
def forward(
|
|
self,
|
|
positions: torch.Tensor,
|
|
hidden_states: torch.Tensor,
|
|
residual: Optional[torch.Tensor],
|
|
kv_cache: Optional[torch.Tensor] = None,
|
|
attn_metadata: Optional[AttentionMetadata] = None,
|
|
) -> torch.Tensor:
|
|
# Self Attention
|
|
if residual is None:
|
|
residual = hidden_states
|
|
hidden_states = self.input_layernorm(hidden_states)
|
|
else:
|
|
previous_hidden_states, previous_residual = hidden_states, residual
|
|
hidden_states, residual = self.input_layernorm(
|
|
hidden_states, residual)
|
|
# Dispose hidden_states and residual from the previous layer
|
|
# to save npu memory because they're no longer used.
|
|
dispose_tensor(previous_hidden_states)
|
|
dispose_tensor(previous_residual)
|
|
|
|
hidden_states = self.self_attn(
|
|
positions=positions,
|
|
hidden_states=hidden_states,
|
|
kv_cache=kv_cache,
|
|
attn_metadata=attn_metadata,
|
|
)
|
|
|
|
if hidden_states.dtype == torch.float16:
|
|
# Fix FP16 overflow
|
|
# We scale both hidden_states and residual before
|
|
# rmsnorm, and rmsnorm result would not affect by scale.
|
|
hidden_states *= 1. / self.routed_scaling_factor
|
|
if self.layer_idx == 0:
|
|
# The residual is shared by all layers, we only scale it on
|
|
# first layer.
|
|
residual *= 1. / self.routed_scaling_factor
|
|
|
|
# Fully Connected
|
|
hidden_states, residual = self.post_attention_layernorm(
|
|
hidden_states, residual)
|
|
|
|
if isinstance(self.mlp, CustomDeepseekDBOMoE):
|
|
hidden_states = self.mlp(hidden_states, attn_metadata)
|
|
else:
|
|
hidden_states = self.mlp(hidden_states)
|
|
|
|
if isinstance(
|
|
self.mlp,
|
|
CustomDeepseekDBOMLP) and hidden_states.dtype == torch.float16:
|
|
# Fix FP16 overflow
|
|
# Scaling the DeepseekV2MLP output, it is the input of
|
|
# input_layernorm of next decoder layer.
|
|
# The scaling of DeepseekV2MOE output would be done in the forward
|
|
# of DeepseekV2MOE
|
|
hidden_states *= 1. / self.routed_scaling_factor
|
|
|
|
return hidden_states, residual
|
|
|
|
# ----------------------------------------- TBO-related --------------------------------------------
|
|
def _forward_ms_layer(
|
|
self,
|
|
positions: List[torch.Tensor],
|
|
hidden_states: List[torch.Tensor],
|
|
residual: List[torch.Tensor],
|
|
attn_metadata: List[AttentionMetadata],
|
|
kv_cache: Optional[torch.Tensor] = None,
|
|
is_prefill: bool = False,
|
|
) -> tuple[List[torch.Tensor], List[torch.Tensor]]:
|
|
layer_index, ms_metadata, _ = get_multistream_layer_context()
|
|
assert layer_index >= 0 and ms_metadata is not None
|
|
num_micro_batchs = ms_metadata.ms_config.num_micro_batches
|
|
assert isinstance(self.mlp, CustomDeepseekDBOMoE)
|
|
assert len(positions) == num_micro_batchs
|
|
assert len(hidden_states) == num_micro_batchs
|
|
assert residual is not None
|
|
assert attn_metadata is not None
|
|
num_tokens = []
|
|
hidden_dims = []
|
|
shared_outputs = []
|
|
router_logits = []
|
|
chunk_hidden_states = []
|
|
|
|
# block 1 : attention
|
|
# block 2 : attn tp communication
|
|
# the attn computation of microbatch 1 can be overlapped with the moe
|
|
# communication in the previous layer, and the attn computation of microbatch 2
|
|
# can be overlapped with the attn communication of microbatch 1
|
|
for i in range(num_micro_batchs):
|
|
# wait last layer moe finishing communication
|
|
ms_metadata.try_wait_event(layer_index - 1, i,
|
|
MSEventKey.FFN_AR_FINISH)
|
|
context = MultiStreamStepMetadata(
|
|
comm_stream=ms_metadata.communicate_stream,
|
|
before_comm_event=ms_metadata.ms_events[layer_index][i][
|
|
MSEventKey.ATTN_COM_FINISH],
|
|
after_comm_event=ms_metadata.ms_events[layer_index][i][
|
|
MSEventKey.ATTN_AR_FINISH],
|
|
)
|
|
|
|
with set_multistream_context(context, i):
|
|
forward_context = get_forward_context()
|
|
forward_context.attn_metadata = attn_metadata[i]
|
|
|
|
# input layernorm
|
|
hidden_states[i], residual[
|
|
i] = self._forward_ms_op_input_layernorm(
|
|
hidden_states[i], residual[i])
|
|
# attention and tp allreduce
|
|
hidden_states[i], residual[i] = self._forward_ms_op_attn(
|
|
positions[i], hidden_states[i], residual[i], kv_cache,
|
|
attn_metadata[i])
|
|
|
|
# block 3 : shared experts
|
|
# if there is an allreduce ops in shared expert, we can overlap it with the computation of the
|
|
# shared expert for next microbatch or moe gating
|
|
for i in range(num_micro_batchs):
|
|
ms_metadata.try_wait_event(layer_index, i,
|
|
MSEventKey.ATTN_AR_FINISH)
|
|
context = MultiStreamStepMetadata(
|
|
comm_stream=ms_metadata.communicate_stream,
|
|
before_comm_event=ms_metadata.ms_events[layer_index][i][
|
|
MSEventKey.MOE_SE_COMP_FINISH],
|
|
after_comm_event=ms_metadata.ms_events[layer_index][i][
|
|
MSEventKey.MOE_SE_COMM_FINISH],
|
|
)
|
|
with set_multistream_context(context, i):
|
|
# compute shared expert after finishing ATTN AR
|
|
hidden_states[i], residual[
|
|
i] = self._forward_ms_op_post_attn_layernorm(
|
|
hidden_states[i], residual[i])
|
|
|
|
num_token, hidden_dim = hidden_states[i].shape
|
|
hidden_states[i] = hidden_states[i].view(-1, hidden_dim)
|
|
num_tokens.append(num_token)
|
|
hidden_dims.append(hidden_dim)
|
|
if self.mlp.n_shared_experts is not None:
|
|
# TODO: we can move shared expert computation into next block if reduce results is false
|
|
shared_output = self.mlp._forward_ms_op_shared_expert(
|
|
hidden_states[i])
|
|
shared_outputs.append(shared_output)
|
|
|
|
# block 4 : moe
|
|
for i in range(num_micro_batchs):
|
|
# when profile runs, force experts to load balanced tokens
|
|
# to avoid high memory consumption on a single rank.
|
|
# TODO: need a better flag to indicate whether in profile run or not.
|
|
if attn_metadata[i] is None:
|
|
# for profile run
|
|
is_prefill = True
|
|
enable_force_load_balance = True
|
|
else:
|
|
is_prefill = attn_metadata[i].num_prefills > 0
|
|
enable_force_load_balance = False
|
|
|
|
if self.mlp.tp_size > 1:
|
|
num_token, _ = hidden_states[i].shape
|
|
padded_num_tokens = (self.mlp.tp_size - num_token %
|
|
self.mlp.tp_size) % self.mlp.tp_size
|
|
if padded_num_tokens > 0:
|
|
hidden_states[i] = nn.functional.pad(
|
|
hidden_states[i], (0, 0, 0, padded_num_tokens))
|
|
chunk_hidden_state = torch.tensor_split(hidden_states[i],
|
|
self.mlp.tp_size,
|
|
dim=0)
|
|
chunk_hidden_states.append(chunk_hidden_state)
|
|
local_hidden_states = chunk_hidden_state[self.mlp.tp_rank]
|
|
else:
|
|
local_hidden_states = hidden_states[i]
|
|
|
|
router_logit = self.mlp._forward_ms_op_gate(local_hidden_states)
|
|
router_logits.append(router_logit)
|
|
|
|
if CustomDeepseekDBOMoE.top_k:
|
|
real_top_k = CustomDeepseekDBOMoE.top_k
|
|
else:
|
|
real_top_k = self.mlp.experts.top_k
|
|
|
|
hidden_states[i] = self.mlp.experts._forward_ms_fused_moe_comp(
|
|
local_hidden_states, router_logits[i], is_prefill, real_top_k,
|
|
enable_force_load_balance)
|
|
|
|
# the following kernels will be submitted to the comm stream to overlap the computation of the
|
|
# moe computation of next microbatch and the attn computation of next layer
|
|
context = MultiStreamStepMetadata(
|
|
comm_stream=ms_metadata.communicate_stream,
|
|
before_comm_event=ms_metadata.ms_events[layer_index][i][
|
|
MSEventKey.FFN_COM_FINISH],
|
|
after_comm_event=ms_metadata.ms_events[layer_index][i][
|
|
MSEventKey.MOE_AFTER_COMM],
|
|
)
|
|
context.before_comm_event.record()
|
|
with torch.npu.stream(ms_metadata.communicate_stream):
|
|
context.before_comm_event.wait()
|
|
if self.mlp.experts.reduce_results and (
|
|
self.mlp.experts.tp_size > 1
|
|
or self.mlp.experts.ep_size > 1):
|
|
hidden_states[i] = tensor_model_parallel_all_reduce(
|
|
hidden_states[i])
|
|
hidden_states[
|
|
i] = hidden_states[i] * self.mlp.routed_scaling_factor
|
|
context.after_comm_event.record()
|
|
|
|
context = MultiStreamStepMetadata(
|
|
comm_stream=ms_metadata.communicate_stream,
|
|
before_comm_event=ms_metadata.ms_events[layer_index][i][
|
|
MSEventKey.MOE_AFTER_COMM],
|
|
after_comm_event=ms_metadata.ms_events[layer_index][i][
|
|
MSEventKey.FFN_AR_FINISH],
|
|
)
|
|
with set_multistream_context(context, i):
|
|
if self.mlp.tp_size > 1:
|
|
hidden_states[i] = self.mlp._forward_ms_op_tp_allgather(
|
|
hidden_states[i], chunk_hidden_states[i],
|
|
padded_num_tokens)
|
|
with torch.npu.stream(ms_metadata.communicate_stream):
|
|
# last
|
|
if shared_outputs[i] is not None:
|
|
hidden_states[i] = hidden_states[i] + shared_outputs[i]
|
|
hidden_states[i] = hidden_states[i].view(
|
|
num_tokens[i], hidden_dims[i])
|
|
if isinstance(self.mlp, CustomDeepseekDBOMLP
|
|
) and hidden_states[i].dtype == torch.float16:
|
|
# Fix FP16 overflow
|
|
# Scaling the DeepseekV2MLP output, it is the input of
|
|
# input_layernorm of next decoder layer.
|
|
# The scaling of DeepseekV2MOE output would be done in the forward
|
|
# of DeepseekV2MOE
|
|
hidden_states[i] *= 1. / self.routed_scaling_factor
|
|
context.after_comm_event.record()
|
|
return hidden_states, residual
|
|
|
|
# should split ops in Decoder Layer
|
|
def _forward_ms_op_input_layernorm(
|
|
self,
|
|
hidden_states: torch.Tensor,
|
|
residual: Optional[torch.Tensor],
|
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
if residual is None:
|
|
residual = hidden_states
|
|
hidden_states = self.input_layernorm(hidden_states)
|
|
else:
|
|
hidden_states, residual = self.input_layernorm(
|
|
hidden_states, residual)
|
|
return hidden_states, residual
|
|
|
|
def _forward_ms_op_attn(
|
|
self,
|
|
positions: torch.Tensor,
|
|
hidden_states: torch.Tensor,
|
|
residual: torch.Tensor,
|
|
kv_cache: Optional[torch.Tensor] = None,
|
|
attn_metadata: Optional[AttentionMetadata] = None,
|
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
hidden_states = self.self_attn(
|
|
positions=positions,
|
|
hidden_states=hidden_states,
|
|
kv_cache=kv_cache,
|
|
attn_metadata=attn_metadata,
|
|
)
|
|
if hidden_states.dtype == torch.float16:
|
|
# Fix FP16 overflow
|
|
# We scale both hidden_states and residual before
|
|
# rmsnorm, and rmsnorm result would not affect by scale.
|
|
hidden_states *= 1. / self.routed_scaling_factor
|
|
if self.layer_idx == 0:
|
|
# The residual is shared by all layers, we only scale it on
|
|
# first layer.
|
|
residual *= 1. / self.routed_scaling_factor
|
|
return hidden_states, residual
|
|
|
|
def _forward_ms_op_post_attn_layernorm(
|
|
self,
|
|
hidden_states: torch.Tensor,
|
|
residual: Optional[torch.Tensor],
|
|
):
|
|
hidden_states, residual = self.post_attention_layernorm(
|
|
hidden_states, residual)
|
|
return hidden_states, residual
|
|
|
|
|
|
class CustomDeepseekDBOModel(nn.Module):
|
|
|
|
fall_back_to_pt_during_load = False
|
|
|
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
|
super().__init__()
|
|
|
|
config = vllm_config.model_config.hf_config
|
|
model_config = vllm_config.model_config
|
|
cache_config = vllm_config.cache_config
|
|
quant_config = vllm_config.quant_config
|
|
|
|
self.padding_idx = config.pad_token_id
|
|
self.vocab_size = config.vocab_size
|
|
self.first_k_dense_replace = config.first_k_dense_replace
|
|
|
|
if get_pp_group().is_first_rank:
|
|
self.embed_tokens = VocabParallelEmbedding(
|
|
config.vocab_size,
|
|
config.hidden_size,
|
|
quant_config=quant_config,
|
|
prefix=f"{prefix}.embed_tokens")
|
|
else:
|
|
self.embed_tokens = PPMissingLayer()
|
|
|
|
self.start_layer, self.end_layer, self.layers = make_layers(
|
|
config.num_hidden_layers,
|
|
lambda prefix: CustomDeepseekDBODecoderLayer(
|
|
config,
|
|
prefix,
|
|
model_config=model_config,
|
|
cache_config=cache_config,
|
|
quant_config=quant_config,
|
|
),
|
|
prefix=f"{prefix}.layers")
|
|
|
|
if get_pp_group().is_last_rank:
|
|
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
|
else:
|
|
self.norm = PPMissingLayer()
|
|
self.make_empty_intermediate_tensors = (
|
|
make_empty_intermediate_tensors_factory(
|
|
["hidden_states", "residual"], config.hidden_size))
|
|
|
|
# tbo related members
|
|
if VLLM_ASCEND_ENABLE_DBO:
|
|
self.use_mla = model_config.use_mla
|
|
self.multistream_config = MultiStreamConfig()
|
|
multistream_metadata = make_multistream_metadata_ds(
|
|
start_layer=self.start_layer + self.first_k_dense_replace,
|
|
end_layer=self.end_layer,
|
|
causal_lm=getattr(config, "causal_lm", True),
|
|
multistream_config=self.multistream_config,
|
|
)
|
|
self.ms_pre_layer = MultiStreamPreTransformerLayer(
|
|
multistream_metadata)
|
|
self.ms_post_layer = MultiStreamPostTransformerLayer(
|
|
multistream_metadata)
|
|
|
|
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
|
return self.embed_tokens(input_ids)
|
|
|
|
def forward(
|
|
self,
|
|
input_ids: torch.Tensor,
|
|
positions: torch.Tensor,
|
|
kv_caches: Optional[List[torch.Tensor]] = None,
|
|
attn_metadata: Optional[AttentionMetadata] = None,
|
|
intermediate_tensors: Optional[IntermediateTensors] = None,
|
|
inputs_embeds: Optional[torch.Tensor] = None,
|
|
) -> Union[torch.Tensor, IntermediateTensors]:
|
|
if get_pp_group().is_first_rank:
|
|
if inputs_embeds is not None:
|
|
hidden_states = inputs_embeds
|
|
else:
|
|
hidden_states = self.get_input_embeddings(input_ids)
|
|
residual = None
|
|
else:
|
|
assert intermediate_tensors is not None
|
|
hidden_states = intermediate_tensors["hidden_states"]
|
|
residual = intermediate_tensors["residual"]
|
|
|
|
num_normal_layers = (self.first_k_dense_replace
|
|
if VLLM_ASCEND_ENABLE_DBO and self.can_run_ms()
|
|
else self.end_layer - self.start_layer)
|
|
|
|
for i in range(self.start_layer, self.start_layer + num_normal_layers):
|
|
layer = self.layers[i]
|
|
hidden_states, residual = layer(
|
|
positions, hidden_states, residual,
|
|
kv_caches[i -
|
|
self.start_layer] if kv_caches is not None else None,
|
|
attn_metadata)
|
|
|
|
moe_start_layer = self.start_layer + num_normal_layers
|
|
if moe_start_layer != self.end_layer:
|
|
# if we enable multistream/dbo, process sparse layers here
|
|
hidden_states, residual = self._forward_ms_layers(
|
|
positions=positions,
|
|
hidden_states=hidden_states,
|
|
residual=residual,
|
|
moe_start_layer=moe_start_layer,
|
|
kv_caches=kv_caches,
|
|
)
|
|
|
|
if not get_pp_group().is_last_rank:
|
|
return IntermediateTensors({
|
|
"hidden_states": hidden_states,
|
|
"residual": residual
|
|
})
|
|
|
|
hidden_states, _ = self.norm(hidden_states, residual)
|
|
return hidden_states
|
|
|
|
def can_run_ms(self):
|
|
attn_metadata = get_forward_context().attn_metadata
|
|
# support mla attention and V1 engine at present
|
|
if not self.use_mla or not envs.VLLM_USE_V1:
|
|
return False
|
|
# enable prefill overlap
|
|
if attn_metadata is None or attn_metadata.num_prefills == 0:
|
|
return False
|
|
else:
|
|
[token_index, seq_index
|
|
] = compute_split_seq_index(attn_metadata.query_lens,
|
|
attn_metadata.attn_state,
|
|
attn_metadata.num_decode_tokens)
|
|
if token_index == 0 or seq_index == 0 or seq_index == len(
|
|
attn_metadata.query_lens):
|
|
return False
|
|
# check whether the total tokens exceed the threshold
|
|
if self.multistream_config is None or attn_metadata.num_actual_tokens < self.multistream_config.min_total_tokens_to_split:
|
|
return False
|
|
return True
|
|
|
|
def _forward_ms_layers(
|
|
self,
|
|
positions: torch.Tensor,
|
|
hidden_states: torch.Tensor,
|
|
residual: torch.Tensor,
|
|
moe_start_layer: int,
|
|
kv_caches: Optional[List[torch.Tensor]] = None,
|
|
is_prefill: bool = False,
|
|
):
|
|
|
|
if moe_start_layer == self.end_layer:
|
|
return hidden_states, residual
|
|
|
|
attn_metadata, [positions, hidden_states,
|
|
residual] = self.ms_pre_layer(
|
|
[positions, hidden_states, residual], )
|
|
# the rest layers
|
|
for i in range(moe_start_layer, self.end_layer):
|
|
layer = self.layers[i]
|
|
hidden_states, residual = layer._forward_ms_layer(
|
|
positions=positions,
|
|
hidden_states=hidden_states,
|
|
residual=residual,
|
|
attn_metadata=attn_metadata,
|
|
kv_cache=kv_caches[i - self.start_layer]
|
|
if kv_caches is not None else None,
|
|
is_prefill=is_prefill)
|
|
advance_step_multistream_layer_context()
|
|
|
|
[hidden_states,
|
|
residual] = self.ms_post_layer([hidden_states, residual], )
|
|
return hidden_states, residual
|
|
|
|
|
|
class CustomDeepseekDBOForCausalLM(DeepseekV2ForCausalLM):
|
|
# add `packed_modules_mapping` in `DeepseekV2ForCausalLM` to support weight merging
|
|
packed_modules_mapping = {
|
|
"gate_up_proj": ["gate_proj", "up_proj"],
|
|
"experts":
|
|
["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
|
|
}
|
|
|
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
|
nn.Module.__init__(self)
|
|
config = vllm_config.model_config.hf_config
|
|
quant_config = vllm_config.quant_config
|
|
self.config = config
|
|
self.quant_config = quant_config
|
|
self.model = CustomDeepseekDBOModel(vllm_config=vllm_config,
|
|
prefix=maybe_prefix(
|
|
prefix, "model"))
|
|
if get_pp_group().is_last_rank:
|
|
self.lm_head = ParallelLMHead(config.vocab_size,
|
|
config.hidden_size,
|
|
quant_config=quant_config)
|
|
else:
|
|
self.lm_head = PPMissingLayer()
|
|
self.logits_processor = LogitsProcessor(config.vocab_size)
|
|
self.sampler = get_sampler()
|
|
self.make_empty_intermediate_tensors = (
|
|
self.model.make_empty_intermediate_tensors)
|
|
|
|
def forward(
|
|
self,
|
|
input_ids: torch.Tensor,
|
|
positions: torch.Tensor,
|
|
kv_caches: Optional[List[torch.Tensor]] = None,
|
|
attn_metadata: Optional[AttentionMetadata] = None,
|
|
intermediate_tensors: Optional[IntermediateTensors] = None,
|
|
inputs_embeds: Optional[torch.Tensor] = None,
|
|
) -> Union[torch.Tensor, IntermediateTensors]:
|
|
hidden_states = self.model(input_ids, positions, kv_caches,
|
|
attn_metadata, intermediate_tensors,
|
|
inputs_embeds)
|
|
return hidden_states
|