Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)
[BugFix] Register expert_map as named buffer for wake_up and sleep (#25458)
Signed-off-by: wuxibin <wuxibin@bytedance.com>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: youkaichao <youkaichao@gmail.com>
@@ -972,12 +972,15 @@ class FusedMoE(CustomOp):
                     "experts. Falling back to linear expert placement.")
                 expert_placement_strategy = "linear"
 
-            self.local_num_experts, self.expert_map = determine_expert_map(
+            self.expert_map: Optional[torch.Tensor]
+            local_num_experts, expert_map = determine_expert_map(
                 ep_size=self.ep_size,
                 ep_rank=self.ep_rank,
                 global_num_experts=self.global_num_experts,
                 expert_placement_strategy=expert_placement_strategy,
             )
+            self.local_num_experts = local_num_experts
+            self.register_buffer("expert_map", expert_map)
             logger.info_once(
                 "[EP Rank %s/%s] Expert parallelism is enabled. Expert "
                 "placement strategy: %s. Local/global"
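
The hunk above swaps a plain tensor attribute for a registered buffer. The PyTorch effect, shown below with a standalone toy module (ToyMoE and plain_map are illustrative names, not vLLM code), is that the tensor becomes visible to anything that walks the module's named buffers, such as state_dict() or device moves, whereas a plain attribute is silently skipped:

import torch
import torch.nn as nn

class ToyMoE(nn.Module):
    def __init__(self, expert_map: torch.Tensor):
        super().__init__()
        # Plain attribute: stored on the module but invisible to module
        # traversal (named_buffers, state_dict, .to(), etc.).
        self.plain_map = expert_map.clone()
        # Registered buffer: tracked alongside parameters.
        self.register_buffer("expert_map", expert_map)

m = ToyMoE(torch.arange(8))
print(list(dict(m.named_buffers()).keys()))  # ['expert_map'] only
print("expert_map" in m.state_dict())        # True
print("plain_map" in m.state_dict())         # False
m.to("cpu")                                  # moves expert_map, ignores plain_map

Per the commit title, this is what the sleep/wake_up path needs: expert_map has to be discoverable as a named buffer. How exactly the sleep-mode machinery consumes named buffers is not shown in this diff.
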
@@ -1154,10 +1157,12 @@ class FusedMoE(CustomOp):
         # ep_size and ep_rank should already be updated
         assert self.expert_map is not None
         with self.expert_map.device:
-            self.local_num_experts, self.expert_map = determine_expert_map(
+            local_num_experts, expert_map = determine_expert_map(
                 ep_size=self.ep_size,
                 ep_rank=self.ep_rank,
                 global_num_experts=self.global_num_experts)
+            self.local_num_experts = local_num_experts
+            self.register_buffer("expert_map", expert_map)
 
     def _load_per_tensor_weight_scale(self, shard_id: str,
                                       param: torch.nn.Parameter,
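
The second hunk applies the same pattern on the path that recomputes the expert map. One PyTorch detail worth noting: calling register_buffer with a name that is already registered as a buffer simply replaces the stored tensor, so the rebuilt expert_map stays tracked. A minimal sketch (toy code, illustrative names only, not vLLM's method):

import torch
import torch.nn as nn

class ToyMoE(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("expert_map", torch.arange(4))

    def rebuild_map(self, new_map: torch.Tensor) -> None:
        # Re-registering under an existing buffer name overwrites the old
        # entry instead of raising or leaving an untracked attribute behind.
        self.register_buffer("expert_map", new_map)

m = ToyMoE()
m.rebuild_map(torch.arange(8))
assert dict(m.named_buffers())["expert_map"].numel() == 8
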