Compare commits

...

1 Commit

SHA1        Message  Date
b1292bca69  nits     2024-01-02 15:50:31 +01:00
3 changed files with 0 additions and 9 deletions
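All three hunks below delete the same three lines: the _shape reshape helper defined on each model's attention class.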

@@ -345,9 +345,6 @@ class LlamaAttention(nn.Module):
             else:
                 raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
 
-    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
-        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
-
     def forward(
         self,
         hidden_states: torch.Tensor,

@@ -233,9 +233,6 @@ class MistralAttention(nn.Module):
             base=self.rope_theta,
         )
 
-    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
-        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
-
     def forward(
         self,
         hidden_states: torch.Tensor,

@@ -281,9 +281,6 @@ class MixtralAttention(nn.Module):
             base=self.rope_theta,
         )
 
-    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
-        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
-
     def forward(
         self,
         hidden_states: torch.Tensor,
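
For context, the removed _shape helper only rearranged a projected hidden state from (bsz, seq_len, num_heads * head_dim) into the (bsz, num_heads, seq_len, head_dim) layout that attention kernels consume. A minimal standalone sketch of that operation, with illustrative dimension values that are not taken from the diff:

import torch

bsz, seq_len, num_heads, head_dim = 2, 16, 8, 64  # illustrative sizes, not from the diff

# A projected hidden state, shaped like the output of the q/k/v linear
# projections in these attention modules.
tensor = torch.randn(bsz, seq_len, num_heads * head_dim)

# Body of the removed helper: split heads out of the last dimension,
# then swap the seq_len and head axes.
shaped = tensor.view(bsz, seq_len, num_heads, head_dim).transpose(1, 2).contiguous()

assert shaped.shape == (bsz, num_heads, seq_len, head_dim)

With 0 additions, nothing replaces the helper, which indicates each forward already builds this layout inline; the deletion is dead-code cleanup and should not change behavior.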