[BugFix] Register expert_map as named buffer for wake_up and sleep (#25458)

Signed-off-by: wuxibin <wuxibin@bytedance.com>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: youkaichao <youkaichao@gmail.com>
Author: Joel
Date: 2025-09-23 21:49:13 +08:00
Committed by: GitHub
Parent: b6a136b58c
Commit: 61d1b35561


@@ -972,12 +972,15 @@ class FusedMoE(CustomOp):
                     "experts. Falling back to linear expert placement.")
                 expert_placement_strategy = "linear"
-            self.local_num_experts, self.expert_map = determine_expert_map(
+            self.expert_map: Optional[torch.Tensor]
+            local_num_experts, expert_map = determine_expert_map(
                 ep_size=self.ep_size,
                 ep_rank=self.ep_rank,
                 global_num_experts=self.global_num_experts,
                 expert_placement_strategy=expert_placement_strategy,
             )
+            self.local_num_experts = local_num_experts
+            self.register_buffer("expert_map", expert_map)
             logger.info_once(
                 "[EP Rank %s/%s] Expert parallelism is enabled. Expert "
                 "placement strategy: %s. Local/global"
@@ -1154,10 +1157,12 @@ class FusedMoE(CustomOp):
         # ep_size and ep_rank should already be updated
         assert self.expert_map is not None
         with self.expert_map.device:
-            self.local_num_experts, self.expert_map = determine_expert_map(
+            local_num_experts, expert_map = determine_expert_map(
                 ep_size=self.ep_size,
                 ep_rank=self.ep_rank,
                 global_num_experts=self.global_num_experts)
+            self.local_num_experts = local_num_experts
+            self.register_buffer("expert_map", expert_map)
 
     def _load_per_tensor_weight_scale(self, shard_id: str,
                                       param: torch.nn.Parameter,
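
Why the buffer registration matters: a tensor stored as a plain attribute never appears in the module's buffer registry, so anything that finds tensors by walking the module tree (which is presumably how the sleep/wake_up handling locates memory to release and restore) would skip it; register_buffer makes it show up in named_buffers() and state_dict(). Below is a minimal sketch of that difference, using standalone toy modules rather than vLLM code (the class names are made up for illustration):

# Minimal sketch (illustrative toy modules, not vLLM code): a tensor kept as
# a plain attribute is invisible to named_buffers()/state_dict(), while a
# registered buffer is discoverable by anything that traverses the module.
import torch
import torch.nn as nn


class PlainAttr(nn.Module):
    def __init__(self):
        super().__init__()
        self.expert_map = torch.arange(8)  # plain attribute, not tracked


class NamedBuffer(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("expert_map", torch.arange(8))  # tracked buffer


print(dict(PlainAttr().named_buffers()))    # {} -> module traversal misses it
print(dict(NamedBuffer().named_buffers()))  # {'expert_map': tensor([0, 1, ..., 7])}

Calling register_buffer again with the same name, as the second hunk does after recomputing the map, simply replaces the stored tensor under that name: PyTorch permits re-registering a name that is already present in the buffer registry.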