[Bugfix] Fix compute_logits in Jamba (#6093)

This commit is contained in:
Roger Wang
2024-07-03 00:32:35 -07:00
committed by GitHub
parent f1c78138aa
commit 7cd2ebb025

View File

@ -876,7 +876,7 @@ class JambaForCausalLM(nn.Module):
def compute_logits(self, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata) -> torch.Tensor:
logits = self.logits_processor(self.lm_head.weight, hidden_states,
logits = self.logits_processor(self.lm_head, hidden_states,
sampling_metadata)
return logits