mirror of https://github.com/vllm-project/vllm.git, synced 2025-10-20 14:53:52 +08:00
[Bugfix][Model] Fix Mllama SDPA illegal memory access for batched multi-image (#9626)
Signed-off-by: mgoin <michael@neuralmagic.com>
@@ -795,17 +795,19 @@ class MllamaTextCrossAttention(nn.Module):
         kv_len = k.shape[0]
         q = q.transpose(0, 1).view(self.num_local_key_value_heads,
                                    self.num_key_value_groups, q_len,
-                                   self.head_dim)
+                                   self.head_dim).contiguous()
         k = k.transpose(0,
                         1)[:,
                            None, :, :].expand(self.num_local_key_value_heads,
                                               self.num_key_value_groups,
-                                              kv_len, self.head_dim)
+                                              kv_len,
+                                              self.head_dim).contiguous()
         v = v.transpose(0,
                         1)[:,
                            None, :, :].expand(self.num_local_key_value_heads,
                                               self.num_key_value_groups,
-                                              kv_len, self.head_dim)
+                                              kv_len,
+                                              self.head_dim).contiguous()
         attention_mask = attention_mask.view(1, 1, q_len, kv_len)
         output = F.scaled_dot_product_attention(q,
                                                 k,
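Why appending .contiguous() helps: expand() broadcasts k and v across the query groups without copying, so the result is a view with stride 0 along the new dimension, and q likewise reaches SDPA as a transposed view. Feeding such non-contiguous operands to the fused scaled_dot_product_attention kernels is what the commit title associates with the illegal memory access in the batched multi-image case; materializing each operand first avoids it. Below is a minimal standalone sketch of the same pattern, not vLLM code — the shapes (8 KV heads, 4 query groups, q_len=16, kv_len=32, head_dim=128) are illustrative assumptions, not Mllama's actual configuration:

```python
import torch
import torch.nn.functional as F

# Assumed toy shapes, not Mllama's real configuration.
num_kv_heads, num_groups, q_len, kv_len, head_dim = 8, 4, 16, 32, 128

q = torch.randn(num_kv_heads, num_groups, q_len, head_dim)
k = torch.randn(kv_len, num_kv_heads, head_dim)
v = torch.randn(kv_len, num_kv_heads, head_dim)

# expand() broadcasts the singleton group dim without copying: the
# result is a view with stride 0 along that dim, i.e. non-contiguous.
k_exp = k.transpose(0, 1)[:, None, :, :].expand(num_kv_heads, num_groups,
                                                kv_len, head_dim)
assert not k_exp.is_contiguous()

# .contiguous() materializes the broadcast into a dense copy, which is
# what the diff now passes to the fused SDPA kernels.
k_safe = k_exp.contiguous()
v_safe = v.transpose(0, 1)[:, None, :, :].expand(num_kv_heads, num_groups,
                                                 kv_len,
                                                 head_dim).contiguous()

# q is freshly allocated here, so .contiguous() is a no-op in this toy;
# in the real code it follows a transpose().view() chain as in the diff.
out = F.scaled_dot_product_attention(q.contiguous(), k_safe, v_safe)
print(out.shape)  # torch.Size([8, 4, 16, 128])
```

The trade-off the fix accepts is one extra dense copy per operand in exchange for handing the fused kernel well-formed contiguous inputs.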