Mirror of https://github.com/huggingface/transformers.git, synced 2025-10-20 17:13:56 +08:00
@@ -286,37 +286,12 @@ class Ernie4_5_MoeStatics(nn.Module):
         return hidden_states + self.e_score_correction_bias.squeeze()
 
 
-class Ernie4_5_MoeRouter(nn.Module):
-    def __init__(self, config):
-        super().__init__()
-        self.top_k = config.moe_k
-        self.num_experts = config.moe_num_experts
-        self.norm_min = config.moe_norm_min
-        self.gate = nn.Linear(config.hidden_size, config.moe_num_experts, bias=False, dtype=torch.float32)
-        self.moe_statics = Ernie4_5_MoeStatics(config)
-
-    def forward(
-        self, hidden_states: torch.Tensor, device_type: str
-    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
-            router_logits = self.gate(hidden_states.float())
-            routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
-            routing_bias = self.moe_statics.e_score_correction_bias.squeeze()
-            _, selected_experts = torch.topk(routing_weights + routing_bias, self.top_k, dim=-1)
-            routing_weights = torch.gather(routing_weights, dim=-1, index=selected_experts)
-            routing_weights = routing_weights / torch.clamp(
-                routing_weights.sum(dim=-1, keepdim=True), min=self.norm_min
-            )
-            routing_weights = routing_weights.to(router_logits.dtype)
-        return router_logits, selected_experts, routing_weights
-
-
 class Ernie4_5_MoeExperts(nn.ModuleList):
     def __init__(self, config):
         super().__init__()
         self.num_experts = config.moe_num_experts
         for _ in range(self.num_experts):
-            self.append(Ernie4_5_MoeMLP(config))
+            self.append(Ernie4_5_MoeMLP(config, config.moe_intermediate_size))
 
     def forward(
         self, hidden_states: torch.Tensor, selected_experts: torch.Tensor, routing_weights: torch.Tensor
@@ -349,7 +324,7 @@ class Ernie4_5_MoeSparseMoeBlock(nn.Module):
         if config.moe_num_shared_experts > 0:
             self.shared_experts = Ernie4_5_MoeMLP(config, config.moe_intermediate_size * config.moe_num_shared_experts)
 
-    def route_tokens_to_experts(self, hidden_states, router_logits):
+    def route_tokens_to_experts(self, hidden_states):
         device_type = (
             hidden_states.device.type
             if isinstance(hidden_states.device.type, str) and hidden_states.device.type != "mps"
@@ -357,9 +332,9 @@ class Ernie4_5_MoeSparseMoeBlock(nn.Module):
         )
 
         with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+            router_logits = self.gate(hidden_states.float())
             routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
-            routing_bias = self.moe_statics.e_score_correction_bias.squeeze()
-            _, selected_experts = torch.topk(routing_weights + routing_bias, self.top_k, dim=-1)
+            _, selected_experts = torch.topk(self.moe_statics(routing_weights), self.top_k, dim=-1)
             routing_weights = torch.gather(routing_weights, dim=-1, index=selected_experts)
             routing_weights = routing_weights / torch.clamp(
                 routing_weights.sum(dim=-1, keepdim=True), min=self.norm_min
@@ -367,20 +342,15 @@ class Ernie4_5_MoeSparseMoeBlock(nn.Module):
             routing_weights = routing_weights.to(router_logits.dtype)
         return selected_experts, routing_weights
 
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         batch_size, sequence_length, _ = hidden_states.shape
-        hidden_states_reshaped = hidden_states.view(-1, self.hidden_dim)
+        hidden_states = hidden_states.view(-1, self.hidden_dim)
 
         if self.shared_experts is not None:
-            shared_output = self.shared_experts(hidden_states_reshaped)
+            shared_output = self.shared_experts(hidden_states)
 
-        router_logits = self.gate(hidden_states_reshaped.float())
-        selected_experts, routing_weights = self.route_tokens_to_experts(hidden_states_reshaped, router_logits)
-
-        final_hidden_states = self.experts(hidden_states_reshaped, selected_experts, routing_weights)
+        selected_experts, routing_weights = self.route_tokens_to_experts(hidden_states)
+        final_hidden_states = self.experts(hidden_states, selected_experts, routing_weights)
 
         if self.shared_experts is not None:
             final_hidden_states = final_hidden_states + shared_output
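For reference, the selection-versus-mixing split that the new route_tokens_to_experts performs (and that the hunks below repeat against a second file) can be reproduced in isolation. The sketch below uses made-up toy sizes (hidden_size=8, 4 experts, top-k of 2, norm_min=1e-12), not real Ernie 4.5 configuration values, and a plain tensor standing in for Ernie4_5_MoeStatics.e_score_correction_bias; it is an illustration of the math, not the model code itself.

# Minimal sketch of bias-corrected top-k routing, assumed toy sizes.
import torch
import torch.nn as nn
import torch.nn.functional as F

hidden_size, num_experts, top_k, norm_min = 8, 4, 2, 1e-12

gate = nn.Linear(hidden_size, num_experts, bias=False, dtype=torch.float32)
# Stand-in for the learned e_score_correction_bias: it shifts the scores
# used for expert *selection* but not the mixing weights themselves.
e_score_correction_bias = torch.zeros(num_experts)

tokens = torch.randn(5, hidden_size)  # 5 flattened tokens

router_logits = gate(tokens.float())
routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float)

# Top-k selection runs on the bias-corrected scores ...
_, selected_experts = torch.topk(routing_weights + e_score_correction_bias, top_k, dim=-1)
# ... but the mixing weights are gathered from the unbiased softmax and then
# renormalized, with a clamp so the division never hits (near-)zero.
routing_weights = torch.gather(routing_weights, dim=-1, index=selected_experts)
routing_weights = routing_weights / torch.clamp(
    routing_weights.sum(dim=-1, keepdim=True), min=norm_min
)

print(selected_experts.shape, routing_weights.shape)  # torch.Size([5, 2]) torch.Size([5, 2])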
@@ -96,37 +96,12 @@ class Ernie4_5_MoeStatics(nn.Module):
         return hidden_states + self.e_score_correction_bias.squeeze()
 
 
-class Ernie4_5_MoeRouter(nn.Module):
-    def __init__(self, config):
-        super().__init__()
-        self.top_k = config.moe_k
-        self.num_experts = config.moe_num_experts
-        self.norm_min = config.moe_norm_min
-        self.gate = nn.Linear(config.hidden_size, config.moe_num_experts, bias=False, dtype=torch.float32)
-        self.moe_statics = Ernie4_5_MoeStatics(config)
-
-    def forward(
-        self, hidden_states: torch.Tensor, device_type: str
-    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
-            router_logits = self.gate(hidden_states.float())
-            routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
-            routing_bias = self.moe_statics.e_score_correction_bias.squeeze()
-            _, selected_experts = torch.topk(routing_weights + routing_bias, self.top_k, dim=-1)
-            routing_weights = torch.gather(routing_weights, dim=-1, index=selected_experts)
-            routing_weights = routing_weights / torch.clamp(
-                routing_weights.sum(dim=-1, keepdim=True), min=self.norm_min
-            )
-            routing_weights = routing_weights.to(router_logits.dtype)
-        return router_logits, selected_experts, routing_weights
-
-
 class Ernie4_5_MoeExperts(nn.ModuleList):
     def __init__(self, config):
         super().__init__()
         self.num_experts = config.moe_num_experts
         for _ in range(self.num_experts):
-            self.append(Ernie4_5_MoeMLP(config))
+            self.append(Ernie4_5_MoeMLP(config, config.moe_intermediate_size))
 
     def forward(
         self, hidden_states: torch.Tensor, selected_experts: torch.Tensor, routing_weights: torch.Tensor
@@ -159,7 +134,7 @@ class Ernie4_5_MoeSparseMoeBlock(nn.Module):
         if config.moe_num_shared_experts > 0:
             self.shared_experts = Ernie4_5_MoeMLP(config, config.moe_intermediate_size * config.moe_num_shared_experts)
 
-    def route_tokens_to_experts(self, hidden_states, router_logits):
+    def route_tokens_to_experts(self, hidden_states):
         device_type = (
             hidden_states.device.type
             if isinstance(hidden_states.device.type, str) and hidden_states.device.type != "mps"
@@ -167,9 +142,9 @@ class Ernie4_5_MoeSparseMoeBlock(nn.Module):
         )
 
         with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+            router_logits = self.gate(hidden_states.float())
             routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
-            routing_bias = self.moe_statics.e_score_correction_bias.squeeze()
-            _, selected_experts = torch.topk(routing_weights + routing_bias, self.top_k, dim=-1)
+            _, selected_experts = torch.topk(self.moe_statics(routing_weights), self.top_k, dim=-1)
            routing_weights = torch.gather(routing_weights, dim=-1, index=selected_experts)
             routing_weights = routing_weights / torch.clamp(
                 routing_weights.sum(dim=-1, keepdim=True), min=self.norm_min
@@ -177,20 +152,15 @@ class Ernie4_5_MoeSparseMoeBlock(nn.Module):
             routing_weights = routing_weights.to(router_logits.dtype)
         return selected_experts, routing_weights
 
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         batch_size, sequence_length, _ = hidden_states.shape
-        hidden_states_reshaped = hidden_states.view(-1, self.hidden_dim)
+        hidden_states = hidden_states.view(-1, self.hidden_dim)
 
         if self.shared_experts is not None:
-            shared_output = self.shared_experts(hidden_states_reshaped)
+            shared_output = self.shared_experts(hidden_states)
 
-        router_logits = self.gate(hidden_states_reshaped.float())
-        selected_experts, routing_weights = self.route_tokens_to_experts(hidden_states_reshaped, router_logits)
-
-        final_hidden_states = self.experts(hidden_states_reshaped, selected_experts, routing_weights)
+        selected_experts, routing_weights = self.route_tokens_to_experts(hidden_states)
+        final_hidden_states = self.experts(hidden_states, selected_experts, routing_weights)
 
         if self.shared_experts is not None:
             final_hidden_states = final_hidden_states + shared_output
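The hunks cut off before the body of Ernie4_5_MoeExperts.forward, but its signature shows what the routing outputs feed into: per-token expert indices and mixing weights of shape (num_tokens, top_k). A dispatch loop consistent with that signature, in the Mixtral-style pattern used elsewhere in the library, might look like the sketch below; the function name experts_forward and all sizes are assumptions for illustration, not the actual Ernie implementation.

# Hypothetical dispatch loop matching the Ernie4_5_MoeExperts.forward signature.
import torch
import torch.nn as nn
import torch.nn.functional as F

def experts_forward(expert_mlps, hidden_states, selected_experts, routing_weights):
    # hidden_states: (num_tokens, hidden_dim)
    # selected_experts / routing_weights: (num_tokens, top_k)
    final_hidden_states = torch.zeros_like(hidden_states)
    # One-hot mask of shape (num_experts, top_k, num_tokens): for each expert,
    # which tokens routed to it and in which top-k slot.
    expert_mask = F.one_hot(selected_experts, num_classes=len(expert_mlps)).permute(2, 1, 0)
    for expert_idx in range(len(expert_mlps)):
        slot_idx, token_idx = torch.where(expert_mask[expert_idx])
        if token_idx.numel() == 0:
            continue  # no tokens routed to this expert
        expert_out = expert_mlps[expert_idx](hidden_states[token_idx])
        # Scale each expert output by its routing weight and scatter-add back.
        expert_out = expert_out * routing_weights[token_idx, slot_idx, None]
        final_hidden_states.index_add_(0, token_idx, expert_out.to(hidden_states.dtype))
    return final_hidden_states

# Toy usage with assumed sizes (5 tokens, hidden_dim=8, 4 experts, top-k=2).
experts = nn.ModuleList([nn.Linear(8, 8) for _ in range(4)])
tokens = torch.randn(5, 8)
sel = torch.randint(0, 4, (5, 2))
weights = torch.full((5, 2), 0.5)
print(experts_forward(experts, tokens, sel, weights).shape)  # torch.Size([5, 8])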