Double router compute? (#41653)

* weird double router compute?

* flip it
This commit is contained in:
Pablo Montalvo
2025-10-16 15:17:21 +02:00
committed by GitHub
parent 503c933f36
commit 9176af574a
2 changed files with 0 additions and 2 deletions

View File

@@ -117,7 +117,6 @@ class DeepseekV2Moe(nn.Module):
residuals = hidden_states
orig_shape = hidden_states.shape
router_logits = nn.functional.linear(hidden_states.type(torch.float32), self.gate.weight.type(torch.float32))
router_logits = self.gate(hidden_states)
topk_indices, topk_weights = self.route_tokens_to_experts(router_logits)
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
hidden_states = self.experts(hidden_states, topk_indices, topk_weights).view(*orig_shape)

View File

@@ -275,7 +275,6 @@ class DeepseekV2Moe(nn.Module):
residuals = hidden_states
orig_shape = hidden_states.shape
router_logits = nn.functional.linear(hidden_states.type(torch.float32), self.gate.weight.type(torch.float32))
router_logits = self.gate(hidden_states)
topk_indices, topk_weights = self.route_tokens_to_experts(router_logits)
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
hidden_states = self.experts(hidden_states, topk_indices, topk_weights).view(*orig_shape)