[V1] Optimize the CPU overheads in FlashAttention custom op (#10733)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon
2024-11-28 09:01:02 -08:00
committed by GitHub
parent 8c1e77fb58
commit 98f47f2a40

View File

@ -135,6 +135,13 @@ class FlashAttentionImpl(AttentionImpl):
assert k_scale == 1.0 and v_scale == 1.0, (
"key/v_scale is not supported in FlashAttention.")
# Reshape the query, key, and value tensors.
# NOTE(woosuk): We do this outside the custom op to minimize the CPU
# overheads from the non-CUDA-graph regions.
query = query.view(-1, self.num_heads, self.head_size)
key = key.view(-1, self.num_kv_heads, self.head_size)
value = value.view(-1, self.num_kv_heads, self.head_size)
output = torch.empty_like(query)
torch.ops.vllm.unified_v1_flash_attention(
output,
@ -153,7 +160,7 @@ class FlashAttentionImpl(AttentionImpl):
self.alibi_slopes,
self.logits_soft_cap,
)
return output
return output.view(-1, self.num_heads * self.head_size)
def unified_v1_flash_attention(
@ -184,11 +191,6 @@ def unified_v1_flash_attention(
attn_metadata: FlashAttentionMetadata = current_metadata
num_actual_tokens = attn_metadata.num_actual_tokens
# Reshape the query, key, and value tensors.
query = query.view(-1, num_heads, head_size)
key = key.view(-1, num_kv_heads, head_size)
value = value.view(-1, num_kv_heads, head_size)
# Reshape the input keys and values and store them in the cache.
key_cache = kv_cache[0]
value_cache = kv_cache[1]
@ -218,8 +220,7 @@ def unified_v1_flash_attention(
block_table=attn_metadata.block_table,
softcap=logits_soft_cap,
)
attn_output = attn_output.view(num_actual_tokens, -1)
# TODO(woosuk): Optimize this.
# TODO(woosuk): Remove this unnecessary copy.
output[:num_actual_tokens].copy_(attn_output)