[V1][Kernel] Refactor the prefix_prefill kernel so that the caller no longer has to pass in the context lengths (#13095)

Author: Sage Moore
Date: 2025-02-22 05:25:41 -08:00
Committed by: GitHub
Parent: e109e598c7
Commit: 558db8083c
6 changed files with 12 additions and 31 deletions
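
The gist of the refactor: context_attention_fwd and its Triton kernels already receive the cumulative query start locations (batch_size + 1 entries) and the per-sequence total lengths, so each sequence's context (prefix) length can be recovered inside the kernel instead of being built by every caller. A minimal sketch of that relationship; the tensor values are illustrative and not taken from this diff:

    import torch

    # query_start_loc has batch_size + 1 cumulative entries; seq_lens has batch_size.
    query_start_loc = torch.tensor([0, 3, 7, 12])
    seq_lens = torch.tensor([10, 9, 12])

    # Query length of sequence i is query_start_loc[i + 1] - query_start_loc[i].
    query_lens = query_start_loc[1:] - query_start_loc[:-1]
    # The cached prefix is whatever part of the sequence is not new query tokens.
    context_lens = seq_lens - query_lens
    print(context_lens)  # tensor([7, 5, 7])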


@@ -100,7 +100,7 @@ def test_contexted_kv_attention(
                               BS, max_block_per_request)
     b_seq_len = torch.tensor(seq_lens, dtype=torch.long)
     b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long)
-    b_start_loc = torch.cumsum(torch.tensor([0] + query_lens[:-1],
+    b_start_loc = torch.cumsum(torch.tensor([0] + query_lens,
                                             dtype=torch.long),
                                dim=0)
     max_input_len = MAX_SEQ_LEN
@@ -154,7 +154,6 @@ def test_contexted_kv_attention(
                           block_table,
                           b_start_loc,
                           b_seq_len,
-                          b_ctx_len,
                           max_input_len,
                           k_scale,
                           v_scale,
@@ -171,7 +170,6 @@ def test_contexted_kv_attention(
                           block_table,
                           b_start_loc,
                           b_seq_len,
-                          b_ctx_len,
                           max_input_len,
                           k_scale,
                           v_scale,
@@ -333,7 +331,7 @@ def test_contexted_kv_attention_alibi(
                               BS, max_block_per_request)
     b_seq_len = torch.tensor(seq_lens, dtype=torch.long)
     b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long)
-    b_start_loc = torch.cumsum(torch.tensor([0] + query_lens[:-1],
+    b_start_loc = torch.cumsum(torch.tensor([0] + query_lens,
                                             dtype=torch.long),
                                dim=0)
     max_input_len = MAX_SEQ_LEN
@@ -387,7 +385,6 @@ def test_contexted_kv_attention_alibi(
                           block_table,
                           b_start_loc,
                           b_seq_len,
-                          b_ctx_len,
                           max_input_len,
                           k_scale,
                           v_scale,
@@ -404,7 +401,6 @@ def test_contexted_kv_attention_alibi(
                           block_table,
                           b_start_loc,
                           b_seq_len,
-                          b_ctx_len,
                           max_input_len,
                           k_scale,
                           v_scale,
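
The test change mirrors the new calling convention: b_start_loc is now the cumulative sum over every query length, giving batch_size + 1 entries, and b_ctx_len is no longer forwarded to the kernel. A small sketch of the old and new constructions, with an illustrative query_lens value:

    import torch

    query_lens = [3, 4, 5]
    # Old convention: batch_size entries, the last prefix sum dropped.
    old_start_loc = torch.cumsum(
        torch.tensor([0] + query_lens[:-1], dtype=torch.long), dim=0)
    # New convention: batch_size + 1 entries, so the kernel can read both the
    # start and the stop offset of every sequence, including the last one.
    new_start_loc = torch.cumsum(
        torch.tensor([0] + query_lens, dtype=torch.long), dim=0)
    print(old_start_loc)  # tensor([0, 3, 7])
    print(new_start_loc)  # tensor([0, 3, 7, 12])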


@@ -753,7 +753,6 @@ class ROCmFlashAttentionImpl(AttentionImpl):
                         prefill_meta.block_tables,
                         prefill_meta.query_start_loc,
                         prefill_meta.seq_lens_tensor,
-                        prefill_meta.context_lens_tensor,
                         prefill_meta.max_query_len,
                         self.alibi_slopes,
                         self.sliding_window[0],


@@ -580,7 +580,6 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
                     prefill_meta.block_tables,
                     prefill_meta.query_start_loc,
                     prefill_meta.seq_lens_tensor,
-                    prefill_meta.context_lens_tensor,
                     prefill_meta.max_query_len,
                     self.alibi_slopes,
                     self.sliding_window,


@@ -202,7 +202,6 @@ class PagedAttention:
         block_tables: torch.Tensor,
         query_start_loc: torch.Tensor,
         seq_lens_tensor: torch.Tensor,
-        context_lens: torch.Tensor,
         max_query_len: int,
         alibi_slopes: Optional[torch.Tensor],
         sliding_window: Optional[int],
@@ -220,9 +219,8 @@ class PagedAttention:
             value_cache,
             block_tables,
             # query_start_loc is (batch_size + 1,)
-            query_start_loc[:-1],
+            query_start_loc,
             seq_lens_tensor,
-            context_lens,
             max_query_len,
             k_scale,
             v_scale,
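
forward_prefix previously sliced the start-location tensor down to batch_size entries before handing it to the kernel; since the kernel now also loads each sequence's stop offset, the full (batch_size + 1,)-shaped query_start_loc is passed through and context_lens disappears from the call. A hedged sketch of the shape invariant the refactored kernel wrapper checks, with illustrative values:

    import torch

    query_start_loc = torch.tensor([0, 3, 7, 12])  # shape (batch_size + 1,)
    seq_lens_tensor = torch.tensor([10, 9, 12])    # shape (batch_size,)
    # What the old code passed down: start offsets only.
    starts_only = query_start_loc[:-1]             # tensor([0, 3, 7])
    # What context_attention_fwd now asserts before launching the grid.
    assert seq_lens_tensor.shape[0] + 1 == query_start_loc.shape[0]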


@@ -31,7 +31,6 @@ if triton.__version__ >= "2.1.0":
         v_scale,
         B_Start_Loc,
         B_Seqlen,
-        B_Ctxlen,
         block_size,
         x,
         Out,
@@ -72,10 +71,12 @@ if triton.__version__ >= "2.1.0":
         cur_kv_head = cur_head // num_queries_per_kv
 
-        cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
         cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
         cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
-        cur_batch_query_len = cur_batch_seq_len - cur_batch_ctx_len
+        cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1)
+        cur_batch_query_len = (cur_batch_in_all_stop_index -
+                               cur_batch_in_all_start_index)
+        cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len
 
         # start position inside of the query
         # generally, N goes over kv, while M goes over query_len
@@ -466,7 +467,6 @@ if triton.__version__ >= "2.1.0":
         v_scale,
         B_Start_Loc,
         B_Seqlen,
-        B_Ctxlen,
         Alibi_slopes,
         block_size,
         x,
@@ -511,9 +511,12 @@ if triton.__version__ >= "2.1.0":
         # cur_batch_seq_len: the length of prompts
         # cur_batch_ctx_len: the length of prefix
         # cur_batch_in_all_start_index: the start id of the dim=0
-        cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
         cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
         cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
+        cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1)
+        cur_batch_query_len = (cur_batch_in_all_stop_index -
+                               cur_batch_in_all_start_index)
+        cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len
 
         block_start_loc = BLOCK_M * start_m
@@ -713,7 +716,6 @@ if triton.__version__ >= "2.1.0":
                               b_loc,
                               b_start_loc,
                               b_seq_len,
-                              b_ctx_len,
                               max_input_len,
                               k_scale: torch.Tensor,
                               v_scale: torch.Tensor,
@@ -765,6 +767,7 @@ if triton.__version__ >= "2.1.0":
         batch, head = b_seq_len.shape[0], q.shape[1]
         num_queries_per_kv = q.shape[1] // k.shape[1]
+        assert batch + 1 == len(b_start_loc)
 
         grid = (batch, head, triton.cdiv(max_input_len, BLOCK))  # batch, head,
 
         # 0 means "disable"
@@ -784,7 +787,6 @@ if triton.__version__ >= "2.1.0":
             v_scale,
             b_start_loc,
             b_seq_len,
-            b_ctx_len,
             alibi_slopes,
             v_cache.shape[3],
             k_cache.shape[4],
@@ -838,7 +840,6 @@ if triton.__version__ >= "2.1.0":
             v_scale,
             b_start_loc,
             b_seq_len,
-            b_ctx_len,
             v_cache.shape[3],
             k_cache.shape[4],
             o,
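
Inside both kernels the per-batch lengths are now derived from B_Start_Loc and B_Seqlen alone, and B_Ctxlen drops out of the signatures. A plain-Python equivalent of the new statements (ordinary indexing standing in for tl.load; names follow the kernel):

    def per_batch_lengths(b_start_loc, b_seq_len, cur_batch):
        # Consecutive entries of the cumulative start-location tensor bound
        # this batch element's slice of the packed query tensor.
        start = b_start_loc[cur_batch]
        stop = b_start_loc[cur_batch + 1]
        query_len = stop - start
        # The cached prefix is the rest of the sequence.
        ctx_len = b_seq_len[cur_batch] - query_len
        return query_len, ctx_len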


@@ -150,17 +150,6 @@ class ROCmAttentionImpl(AttentionImpl):
                 layer._v_scale,
             )
 
-        # TODO(sage): Refactor the context_attention_fwd kernel so that this
-        # overhead can be removed
-        context_lens = torch.empty_like(attn_metadata.seq_lens)
-        batch_size = len(attn_metadata.query_start_loc) - 1
-        assert len(context_lens) == batch_size
-        for i in range(batch_size):
-            query_start = attn_metadata.query_start_loc[i]
-            query_end = attn_metadata.query_start_loc[i + 1]
-            context_lens[i] = attn_metadata.seq_lens[i] - (query_end -
-                                                           query_start)
-
         # Compute attention and update output up to `num_actual_tokens`.
         context_attention_fwd(q=query[:num_actual_tokens],
                               k=key[:num_actual_tokens],
@@ -172,7 +161,6 @@ class ROCmAttentionImpl(AttentionImpl):
                               b_loc=attn_metadata.block_table,
                               b_start_loc=attn_metadata.query_start_loc,
                               b_seq_len=attn_metadata.seq_lens,
-                              b_ctx_len=context_lens,
                               max_input_len=attn_metadata.max_query_len,
                               k_scale=layer._k_scale,
                               v_scale=layer._v_scale,