[ROCm][Misc] Rename context_len to seq_len in the ROCm custom paged attention kernel (#22097)

Signed-off-by: charlifu <charlifu@amd.com>

Commit: b7c0942b65 (parent: 9a0c5ded5a)
Author: Charlie Fu, committed via GitHub
Date: 2025-08-09 01:15:06 -05:00

3 changed files with 91 additions and 96 deletions
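
The rename changes no logic: the values these kernels consume are full per-sequence token counts, which vLLM elsewhere calls sequence lengths, reserving "context" for the prefix already in the KV cache. A minimal sketch of that convention, assuming the usual definitions (the helper below is illustrative, not part of the commit):

#include <cassert>

// seq_len counts every token the attention kernels iterate over:
// the cached prefix ("context") plus the tokens added this step.
int seq_len_from(int context_len, int query_len) {
  return context_len + query_len;
}

int main() {
  // Decode step: a 512-token cached prefix plus one new token.
  assert(seq_len_from(512, 1) == 513);
  return 0;
}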


@@ -270,7 +270,7 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
const int num_kv_heads,
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
-const int* __restrict__ context_lens, // [num_seqs]
+const int* __restrict__ seq_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
@@ -304,12 +304,12 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
const auto max_num_partitions = gridDim.y;
-const int context_len = context_lens[seq_idx];
+const int seq_len = seq_lens[seq_idx];
const int partition_start_token_idx =
partition_idx * T_PAR_SIZE; // partition_size;
// exit if partition is out of context for seq
-if (partition_start_token_idx >= context_len) {
+if (partition_start_token_idx >= seq_len) {
return;
}
@@ -361,8 +361,8 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
// output layout from QKmfma : QH16xT4x4 16 qheads across 16 lanes, 16 tokens
// across 4 rows x 4 tokens per lane
-const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
-const int last_ctx_block = num_context_blocks - 1;
+const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE);
+const int last_seq_block = num_seq_blocks - 1;
const int* block_table_seq = block_tables + seq_idx * max_num_blocks_per_seq;
@@ -373,9 +373,9 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
const int klocal_token_idx =
TOKENS_PER_WARP * warpid + token_depth * 16 + lane16id;
const int kglobal_token_idx = partition_start_token_idx + klocal_token_idx;
-const int kblock_idx = (kglobal_token_idx < context_len)
+const int kblock_idx = (kglobal_token_idx < seq_len)
? kglobal_token_idx / BLOCK_SIZE
-: last_ctx_block;
+: last_seq_block;
kphysical_block_number[token_depth] = block_table_seq[kblock_idx];
}
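
This clamp-and-mask pattern recurs throughout the diff: token indices past seq_len are clamped to the last block so the gather stays in bounds, and the corresponding logits are masked to -FLT_MAX before the softmax. A self-contained sketch, assuming DIVIDE_ROUND_UP is the usual ceiling-division macro (its definition is not shown in this diff):

#include <cassert>

#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))  // assumed definition

int lookup_physical_block(const int* block_table_seq, int global_token_idx,
                          int seq_len, int block_size) {
  const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, block_size);
  const int last_seq_block = num_seq_blocks - 1;
  // Out-of-range tokens reread the last valid block instead of running off
  // the table; their scores are zeroed out later by the softmax masking.
  const int block_idx = (global_token_idx < seq_len)
                            ? global_token_idx / block_size
                            : last_seq_block;
  return block_table_seq[block_idx];
}

int main() {
  const int table[4] = {7, 3, 9, 1};  // logical -> physical block ids
  assert(lookup_physical_block(table, 5, 50, 16) == 7);   // in range
  assert(lookup_physical_block(table, 63, 50, 16) == 1);  // clamped to last
  return 0;
}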
@@ -476,9 +476,9 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
// tokens
const int vglobal_token_idx =
partition_start_token_idx + vlocal_token_idx;
-const int vblock_idx = (vglobal_token_idx < context_len)
+const int vblock_idx = (vglobal_token_idx < seq_len)
? vglobal_token_idx / BLOCK_SIZE
-: last_ctx_block;
+: last_seq_block;
vphysical_block_number[vtoken_depth][vblock_depth] =
block_table_seq[vblock_idx];
}
@@ -554,7 +554,7 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
if constexpr (ALIBI_ENABLED) {
for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
const int local_token_idx = qkout_token_idx + token_depth * 16;
-const int alibi_offset = local_token_idx - context_len + 1;
+const int alibi_offset = local_token_idx - seq_len + 1;
for (int i = 0; i < 4; i++) {
d_out[token_depth][i] += alibi_slope * (alibi_offset + i);
}
@@ -568,9 +568,8 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
const int local_token_idx = qkout_token_idx + token_depth * 16;
for (int i = 0; i < 4; i++) {
-const float tmp = (local_token_idx + i < context_len)
-? d_out[token_depth][i]
-: -FLT_MAX;
+const float tmp =
+(local_token_idx + i < seq_len) ? d_out[token_depth][i] : -FLT_MAX;
qk_max = fmaxf(qk_max, tmp);
}
}
@@ -582,7 +581,7 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
const int local_token_idx = qkout_token_idx + token_depth * 16;
for (int i = 0; i < 4; i++) {
-const float tmp = (local_token_idx + i < context_len)
+const float tmp = (local_token_idx + i < seq_len)
? __expf(d_out[token_depth][i] - qk_max)
: 0.0f;
d_out[token_depth][i] = tmp;
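
The two loops above are the masking counterpart of the clamped block reads: positions at or beyond seq_len feed -FLT_MAX into the running max and 0.0f into the exponential sum, so padded slots never influence the softmax. The same arithmetic in isolation (a sketch, not the kernel's code):

#include <cassert>
#include <cfloat>
#include <cmath>

// Value fed into the running qk_max reduction.
float masked_logit(float logit, int token_idx, int seq_len) {
  return (token_idx < seq_len) ? logit : -FLT_MAX;
}

// Value fed into the softmax numerator and exp_sum.
float masked_exp(float logit, float qk_max, int token_idx, int seq_len) {
  return (token_idx < seq_len) ? expf(logit - qk_max) : 0.0f;
}

int main() {
  assert(masked_logit(1.5f, 10, 8) == -FLT_MAX);  // past the sequence end
  assert(masked_exp(1.5f, 2.0f, 10, 8) == 0.0f);  // contributes nothing
  assert(masked_exp(2.0f, 2.0f, 3, 8) == 1.0f);   // expf(0) == 1
  return 0;
}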
@@ -780,7 +779,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
const int num_kv_heads,
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
-const int* __restrict__ context_lens, // [num_seqs]
+const int* __restrict__ seq_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
@@ -809,10 +808,10 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
const auto partition_size = blockDim.x;
const auto max_num_partitions = gridDim.y;
-const int context_len = context_lens[seq_idx];
+const int seq_len = seq_lens[seq_idx];
const int partition_start_token_idx = partition_idx * partition_size;
// exit if partition is out of context for seq
-if (partition_start_token_idx >= context_len) {
+if (partition_start_token_idx >= seq_len) {
return;
}
// every 4 lanes fetch 4 different qheads
@@ -855,7 +854,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
const int warp_start_token_idx =
partition_start_token_idx + warpid * WARP_SIZE;
-if (warp_start_token_idx >= context_len) { // warp out of context
+if (warp_start_token_idx >= seq_len) { // warp out of context
#pragma unroll
for (int h = 0; h < GQA_RATIO4; h++) {
shared_qk_max[warpid][h] = -FLT_MAX;
@@ -863,8 +862,8 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
}
} else { // warp within context
-const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
-const int last_ctx_block = num_context_blocks - 1;
+const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE);
+const int last_seq_block = num_seq_blocks - 1;
const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
// token id within partition
@@ -873,9 +872,9 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
const int global_token_idx = partition_start_token_idx + local_token_idx;
// fetch block number for k
-const int block_idx = (global_token_idx < context_len)
+const int block_idx = (global_token_idx < seq_len)
? global_token_idx / BLOCK_SIZE
-: last_ctx_block;
+: last_seq_block;
// fetch k physical block number
// int32 physical_block_number leads to overflow when multiplied with
@@ -888,7 +887,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
for (int b = 0; b < VBLOCKS; b++) {
const int vblock_idx = warp_start_block_idx + b;
const int vblock_idx_ctx =
-(vblock_idx <= last_ctx_block) ? vblock_idx : last_ctx_block;
+(vblock_idx <= last_seq_block) ? vblock_idx : last_seq_block;
vphysical_blocks[b] = block_table[vblock_idx_ctx];
}
@@ -1057,7 +1056,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
const int lane4_token_idx = 4 * (global_token_idx >> 2);
if constexpr (ALIBI_ENABLED) {
-const int alibi_offset = lane4_token_idx - context_len + 1;
+const int alibi_offset = lane4_token_idx - seq_len + 1;
for (int h = 0; h < QHLOOP; h++) {
for (int i = 0; i < 4; i++) {
d_out[h][i] += alibi_slope[h] * (alibi_offset + i);
@@ -1070,7 +1069,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
for (int h = 0; h < QHLOOP; h++) {
qk_max[h] = -FLT_MAX;
for (int i = 0; i < 4; i++) {
-qk_max[h] = (lane4_token_idx + i < context_len)
+qk_max[h] = (lane4_token_idx + i < seq_len)
? fmaxf(qk_max[h], d_out[h][i])
: qk_max[h];
}
@@ -1101,7 +1100,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
for (int h = 0; h < QHLOOP; h++) {
exp_sum[h] = 0.0f;
for (int i = 0; i < 4; i++) {
-d_out[h][i] = (lane4_token_idx + i < context_len)
+d_out[h][i] = (lane4_token_idx + i < seq_len)
? __expf(d_out[h][i] - qk_max[h])
: 0.0f;
exp_sum[h] += d_out[h][i];
@@ -1181,7 +1180,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
}
}
-if (warp_start_token_idx >= context_len) { // warp out of context
+if (warp_start_token_idx >= seq_len) { // warp out of context
for (int qh = 0; qh < QHLOOP; qh++) {
for (int vh = 0; vh < VHELOOP; vh++) {
vout_shared[qh][vh][laneid][warpid] = {0};
@@ -1279,7 +1278,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
// max_num_partitions]
const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
// max_num_partitions, head_size]
-const int* __restrict__ context_lens, // [num_seqs]
+const int* __restrict__ seq_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_partitions, const float* __restrict__ fp8_out_scale_ptr) {
const auto num_heads = gridDim.x;
@@ -1293,8 +1292,8 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
return;
}
-const int context_len = context_lens[seq_idx];
-const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
+const int seq_len = seq_lens[seq_idx];
+const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE);
const auto warpid = threadIdx.x / WARP_SIZE;
__shared__ float shared_global_exp_sum;
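
The reduce kernel recomputes the per-sequence partition count from seq_len so it only folds together partitions that actually ran; the QKV kernels' early exit on partition_start_token_idx is the mirror image of the same test. A sketch of that bookkeeping, assuming the usual ceiling division and the PARTITION_SIZE of 256 fixed by the launcher later in this file:

#include <cassert>

#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))  // assumed definition
constexpr int kPartitionSize = 256;  // PARTITION_SIZE in the launcher

int num_partitions_for(int seq_len) {
  return DIVIDE_ROUND_UP(seq_len, kPartitionSize);
}

bool partition_active(int partition_idx, int seq_len) {
  // Mirrors the QKV kernels' "if (partition_start_token_idx >= seq_len) return;"
  return partition_idx * kPartitionSize < seq_len;
}

int main() {
  assert(num_partitions_for(300) == 2);  // tokens 0..255 and 256..299
  assert(partition_active(1, 300));      // second partition still has tokens
  assert(!partition_active(2, 300));     // third partition would exit early
  return 0;
}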
@@ -1581,7 +1580,7 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
// head_size, block_size]
const int num_kv_heads, const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
-const int* __restrict__ context_lens, // [num_seqs]
+const int* __restrict__ seq_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
@@ -1615,11 +1614,11 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
const int max_num_partitions = gridDim.y;
-const int context_len = context_lens[seq_idx]; // length of a seq
+const int seq_len = seq_lens[seq_idx]; // length of a seq
const int partition_start_token_idx = partition_idx * T_PAR_SIZE;
// exit if partition is out of context for seq
-if (partition_start_token_idx >= context_len) {
+if (partition_start_token_idx >= seq_len) {
return;
}
@@ -1715,8 +1714,8 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
}
}
-const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
-const int last_ctx_block = num_context_blocks - 1;
+const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE);
+const int last_seq_block = num_seq_blocks - 1;
const int* block_table_seq = block_tables + seq_idx * max_num_blocks_per_seq;
@@ -1727,9 +1726,9 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
const int klocal_token_idx =
TOKENS_PER_WARP * warpid + token_depth * 16 + lane16id;
const int kglobal_token_idx = partition_start_token_idx + klocal_token_idx;
-const int kblock_idx = (kglobal_token_idx < context_len)
+const int kblock_idx = (kglobal_token_idx < seq_len)
? kglobal_token_idx / BLOCK_SIZE
-: last_ctx_block;
+: last_seq_block;
kphysical_block_number[token_depth] = block_table_seq[kblock_idx];
}
@@ -1781,9 +1780,9 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
vblock_depth * BLOCK_SIZE;
const int vglobal_token_idx =
partition_start_token_idx + vlocal_token_idx;
-const int vblock_idx = (vglobal_token_idx < context_len)
+const int vblock_idx = (vglobal_token_idx < seq_len)
? vglobal_token_idx / BLOCK_SIZE
-: last_ctx_block;
+: last_seq_block;
vphysical_block_number[vtoken_depth][vblock_depth] =
block_table_seq[vblock_idx];
}
@@ -1836,9 +1835,8 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
const int local_token_idx = qkout_token_idx + token_depth * 16;
for (int i = 0; i < 8; i++) {
-const float tmp = (local_token_idx + 2 * i < context_len)
-? dout[token_depth][i]
-: -FLT_MAX;
+const float tmp =
+(local_token_idx + 2 * i < seq_len) ? dout[token_depth][i] : -FLT_MAX;
qk_max = fmaxf(qk_max, tmp);
}
}
@@ -1848,7 +1846,7 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
const int local_token_idx = qkout_token_idx + token_depth * 16;
for (int i = 0; i < 8; i++) {
-const float tmp = (local_token_idx + 2 * i < context_len)
+const float tmp = (local_token_idx + 2 * i < seq_len)
? __expf(dout[token_depth][i] - qk_max)
: 0.0f;
dout[token_depth][i] = tmp;
@@ -2019,7 +2017,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
// head_size, block_size]
const int num_kv_heads, const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
-const int* __restrict__ context_lens, // [num_seqs]
+const int* __restrict__ seq_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
@@ -2046,7 +2044,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
// max_num_partitions]
const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
// max_num_partitions, head_size]
-const int* __restrict__ context_lens, // [num_seqs]
+const int* __restrict__ seq_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_partitions, const float* __restrict__ fp8_out_scale_ptr) {
const auto num_heads = gridDim.x;
@@ -2060,8 +2058,8 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
return;
}
-const int context_len = context_lens[seq_idx];
-const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
+const int seq_len = seq_lens[seq_idx];
+const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE);
const int warpid = threadIdx.x / WARP_SIZE;
__shared__ float shared_global_exp_sum;
@@ -2349,7 +2347,7 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
// head_size, block_size]
const int num_kv_heads, const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
-const int* __restrict__ context_lens, // [num_seqs]
+const int* __restrict__ seq_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
@@ -2382,11 +2380,11 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
const int max_num_partitions = gridDim.y;
-const int context_len = context_lens[seq_idx]; // length of a seq
+const int seq_len = seq_lens[seq_idx]; // length of a seq
const int partition_start_token_idx = partition_idx * T_PAR_SIZE;
// exit if partition is out of context for seq
-if (partition_start_token_idx >= context_len) {
+if (partition_start_token_idx >= seq_len) {
return;
}
@@ -2482,8 +2480,8 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
}
}
-const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
-const int last_ctx_block = num_context_blocks - 1;
+const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE);
+const int last_seq_block = num_seq_blocks - 1;
const int* block_table_seq = block_tables + seq_idx * max_num_blocks_per_seq;
@@ -2494,9 +2492,9 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
const int klocal_token_idx =
TOKENS_PER_WARP * warpid + token_depth * 16 + lane16id;
const int kglobal_token_idx = partition_start_token_idx + klocal_token_idx;
-const int kblock_idx = (kglobal_token_idx < context_len)
+const int kblock_idx = (kglobal_token_idx < seq_len)
? kglobal_token_idx / BLOCK_SIZE
-: last_ctx_block;
+: last_seq_block;
kphysical_block_number[token_depth] = block_table_seq[kblock_idx];
}
@@ -2548,9 +2546,9 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
rowid * VTOKENS_PER_LANE + vblock_depth * BLOCK_SIZE;
const int vglobal_token_idx =
partition_start_token_idx + vlocal_token_idx;
-const int vblock_idx = (vglobal_token_idx < context_len)
+const int vblock_idx = (vglobal_token_idx < seq_len)
? vglobal_token_idx / BLOCK_SIZE
-: last_ctx_block;
+: last_seq_block;
vphysical_block_number[vtoken_depth][vblock_depth] =
block_table_seq[vblock_idx];
}
@@ -2604,7 +2602,7 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
const int local_token_idx = qkout_token_idx + token_depth * 16;
for (int i = 0; i < 8; i++) {
const float tmp =
-(local_token_idx + i < context_len) ? dout[token_depth][i] : -FLT_MAX;
+(local_token_idx + i < seq_len) ? dout[token_depth][i] : -FLT_MAX;
qk_max = fmaxf(qk_max, tmp);
}
}
@@ -2614,7 +2612,7 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
const int local_token_idx = qkout_token_idx + token_depth * 16;
for (int i = 0; i < 8; i++) {
-const float tmp = (local_token_idx + i < context_len)
+const float tmp = (local_token_idx + i < seq_len)
? __expf(dout[token_depth][i] - qk_max)
: 0.0f;
dout[token_depth][i] = tmp;
@@ -2751,7 +2749,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
// head_size, block_size]
const int num_kv_heads, const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
-const int* __restrict__ context_lens, // [num_seqs]
+const int* __restrict__ seq_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
@@ -2778,7 +2776,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
// max_num_partitions]
const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
// max_num_partitions, head_size]
-const int* __restrict__ context_lens, // [num_seqs]
+const int* __restrict__ seq_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_partitions, const float* __restrict__ fp8_out_scale_ptr) {
const auto num_heads = gridDim.x;
@@ -2792,8 +2790,8 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
return;
}
-const int context_len = context_lens[seq_idx];
-const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
+const int seq_len = seq_lens[seq_idx];
+const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE);
const int warpid = threadIdx.x / WARP_SIZE;
__shared__ float shared_global_exp_sum;
@@ -2980,7 +2978,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_kernel(
const int num_kv_heads,
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
-const int* __restrict__ context_lens, // [num_seqs]
+const int* __restrict__ seq_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
@@ -3007,7 +3005,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
const int num_kv_heads,
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
-const int* __restrict__ context_lens, // [num_seqs]
+const int* __restrict__ seq_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
@@ -3031,7 +3029,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
const float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
-const int* __restrict__ context_lens, // [num_seqs]
+const int* __restrict__ seq_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_partitions, const float* __restrict__ fp8_out_scale_ptr) {
UNREACHABLE_CODE
@@ -3046,7 +3044,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
GQA_RATIO> \
<<<grid, block, 0, stream>>>( \
query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
-block_tables_ptr, context_lens_ptr, query_start_loc_ptr, \
+block_tables_ptr, seq_lens_ptr, query_start_loc_ptr, \
max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \
kv_head_stride, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, \
max_ctx_blocks, k_scale_ptr, v_scale_ptr);
@@ -3057,18 +3055,17 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
GQA_RATIO> \
<<<grid, block, 0, stream>>>( \
query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
-block_tables_ptr, context_lens_ptr, query_start_loc_ptr, \
+block_tables_ptr, seq_lens_ptr, query_start_loc_ptr, \
max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \
kv_head_stride, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, \
max_ctx_blocks, k_scale_ptr, v_scale_ptr);
-#define LAUNCH_CUSTOM_REDUCTION(NPAR_LOOPS) \
-paged_attention_ll4mi_reduce_kernel<T, OUTT, HEAD_SIZE, HEAD_SIZE, \
-PARTITION_SIZE, NPAR_LOOPS> \
-<<<reduce_grid, reduce_block, 0, stream>>>( \
-out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, \
-context_lens_ptr, query_start_loc_ptr, max_num_partitions, \
-fp8_out_scale_ptr);
+#define LAUNCH_CUSTOM_REDUCTION(NPAR_LOOPS) \
+paged_attention_ll4mi_reduce_kernel<T, OUTT, HEAD_SIZE, HEAD_SIZE, \
+PARTITION_SIZE, NPAR_LOOPS> \
+<<<reduce_grid, reduce_block, 0, stream>>>( \
+out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \
+query_start_loc_ptr, max_num_partitions, fp8_out_scale_ptr);
template <typename T, typename KVT, vllm::Fp8KVCacheDataType KV_DTYPE,
int BLOCK_SIZE, int HEAD_SIZE, typename OUTT, int PARTITION_SIZE_OLD,
@@ -3077,8 +3074,8 @@ void paged_attention_custom_launcher(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, const int num_kv_heads, float scale,
-torch::Tensor& block_tables, torch::Tensor& context_lens,
-const std::optional<torch::Tensor>& query_start_loc, int max_context_len,
+torch::Tensor& block_tables, torch::Tensor& seq_lens,
+const std::optional<torch::Tensor>& query_start_loc, int max_seq_len,
const std::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
torch::Tensor& v_scale, const std::optional<torch::Tensor>& fp8_out_scale) {
int num_seqs = block_tables.size(0);
@@ -3109,7 +3106,7 @@ void paged_attention_custom_launcher(
KVT* key_cache_ptr = reinterpret_cast<KVT*>(key_cache.data_ptr());
KVT* value_cache_ptr = reinterpret_cast<KVT*>(value_cache.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>();
-int* context_lens_ptr = context_lens.data_ptr<int>();
+int* seq_lens_ptr = seq_lens.data_ptr<int>();
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
// NOTE: fp8_out_scale is optional.
@@ -3119,13 +3116,12 @@ void paged_attention_custom_launcher(
: nullptr;
OUTT* out_ptr = reinterpret_cast<OUTT*>(out.data_ptr());
-const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE);
+const int max_ctx_blocks = DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE);
// partition size is fixed at 256 since both mfma4 and mfma16 kernels support
// it mfma4 kernel also supports partition size 512
constexpr int PARTITION_SIZE = 256;
-const int max_num_partitions =
-DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE);
+const int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
const int gqa_ratio = num_heads / num_kv_heads;
assert(num_heads % num_kv_heads == 0);
assert(head_size == HEAD_SIZE);
@@ -3234,8 +3230,8 @@ void paged_attention_custom_launcher_navi(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, const int num_kv_heads, float scale,
-torch::Tensor& block_tables, torch::Tensor& context_lens,
-const std::optional<torch::Tensor>& query_start_loc, int max_context_len,
+torch::Tensor& block_tables, torch::Tensor& seq_lens,
+const std::optional<torch::Tensor>& query_start_loc, int max_seq_len,
const std::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
torch::Tensor& v_scale) {
int num_seqs = block_tables.size(0);
@@ -3263,7 +3259,7 @@ void paged_attention_custom_launcher_navi(
KVT* key_cache_ptr = reinterpret_cast<KVT*>(key_cache.data_ptr());
KVT* value_cache_ptr = reinterpret_cast<KVT*>(value_cache.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>();
-int* context_lens_ptr = context_lens.data_ptr<int>();
+int* seq_lens_ptr = seq_lens.data_ptr<int>();
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
@@ -3271,11 +3267,10 @@ void paged_attention_custom_launcher_navi(
const auto fp8_out_scale_ptr = nullptr;
OUTT* out_ptr = reinterpret_cast<OUTT*>(out.data_ptr());
-const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE);
+const int max_ctx_blocks = DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE);
constexpr int PARTITION_SIZE = 256;
-const int max_num_partitions =
-DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE);
+const int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
const int gqa_ratio = num_heads / num_kv_heads;
assert(num_heads % num_kv_heads == 0);
assert(head_size == HEAD_SIZE);
@@ -3407,14 +3402,14 @@ void paged_attention_custom_launcher_navi(
paged_attention_custom_launcher<T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \
OUTT, PSIZE, ALIBI_ENABLED>( \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
-num_kv_heads, scale, block_tables, context_lens, query_start_loc, \
-max_context_len, alibi_slopes, k_scale, v_scale, fp8_out_scale); \
+num_kv_heads, scale, block_tables, seq_lens, query_start_loc, \
+max_seq_len, alibi_slopes, k_scale, v_scale, fp8_out_scale); \
} else { \
paged_attention_custom_launcher_navi< \
T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE, ALIBI_ENABLED>( \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
-num_kv_heads, scale, block_tables, context_lens, query_start_loc, \
-max_context_len, alibi_slopes, k_scale, v_scale); \
+num_kv_heads, scale, block_tables, seq_lens, query_start_loc, \
+max_seq_len, alibi_slopes, k_scale, v_scale); \
}
#define CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \
@@ -3502,9 +3497,9 @@ void paged_attention(
int64_t num_kv_heads,
double scale,
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
-torch::Tensor& context_lens, // [num_seqs]
+torch::Tensor& seq_lens, // [num_seqs]
const std::optional<torch::Tensor>& query_start_loc, // [num_seqs]
-int64_t block_size, int64_t max_context_len,
+int64_t block_size, int64_t max_seq_len,
const std::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
torch::Tensor& v_scale,


@@ -15,8 +15,8 @@ void paged_attention(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
-torch::Tensor& block_tables, torch::Tensor& context_lens,
+torch::Tensor& block_tables, torch::Tensor& seq_lens,
const std::optional<torch::Tensor>& query_start_loc, int64_t block_size,
-int64_t max_context_len, const std::optional<torch::Tensor>& alibi_slopes,
+int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
torch::Tensor& v_scale, const std::optional<torch::Tensor>& fp8_out_scale);
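
For orientation, a hypothetical call-site sketch against the renamed declaration above. The tensor layouts and the derivations of num_kv_heads and head_size are assumptions for illustration; this diff specifies neither.

#include <cmath>
#include <optional>
#include <torch/all.h>

// paged_attention(...) is declared in the header hunk above.
void paged_attention_example(
    torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
    torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
    torch::Tensor& value_cache, torch::Tensor& block_tables,
    torch::Tensor& seq_lens,  // [num_seqs]; was context_lens before #22097
    torch::Tensor& k_scale, torch::Tensor& v_scale) {
  const int64_t num_kv_heads = key_cache.size(1);  // assumed cache layout
  const int64_t head_size = query.size(2);  // assumed [num_tokens, num_heads, head_size]
  const int64_t block_size = 16;            // illustrative value
  const int64_t max_seq_len = seq_lens.max().item<int64_t>();
  const double scale = 1.0 / std::sqrt(static_cast<double>(head_size));
  paged_attention(out, exp_sums, max_logits, tmp_out, query, key_cache,
                  value_cache, num_kv_heads, scale, block_tables,
                  seq_lens,  // renamed argument
                  /*query_start_loc=*/std::nullopt, block_size,
                  max_seq_len,  // renamed from max_context_len
                  /*alibi_slopes=*/std::nullopt, /*kv_cache_dtype=*/"auto",
                  k_scale, v_scale, /*fp8_out_scale=*/std::nullopt);
}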


@@ -41,10 +41,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) {
" Tensor query, Tensor key_cache,"
" Tensor value_cache, int num_kv_heads,"
" float scale, Tensor block_tables,"
-" Tensor context_lens,"
+" Tensor seq_lens,"
" Tensor? query_start_loc,"
" int block_size,"
-" int max_context_len,"
+" int max_seq_len,"
" Tensor? alibi_slopes,"
" str kv_cache_dtype,"
" Tensor k_scale, Tensor v_scale,"