mirror of
https://github.com/vllm-project/vllm-ascend.git
synced 2025-10-20 13:43:53 +08:00
### What this PR does / why we need it? Add model basic accuracy test(Qwen2.5-0.5B-Instruct) Signed-off-by: hfadzxy <starmoon_zhang@163.com>
331 lines
12 KiB
Python
331 lines
12 KiB
Python
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
|
# Copyright 2023 The vLLM team.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# This file is a part of the vllm-ascend project.
|
|
# Adapted from vllm/tests/kernels/test_moe.py
|
|
|
|
from typing import Callable, Optional
|
|
|
|
import torch
|
|
import torch_npu
|
|
from vllm.model_executor.layers.fused_moe.layer import \
|
|
UnquantizedFusedMoEMethod
|
|
|
|
|
|
def fused_experts(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    top_k: int,
    expert_map: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Fused experts with top-k routing.

    Runs the first expert projection (w1), the torch_npu.npu_swiglu
    activation and the second expert projection (w2) for every selected
    (token, expert) pair using torch_npu grouped matmuls, then combines the
    per-pair outputs back into one row per token, weighted by `topk_weights`.

    Args:
        hidden_states: Hidden states of shape (num_tokens, hidden_size).
            Must be 2-D, contiguous and float32/float16/bfloat16.
        w1: Expert weights1 of shape (num_experts, intermediate_size * 2, hidden_size).
        w2: Expert weights2 of shape (num_experts, hidden_size, intermediate_size).
        topk_weights: Routing weights of shape (num_tokens, top_k).
        topk_ids: Selected expert IDs of shape (num_tokens, top_k).
        top_k: Number of experts selected per token.
        expert_map: Optional lookup table indexed by the ids in `topk_ids`
            that yields an index into the first dimension of w1/w2. Entries
            equal to -1 mark token-expert pairs that are dropped (presumably
            experts not hosted locally -- confirm with the caller). When None,
            routing is done by torch_npu.npu_moe_init_routing /
            npu_moe_finalize_routing instead.

    Returns:
        hidden_states: Hidden states after routing. On the expert_map path
        this has the same shape and dtype as the input `hidden_states`.
    """
    # Check constraints.
    assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
    assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
    assert w2.is_contiguous(), "Expert weights2 must be contiguous"

    original_shape = hidden_states.shape
    assert len(original_shape) == 2

    num_tokens = hidden_states.shape[:-1].numel()
    # Number of experts held in w1/w2 (the local index space of expert_map).
    num_experts = w1.shape[0]
    dtype = hidden_states.dtype
    device = hidden_states.device
    assert dtype in [torch.float32, torch.float16, torch.bfloat16
                     ], "Only float32, float16, and bfloat16 are supported"

    if expert_map is not None:
        # Generate token indices and flatten: token t repeated top_k times,
        # aligned element-wise with topk_ids.view(-1) below.
        token_indices = (torch.arange(num_tokens,
                                      device=device,
                                      dtype=torch.int64).unsqueeze(1).expand(
                                          -1, top_k).reshape(-1))

        # Flatten token-to-expert mappings and map to local experts
        weights_flat = topk_weights.view(-1)
        experts_flat = topk_ids.view(-1)
        local_experts_flat = expert_map[experts_flat]

        # Filter valid token-expert pairs. Dropped pairs get weight 0 and the
        # out-of-range sentinel id `num_experts`, so they sort to the end and
        # fall into the extra bucket that is sliced off the counts below.
        mask = local_experts_flat != -1
        filtered_weights = torch.where(
            mask, weights_flat, torch.zeros_like(weights_flat)).to(dtype)
        filtered_experts = torch.where(
            mask, local_experts_flat,
            torch.full_like(local_experts_flat,
                            num_experts)).to(topk_ids.dtype)

        # Sort by local expert IDs
        sort_indices = torch.argsort(filtered_experts)
        sorted_token_indices = token_indices[sort_indices]
        sorted_weights = filtered_weights[sort_indices]

        # Compute token counts with minlength of num_experts
        # This is equivalent to but faster than:
        # >>> token_counts = torch.bincount(filtered_experts, minlength=num_experts)[:-1]
        # Bucket `num_experts` collects the sentinel id and is discarded.
        token_counts = torch.zeros(num_experts + 1,
                                   device=device,
                                   dtype=torch.int64)
        ones = torch.ones_like(filtered_experts, dtype=torch.int64)
        token_counts.scatter_add_(0, filtered_experts.to(torch.int64), ones)
        token_counts = token_counts[:num_experts]
        # Cumulative per-expert row offsets, passed as group_list below.
        expert_tokens = torch.cumsum(token_counts, dim=0, dtype=torch.int64)

        # Rearrange hidden_states: one row per token-expert pair, grouped by
        # expert, with dropped pairs at the end.
        sorted_hidden_states = hidden_states[sorted_token_indices]
    else:
        row_idx_len = num_tokens * top_k
        # row_idx[t, k] == k * num_tokens + t.
        row_idx = (torch.arange(0,
                                row_idx_len,
                                dtype=torch.int32,
                                device=device).view(top_k, -1).permute(
                                    1, 0).contiguous())
        sorted_hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
            hidden_states,
            row_idx=row_idx,
            expert_idx=topk_ids,
            active_num=num_tokens)

        # NOTE(review): presumably per-expert cumulative row counts, i.e. the
        # same form as the cumsum on the expert_map path -- confirm against
        # the torch_npu documentation.
        expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
            expanded_expert_idx, num_experts)
        expert_tokens = expert_tokens.to(torch.int64)

    # Transpose to (num_experts, hidden_size, intermediate_size * 2),
    # presumably the (in, out) weight layout npu_grouped_matmul expects.
    w1 = w1.transpose(1, 2)
    gate_up_out_list = torch_npu.npu_grouped_matmul(
        x=[sorted_hidden_states],
        weight=[w1],
        split_item=2,
        group_list_type=0,
        group_type=0,
        group_list=expert_tokens,
    )

    # TODO: Remove this in the future.
    gate_up_out = torch.cat(gate_up_out_list, dim=0)
    gate_up_out = torch_npu.npu_swiglu(gate_up_out)

    w2 = w2.transpose(1, 2)
    down_out_list = torch_npu.npu_grouped_matmul(
        x=[gate_up_out],
        weight=[w2],
        split_item=2,
        group_list_type=0,
        group_type=0,
        group_list=expert_tokens,
    )

    down_out_list = torch.cat(down_out_list, dim=0)

    if expert_map is not None:
        # Scale each pair's output by its routing weight (0 for dropped pairs).
        weighted_down_out = down_out_list * sorted_weights.unsqueeze(1)

        # Scatter-add every pair's weighted output back into its token row.
        final_hidden_states = torch.zeros(*original_shape,
                                          device=hidden_states.device,
                                          dtype=dtype)
        final_hidden_states.index_add_(0, sorted_token_indices,
                                       weighted_down_out)
        # TODO: This should not happen! Look into it!
        # fill nan with 0.0
        # NOTE(review): a plausible source is the rows of dropped pairs, which
        # lie past the last group in expert_tokens and may therefore not be
        # written by the grouped matmuls; 0 * NaN stays NaN despite the zero
        # weight. Confirm before removing this workaround.
        final_hidden_states[torch.isnan(final_hidden_states)] = 0.0
    else:
        # TODO: Reorder device memory 2 times here, replace the current
        # implementation here when suitable operators become available.
        final_hidden_states = torch_npu.npu_moe_finalize_routing(
            down_out_list,
            skip1=None,
            skip2=None,
            bias=None,
            scales=topk_weights,
            expanded_src_to_dst_row=expanded_row_idx,
            export_for_source_row=topk_ids,
        )

    return final_hidden_states
|
|
|
|
|
|
def native_grouped_topk(
    topk_weights: torch.Tensor,
    num_expert_group: Optional[int],
    topk_group: Optional[int],
):
    """
    Keep only the scores of the best `topk_group` expert groups per token.

    The expert dimension of `topk_weights` is split into `num_expert_group`
    contiguous, equally sized groups. Each group is ranked by its maximum
    score, and every score outside the `topk_group` highest-ranked groups is
    replaced by -inf. A subsequent top-k over the result can therefore only
    pick experts from the kept groups.

    Args:
        topk_weights: Scores of shape (num_tokens, num_experts); num_experts
            must be divisible by num_expert_group. num_tokens may be 0.
        num_expert_group: Number of groups the experts are split into.
        topk_group: Number of groups to keep per token.

    Returns:
        A tensor shaped like `topk_weights` in which the scores of the
        discarded groups are -inf and all other scores are unchanged.

    Raises:
        ValueError: If num_expert_group or topk_group is None.
    """
    if num_expert_group is None or topk_group is None:
        raise ValueError(
            "native_grouped_topk requires num_expert_group and topk_group")

    num_token = topk_weights.shape[0]
    num_experts = topk_weights.shape[-1]
    # Explicit sizes (instead of -1) keep view/reshape valid for 0 tokens.
    experts_per_group = num_experts // num_expert_group

    grouped_weights = topk_weights.view(num_token, num_expert_group,
                                        experts_per_group).max(dim=-1).values
    topk_group_indices = torch.topk(grouped_weights.to(torch.float32),
                                    k=topk_group,
                                    dim=-1,
                                    sorted=False)[1]
    topk_group_mask = torch.zeros_like(grouped_weights)
    topk_group_mask.scatter_(1, topk_group_indices, 1)
    topk_weight_mask = (topk_group_mask.unsqueeze(-1).expand(
        num_token, num_expert_group,
        experts_per_group).reshape(num_token, num_experts))

    # Fill with -inf rather than 0.0: scores may be negative (e.g. after the
    # e_score_correction_bias is added in select_experts), and a 0.0 fill
    # would then let experts from discarded groups win the later top-k.
    topk_weights = topk_weights.masked_fill(~topk_weight_mask.bool(),
                                            float("-inf"))

    return topk_weights
|
|
|
|
|
|
def select_experts(
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    use_grouped_topk: bool,
    renormalize: bool,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[torch.Tensor] = None
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Select top-k experts based on router logits.

    Args:
        hidden_states: Hidden states of shape (num_tokens, hidden_size). Only
            its first dimension (checked against router_logits) and its dtype
            (used for the returned weights) are read.
        router_logits: Router logits of shape (num_tokens, num_experts).
        top_k: Number of experts to select per token.
        use_grouped_topk: Whether to restrict selection to the best
            `topk_group` expert groups before selecting top-k.
        renormalize: Whether to rescale each token's selected weights so
            they sum to 1.
        topk_group: Number of expert groups kept per token (grouped top-k
            only; required when use_grouped_topk is True).
        num_expert_group: Number of groups the experts are split into
            (grouped top-k only; required when use_grouped_topk is True).
            num_experts must be divisible by it.
        custom_routing_function: Not supported; must be None.
        scoring_func: Scoring function to use: "softmax" or "sigmoid".
        e_score_correction_bias: Optional per-expert bias added to the scores
            for expert selection only (grouped top-k only); the returned
            routing weights are taken from the unbiased scores.

    Returns:
        topk_weights: Routing weights of shape (num_tokens, top_k), in
            hidden_states.dtype.
        topk_ids: Selected expert IDs of shape (num_tokens, top_k), as int32.

    Raises:
        NotImplementedError: If a custom routing function is provided.
        ValueError: If an unsupported scoring function is provided.
    """
    assert hidden_states.shape[0] == router_logits.shape[0], (
        "Number of tokens mismatch")

    if custom_routing_function is not None:
        raise NotImplementedError(
            "Custom routing function is not supported now")

    if scoring_func == "softmax":
        # NOTE: vLLM use dtype=torch.float here
        topk_weights = router_logits.softmax(dim=-1)
    elif scoring_func == "sigmoid":
        topk_weights = router_logits.sigmoid()
    else:
        raise ValueError(f"Unsupported scoring function: {scoring_func}")

    if use_grouped_topk:
        assert topk_group is not None
        assert num_expert_group is not None

        if e_score_correction_bias is not None:
            # Store original scores before applying correction bias. We use biased
            # scores for expert selection but original scores for routing weights
            original_weights = topk_weights
            topk_weights = topk_weights + e_score_correction_bias.unsqueeze(0)

        # TODO: Change to npu_group_topk when the latest CANN and NNAL is available
        # >>> torch_npu._npu_group_topk(topk_weights, group_num=num_expert_group, k=topk_group)
        topk_weights = native_grouped_topk(topk_weights, num_expert_group,
                                           topk_group)

        if e_score_correction_bias is not None:
            topk_ids = torch.topk(topk_weights, k=top_k, dim=-1,
                                  sorted=False)[1]
            # Use original unbiased scores for the routing weights
            topk_weights = original_weights.gather(1, topk_ids)
        else:
            topk_weights, topk_ids = torch.topk(topk_weights,
                                                k=top_k,
                                                dim=-1,
                                                sorted=False)
    else:
        topk_weights, topk_ids = topk_weights.topk(top_k, dim=-1)

    if renormalize:
        # Renormalize before the cast below so the division runs in the
        # scoring dtype (router_logits' dtype), which may be wider than
        # hidden_states.dtype.
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    topk_weights = topk_weights.to(hidden_states.dtype)

    # Required by npu_moe_init_routing
    topk_ids = topk_ids.to(torch.int32)

    return topk_weights, topk_ids
|
|
|
|
|
|
def forward_oot(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    use_grouped_topk: bool,
    top_k: int,
    router_logits: torch.Tensor,
    renormalize: bool,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[torch.Tensor] = None,
    **kwargs,
):
    """
    Unquantized fused-MoE forward: route tokens, then run the experts.

    Installed as a method on UnquantizedFusedMoEMethod; `self` is not used.

    Args:
        layer: Module providing the expert weights `w13_weight` (passed as
            w1) and `w2_weight` (passed as w2) to fused_experts.
        x: Hidden states of the tokens to route.
        use_grouped_topk, top_k, renormalize, topk_group, num_expert_group,
        custom_routing_function, scoring_func, e_score_correction_bias:
            Forwarded unchanged to select_experts.
        router_logits: Router logits; dimension 1 is the expert dimension.
        global_num_experts: Total number of experts, checked against
            router_logits.shape[1]. The default of -1 skips the check.
        expert_map: Forwarded to fused_experts.
        **kwargs: Accepted for interface compatibility and ignored.

    Returns:
        The output of fused_experts.
    """
    # The default of -1 carries no expert count, so only validate the router
    # width when the caller actually supplied one.
    assert global_num_experts == -1 or router_logits.shape[
        1] == global_num_experts, "Number of global experts mismatch"

    topk_weights, topk_ids = select_experts(
        hidden_states=x,
        router_logits=router_logits,
        top_k=top_k,
        use_grouped_topk=use_grouped_topk,
        renormalize=renormalize,
        topk_group=topk_group,
        num_expert_group=num_expert_group,
        custom_routing_function=custom_routing_function,
        scoring_func=scoring_func,
        e_score_correction_bias=e_score_correction_bias,
    )

    return fused_experts(hidden_states=x,
                         w1=layer.w13_weight,
                         w2=layer.w2_weight,
                         topk_weights=topk_weights,
                         topk_ids=topk_ids,
                         top_k=top_k,
                         expert_map=expert_map)
|
|
|
|
|
|
# Install forward_oot on vLLM's UnquantizedFusedMoEMethod. This assignment
# runs at import time and replaces the class attribute, so it affects every
# instance of the class in the process.
# NOTE(review): presumably vLLM's CustomOp dispatch calls forward_oot on
# out-of-tree platforms such as Ascend -- confirm against the vLLM version.
UnquantizedFusedMoEMethod.forward_oot = forward_oot
|