From 558db8083cfd9b7ee76eafdae32b237d951c8b10 Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Sat, 22 Feb 2025 05:25:41 -0800 Subject: [PATCH] [V1][Kernel] Refactor the prefix_prefill kernel so that the caller no longer has to pass in the context lengths (#13095) --- tests/kernels/test_prefix_prefill.py | 8 ++------ vllm/attention/backends/rocm_flash_attn.py | 1 - vllm/attention/backends/xformers.py | 1 - vllm/attention/ops/paged_attn.py | 4 +--- vllm/attention/ops/prefix_prefill.py | 17 +++++++++-------- vllm/v1/attention/backends/rocm_attn.py | 12 ------------ 6 files changed, 12 insertions(+), 31 deletions(-) diff --git a/tests/kernels/test_prefix_prefill.py b/tests/kernels/test_prefix_prefill.py index 2184c98525..c3ac6a37e7 100644 --- a/tests/kernels/test_prefix_prefill.py +++ b/tests/kernels/test_prefix_prefill.py @@ -100,7 +100,7 @@ def test_contexted_kv_attention( BS, max_block_per_request) b_seq_len = torch.tensor(seq_lens, dtype=torch.long) b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long) - b_start_loc = torch.cumsum(torch.tensor([0] + query_lens[:-1], + b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.long), dim=0) max_input_len = MAX_SEQ_LEN @@ -154,7 +154,6 @@ def test_contexted_kv_attention( block_table, b_start_loc, b_seq_len, - b_ctx_len, max_input_len, k_scale, v_scale, @@ -171,7 +170,6 @@ def test_contexted_kv_attention( block_table, b_start_loc, b_seq_len, - b_ctx_len, max_input_len, k_scale, v_scale, @@ -333,7 +331,7 @@ def test_contexted_kv_attention_alibi( BS, max_block_per_request) b_seq_len = torch.tensor(seq_lens, dtype=torch.long) b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long) - b_start_loc = torch.cumsum(torch.tensor([0] + query_lens[:-1], + b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.long), dim=0) max_input_len = MAX_SEQ_LEN @@ -387,7 +385,6 @@ def test_contexted_kv_attention_alibi( block_table, b_start_loc, b_seq_len, - b_ctx_len, max_input_len, k_scale, v_scale, @@ -404,7 +401,6 @@ def test_contexted_kv_attention_alibi( block_table, b_start_loc, b_seq_len, - b_ctx_len, max_input_len, k_scale, v_scale, diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index e1a8d3d336..1b1f6ca9be 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -753,7 +753,6 @@ class ROCmFlashAttentionImpl(AttentionImpl): prefill_meta.block_tables, prefill_meta.query_start_loc, prefill_meta.seq_lens_tensor, - prefill_meta.context_lens_tensor, prefill_meta.max_query_len, self.alibi_slopes, self.sliding_window[0], diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 723a4558d0..ec8e1f2ee5 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -580,7 +580,6 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]): prefill_meta.block_tables, prefill_meta.query_start_loc, prefill_meta.seq_lens_tensor, - prefill_meta.context_lens_tensor, prefill_meta.max_query_len, self.alibi_slopes, self.sliding_window, diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index 2c60bd0c38..fd703413db 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -202,7 +202,6 @@ class PagedAttention: block_tables: torch.Tensor, query_start_loc: torch.Tensor, seq_lens_tensor: torch.Tensor, - context_lens: torch.Tensor, max_query_len: int, alibi_slopes: Optional[torch.Tensor], sliding_window: Optional[int], @@ -220,9 +219,8 @@ 
class PagedAttention: value_cache, block_tables, # query_start_loc is (batch_size + 1,) - query_start_loc[:-1], + query_start_loc, seq_lens_tensor, - context_lens, max_query_len, k_scale, v_scale, diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index 362c46a95f..103c408ebb 100644 --- a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -31,7 +31,6 @@ if triton.__version__ >= "2.1.0": v_scale, B_Start_Loc, B_Seqlen, - B_Ctxlen, block_size, x, Out, @@ -72,10 +71,12 @@ if triton.__version__ >= "2.1.0": cur_kv_head = cur_head // num_queries_per_kv - cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch) cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - cur_batch_query_len = cur_batch_seq_len - cur_batch_ctx_len + cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1) + cur_batch_query_len = (cur_batch_in_all_stop_index - + cur_batch_in_all_start_index) + cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len # start position inside of the query # generally, N goes over kv, while M goes over query_len @@ -466,7 +467,6 @@ if triton.__version__ >= "2.1.0": v_scale, B_Start_Loc, B_Seqlen, - B_Ctxlen, Alibi_slopes, block_size, x, @@ -511,9 +511,12 @@ if triton.__version__ >= "2.1.0": # cur_batch_seq_len: the length of prompts # cur_batch_ctx_len: the length of prefix # cur_batch_in_all_start_index: the start id of the dim=0 - cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch) cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1) + cur_batch_query_len = (cur_batch_in_all_stop_index - + cur_batch_in_all_start_index) + cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len block_start_loc = BLOCK_M * start_m @@ -713,7 +716,6 @@ if triton.__version__ >= "2.1.0": b_loc, b_start_loc, b_seq_len, - b_ctx_len, max_input_len, k_scale: torch.Tensor, v_scale: torch.Tensor, @@ -765,6 +767,7 @@ if triton.__version__ >= "2.1.0": batch, head = b_seq_len.shape[0], q.shape[1] num_queries_per_kv = q.shape[1] // k.shape[1] + assert batch + 1 == len(b_start_loc) grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head, # 0 means "disable" @@ -784,7 +787,6 @@ if triton.__version__ >= "2.1.0": v_scale, b_start_loc, b_seq_len, - b_ctx_len, alibi_slopes, v_cache.shape[3], k_cache.shape[4], @@ -838,7 +840,6 @@ if triton.__version__ >= "2.1.0": v_scale, b_start_loc, b_seq_len, - b_ctx_len, v_cache.shape[3], k_cache.shape[4], o, diff --git a/vllm/v1/attention/backends/rocm_attn.py b/vllm/v1/attention/backends/rocm_attn.py index 5f3eb37514..0f3fabf05f 100644 --- a/vllm/v1/attention/backends/rocm_attn.py +++ b/vllm/v1/attention/backends/rocm_attn.py @@ -150,17 +150,6 @@ class ROCmAttentionImpl(AttentionImpl): layer._v_scale, ) - # TODO(sage): Refactor the context_attention_fwd kernel so that this - # overhead can be removed - context_lens = torch.empty_like(attn_metadata.seq_lens) - batch_size = len(attn_metadata.query_start_loc) - 1 - assert len(context_lens) == batch_size - for i in range(batch_size): - query_start = attn_metadata.query_start_loc[i] - query_end = attn_metadata.query_start_loc[i + 1] - context_lens[i] = attn_metadata.seq_lens[i] - (query_end - - query_start) - # Compute attention and update output up to `num_actual_tokens`. 
context_attention_fwd(q=query[:num_actual_tokens], k=key[:num_actual_tokens], @@ -172,7 +161,6 @@ class ROCmAttentionImpl(AttentionImpl): b_loc=attn_metadata.block_table, b_start_loc=attn_metadata.query_start_loc, b_seq_len=attn_metadata.seq_lens, - b_ctx_len=context_lens, max_input_len=attn_metadata.max_query_len, k_scale=layer._k_scale, v_scale=layer._v_scale,
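
Note (not part of the applied diff): the refactor works because each sequence's context length is fully determined by b_seq_len together with the now (batch + 1)-sized b_start_loc, so the kernel can recover it on the fly instead of taking b_ctx_len as an argument. Below is a minimal host-side Python sketch of the index arithmetic the Triton kernel now performs via tl.load(B_Start_Loc + cur_batch) and tl.load(B_Start_Loc + cur_batch + 1); the concrete lengths are made up for illustration and are not vLLM APIs.

    import torch

    query_lens = [4, 1, 7]      # new tokens per sequence in this batch (illustrative)
    ctx_lens   = [16, 32, 0]    # cached-prefix tokens per sequence (illustrative)
    seq_lens   = [q + c for q, c in zip(query_lens, ctx_lens)]

    # Caller now passes the full (batch + 1,) cumulative query offsets,
    # matching the new b_start_loc / query_start_loc contract: [0, 4, 5, 12].
    b_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.long), dim=0)
    b_seq_len = torch.tensor(seq_lens, dtype=torch.long)

    # Per-program arithmetic mirrored from the updated kernel body.
    for cur_batch in range(len(query_lens)):
        start = b_start_loc[cur_batch]
        stop = b_start_loc[cur_batch + 1]
        cur_batch_query_len = stop - start
        cur_batch_ctx_len = b_seq_len[cur_batch] - cur_batch_query_len
        assert cur_batch_ctx_len == ctx_lens[cur_batch]

This is the same per-batch computation that the patch deletes from rocm_attn.py, which is why context_lens no longer has to be materialized on the host before calling context_attention_fwd.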