[Bug] Fix usage of .transpose() and .view() consecutively. (#11979)

This commit is contained in:
Siyuan Li
2025-01-13 14:24:10 +08:00
committed by GitHub
parent f7b3ba82c3
commit 9dd02d85ca
2 changed files with 2 additions and 2 deletions

View File

@@ -230,7 +230,7 @@ class MultiHeadAttention(nn.Module):
value,
scale=self.scale)
out = out.transpose(1, 2)
-return out.view(bsz, q_len, -1)
+return out.reshape(bsz, q_len, -1)
def unified_attention(

View File

@@ -271,7 +271,7 @@ class InternSdpaAttention(nn.Module):
v = v.transpose(1, 2)
x = F.scaled_dot_product_attention(q, k, v, scale=self.scale)
-x = x.transpose(1, 2).view(B, N, -1)
+x = x.transpose(1, 2).reshape(B, N, -1)
x = self.proj(x)
return x