mirror of
https://github.com/huggingface/transformers.git
synced 2025-10-20 09:03:53 +08:00
Double router compute? (#41653)
* weird double router compute?
* flip it
This commit is contained in:
@@ -117,7 +117,6 @@ class DeepseekV2Moe(nn.Module):
residuals = hidden_states
orig_shape = hidden_states.shape
router_logits = nn.functional.linear(hidden_states.type(torch.float32), self.gate.weight.type(torch.float32))
router_logits = self.gate(hidden_states)
topk_indices, topk_weights = self.route_tokens_to_experts(router_logits)
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
hidden_states = self.experts(hidden_states, topk_indices, topk_weights).view(*orig_shape)
@@ -275,7 +275,6 @@ class DeepseekV2Moe(nn.Module):
residuals = hidden_states
orig_shape = hidden_states.shape
router_logits = nn.functional.linear(hidden_states.type(torch.float32), self.gate.weight.type(torch.float32))
router_logits = self.gate(hidden_states)
topk_indices, topk_weights = self.route_tokens_to_experts(router_logits)
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
hidden_states = self.experts(hidden_states, topk_indices, topk_weights).view(*orig_shape)
Reference in New Issue
Block a user