From 2a0309a646b1ed83a0c40974e08c8dc628726d3c Mon Sep 17 00:00:00 2001
From: Roger Wang <136131678+ywang96@users.noreply.github.com>
Date: Sat, 25 Jan 2025 21:00:31 -0800
Subject: [PATCH] [Misc][Bugfix] FA3 support to ViT MHA layer (#12435)

Signed-off-by: Roger Wang
Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Isotr0py <2037008807@qq.com>
---
 vllm/attention/layer.py | 23 +++++++++++++++++++++--
 1 file changed, 21 insertions(+), 2 deletions(-)

diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index a90bb4fbf5..db682b4ac6 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -251,9 +251,28 @@ class MultiHeadAttention(nn.Module):
                 _Backend.FLASH_ATTN,
                 _Backend.FLASH_ATTN_VLLM_V1,
         }:
-            from vllm.vllm_flash_attn import flash_attn_func
+            from vllm.vllm_flash_attn import flash_attn_varlen_func
 
-            out = flash_attn_func(query, key, value, softmax_scale=self.scale)
+            cu_seqlens_q = torch.arange(0, (bsz + 1) * q_len,
+                                        step=q_len,
+                                        dtype=torch.int32,
+                                        device=query.device)
+            cu_seqlens_k = torch.arange(0, (bsz + 1) * kv_len,
+                                        step=kv_len,
+                                        dtype=torch.int32,
+                                        device=key.device)
+
+            out = flash_attn_varlen_func(
+                query.flatten(0, 1),
+                key.flatten(0, 1),
+                value.flatten(0, 1),
+                cu_seqlens_q=cu_seqlens_q,
+                cu_seqlens_k=cu_seqlens_k,
+                max_seqlen_q=q_len,
+                max_seqlen_k=kv_len,
+                softmax_scale=self.scale,
+            )
+            out = out.reshape(bsz, q_len, -1)
         elif self.attn_backend == _Backend.XFORMERS:
             from xformers import ops as xops
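
Note (not part of the patch): the diff replaces the fixed-shape flash_attn_func
call with flash_attn_varlen_func, which expects sequences packed along dim 0
plus cumulative-length tensors marking sequence boundaries. The following is a
minimal PyTorch-only sketch, with hypothetical shapes (bsz, q_len, num_heads,
head_dim chosen for illustration), of the cu_seqlens construction and the
flatten/reshape bookkeeping used above; it stands in for the kernel output and
does not call the attention kernel itself.

    import torch

    bsz, q_len, num_heads, head_dim = 2, 4, 8, 64
    query = torch.randn(bsz, q_len, num_heads, head_dim)

    # Cumulative sequence lengths [0, q_len, 2*q_len, ...]: one boundary per
    # sequence in the packed layout. For ViT, every sequence has equal length,
    # so a simple strided arange is enough.
    cu_seqlens_q = torch.arange(0, (bsz + 1) * q_len,
                                step=q_len,
                                dtype=torch.int32,
                                device=query.device)
    assert cu_seqlens_q.tolist() == [0, 4, 8]

    # flatten(0, 1) merges batch and sequence dims into the varlen layout:
    # (bsz * q_len, num_heads, head_dim).
    packed_q = query.flatten(0, 1)
    assert packed_q.shape == (bsz * q_len, num_heads, head_dim)

    # The varlen kernel returns (total_tokens, num_heads, head_dim); the patch
    # then restores the batched view and merges the head dims.
    out = packed_q  # stand-in for the attention output, same layout
    assert out.reshape(bsz, q_len, -1).shape == (bsz, q_len, num_heads * head_dim)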