[Attention] Allow V1 flash_attn to support cross-attention (#23297)
Signed-off-by: Russell Bryant <rbryant@redhat.com>
@@ -405,13 +405,6 @@ class FlashAttentionImpl(AttentionImpl):
         FlashAttentionBackend.validate_head_size(head_size)
 
-        if attn_type not in [
-                AttentionType.DECODER, AttentionType.ENCODER_ONLY
-        ]:
-            raise NotImplementedError("Encoder/decoder cross-attention "
-                                      "is not implemented for "
-                                      "FlashAttentionImpl")
-
         self.attn_type = attn_type
         self.vllm_flash_attn_version = get_flash_attn_version()
         if is_quantized_kv_cache(self.kv_cache_dtype) \
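The net effect of this hunk is that the constructor no longer rejects cross-attention: attn_type is stored unconditionally and the forward path dispatches on it. A minimal sketch of the before/after behavior, assuming an illustrative AttentionType enum (only DECODER and ENCODER_ONLY appear verbatim above; ENCODER and ENCODER_DECODER are assumed members, not copied from this file):

from enum import Enum


class AttentionType(str, Enum):
    # Illustrative stand-in for vLLM's AttentionType constants; only
    # DECODER and ENCODER_ONLY appear verbatim in the hunk above.
    DECODER = "decoder"
    ENCODER = "encoder"
    ENCODER_ONLY = "encoder_only"
    ENCODER_DECODER = "encoder_decoder"


def store_attn_type(attn_type: AttentionType) -> AttentionType:
    # Old behavior (removed above): raise NotImplementedError for anything
    # other than DECODER or ENCODER_ONLY.
    # New behavior: accept every attention type; forward() dispatches on it.
    return attn_type


assert store_attn_type(AttentionType.ENCODER_DECODER) is AttentionType.ENCODER_DECODER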
@@ -477,7 +470,7 @@ class FlashAttentionImpl(AttentionImpl):
         num_actual_tokens = attn_metadata.num_actual_tokens
 
         # Handle encoder attention differently - no KV cache needed
-        if attn_type in (AttentionType.ENCODER_ONLY, ):
+        if attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER):
             # For encoder attention,
             # we use direct Q, K, V tensors without caching
             return self._forward_encoder_attention(query[:num_actual_tokens],
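With this change, plain encoder attention (AttentionType.ENCODER) takes the same cache-free path as ENCODER_ONLY: attend directly over the provided Q, K, V. A self-contained sketch of that path, using PyTorch SDPA as a stand-in for flash_attn_varlen_func and assuming num_kv_heads == num_heads and a single sequence for brevity:

import torch
import torch.nn.functional as F


def encoder_attention_sketch(query: torch.Tensor,
                             key: torch.Tensor,
                             value: torch.Tensor) -> torch.Tensor:
    # Inputs are [num_tokens, num_heads, head_dim]; no KV cache is touched.
    q, k, v = (t.transpose(0, 1) for t in (query, key, value))  # heads first
    out = F.scaled_dot_product_attention(q, k, v)  # full, bidirectional attention
    return out.transpose(0, 1)                     # back to [num_tokens, num_heads, head_dim]


# Usage: both ENCODER and ENCODER_ONLY now take this direct path.
q = torch.randn(6, 8, 64)
k = torch.randn(6, 8, 64)
v = torch.randn(6, 8, 64)
print(encoder_attention_sketch(q, k, v).shape)  # torch.Size([6, 8, 64])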
@@ -489,7 +482,11 @@ class FlashAttentionImpl(AttentionImpl):
         # For decoder and cross-attention, use KV cache as before
         key_cache, value_cache = kv_cache.unbind(0)
 
-        if self.kv_sharing_target_layer_name is None:
+        # key and value may be None in the case of cross attention. They are
+        # calculated once based on the output from the encoder and then cached
+        # in KV cache.
+        if (self.kv_sharing_target_layer_name is None and key is not None
+                and value is not None):
             # Reshape the input keys and values and store them in the cache.
             # Skip this if sharing KV cache with an earlier attention layer.
             # NOTE(woosuk): Here, key and value are padded while slot_mapping is
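The widened guard is what makes cross-attention work with the existing cache path: the encoder's K/V are written to the cache once, and later decode steps call the layer with key=None / value=None and read everything from the cache. A hedged sketch of the guard, with a flat slot-indexed cache standing in for vLLM's paged layout:

from typing import Optional

import torch


def maybe_write_kv_cache(key_cache: torch.Tensor,     # [num_slots, num_kv_heads, head_dim]
                         value_cache: torch.Tensor,
                         key: Optional[torch.Tensor],  # [num_tokens, num_kv_heads, head_dim] or None
                         value: Optional[torch.Tensor],
                         slot_mapping: torch.Tensor,   # [num_tokens], int64
                         kv_sharing_target_layer_name: Optional[str] = None) -> None:
    # Cross-attention K/V are computed once from the encoder output and
    # cached; subsequent decode steps pass key=None / value=None and skip
    # the write entirely.
    if (kv_sharing_target_layer_name is None and key is not None
            and value is not None):
        key_cache[slot_mapping] = key
        value_cache[slot_mapping] = value


# First cross-attention call: cache the encoder K/V.
kc = torch.zeros(16, 4, 64)
vc = torch.zeros(16, 4, 64)
maybe_write_kv_cache(kc, vc, torch.randn(3, 4, 64), torch.randn(3, 4, 64),
                     torch.tensor([0, 1, 2]))
# Later decode steps: key/value are None, so the cache is left untouched.
maybe_write_kv_cache(kc, vc, None, None, torch.empty(0, dtype=torch.long))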
@@ -528,7 +525,7 @@ class FlashAttentionImpl(AttentionImpl):
             block_table = attn_metadata.block_table
             scheduler_metadata = attn_metadata.scheduler_metadata
 
-            descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
+            descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads)
 
             flash_attn_varlen_func(
                 q=query[:num_actual_tokens],
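Because key can now be None on cross-attention decode steps, the FP8 descale shape can no longer be derived from key.shape[1]; the layer's own num_kv_heads is the same value and is always available. A small sketch of the computation (descale_shape_sketch is a hypothetical helper, not vLLM code):

import torch


def descale_shape_sketch(cu_seqlens_q: torch.Tensor, num_kv_heads: int) -> tuple:
    # One descale factor per (query sequence, KV head). Using the layer's
    # num_kv_heads instead of key.shape[1] works even when key is None,
    # which is the cross-attention decode case enabled by this commit.
    num_seqs = cu_seqlens_q.shape[0] - 1
    return (num_seqs, num_kv_heads)


# cu_seqlens_q = [0, 5, 9] describes two query sequences of lengths 5 and 4.
print(descale_shape_sketch(torch.tensor([0, 5, 9]), num_kv_heads=8))  # (2, 8)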