This PR adds support for redundant experts in the EPLB. Key points:
- Use global_num_experts = num_experts + num_redundant_experts consistently.
- Backward compatible when num_redundant_experts = 0.
- Tested on a 16-rank setup (W8A8) with static EPLB and expert_map_path, verifying the router-logits shape and successful requests.
- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

Signed-off-by: yechao237 <yechao20180411@gmail.com>
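As a rough illustration of the bookkeeping this change standardizes (a minimal sketch with made-up numbers and hypothetical variable names, not the vllm-ascend API):

    num_experts = 256              # logical experts defined by the model
    num_redundant_experts = 16     # extra replicas placed by the EPLB
    global_num_experts = num_experts + num_redundant_experts  # 272 routing slots
    # With num_redundant_experts = 0 this collapses to the old behavior:
    # global_num_experts == num_experts, hence backward compatibility.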
598 lines · 25 KiB · Python
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from typing import List, TypedDict
from unittest.mock import MagicMock, patch

import pytest
import torch
import torch.nn as nn
import torch_npu
from pytest_mock import MockerFixture
from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase

from tests.ut.base import TestBase
from vllm_ascend.ascend_forward_context import MoECommType
from vllm_ascend.ops.common_fused_moe import AscendUnquantizedFusedMoEMethod
from vllm_ascend.ops.moe.experts_selector import select_experts
from vllm_ascend.ops.moe.moe_mlp import cumsum_group_list, unified_apply_mlp
from vllm_ascend.utils import AscendSocVersion, adapt_patch

adapt_patch(True)


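# Lightweight stand-ins for the expert-parallel (EP/MC2) and data/tensor
# parallel (DP/TP) process groups; they expose only the attributes the MoE
# layer reads (rank, world size, device group, collectives).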
def mock_ep_and_mc2_group(mocker):
    mock_group = mocker.MagicMock()
    mock_group.rank_in_group = 0
    mock_group.rank = 0
    mock_group.world_size = 4
    mock_group.device_group = "mock_group_ep"
    mock_group.all_to_all = MagicMock(return_value=torch.randn(8, 8))
    return mock_group


def mock_dp_and_tp_group(mocker):
    mock_group = mocker.MagicMock()
    mock_group.rank_in_group = 0
    mock_group.world_size = 2
    mock_group.device_group = "mock_group"
    mock_group.all_gather = MagicMock(return_value=torch.randn(10, 32))
    return mock_group


def mock_npu_format_cast(weight_data, format):
    return weight_data


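# Autouse fixture: every test in this module sees a minimal mocked VllmConfig
# (llama hf_config, tensor_parallel_size=2, max_num_seqs=4, max_model_len=2048).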
@pytest.fixture(autouse=True)
def setup_vllm_config_mock(mocker: MockerFixture):
    mock_hf_config = MagicMock()
    mock_hf_config.model_type = "llama"

    mock_model_config = MagicMock()
    mock_model_config.hf_config = mock_hf_config

    mock_vllm_config = MagicMock()
    mock_vllm_config.model_config = mock_model_config
    mock_vllm_config.parallel_config = MagicMock(tensor_parallel_size=2)
    mock_vllm_config.scheduler_config = MagicMock(max_num_seqs=4)
    mock_vllm_config.model_config.max_model_len = 2048

    mocker.patch('vllm_ascend.ops.common_fused_moe.get_current_vllm_config',
                 return_value=mock_vllm_config)
    mocker.patch('vllm_ascend.ops.moe.moe_comm_method.get_current_vllm_config',
                 return_value=mock_vllm_config)


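# Patches the distributed groups, ascend config, expert map and forward
# context used by the fused-MoE path so the layer logic can run in a single
# test process.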
@pytest.fixture
def mock_dist_env(mocker: MockerFixture):
    mock_moe_comm_method = MagicMock()

    def mock_prepare(hidden_states, router_logits, **kwargs):
        return hidden_states, router_logits

    mock_moe_comm_method.prepare.side_effect = mock_prepare

    mock_fused_experts_result = torch.randn(16, 2)
    mock_moe_comm_method.fused_experts.return_value = mock_fused_experts_result

    def mock_finalize(hidden_states, **kwargs):
        return hidden_states

    mock_moe_comm_method.finalize.side_effect = mock_finalize
    dp_metadata = MagicMock(num_tokens_across_dp_cpu=[5, 5])
    mock_weight_prefetch_method = MagicMock()
    mock_forward_context_obj = MagicMock(
        moe_comm_method=mock_moe_comm_method,
        moe_comm_type=MoECommType.MC2,
        max_tokens_across_dp=10,
        dp_metadata=dp_metadata,
        mc2_mask=torch.zeros(16, dtype=torch.bool),
        padded_num_tokens=16,
        with_quant=False,
        weight_prefetch_method=mock_weight_prefetch_method)

    with patch('torch.distributed.get_rank', return_value=0), \
         patch('torch.distributed.get_world_size', return_value=4), \
         patch('vllm_ascend.ops.common_fused_moe.get_ep_group',
               return_value=mock_ep_and_mc2_group(mocker)), \
         patch('vllm_ascend.ops.moe.token_dispatcher.get_ep_group',
               return_value=mock_ep_and_mc2_group(mocker)), \
         patch('vllm_ascend.ops.common_fused_moe.get_mc2_group',
               return_value=mock_ep_and_mc2_group(mocker)), \
         patch('vllm_ascend.ops.common_fused_moe.get_tp_group',
               return_value=mock_dp_and_tp_group(mocker)), \
         patch('vllm.distributed.parallel_state.get_tp_group',
               return_value=mock_dp_and_tp_group(mocker)), \
         patch('vllm_ascend.ops.common_fused_moe.get_dp_group',
               return_value=mock_dp_and_tp_group(mocker)), \
         patch('vllm.model_executor.layers.fused_moe.layer.get_dp_group',
               return_value=mock_dp_and_tp_group(mocker)), \
         patch('vllm.model_executor.layers.fused_moe.config.get_dp_group',
               return_value=mock_dp_and_tp_group(mocker)), \
         patch('vllm_ascend.ops.common_fused_moe.get_ascend_config',
               return_value=MagicMock(
                   torchair_graph_config=MagicMock(enabled=False),
                   enable_multistream_moe=False,
                   expert_map_path=None)), \
         patch('vllm_ascend.ops.common_fused_moe.determine_expert_map',
               return_value=(3, torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]))), \
         patch('vllm_ascend.ops.common_fused_moe.get_forward_context',
               return_value=mock_forward_context_obj), \
         patch('vllm_ascend.ops.moe.fused_moe_prepare_and_finalize.get_forward_context',
               return_value=mock_forward_context_obj), \
         patch("vllm_ascend.utils.get_ascend_soc_version",
               return_value=AscendSocVersion.A3), \
         patch('vllm_ascend.ops.moe.moe_mlp.get_forward_context',
               return_value=mock_forward_context_obj), \
         patch('vllm_ascend.ops.moe.moe_comm_method.MC2CommImpl._get_token_dispatcher',
               return_value=None), \
         patch('vllm_ascend.ops.moe.moe_comm_method.AlltoAllCommImpl._get_token_dispatcher',
               return_value=None), \
         patch('vllm_ascend.ops.moe.moe_comm_method.AllGatherCommImpl._get_token_dispatcher',
               return_value=None), \
         patch('vllm_ascend.ops.moe.experts_selector.get_forward_context',
               return_value=mock_forward_context_obj):

        yield {
            'mock_forward_context_obj': mock_forward_context_obj,
            'mock_moe_comm_method': mock_moe_comm_method,
        }


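# Stubs the torch_npu MoE kernels (gating, routing, dispatch/combine, grouped
# matmul, swiglu) with CPU tensors of the expected shapes.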
@pytest.fixture
def mock_moe_env(mocker: MockerFixture):

    with patch('torch_npu.npu_moe_gating_top_k',
               return_value=(torch.randn(8, 2),
                             torch.randint(0, 8, (8, 2)),
                             None)), \
         patch('torch_npu.npu_moe_init_routing',
               return_value=(torch.randn(8, 2),
                             torch.randint(0, 8, (8, 2)),
                             torch.tensor([0, 1, 2, 4, 6, 2, 7, 1]))), \
         patch("torch_npu.npu_moe_compute_expert_tokens",
               return_value=torch.randn(8, 2)), \
         patch("torch_npu.npu_moe_distribute_dispatch",
               return_value=torch.randn(16, 2)), \
         patch("torch_npu.npu_moe_distribute_combine",
               return_value=torch.randn(16, 2)), \
         patch("torch_npu.npu_grouped_matmul",
               return_value=[torch.randn(16, 2)]), \
         patch("torch_npu.npu_swiglu", return_value=torch.randn(16, 2)), \
         patch("torch_npu.npu_moe_gating_top_k_softmax",
               return_value=(torch.randn(8, 2),
                             torch.randint(0, 8, (8, 2)),
                             torch.tensor([0, 1, 2, 4, 6, 2, 7, 1]))), \
         patch("torch_npu.npu_moe_finalize_routing",
               return_value=torch.randn(16, 2)):
        if hasattr(torch_npu, 'npu_moe_distribute_dispatch_v2'):
            with patch("torch_npu.npu_moe_distribute_dispatch_v2",
                       return_value=torch.randn(16, 2)), \
                 patch("torch_npu.npu_moe_distribute_combine_v2",
                       return_value=torch.randn(16, 2)):
                yield
        else:
            yield


@pytest.fixture
def default_moe_config():
    return {
        'num_experts': 8,
        'top_k': 2,
        'hidden_size': 512,
        'intermediate_size': 1024
    }


@pytest.fixture
def moe_method(mock_dist_env):
    moe = MagicMock()
    moe.moe_parallel_config.return_value = MagicMock(ep_size=4)
    moe.moe_parallel_config.use_ep = False
    moe.moe_parallel_config.dp_size = 1
    return AscendUnquantizedFusedMoEMethod(moe)


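# TypedDicts describing a per-layer expert placement, i.e. the shape of a
# static expert map file (expert_map_path).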
class Device(TypedDict):
    device_id: int
    device_expert: List[int]


class Layer(TypedDict):
    layer_id: int
    device_count: int
    device_list: List[Device]


class MockData(TypedDict):
    moe_layer_count: int
    layer_list: List[Layer]


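# Minimal quant-method doubles: MockQuantMethod fakes apply() (optionally also
# returning a shared-expert output), while MockFusedMoEMethod satisfies the
# abstract FusedMoEMethodBase interface with no-op implementations.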
class MockQuantMethod(nn.Module):

    def __init__(self, shared_experts, num_tokens):
        super().__init__()
        if shared_experts:
            self.apply = MagicMock(return_value=(torch.randn(num_tokens, 32),
                                                 torch.randn(num_tokens, 10)))
        else:
            self.apply = MagicMock(return_value=torch.randn(num_tokens, 32))


class MockFusedMoEMethod(FusedMoEMethodBase):
    moe = MagicMock()

    def __init__(self):
        super().__init__(self.moe)

    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):
        pass

    def apply(self, hidden_states: torch.Tensor,
              expert_weights: torch.Tensor) -> torch.Tensor:
        pass

    def get_fused_moe_quant_config(self, layer: torch.nn.Module):
        pass


class TestExpertsSelector:

    @pytest.mark.parametrize("global_num_experts", [256, 128])
    def test_select_experts(self, mock_dist_env, mock_moe_env,
                            global_num_experts):
        x = torch.randn(8, 2)
        router_logits = torch.randn(8, 2)
        topk_weights, topk_ids = select_experts(
            hidden_states=x,
            router_logits=router_logits,
            top_k=2,
            use_grouped_topk=False,
            renormalize=True,
            topk_group=None,
            num_expert_group=None,
            custom_routing_function=None,
            scoring_func="softmax",
            e_score_correction_bias=None,
            global_num_experts=global_num_experts)

        assert topk_weights.shape == (8, 2)
        assert topk_ids.shape == (8, 2)


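# cumsum_group_list normalizes the three group_list encodings used by the MoE
# kernels: type 0 is already a cumulative sum, type 1 holds per-expert token
# counts, and type 2 holds (index, count) pairs; all three should reduce to
# the same cumulative group list.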
class TestCumsumGroupList(TestBase):

    def setUp(self):
        self.active_num = 8
        self.expert_num = 128
        self.experts = torch.zeros((self.expert_num, ), dtype=torch.int64)
        self.experts[:self.active_num] = 1
        self.experts = self.experts[torch.randperm(self.expert_num)]
        self.group_list = self.experts.cumsum(dim=0)

    def test_cumsum_group_list_with_type_0(self):
        group_list = self.experts.cumsum(dim=0)
        group_list_type = 0
        result = cumsum_group_list(group_list, group_list_type)
        self.assertTrue(torch.equal(result, self.group_list))

    def test_cumsum_group_list_with_type_1(self):
        group_list = self.experts
        group_list_type = 1
        result = cumsum_group_list(group_list, group_list_type)
        self.assertTrue(torch.equal(result, self.group_list))

    def test_cumsum_group_list_with_type_2(self):
        tokens = torch.arange(self.expert_num, dtype=torch.int64)
        group_list = torch.cat([
            tokens.reshape(self.expert_num, 1),
            self.experts.reshape(self.expert_num, 1)
        ], dim=1)
        group_list_type = 2
        result = cumsum_group_list(group_list,
                                   group_list_type,
                                   active_num=self.active_num,
                                   expert_num=self.expert_num)
        self.assertTrue(torch.equal(result, self.group_list))


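# unified_apply_mlp is exercised on the quantized (W8A8) MC2 path, the
# unquantized path, the 310P path and the fused grouped-matmul+swiglu path,
# with every torch_npu kernel mocked out.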
class TestUnifiedApplyMLP(TestBase):

    @patch('vllm_ascend.ops.moe.moe_mlp.get_forward_context')
    @patch('vllm_ascend.ops.moe.moe_mlp.is_310p')
    @patch('torch_npu.npu_grouped_matmul')
    @patch('torch_npu.npu_dynamic_quant')
    @patch('torch_npu.npu_dequant_swiglu_quant')
    def test_unified_apply_mlp_with_quantization_mc2(
            self, mock_npu_dequant, mock_npu_dynamic_quant,
            mock_npu_grouped_matmul, mock_is_310p, mock_get_forward_context):
        mock_forward_context = MagicMock()
        mock_forward_context.moe_comm_type = MoECommType.MC2
        mock_get_forward_context.return_value = mock_forward_context

        mock_is_310p.return_value = False

        mock_npu_dynamic_quant.return_value = (
            torch.randint(-128, 127, (10, 20), dtype=torch.int8),
            torch.rand(10, 1, dtype=torch.float32))

        mock_npu_grouped_matmul.side_effect = [
            [torch.randint(-2147483648, 2147483647, (10, 40), dtype=torch.int32)],
            [torch.randn(10, 20, dtype=torch.bfloat16)]
        ]

        mock_npu_dequant.return_value = (
            torch.randn(10, 40, dtype=torch.bfloat16),
            torch.randn(10, 1, dtype=torch.float32))

        hidden_states = torch.randn(10, 20, dtype=torch.bfloat16)
        w1 = torch.randint(-128, 127, (5, 20, 40), dtype=torch.int8)
        w1_scale = torch.randn(5, 40, dtype=torch.float32)
        w2 = torch.randint(-128, 127, (5, 40, 20), dtype=torch.int8)
        w2_scale = torch.randn(5, 20, dtype=torch.bfloat16)
        group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)

        result = unified_apply_mlp(hidden_states=hidden_states,
                                   w1=w1,
                                   w1_scale=w1_scale,
                                   w2=w2,
                                   w2_scale=w2_scale,
                                   group_list=group_list,
                                   dynamic_scale=None,
                                   group_list_type=1,
                                   w1_scale_bias=None,
                                   w2_scale_bias=None,
                                   topk_scales=None,
                                   with_quant=True)

        mock_get_forward_context.assert_called()
        mock_npu_dynamic_quant.assert_called()
        self.assertEqual(mock_npu_grouped_matmul.call_count, 2)
        mock_npu_dequant.assert_called_once()
        self.assertEqual(result.dtype, torch.bfloat16)

    @patch('vllm_ascend.ops.moe.moe_mlp.is_310p')
    @patch('torch_npu.npu_grouped_matmul')
    @patch('torch_npu.npu_swiglu')
    @patch('torch_npu.npu_dynamic_quant')
    def test_unified_apply_mlp_without_quantization(self,
                                                    mock_npu_dynamic_quant,
                                                    mock_npu_swiglu,
                                                    mock_npu_grouped_matmul,
                                                    mock_is_310p):
        mock_is_310p.return_value = False

        mock_npu_grouped_matmul.side_effect = [
            [torch.randn(10, 40, dtype=torch.float16)],
            [torch.randn(10, 20, dtype=torch.float16)]
        ]
        mock_npu_swiglu.return_value = torch.randn(10, 40, dtype=torch.float16)
        mock_npu_dynamic_quant.return_value = (MagicMock(), MagicMock())

        hidden_states = torch.randn(10, 20, dtype=torch.float16)
        w1 = torch.randn(5, 20, 40, dtype=torch.float16)
        w2 = torch.randn(5, 40, 20, dtype=torch.float16)
        group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
        topk_scales = torch.randn(10, 1, dtype=torch.float16)

        result = unified_apply_mlp(hidden_states=hidden_states,
                                   w1=w1,
                                   w1_scale=None,
                                   w2=w2,
                                   w2_scale=None,
                                   group_list=group_list,
                                   dynamic_scale=None,
                                   group_list_type=1,
                                   w1_scale_bias=None,
                                   w2_scale_bias=None,
                                   topk_scales=topk_scales,
                                   with_quant=False)

        self.assertEqual(mock_npu_grouped_matmul.call_count, 2)
        mock_npu_swiglu.assert_called_once()

        self.assertEqual(result.shape, hidden_states.shape)
        self.assertEqual(result.dtype, torch.float16)

    @patch('vllm_ascend.ops.moe.moe_mlp.get_forward_context')
    @patch('torch_npu.npu_grouped_matmul')
    @patch('torch_npu.npu_swiglu')
    @patch('torch_npu.npu_dynamic_quant')
    def test_unified_apply_mlp_with_quantization_and_dynamic_scale(
            self, mock_npu_dynamic_quant, mock_npu_swiglu,
            mock_npu_grouped_matmul, mock_get_forward_context):
        mock_forward_context = MagicMock()
        mock_forward_context.with_quant = True
        mock_forward_context.fused_moe_state = "NOT_MC2"
        mock_get_forward_context.return_value = mock_forward_context

        mock_npu_grouped_matmul.side_effect = [
            [torch.randn(10, 40, dtype=torch.bfloat16)],
            [torch.randn(10, 20, dtype=torch.bfloat16)]
        ]

        mock_npu_swiglu.return_value = torch.randn(10, 40, dtype=torch.bfloat16)

        mock_npu_dynamic_quant.return_value = (
            torch.randint(-128, 127, (10, 40), dtype=torch.int8),
            torch.rand(10, 1, dtype=torch.float32))

        hidden_states = torch.randn(10, 20, dtype=torch.bfloat16)
        w1 = torch.randn(5, 20, 40, dtype=torch.bfloat16)
        w1_scale = torch.randn(5, 40, dtype=torch.bfloat16)
        w2 = torch.randn(5, 40, 20, dtype=torch.bfloat16)
        w2_scale = torch.randn(5, 20, dtype=torch.bfloat16)
        w1_scale_bias = torch.randn(5, 40, dtype=torch.bfloat16)
        w2_scale_bias = torch.randn(5, 20, dtype=torch.bfloat16)
        group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
        provided_dynamic_scale = torch.rand(10, 1, dtype=torch.float32)

        result = unified_apply_mlp(hidden_states=hidden_states,
                                   w1=w1,
                                   w1_scale=w1_scale,
                                   w2=w2,
                                   w2_scale=w2_scale,
                                   group_list=group_list,
                                   dynamic_scale=provided_dynamic_scale,
                                   group_list_type=1,
                                   w1_scale_bias=w1_scale_bias,
                                   w2_scale_bias=w2_scale_bias,
                                   topk_scales=None,
                                   with_quant=True)

        mock_get_forward_context.assert_called()

        self.assertEqual(mock_npu_grouped_matmul.call_count, 2)
        mock_npu_swiglu.assert_called_once()
        mock_npu_dynamic_quant.assert_called_once()

        self.assertEqual(result.shape, hidden_states.shape)
        self.assertEqual(result.dtype, torch.bfloat16)

    @patch('vllm_ascend.ops.moe.moe_mlp.is_310p')
    @patch('torch_npu.npu_grouped_matmul')
    @patch('torch_npu.npu_swiglu')
    @patch('torch_npu.npu_dynamic_quant')
    def test_unified_apply_mlp_without_quantization_310p(
            self, mock_npu_dynamic_quant, mock_npu_swiglu,
            mock_npu_grouped_matmul, mock_is_310p):
        mock_is_310p.return_value = True

        mock_gmm1_out = torch.randn(10, 40, dtype=torch.float16)
        mock_gmm2_out = torch.randn(10, 20, dtype=torch.float16)
        mock_npu_grouped_matmul.side_effect = [[mock_gmm1_out],
                                               [mock_gmm2_out]]

        mock_npu_swiglu.return_value = torch.randn(10, 40, dtype=torch.float16)

        mock_npu_dynamic_quant.return_value = (MagicMock(), MagicMock())

        hidden_states = torch.randn(10, 20, dtype=torch.float16)
        w1 = torch.randn(5, 20, 40, dtype=torch.float16)
        w2 = torch.randn(5, 40, 20, dtype=torch.float16)
        group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
        topk_scales = torch.randn(10, 1, dtype=torch.float16)

        result = unified_apply_mlp(hidden_states=hidden_states,
                                   w1=w1,
                                   w1_scale=None,
                                   w2=w2,
                                   w2_scale=None,
                                   group_list=group_list,
                                   dynamic_scale=None,
                                   group_list_type=1,
                                   w1_scale_bias=None,
                                   w2_scale_bias=None,
                                   topk_scales=topk_scales,
                                   with_quant=False)

        mock_is_310p.assert_called_once()

        self.assertEqual(mock_npu_grouped_matmul.call_count, 2)
        mock_npu_swiglu.assert_called_once()

        self.assertEqual(result.shape, hidden_states.shape)
        self.assertEqual(result.dtype, torch.float16)

@patch("vllm_ascend.ops.moe.moe_mlp.get_forward_context")
|
|
@patch("torch_npu.npu_grouped_matmul")
|
|
@patch("torch_npu.npu_swiglu")
|
|
@patch("torch_npu.npu_grouped_matmul_swiglu_quant")
|
|
@patch("torch_npu.npu_dynamic_quant")
|
|
def test_unified_apply_mlp_with_quantization_and_fusion_mlp(
|
|
self, mock_npu_dynamic_quant, mock_npu_grouped_matmul_swiglu_quant,
|
|
mock_npu_swiglu, mock_npu_grouped_matmul,
|
|
mock_get_forward_context):
|
|
|
|
mock_forward_context = MagicMock()
|
|
mock_forward_context.with_quant = True
|
|
mock_forward_context.fused_moe_state = "NOT_MC2"
|
|
mock_get_forward_context.return_value = mock_forward_context
|
|
|
|
mock_npu_grouped_matmul_swiglu_quant.return_value = (torch.randint(
|
|
-128, 127, (10, 40),
|
|
dtype=torch.int8), torch.rand(
|
|
10, 1,
|
|
dtype=torch.float32), torch.rand(10, 1, dtype=torch.float32))
|
|
mock_npu_grouped_matmul.side_effect = [[
|
|
torch.randn(10, 20, dtype=torch.bfloat16)
|
|
]]
|
|
mock_npu_swiglu.return_value = torch.randn(10,
|
|
40,
|
|
dtype=torch.bfloat16)
|
|
mock_npu_dynamic_quant.return_value = (torch.randint(-128,
|
|
127, (10, 40),
|
|
dtype=torch.int8),
|
|
torch.rand(10,
|
|
1,
|
|
dtype=torch.float32))
|
|
|
|
hidden_states = torch.randn(10, 20, dtype=torch.bfloat16)
|
|
w1 = torch.randn(5, 20, 40, dtype=torch.bfloat16)
|
|
w1_scale = torch.randn(5, 40, dtype=torch.bfloat16)
|
|
w2 = torch.randn(5, 40, 20, dtype=torch.bfloat16)
|
|
w2_scale = torch.randn(5, 20, dtype=torch.bfloat16)
|
|
w1_scale_bias = torch.randn(5, 40, dtype=torch.bfloat16)
|
|
w2_scale_bias = torch.randn(5, 20, dtype=torch.bfloat16)
|
|
group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
|
|
provided_dynamic_scale = torch.rand(10, 1, dtype=torch.float32)
|
|
|
|
result = unified_apply_mlp(hidden_states=hidden_states,
|
|
w1=w1,
|
|
w1_scale=w1_scale,
|
|
w2=w2,
|
|
w2_scale=w2_scale,
|
|
group_list=group_list,
|
|
dynamic_scale=provided_dynamic_scale,
|
|
group_list_type=1,
|
|
w1_scale_bias=w1_scale_bias,
|
|
w2_scale_bias=w2_scale_bias,
|
|
topk_scales=None,
|
|
with_quant=True,
|
|
fusion=True)
|
|
|
|
mock_get_forward_context.assert_called()
|
|
mock_npu_grouped_matmul.assert_called_once()
|
|
mock_npu_grouped_matmul_swiglu_quant.assert_called_once()
|
|
|
|
self.assertTrue(mock_forward_context.with_quant)
|
|
self.assertEqual(result.shape, hidden_states.shape)
|
|
self.assertEqual(result.dtype, torch.bfloat16)
|