[Bugfix][Model] Fix Mllama SDPA illegal memory access for batched multi-image (#9626)

Signed-off-by: mgoin <michael@neuralmagic.com>
Author: Michael Goin
Date: 2024-10-23 22:03:44 -04:00
Committed by: GitHub
Parent: b548d7a5f4
Commit: bb01f2915e


@@ -795,17 +795,19 @@ class MllamaTextCrossAttention(nn.Module):
         kv_len = k.shape[0]
         q = q.transpose(0, 1).view(self.num_local_key_value_heads,
                                    self.num_key_value_groups, q_len,
-                                   self.head_dim)
+                                   self.head_dim).contiguous()
         k = k.transpose(0,
                         1)[:,
                            None, :, :].expand(self.num_local_key_value_heads,
                                               self.num_key_value_groups,
-                                              kv_len, self.head_dim)
+                                              kv_len,
+                                              self.head_dim).contiguous()
         v = v.transpose(0,
                         1)[:,
                            None, :, :].expand(self.num_local_key_value_heads,
                                               self.num_key_value_groups,
-                                              kv_len, self.head_dim)
+                                              kv_len,
+                                              self.head_dim).contiguous()
         attention_mask = attention_mask.view(1, 1, q_len, kv_len)
         output = F.scaled_dot_product_attention(q,
                                                 k,
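
For context, below is a minimal, self-contained sketch (not the vLLM module itself) of the tensor pattern this hunk patches. In PyTorch, `expand()` returns a non-contiguous view that uses stride 0 along the broadcast dimension, and per the commit title, passing such views together with a transposed `view` of `q` into `F.scaled_dot_product_attention` could trigger an illegal memory access for batched multi-image inputs; the added `.contiguous()` calls materialize the tensors before SDPA runs. All names and shapes below are illustrative placeholders, not the actual Mllama configuration.

```python
import torch
import torch.nn.functional as F

# Illustrative sizes only; the real values come from the model config.
num_kv_heads, num_kv_groups, head_dim = 8, 4, 64
q_len, kv_len = 16, 32

# Stand-ins for the projected q/k/v of one cross-attention call:
# q has one row per query token, k/v one row per image token.
q = torch.randn(q_len, num_kv_heads * num_kv_groups, head_dim)
k = torch.randn(kv_len, num_kv_heads, head_dim)
v = torch.randn(kv_len, num_kv_heads, head_dim)

# Same reshaping pattern as the patched code path, with .contiguous()
# materializing each tensor before it reaches SDPA.
q = q.transpose(0, 1).view(num_kv_heads, num_kv_groups, q_len,
                           head_dim).contiguous()
# expand() broadcasts the singleton group dim with stride 0 (a view,
# no copy); .contiguous() copies it into dense memory.
k = k.transpose(0, 1)[:, None, :, :].expand(num_kv_heads, num_kv_groups,
                                            kv_len, head_dim).contiguous()
v = v.transpose(0, 1)[:, None, :, :].expand(num_kv_heads, num_kv_groups,
                                            kv_len, head_dim).contiguous()

# Cross-attention mask, broadcast over (heads, groups, q_len, kv_len).
attention_mask = torch.zeros(1, 1, q_len, kv_len)
output = F.scaled_dot_product_attention(q, k, v, attn_mask=attention_mask)
print(output.shape)  # torch.Size([8, 4, 16, 64])
```

The extra copies cost some memory, but they hand SDPA plainly strided inputs regardless of which attention backend PyTorch dispatches to, which is presumably why the fix is applied to all three tensors rather than only the expanded `k`/`v`.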