[Models][QwenVL] Remove unnecessary .contiguous() calls (#27106)

Signed-off-by: Lukas Geiger <lukas.geiger94@gmail.com>
This commit is contained in:
Lukas Geiger
2025-10-18 16:05:05 +02:00
committed by GitHub
parent b26b70bec4
commit 5c2acb270a
2 changed files with 2 additions and 2 deletions

View File

@@ -396,7 +396,7 @@ class Qwen2_5_VisionAttention(nn.Module):
q, k, v = self.split_qkv(x)
batch_size = q.shape[1]
-        q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v))
+        q, k, v = (rearrange(x, "s b ... -> b s ...") for x in (q, k, v))
if rotary_pos_emb is not None:
# [2 * b, s, heads, head_dim]
qk_concat = torch.cat([q, k], dim=0)

View File

@@ -423,7 +423,7 @@ class Qwen2VisionAttention(nn.Module):
q, k, v = self.split_qkv(x)
batch_size = q.shape[1]
-        q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v))
+        q, k, v = (rearrange(x, "s b ... -> b s ...") for x in (q, k, v))
if rotary_pos_emb is not None:
# [2 * b, s, heads, head_dim]
qk_concat = torch.cat([q, k], dim=0)