Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)
[Hardware][Gaudi][BugFix] fix arguments of hpu fused moe (#15945)
Signed-off-by: zhenwei <zhenweiliu@habana.ai>
@@ -254,9 +254,12 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         renormalize: bool,
         topk_group: Optional[int] = None,
         num_expert_group: Optional[int] = None,
+        global_num_experts: int = -1,
+        expert_map: Optional[torch.Tensor] = None,
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
-        e_score_correction_bias: Optional[torch.Tensor] = None
+        e_score_correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
     ) -> torch.Tensor:
         assert not use_grouped_topk
         assert num_expert_group is None
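The signature fix in the hunk above matters because vLLM's CustomOp machinery forwards a single keyword-argument set to whichever per-platform forward is active; if forward_hpu does not accept every argument the caller passes (here activation, plus global_num_experts and expert_map), the dispatch fails with a TypeError. Below is a minimal, self-contained sketch of that pattern, not vLLM's actual CustomOp code; the class and method names are illustrative.

# Minimal sketch of the keyword dispatch behind this fix; names are
# illustrative only, not vLLM's actual CustomOp implementation.
from typing import Callable, Optional


class SketchFusedMoEMethod:
    def apply(self, **kwargs):
        # The platform-agnostic entry point forwards one keyword set to the
        # selected backend, so forward_hpu must accept all of these arguments.
        return self.forward_hpu(**kwargs)

    def forward_hpu(
        self,
        renormalize: bool = True,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map=None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias=None,
        activation: str = "silu",  # the parameter this commit adds
    ) -> str:
        return f"dispatched with activation={activation}"


m = SketchFusedMoEMethod()
print(m.apply(renormalize=True, activation="silu"))
# Before the fix, the extra keyword raised:
#   TypeError: forward_hpu() got an unexpected keyword argument 'activation'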
@@ -472,7 +475,7 @@ class FusedMoE(torch.nn.Module):
                 "non-grouped topk.")
         if current_platform.is_hpu():
             from vllm_hpu_extension.ops import DynamicFusedMOE
-            self.hpu_fused_moe = DynamicFusedMOE(self.num_experts)
+            self.hpu_fused_moe = DynamicFusedMOE(self.global_num_experts)

         # Note: get_quant_method will look at the layer's local_num_experts
         # for heuristic purposes, so it must be initialized first.
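The second hunk sizes DynamicFusedMOE by self.global_num_experts rather than self.num_experts. Under expert parallelism the router scores every expert in the model, while each rank materializes only its local shard; the note about local_num_experts in the surrounding code points at exactly this split. The sketch below illustrates the distinction under an assumed even partitioning; partition_experts is a hypothetical helper written for this explanation, not a vLLM API.

# Hedged sketch of the global-vs-local expert split under expert parallelism.
# partition_experts is a hypothetical helper for illustration, not vLLM code.
def partition_experts(global_num_experts: int, ep_size: int, ep_rank: int):
    """Return (local_num_experts, expert_map) for one EP rank.

    expert_map[g] holds the local index of global expert g on this rank,
    or -1 when that expert lives on another rank.
    """
    assert global_num_experts % ep_size == 0, "assume an even split"
    local_num_experts = global_num_experts // ep_size
    start = ep_rank * local_num_experts
    expert_map = [-1] * global_num_experts
    for g in range(start, start + local_num_experts):
        expert_map[g] = g - start
    return local_num_experts, expert_map


local, emap = partition_experts(global_num_experts=8, ep_size=2, ep_rank=1)
print(local)  # 4 -> experts resident on this rank
print(emap)   # [-1, -1, -1, -1, 0, 1, 2, 3]
# The router emits logits over all 8 experts, so the HPU fused-MoE wrapper
# must be constructed with the global count, matching the fix above.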