[Ernie 4.5 Moe] Fix Moe and offloading (#41385)

fix
Anton Vlasjuk
2025-10-16 13:59:01 +02:00
committed by GitHub
parent 44539827d5
commit baecdb8a97
2 changed files with 18 additions and 78 deletions


@@ -286,37 +286,12 @@ class Ernie4_5_MoeStatics(nn.Module):
         return hidden_states + self.e_score_correction_bias.squeeze()


-class Ernie4_5_MoeRouter(nn.Module):
-    def __init__(self, config):
-        super().__init__()
-        self.top_k = config.moe_k
-        self.num_experts = config.moe_num_experts
-        self.norm_min = config.moe_norm_min
-        self.gate = nn.Linear(config.hidden_size, config.moe_num_experts, bias=False, dtype=torch.float32)
-        self.moe_statics = Ernie4_5_MoeStatics(config)
-
-    def forward(
-        self, hidden_states: torch.Tensor, device_type: str
-    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
-            router_logits = self.gate(hidden_states.float())
-            routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
-            routing_bias = self.moe_statics.e_score_correction_bias.squeeze()
-            _, selected_experts = torch.topk(routing_weights + routing_bias, self.top_k, dim=-1)
-            routing_weights = torch.gather(routing_weights, dim=-1, index=selected_experts)
-            routing_weights = routing_weights / torch.clamp(
-                routing_weights.sum(dim=-1, keepdim=True), min=self.norm_min
-            )
-            routing_weights = routing_weights.to(router_logits.dtype)
-        return router_logits, selected_experts, routing_weights
-
-
 class Ernie4_5_MoeExperts(nn.ModuleList):
     def __init__(self, config):
         super().__init__()
         self.num_experts = config.moe_num_experts
         for _ in range(self.num_experts):
-            self.append(Ernie4_5_MoeMLP(config))
+            self.append(Ernie4_5_MoeMLP(config, config.moe_intermediate_size))

     def forward(
         self, hidden_states: torch.Tensor, selected_experts: torch.Tensor, routing_weights: torch.Tensor
@@ -349,7 +324,7 @@ class Ernie4_5_MoeSparseMoeBlock(nn.Module):
         if config.moe_num_shared_experts > 0:
             self.shared_experts = Ernie4_5_MoeMLP(config, config.moe_intermediate_size * config.moe_num_shared_experts)

-    def route_tokens_to_experts(self, hidden_states, router_logits):
+    def route_tokens_to_experts(self, hidden_states):
         device_type = (
             hidden_states.device.type
             if isinstance(hidden_states.device.type, str) and hidden_states.device.type != "mps"
@@ -357,9 +332,9 @@ class Ernie4_5_MoeSparseMoeBlock(nn.Module):
         )
         with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+            router_logits = self.gate(hidden_states.float())
             routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
-            routing_bias = self.moe_statics.e_score_correction_bias.squeeze()
-            _, selected_experts = torch.topk(routing_weights + routing_bias, self.top_k, dim=-1)
+            _, selected_experts = torch.topk(self.moe_statics(routing_weights), self.top_k, dim=-1)
             routing_weights = torch.gather(routing_weights, dim=-1, index=selected_experts)
             routing_weights = routing_weights / torch.clamp(
                 routing_weights.sum(dim=-1, keepdim=True), min=self.norm_min
@@ -367,20 +342,15 @@ class Ernie4_5_MoeSparseMoeBlock(nn.Module):
             routing_weights = routing_weights.to(router_logits.dtype)
         return selected_experts, routing_weights

-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         batch_size, sequence_length, _ = hidden_states.shape
-        hidden_states_reshaped = hidden_states.view(-1, self.hidden_dim)
+        hidden_states = hidden_states.view(-1, self.hidden_dim)

         if self.shared_experts is not None:
-            shared_output = self.shared_experts(hidden_states_reshaped)
+            shared_output = self.shared_experts(hidden_states)

-        router_logits = self.gate(hidden_states_reshaped.float())
-        selected_experts, routing_weights = self.route_tokens_to_experts(hidden_states_reshaped, router_logits)
-        final_hidden_states = self.experts(hidden_states_reshaped, selected_experts, routing_weights)
+        selected_experts, routing_weights = self.route_tokens_to_experts(hidden_states)
+        final_hidden_states = self.experts(hidden_states, selected_experts, routing_weights)

         if self.shared_experts is not None:
             final_hidden_states = final_hidden_states + shared_output


@@ -96,37 +96,12 @@ class Ernie4_5_MoeStatics(nn.Module):
         return hidden_states + self.e_score_correction_bias.squeeze()


-class Ernie4_5_MoeRouter(nn.Module):
-    def __init__(self, config):
-        super().__init__()
-        self.top_k = config.moe_k
-        self.num_experts = config.moe_num_experts
-        self.norm_min = config.moe_norm_min
-        self.gate = nn.Linear(config.hidden_size, config.moe_num_experts, bias=False, dtype=torch.float32)
-        self.moe_statics = Ernie4_5_MoeStatics(config)
-
-    def forward(
-        self, hidden_states: torch.Tensor, device_type: str
-    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
-            router_logits = self.gate(hidden_states.float())
-            routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
-            routing_bias = self.moe_statics.e_score_correction_bias.squeeze()
-            _, selected_experts = torch.topk(routing_weights + routing_bias, self.top_k, dim=-1)
-            routing_weights = torch.gather(routing_weights, dim=-1, index=selected_experts)
-            routing_weights = routing_weights / torch.clamp(
-                routing_weights.sum(dim=-1, keepdim=True), min=self.norm_min
-            )
-            routing_weights = routing_weights.to(router_logits.dtype)
-        return router_logits, selected_experts, routing_weights
-
-
 class Ernie4_5_MoeExperts(nn.ModuleList):
     def __init__(self, config):
         super().__init__()
         self.num_experts = config.moe_num_experts
         for _ in range(self.num_experts):
-            self.append(Ernie4_5_MoeMLP(config))
+            self.append(Ernie4_5_MoeMLP(config, config.moe_intermediate_size))

     def forward(
         self, hidden_states: torch.Tensor, selected_experts: torch.Tensor, routing_weights: torch.Tensor
@@ -159,7 +134,7 @@ class Ernie4_5_MoeSparseMoeBlock(nn.Module):
         if config.moe_num_shared_experts > 0:
             self.shared_experts = Ernie4_5_MoeMLP(config, config.moe_intermediate_size * config.moe_num_shared_experts)

-    def route_tokens_to_experts(self, hidden_states, router_logits):
+    def route_tokens_to_experts(self, hidden_states):
         device_type = (
             hidden_states.device.type
             if isinstance(hidden_states.device.type, str) and hidden_states.device.type != "mps"
@@ -167,9 +142,9 @@ class Ernie4_5_MoeSparseMoeBlock(nn.Module):
         )
         with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+            router_logits = self.gate(hidden_states.float())
             routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
-            routing_bias = self.moe_statics.e_score_correction_bias.squeeze()
-            _, selected_experts = torch.topk(routing_weights + routing_bias, self.top_k, dim=-1)
+            _, selected_experts = torch.topk(self.moe_statics(routing_weights), self.top_k, dim=-1)
             routing_weights = torch.gather(routing_weights, dim=-1, index=selected_experts)
             routing_weights = routing_weights / torch.clamp(
                 routing_weights.sum(dim=-1, keepdim=True), min=self.norm_min
@@ -177,20 +152,15 @@ class Ernie4_5_MoeSparseMoeBlock(nn.Module):
             routing_weights = routing_weights.to(router_logits.dtype)
         return selected_experts, routing_weights

-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         batch_size, sequence_length, _ = hidden_states.shape
-        hidden_states_reshaped = hidden_states.view(-1, self.hidden_dim)
+        hidden_states = hidden_states.view(-1, self.hidden_dim)

         if self.shared_experts is not None:
-            shared_output = self.shared_experts(hidden_states_reshaped)
+            shared_output = self.shared_experts(hidden_states)

-        router_logits = self.gate(hidden_states_reshaped.float())
-        selected_experts, routing_weights = self.route_tokens_to_experts(hidden_states_reshaped, router_logits)
-        final_hidden_states = self.experts(hidden_states_reshaped, selected_experts, routing_weights)
+        selected_experts, routing_weights = self.route_tokens_to_experts(hidden_states)
+        final_hidden_states = self.experts(hidden_states, selected_experts, routing_weights)

         if self.shared_experts is not None:
             final_hidden_states = final_hidden_states + shared_output
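
For reference, below is a standalone sketch of the routing math that the patched route_tokens_to_experts performs. It is not the transformers implementation itself: the tensor shapes, dtypes, and config values (num_experts, top_k, norm_min, the correction bias) are made up for illustration. It mirrors the diff above: the gate and softmax run in float32, top-k selection happens on bias-corrected scores (what self.moe_statics(routing_weights) does), and only the gathered weights are renormalized.

# Standalone sketch of the routing path in the patched route_tokens_to_experts.
# All sizes and values below are hypothetical, not taken from the Ernie 4.5 config.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_tokens, hidden_size = 8, 16
num_experts, top_k, norm_min = 4, 2, 1e-12

hidden_states = torch.randn(num_tokens, hidden_size, dtype=torch.bfloat16)
gate_weight = torch.randn(num_experts, hidden_size, dtype=torch.float32)   # gate kept in float32
e_score_correction_bias = torch.zeros(num_experts, dtype=torch.float32)    # stands in for Ernie4_5_MoeStatics

# Gate and softmax in float32, as the autocast(enabled=False) block enforces in the diff.
router_logits = F.linear(hidden_states.float(), gate_weight)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)

# The correction bias only affects which experts are selected, not the mixing weights.
_, selected_experts = torch.topk(routing_weights + e_score_correction_bias, top_k, dim=-1)
routing_weights = torch.gather(routing_weights, dim=-1, index=selected_experts)
routing_weights = routing_weights / torch.clamp(routing_weights.sum(dim=-1, keepdim=True), min=norm_min)
routing_weights = routing_weights.to(router_logits.dtype)

print(selected_experts.shape, routing_weights.shape)  # torch.Size([8, 2]) torch.Size([8, 2])

After the change, forward only flattens the tokens, optionally runs the shared experts, and hands the flat hidden states to route_tokens_to_experts and Ernie4_5_MoeExperts; the gate call no longer lives in forward.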