mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 23:03:52 +08:00
[V1][Kernel] Refactor the prefix_prefill kernel so that the caller no longer has to pass in the context lengths (#13095)
This commit is contained in:
@ -100,7 +100,7 @@ def test_contexted_kv_attention(
|
|||||||
BS, max_block_per_request)
|
BS, max_block_per_request)
|
||||||
b_seq_len = torch.tensor(seq_lens, dtype=torch.long)
|
b_seq_len = torch.tensor(seq_lens, dtype=torch.long)
|
||||||
b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long)
|
b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long)
|
||||||
b_start_loc = torch.cumsum(torch.tensor([0] + query_lens[:-1],
|
b_start_loc = torch.cumsum(torch.tensor([0] + query_lens,
|
||||||
dtype=torch.long),
|
dtype=torch.long),
|
||||||
dim=0)
|
dim=0)
|
||||||
max_input_len = MAX_SEQ_LEN
|
max_input_len = MAX_SEQ_LEN
|
||||||
@ -154,7 +154,6 @@ def test_contexted_kv_attention(
|
|||||||
block_table,
|
block_table,
|
||||||
b_start_loc,
|
b_start_loc,
|
||||||
b_seq_len,
|
b_seq_len,
|
||||||
b_ctx_len,
|
|
||||||
max_input_len,
|
max_input_len,
|
||||||
k_scale,
|
k_scale,
|
||||||
v_scale,
|
v_scale,
|
||||||
@ -171,7 +170,6 @@ def test_contexted_kv_attention(
|
|||||||
block_table,
|
block_table,
|
||||||
b_start_loc,
|
b_start_loc,
|
||||||
b_seq_len,
|
b_seq_len,
|
||||||
b_ctx_len,
|
|
||||||
max_input_len,
|
max_input_len,
|
||||||
k_scale,
|
k_scale,
|
||||||
v_scale,
|
v_scale,
|
||||||
@ -333,7 +331,7 @@ def test_contexted_kv_attention_alibi(
|
|||||||
BS, max_block_per_request)
|
BS, max_block_per_request)
|
||||||
b_seq_len = torch.tensor(seq_lens, dtype=torch.long)
|
b_seq_len = torch.tensor(seq_lens, dtype=torch.long)
|
||||||
b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long)
|
b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long)
|
||||||
b_start_loc = torch.cumsum(torch.tensor([0] + query_lens[:-1],
|
b_start_loc = torch.cumsum(torch.tensor([0] + query_lens,
|
||||||
dtype=torch.long),
|
dtype=torch.long),
|
||||||
dim=0)
|
dim=0)
|
||||||
max_input_len = MAX_SEQ_LEN
|
max_input_len = MAX_SEQ_LEN
|
||||||
@ -387,7 +385,6 @@ def test_contexted_kv_attention_alibi(
|
|||||||
block_table,
|
block_table,
|
||||||
b_start_loc,
|
b_start_loc,
|
||||||
b_seq_len,
|
b_seq_len,
|
||||||
b_ctx_len,
|
|
||||||
max_input_len,
|
max_input_len,
|
||||||
k_scale,
|
k_scale,
|
||||||
v_scale,
|
v_scale,
|
||||||
@ -404,7 +401,6 @@ def test_contexted_kv_attention_alibi(
|
|||||||
block_table,
|
block_table,
|
||||||
b_start_loc,
|
b_start_loc,
|
||||||
b_seq_len,
|
b_seq_len,
|
||||||
b_ctx_len,
|
|
||||||
max_input_len,
|
max_input_len,
|
||||||
k_scale,
|
k_scale,
|
||||||
v_scale,
|
v_scale,
|
||||||
|
@ -753,7 +753,6 @@ class ROCmFlashAttentionImpl(AttentionImpl):
|
|||||||
prefill_meta.block_tables,
|
prefill_meta.block_tables,
|
||||||
prefill_meta.query_start_loc,
|
prefill_meta.query_start_loc,
|
||||||
prefill_meta.seq_lens_tensor,
|
prefill_meta.seq_lens_tensor,
|
||||||
prefill_meta.context_lens_tensor,
|
|
||||||
prefill_meta.max_query_len,
|
prefill_meta.max_query_len,
|
||||||
self.alibi_slopes,
|
self.alibi_slopes,
|
||||||
self.sliding_window[0],
|
self.sliding_window[0],
|
||||||
|
@ -580,7 +580,6 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
|
|||||||
prefill_meta.block_tables,
|
prefill_meta.block_tables,
|
||||||
prefill_meta.query_start_loc,
|
prefill_meta.query_start_loc,
|
||||||
prefill_meta.seq_lens_tensor,
|
prefill_meta.seq_lens_tensor,
|
||||||
prefill_meta.context_lens_tensor,
|
|
||||||
prefill_meta.max_query_len,
|
prefill_meta.max_query_len,
|
||||||
self.alibi_slopes,
|
self.alibi_slopes,
|
||||||
self.sliding_window,
|
self.sliding_window,
|
||||||
|
@ -202,7 +202,6 @@ class PagedAttention:
|
|||||||
block_tables: torch.Tensor,
|
block_tables: torch.Tensor,
|
||||||
query_start_loc: torch.Tensor,
|
query_start_loc: torch.Tensor,
|
||||||
seq_lens_tensor: torch.Tensor,
|
seq_lens_tensor: torch.Tensor,
|
||||||
context_lens: torch.Tensor,
|
|
||||||
max_query_len: int,
|
max_query_len: int,
|
||||||
alibi_slopes: Optional[torch.Tensor],
|
alibi_slopes: Optional[torch.Tensor],
|
||||||
sliding_window: Optional[int],
|
sliding_window: Optional[int],
|
||||||
@ -220,9 +219,8 @@ class PagedAttention:
|
|||||||
value_cache,
|
value_cache,
|
||||||
block_tables,
|
block_tables,
|
||||||
# query_start_loc is (batch_size + 1,)
|
# query_start_loc is (batch_size + 1,)
|
||||||
query_start_loc[:-1],
|
query_start_loc,
|
||||||
seq_lens_tensor,
|
seq_lens_tensor,
|
||||||
context_lens,
|
|
||||||
max_query_len,
|
max_query_len,
|
||||||
k_scale,
|
k_scale,
|
||||||
v_scale,
|
v_scale,
|
||||||
|
@ -31,7 +31,6 @@ if triton.__version__ >= "2.1.0":
|
|||||||
v_scale,
|
v_scale,
|
||||||
B_Start_Loc,
|
B_Start_Loc,
|
||||||
B_Seqlen,
|
B_Seqlen,
|
||||||
B_Ctxlen,
|
|
||||||
block_size,
|
block_size,
|
||||||
x,
|
x,
|
||||||
Out,
|
Out,
|
||||||
@ -72,10 +71,12 @@ if triton.__version__ >= "2.1.0":
|
|||||||
|
|
||||||
cur_kv_head = cur_head // num_queries_per_kv
|
cur_kv_head = cur_head // num_queries_per_kv
|
||||||
|
|
||||||
cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
|
|
||||||
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
|
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
|
||||||
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
|
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
|
||||||
cur_batch_query_len = cur_batch_seq_len - cur_batch_ctx_len
|
cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1)
|
||||||
|
cur_batch_query_len = (cur_batch_in_all_stop_index -
|
||||||
|
cur_batch_in_all_start_index)
|
||||||
|
cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len
|
||||||
|
|
||||||
# start position inside of the query
|
# start position inside of the query
|
||||||
# generally, N goes over kv, while M goes over query_len
|
# generally, N goes over kv, while M goes over query_len
|
||||||
@ -466,7 +467,6 @@ if triton.__version__ >= "2.1.0":
|
|||||||
v_scale,
|
v_scale,
|
||||||
B_Start_Loc,
|
B_Start_Loc,
|
||||||
B_Seqlen,
|
B_Seqlen,
|
||||||
B_Ctxlen,
|
|
||||||
Alibi_slopes,
|
Alibi_slopes,
|
||||||
block_size,
|
block_size,
|
||||||
x,
|
x,
|
||||||
@ -511,9 +511,12 @@ if triton.__version__ >= "2.1.0":
|
|||||||
# cur_batch_seq_len: the length of prompts
|
# cur_batch_seq_len: the length of prompts
|
||||||
# cur_batch_ctx_len: the length of prefix
|
# cur_batch_ctx_len: the length of prefix
|
||||||
# cur_batch_in_all_start_index: the start id of the dim=0
|
# cur_batch_in_all_start_index: the start id of the dim=0
|
||||||
cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
|
|
||||||
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
|
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
|
||||||
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
|
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
|
||||||
|
cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1)
|
||||||
|
cur_batch_query_len = (cur_batch_in_all_stop_index -
|
||||||
|
cur_batch_in_all_start_index)
|
||||||
|
cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len
|
||||||
|
|
||||||
block_start_loc = BLOCK_M * start_m
|
block_start_loc = BLOCK_M * start_m
|
||||||
|
|
||||||
@ -713,7 +716,6 @@ if triton.__version__ >= "2.1.0":
|
|||||||
b_loc,
|
b_loc,
|
||||||
b_start_loc,
|
b_start_loc,
|
||||||
b_seq_len,
|
b_seq_len,
|
||||||
b_ctx_len,
|
|
||||||
max_input_len,
|
max_input_len,
|
||||||
k_scale: torch.Tensor,
|
k_scale: torch.Tensor,
|
||||||
v_scale: torch.Tensor,
|
v_scale: torch.Tensor,
|
||||||
@ -765,6 +767,7 @@ if triton.__version__ >= "2.1.0":
|
|||||||
batch, head = b_seq_len.shape[0], q.shape[1]
|
batch, head = b_seq_len.shape[0], q.shape[1]
|
||||||
num_queries_per_kv = q.shape[1] // k.shape[1]
|
num_queries_per_kv = q.shape[1] // k.shape[1]
|
||||||
|
|
||||||
|
assert batch + 1 == len(b_start_loc)
|
||||||
grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head,
|
grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head,
|
||||||
|
|
||||||
# 0 means "disable"
|
# 0 means "disable"
|
||||||
@ -784,7 +787,6 @@ if triton.__version__ >= "2.1.0":
|
|||||||
v_scale,
|
v_scale,
|
||||||
b_start_loc,
|
b_start_loc,
|
||||||
b_seq_len,
|
b_seq_len,
|
||||||
b_ctx_len,
|
|
||||||
alibi_slopes,
|
alibi_slopes,
|
||||||
v_cache.shape[3],
|
v_cache.shape[3],
|
||||||
k_cache.shape[4],
|
k_cache.shape[4],
|
||||||
@ -838,7 +840,6 @@ if triton.__version__ >= "2.1.0":
|
|||||||
v_scale,
|
v_scale,
|
||||||
b_start_loc,
|
b_start_loc,
|
||||||
b_seq_len,
|
b_seq_len,
|
||||||
b_ctx_len,
|
|
||||||
v_cache.shape[3],
|
v_cache.shape[3],
|
||||||
k_cache.shape[4],
|
k_cache.shape[4],
|
||||||
o,
|
o,
|
||||||
|
@ -150,17 +150,6 @@ class ROCmAttentionImpl(AttentionImpl):
|
|||||||
layer._v_scale,
|
layer._v_scale,
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO(sage): Refactor the context_attention_fwd kernel so that this
|
|
||||||
# overhead can be removed
|
|
||||||
context_lens = torch.empty_like(attn_metadata.seq_lens)
|
|
||||||
batch_size = len(attn_metadata.query_start_loc) - 1
|
|
||||||
assert len(context_lens) == batch_size
|
|
||||||
for i in range(batch_size):
|
|
||||||
query_start = attn_metadata.query_start_loc[i]
|
|
||||||
query_end = attn_metadata.query_start_loc[i + 1]
|
|
||||||
context_lens[i] = attn_metadata.seq_lens[i] - (query_end -
|
|
||||||
query_start)
|
|
||||||
|
|
||||||
# Compute attention and update output up to `num_actual_tokens`.
|
# Compute attention and update output up to `num_actual_tokens`.
|
||||||
context_attention_fwd(q=query[:num_actual_tokens],
|
context_attention_fwd(q=query[:num_actual_tokens],
|
||||||
k=key[:num_actual_tokens],
|
k=key[:num_actual_tokens],
|
||||||
@ -172,7 +161,6 @@ class ROCmAttentionImpl(AttentionImpl):
|
|||||||
b_loc=attn_metadata.block_table,
|
b_loc=attn_metadata.block_table,
|
||||||
b_start_loc=attn_metadata.query_start_loc,
|
b_start_loc=attn_metadata.query_start_loc,
|
||||||
b_seq_len=attn_metadata.seq_lens,
|
b_seq_len=attn_metadata.seq_lens,
|
||||||
b_ctx_len=context_lens,
|
|
||||||
max_input_len=attn_metadata.max_query_len,
|
max_input_len=attn_metadata.max_query_len,
|
||||||
k_scale=layer._k_scale,
|
k_scale=layer._k_scale,
|
||||||
v_scale=layer._v_scale,
|
v_scale=layer._v_scale,
|
||||||
|
Reference in New Issue
Block a user