mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[V1] Optimize the CPU overheads in FlashAttention custom op (#10733)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
@ -135,6 +135,13 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
assert k_scale == 1.0 and v_scale == 1.0, (
|
||||
"key/v_scale is not supported in FlashAttention.")
|
||||
|
||||
# Reshape the query, key, and value tensors.
|
||||
# NOTE(woosuk): We do this outside the custom op to minimize the CPU
|
||||
# overheads from the non-CUDA-graph regions.
|
||||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
key = key.view(-1, self.num_kv_heads, self.head_size)
|
||||
value = value.view(-1, self.num_kv_heads, self.head_size)
|
||||
|
||||
output = torch.empty_like(query)
|
||||
torch.ops.vllm.unified_v1_flash_attention(
|
||||
output,
|
||||
@ -153,7 +160,7 @@ class FlashAttentionImpl(AttentionImpl):
|
||||
self.alibi_slopes,
|
||||
self.logits_soft_cap,
|
||||
)
|
||||
return output
|
||||
return output.view(-1, self.num_heads * self.head_size)
|
||||
|
||||
|
||||
def unified_v1_flash_attention(
|
||||
@ -184,11 +191,6 @@ def unified_v1_flash_attention(
|
||||
attn_metadata: FlashAttentionMetadata = current_metadata
|
||||
num_actual_tokens = attn_metadata.num_actual_tokens
|
||||
|
||||
# Reshape the query, key, and value tensors.
|
||||
query = query.view(-1, num_heads, head_size)
|
||||
key = key.view(-1, num_kv_heads, head_size)
|
||||
value = value.view(-1, num_kv_heads, head_size)
|
||||
|
||||
# Reshape the input keys and values and store them in the cache.
|
||||
key_cache = kv_cache[0]
|
||||
value_cache = kv_cache[1]
|
||||
@ -218,8 +220,7 @@ def unified_v1_flash_attention(
|
||||
block_table=attn_metadata.block_table,
|
||||
softcap=logits_soft_cap,
|
||||
)
|
||||
attn_output = attn_output.view(num_actual_tokens, -1)
|
||||
# TODO(woosuk): Optimize this.
|
||||
# TODO(woosuk): Remove this unnecessary copy.
|
||||
output[:num_actual_tokens].copy_(attn_output)
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user