mirror of
https://github.com/huggingface/transformers.git
synced 2025-10-20 17:13:56 +08:00
Double router compute? (#41653)
* weird double router compute? * flip it
This commit is contained in:
@@ -117,7 +117,6 @@ class DeepseekV2Moe(nn.Module):
|
|||||||
residuals = hidden_states
|
residuals = hidden_states
|
||||||
orig_shape = hidden_states.shape
|
orig_shape = hidden_states.shape
|
||||||
router_logits = nn.functional.linear(hidden_states.type(torch.float32), self.gate.weight.type(torch.float32))
|
router_logits = nn.functional.linear(hidden_states.type(torch.float32), self.gate.weight.type(torch.float32))
|
||||||
router_logits = self.gate(hidden_states)
|
|
||||||
topk_indices, topk_weights = self.route_tokens_to_experts(router_logits)
|
topk_indices, topk_weights = self.route_tokens_to_experts(router_logits)
|
||||||
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
||||||
hidden_states = self.experts(hidden_states, topk_indices, topk_weights).view(*orig_shape)
|
hidden_states = self.experts(hidden_states, topk_indices, topk_weights).view(*orig_shape)
|
||||||
|
@@ -275,7 +275,6 @@ class DeepseekV2Moe(nn.Module):
|
|||||||
residuals = hidden_states
|
residuals = hidden_states
|
||||||
orig_shape = hidden_states.shape
|
orig_shape = hidden_states.shape
|
||||||
router_logits = nn.functional.linear(hidden_states.type(torch.float32), self.gate.weight.type(torch.float32))
|
router_logits = nn.functional.linear(hidden_states.type(torch.float32), self.gate.weight.type(torch.float32))
|
||||||
router_logits = self.gate(hidden_states)
|
|
||||||
topk_indices, topk_weights = self.route_tokens_to_experts(router_logits)
|
topk_indices, topk_weights = self.route_tokens_to_experts(router_logits)
|
||||||
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
||||||
hidden_states = self.experts(hidden_states, topk_indices, topk_weights).view(*orig_shape)
|
hidden_states = self.experts(hidden_states, topk_indices, topk_weights).view(*orig_shape)
|
||||||
|
Reference in New Issue
Block a user