[Attention] Allow V1 flash_attn to support cross-attention (#23297)

Signed-off-by: Russell Bryant <rbryant@redhat.com>
This commit is contained in:
Russell Bryant
2025-08-22 08:10:16 -04:00
committed by GitHub
parent 808d2e9aa0
commit 281710ef9a

View File

@ -405,13 +405,6 @@ class FlashAttentionImpl(AttentionImpl):
FlashAttentionBackend.validate_head_size(head_size)
if attn_type not in [
AttentionType.DECODER, AttentionType.ENCODER_ONLY
]:
raise NotImplementedError("Encoder/decoder cross-attention "
"is not implemented for "
"FlashAttentionImpl")
self.attn_type = attn_type
self.vllm_flash_attn_version = get_flash_attn_version()
if is_quantized_kv_cache(self.kv_cache_dtype) \
@ -477,7 +470,7 @@ class FlashAttentionImpl(AttentionImpl):
num_actual_tokens = attn_metadata.num_actual_tokens
# Handle encoder attention differently - no KV cache needed
if attn_type in (AttentionType.ENCODER_ONLY, ):
if attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
# For encoder attention,
# we use direct Q, K, V tensors without caching
return self._forward_encoder_attention(query[:num_actual_tokens],
@ -489,7 +482,11 @@ class FlashAttentionImpl(AttentionImpl):
# For decoder and cross-attention, use KV cache as before
key_cache, value_cache = kv_cache.unbind(0)
if self.kv_sharing_target_layer_name is None:
# key and value may be None in the case of cross attention. They are
# calculated once based on the output from the encoder and then cached
# in KV cache.
if (self.kv_sharing_target_layer_name is None and key is not None
and value is not None):
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
# NOTE(woosuk): Here, key and value are padded while slot_mapping is
@ -528,7 +525,7 @@ class FlashAttentionImpl(AttentionImpl):
block_table = attn_metadata.block_table
scheduler_metadata = attn_metadata.scheduler_metadata
descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads)
flash_attn_varlen_func(
q=query[:num_actual_tokens],