[ROCm][Misc] Rename the context_len to seq_len in ROCm custom paged attention kernel (#22097)
Signed-off-by: charlifu <charlifu@amd.com>
This commit is contained in:
@ -270,7 +270,7 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
|
||||
const int num_kv_heads,
|
||||
const float scale,
|
||||
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||
const int* __restrict__ context_lens, // [num_seqs]
|
||||
const int* __restrict__ seq_lens, // [num_seqs]
|
||||
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
|
||||
const int max_num_blocks_per_seq,
|
||||
const float* __restrict__ alibi_slopes, // [num_heads]
|
||||
@ -304,12 +304,12 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
|
||||
|
||||
const auto max_num_partitions = gridDim.y;
|
||||
|
||||
const int context_len = context_lens[seq_idx];
|
||||
const int seq_len = seq_lens[seq_idx];
|
||||
|
||||
const int partition_start_token_idx =
|
||||
partition_idx * T_PAR_SIZE; // partition_size;
|
||||
// exit if partition is out of context for seq
|
||||
if (partition_start_token_idx >= context_len) {
|
||||
if (partition_start_token_idx >= seq_len) {
|
||||
return;
|
||||
}
|
||||
|
||||
@ -361,8 +361,8 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
|
||||
// output layout from QKmfma : QH16xT4x4 16 qheads across 16 lanes, 16 tokens
|
||||
// across 4 rows x 4 tokens per lane
|
||||
|
||||
const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
|
||||
const int last_ctx_block = num_context_blocks - 1;
|
||||
const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE);
|
||||
const int last_seq_block = num_seq_blocks - 1;
|
||||
|
||||
const int* block_table_seq = block_tables + seq_idx * max_num_blocks_per_seq;
|
||||
|
||||
@ -373,9 +373,9 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
|
||||
const int klocal_token_idx =
|
||||
TOKENS_PER_WARP * warpid + token_depth * 16 + lane16id;
|
||||
const int kglobal_token_idx = partition_start_token_idx + klocal_token_idx;
|
||||
const int kblock_idx = (kglobal_token_idx < context_len)
|
||||
const int kblock_idx = (kglobal_token_idx < seq_len)
|
||||
? kglobal_token_idx / BLOCK_SIZE
|
||||
: last_ctx_block;
|
||||
: last_seq_block;
|
||||
kphysical_block_number[token_depth] = block_table_seq[kblock_idx];
|
||||
}
|
||||
|
||||
@ -476,9 +476,9 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
|
||||
// tokens
|
||||
const int vglobal_token_idx =
|
||||
partition_start_token_idx + vlocal_token_idx;
|
||||
const int vblock_idx = (vglobal_token_idx < context_len)
|
||||
const int vblock_idx = (vglobal_token_idx < seq_len)
|
||||
? vglobal_token_idx / BLOCK_SIZE
|
||||
: last_ctx_block;
|
||||
: last_seq_block;
|
||||
vphysical_block_number[vtoken_depth][vblock_depth] =
|
||||
block_table_seq[vblock_idx];
|
||||
}
|
||||
@ -554,7 +554,7 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
|
||||
if constexpr (ALIBI_ENABLED) {
|
||||
for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
|
||||
const int local_token_idx = qkout_token_idx + token_depth * 16;
|
||||
const int alibi_offset = local_token_idx - context_len + 1;
|
||||
const int alibi_offset = local_token_idx - seq_len + 1;
|
||||
for (int i = 0; i < 4; i++) {
|
||||
d_out[token_depth][i] += alibi_slope * (alibi_offset + i);
|
||||
}
|
||||
@ -568,9 +568,8 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
|
||||
for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
|
||||
const int local_token_idx = qkout_token_idx + token_depth * 16;
|
||||
for (int i = 0; i < 4; i++) {
|
||||
const float tmp = (local_token_idx + i < context_len)
|
||||
? d_out[token_depth][i]
|
||||
: -FLT_MAX;
|
||||
const float tmp =
|
||||
(local_token_idx + i < seq_len) ? d_out[token_depth][i] : -FLT_MAX;
|
||||
qk_max = fmaxf(qk_max, tmp);
|
||||
}
|
||||
}
|
||||
@ -582,7 +581,7 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
|
||||
for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
|
||||
const int local_token_idx = qkout_token_idx + token_depth * 16;
|
||||
for (int i = 0; i < 4; i++) {
|
||||
const float tmp = (local_token_idx + i < context_len)
|
||||
const float tmp = (local_token_idx + i < seq_len)
|
||||
? __expf(d_out[token_depth][i] - qk_max)
|
||||
: 0.0f;
|
||||
d_out[token_depth][i] = tmp;
|
||||
@ -780,7 +779,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
|
||||
const int num_kv_heads,
|
||||
const float scale,
|
||||
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||
const int* __restrict__ context_lens, // [num_seqs]
|
||||
const int* __restrict__ seq_lens, // [num_seqs]
|
||||
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
|
||||
const int max_num_blocks_per_seq,
|
||||
const float* __restrict__ alibi_slopes, // [num_heads]
|
||||
@ -809,10 +808,10 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
|
||||
const auto partition_size = blockDim.x;
|
||||
const auto max_num_partitions = gridDim.y;
|
||||
|
||||
const int context_len = context_lens[seq_idx];
|
||||
const int seq_len = seq_lens[seq_idx];
|
||||
const int partition_start_token_idx = partition_idx * partition_size;
|
||||
// exit if partition is out of context for seq
|
||||
if (partition_start_token_idx >= context_len) {
|
||||
if (partition_start_token_idx >= seq_len) {
|
||||
return;
|
||||
}
|
||||
// every 4 lanes fetch 4 different qheads
|
||||
@ -855,7 +854,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
|
||||
const int warp_start_token_idx =
|
||||
partition_start_token_idx + warpid * WARP_SIZE;
|
||||
|
||||
if (warp_start_token_idx >= context_len) { // warp out of context
|
||||
if (warp_start_token_idx >= seq_len) { // warp out of context
|
||||
#pragma unroll
|
||||
for (int h = 0; h < GQA_RATIO4; h++) {
|
||||
shared_qk_max[warpid][h] = -FLT_MAX;
|
||||
@ -863,8 +862,8 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
|
||||
}
|
||||
} else { // warp within context
|
||||
|
||||
const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
|
||||
const int last_ctx_block = num_context_blocks - 1;
|
||||
const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE);
|
||||
const int last_seq_block = num_seq_blocks - 1;
|
||||
|
||||
const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
|
||||
// token id within partition
|
||||
@ -873,9 +872,9 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
|
||||
const int global_token_idx = partition_start_token_idx + local_token_idx;
|
||||
|
||||
// fetch block number for k
|
||||
const int block_idx = (global_token_idx < context_len)
|
||||
const int block_idx = (global_token_idx < seq_len)
|
||||
? global_token_idx / BLOCK_SIZE
|
||||
: last_ctx_block;
|
||||
: last_seq_block;
|
||||
|
||||
// fetch k physical block number
|
||||
// int32 physical_block_number leads to overflow when multiplied with
|
||||
@ -888,7 +887,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
|
||||
for (int b = 0; b < VBLOCKS; b++) {
|
||||
const int vblock_idx = warp_start_block_idx + b;
|
||||
const int vblock_idx_ctx =
|
||||
(vblock_idx <= last_ctx_block) ? vblock_idx : last_ctx_block;
|
||||
(vblock_idx <= last_seq_block) ? vblock_idx : last_seq_block;
|
||||
vphysical_blocks[b] = block_table[vblock_idx_ctx];
|
||||
}
|
||||
|
||||
@ -1057,7 +1056,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
|
||||
const int lane4_token_idx = 4 * (global_token_idx >> 2);
|
||||
|
||||
if constexpr (ALIBI_ENABLED) {
|
||||
const int alibi_offset = lane4_token_idx - context_len + 1;
|
||||
const int alibi_offset = lane4_token_idx - seq_len + 1;
|
||||
for (int h = 0; h < QHLOOP; h++) {
|
||||
for (int i = 0; i < 4; i++) {
|
||||
d_out[h][i] += alibi_slope[h] * (alibi_offset + i);
|
||||
@ -1070,7 +1069,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
|
||||
for (int h = 0; h < QHLOOP; h++) {
|
||||
qk_max[h] = -FLT_MAX;
|
||||
for (int i = 0; i < 4; i++) {
|
||||
qk_max[h] = (lane4_token_idx + i < context_len)
|
||||
qk_max[h] = (lane4_token_idx + i < seq_len)
|
||||
? fmaxf(qk_max[h], d_out[h][i])
|
||||
: qk_max[h];
|
||||
}
|
||||
@ -1101,7 +1100,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
|
||||
for (int h = 0; h < QHLOOP; h++) {
|
||||
exp_sum[h] = 0.0f;
|
||||
for (int i = 0; i < 4; i++) {
|
||||
d_out[h][i] = (lane4_token_idx + i < context_len)
|
||||
d_out[h][i] = (lane4_token_idx + i < seq_len)
|
||||
? __expf(d_out[h][i] - qk_max[h])
|
||||
: 0.0f;
|
||||
exp_sum[h] += d_out[h][i];
|
||||
@ -1181,7 +1180,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
|
||||
}
|
||||
}
|
||||
|
||||
if (warp_start_token_idx >= context_len) { // warp out of context
|
||||
if (warp_start_token_idx >= seq_len) { // warp out of context
|
||||
for (int qh = 0; qh < QHLOOP; qh++) {
|
||||
for (int vh = 0; vh < VHELOOP; vh++) {
|
||||
vout_shared[qh][vh][laneid][warpid] = {0};
|
||||
@ -1279,7 +1278,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
|
||||
// max_num_partitions]
|
||||
const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
|
||||
// max_num_partitions, head_size]
|
||||
const int* __restrict__ context_lens, // [num_seqs]
|
||||
const int* __restrict__ seq_lens, // [num_seqs]
|
||||
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
|
||||
const int max_num_partitions, const float* __restrict__ fp8_out_scale_ptr) {
|
||||
const auto num_heads = gridDim.x;
|
||||
@ -1293,8 +1292,8 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
|
||||
return;
|
||||
}
|
||||
|
||||
const int context_len = context_lens[seq_idx];
|
||||
const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
|
||||
const int seq_len = seq_lens[seq_idx];
|
||||
const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE);
|
||||
const auto warpid = threadIdx.x / WARP_SIZE;
|
||||
|
||||
__shared__ float shared_global_exp_sum;
|
||||
@ -1581,7 +1580,7 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
|
||||
// head_size, block_size]
|
||||
const int num_kv_heads, const float scale,
|
||||
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||
const int* __restrict__ context_lens, // [num_seqs]
|
||||
const int* __restrict__ seq_lens, // [num_seqs]
|
||||
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
|
||||
const int max_num_blocks_per_seq,
|
||||
const float* __restrict__ alibi_slopes, // [num_heads]
|
||||
@ -1615,11 +1614,11 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
|
||||
|
||||
const int max_num_partitions = gridDim.y;
|
||||
|
||||
const int context_len = context_lens[seq_idx]; // length of a seq
|
||||
const int seq_len = seq_lens[seq_idx]; // length of a seq
|
||||
|
||||
const int partition_start_token_idx = partition_idx * T_PAR_SIZE;
|
||||
// exit if partition is out of context for seq
|
||||
if (partition_start_token_idx >= context_len) {
|
||||
if (partition_start_token_idx >= seq_len) {
|
||||
return;
|
||||
}
|
||||
|
||||
@ -1715,8 +1714,8 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
|
||||
}
|
||||
}
|
||||
|
||||
const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
|
||||
const int last_ctx_block = num_context_blocks - 1;
|
||||
const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE);
|
||||
const int last_seq_block = num_seq_blocks - 1;
|
||||
|
||||
const int* block_table_seq = block_tables + seq_idx * max_num_blocks_per_seq;
|
||||
|
||||
@ -1727,9 +1726,9 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
|
||||
const int klocal_token_idx =
|
||||
TOKENS_PER_WARP * warpid + token_depth * 16 + lane16id;
|
||||
const int kglobal_token_idx = partition_start_token_idx + klocal_token_idx;
|
||||
const int kblock_idx = (kglobal_token_idx < context_len)
|
||||
const int kblock_idx = (kglobal_token_idx < seq_len)
|
||||
? kglobal_token_idx / BLOCK_SIZE
|
||||
: last_ctx_block;
|
||||
: last_seq_block;
|
||||
kphysical_block_number[token_depth] = block_table_seq[kblock_idx];
|
||||
}
|
||||
|
||||
@ -1781,9 +1780,9 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
|
||||
vblock_depth * BLOCK_SIZE;
|
||||
const int vglobal_token_idx =
|
||||
partition_start_token_idx + vlocal_token_idx;
|
||||
const int vblock_idx = (vglobal_token_idx < context_len)
|
||||
const int vblock_idx = (vglobal_token_idx < seq_len)
|
||||
? vglobal_token_idx / BLOCK_SIZE
|
||||
: last_ctx_block;
|
||||
: last_seq_block;
|
||||
vphysical_block_number[vtoken_depth][vblock_depth] =
|
||||
block_table_seq[vblock_idx];
|
||||
}
|
||||
@ -1836,9 +1835,8 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
|
||||
for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
|
||||
const int local_token_idx = qkout_token_idx + token_depth * 16;
|
||||
for (int i = 0; i < 8; i++) {
|
||||
const float tmp = (local_token_idx + 2 * i < context_len)
|
||||
? dout[token_depth][i]
|
||||
: -FLT_MAX;
|
||||
const float tmp =
|
||||
(local_token_idx + 2 * i < seq_len) ? dout[token_depth][i] : -FLT_MAX;
|
||||
qk_max = fmaxf(qk_max, tmp);
|
||||
}
|
||||
}
|
||||
@ -1848,7 +1846,7 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
|
||||
for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
|
||||
const int local_token_idx = qkout_token_idx + token_depth * 16;
|
||||
for (int i = 0; i < 8; i++) {
|
||||
const float tmp = (local_token_idx + 2 * i < context_len)
|
||||
const float tmp = (local_token_idx + 2 * i < seq_len)
|
||||
? __expf(dout[token_depth][i] - qk_max)
|
||||
: 0.0f;
|
||||
dout[token_depth][i] = tmp;
|
||||
@ -2019,7 +2017,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
|
||||
// head_size, block_size]
|
||||
const int num_kv_heads, const float scale,
|
||||
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||
const int* __restrict__ context_lens, // [num_seqs]
|
||||
const int* __restrict__ seq_lens, // [num_seqs]
|
||||
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
|
||||
const int max_num_blocks_per_seq,
|
||||
const float* __restrict__ alibi_slopes, // [num_heads]
|
||||
@ -2046,7 +2044,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
|
||||
// max_num_partitions]
|
||||
const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
|
||||
// max_num_partitions, head_size]
|
||||
const int* __restrict__ context_lens, // [num_seqs]
|
||||
const int* __restrict__ seq_lens, // [num_seqs]
|
||||
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
|
||||
const int max_num_partitions, const float* __restrict__ fp8_out_scale_ptr) {
|
||||
const auto num_heads = gridDim.x;
|
||||
@ -2060,8 +2058,8 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
|
||||
return;
|
||||
}
|
||||
|
||||
const int context_len = context_lens[seq_idx];
|
||||
const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
|
||||
const int seq_len = seq_lens[seq_idx];
|
||||
const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE);
|
||||
const int warpid = threadIdx.x / WARP_SIZE;
|
||||
|
||||
__shared__ float shared_global_exp_sum;
|
||||
@ -2349,7 +2347,7 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
|
||||
// head_size, block_size]
|
||||
const int num_kv_heads, const float scale,
|
||||
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||
const int* __restrict__ context_lens, // [num_seqs]
|
||||
const int* __restrict__ seq_lens, // [num_seqs]
|
||||
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
|
||||
const int max_num_blocks_per_seq,
|
||||
const float* __restrict__ alibi_slopes, // [num_heads]
|
||||
@ -2382,11 +2380,11 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
|
||||
|
||||
const int max_num_partitions = gridDim.y;
|
||||
|
||||
const int context_len = context_lens[seq_idx]; // length of a seq
|
||||
const int seq_len = seq_lens[seq_idx]; // length of a seq
|
||||
|
||||
const int partition_start_token_idx = partition_idx * T_PAR_SIZE;
|
||||
// exit if partition is out of context for seq
|
||||
if (partition_start_token_idx >= context_len) {
|
||||
if (partition_start_token_idx >= seq_len) {
|
||||
return;
|
||||
}
|
||||
|
||||
@ -2482,8 +2480,8 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
|
||||
}
|
||||
}
|
||||
|
||||
const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
|
||||
const int last_ctx_block = num_context_blocks - 1;
|
||||
const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE);
|
||||
const int last_seq_block = num_seq_blocks - 1;
|
||||
|
||||
const int* block_table_seq = block_tables + seq_idx * max_num_blocks_per_seq;
|
||||
|
||||
@ -2494,9 +2492,9 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
|
||||
const int klocal_token_idx =
|
||||
TOKENS_PER_WARP * warpid + token_depth * 16 + lane16id;
|
||||
const int kglobal_token_idx = partition_start_token_idx + klocal_token_idx;
|
||||
const int kblock_idx = (kglobal_token_idx < context_len)
|
||||
const int kblock_idx = (kglobal_token_idx < seq_len)
|
||||
? kglobal_token_idx / BLOCK_SIZE
|
||||
: last_ctx_block;
|
||||
: last_seq_block;
|
||||
kphysical_block_number[token_depth] = block_table_seq[kblock_idx];
|
||||
}
|
||||
|
||||
@ -2548,9 +2546,9 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
|
||||
rowid * VTOKENS_PER_LANE + vblock_depth * BLOCK_SIZE;
|
||||
const int vglobal_token_idx =
|
||||
partition_start_token_idx + vlocal_token_idx;
|
||||
const int vblock_idx = (vglobal_token_idx < context_len)
|
||||
const int vblock_idx = (vglobal_token_idx < seq_len)
|
||||
? vglobal_token_idx / BLOCK_SIZE
|
||||
: last_ctx_block;
|
||||
: last_seq_block;
|
||||
vphysical_block_number[vtoken_depth][vblock_depth] =
|
||||
block_table_seq[vblock_idx];
|
||||
}
|
||||
@ -2604,7 +2602,7 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
|
||||
const int local_token_idx = qkout_token_idx + token_depth * 16;
|
||||
for (int i = 0; i < 8; i++) {
|
||||
const float tmp =
|
||||
(local_token_idx + i < context_len) ? dout[token_depth][i] : -FLT_MAX;
|
||||
(local_token_idx + i < seq_len) ? dout[token_depth][i] : -FLT_MAX;
|
||||
qk_max = fmaxf(qk_max, tmp);
|
||||
}
|
||||
}
|
||||
@ -2614,7 +2612,7 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
|
||||
for (int token_depth = 0; token_depth < TLOOP; token_depth++) {
|
||||
const int local_token_idx = qkout_token_idx + token_depth * 16;
|
||||
for (int i = 0; i < 8; i++) {
|
||||
const float tmp = (local_token_idx + i < context_len)
|
||||
const float tmp = (local_token_idx + i < seq_len)
|
||||
? __expf(dout[token_depth][i] - qk_max)
|
||||
: 0.0f;
|
||||
dout[token_depth][i] = tmp;
|
||||
@ -2751,7 +2749,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
|
||||
// head_size, block_size]
|
||||
const int num_kv_heads, const float scale,
|
||||
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||
const int* __restrict__ context_lens, // [num_seqs]
|
||||
const int* __restrict__ seq_lens, // [num_seqs]
|
||||
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
|
||||
const int max_num_blocks_per_seq,
|
||||
const float* __restrict__ alibi_slopes, // [num_heads]
|
||||
@ -2778,7 +2776,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
|
||||
// max_num_partitions]
|
||||
const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
|
||||
// max_num_partitions, head_size]
|
||||
const int* __restrict__ context_lens, // [num_seqs]
|
||||
const int* __restrict__ seq_lens, // [num_seqs]
|
||||
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
|
||||
const int max_num_partitions, const float* __restrict__ fp8_out_scale_ptr) {
|
||||
const auto num_heads = gridDim.x;
|
||||
@ -2792,8 +2790,8 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
|
||||
return;
|
||||
}
|
||||
|
||||
const int context_len = context_lens[seq_idx];
|
||||
const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
|
||||
const int seq_len = seq_lens[seq_idx];
|
||||
const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE);
|
||||
const int warpid = threadIdx.x / WARP_SIZE;
|
||||
|
||||
__shared__ float shared_global_exp_sum;
|
||||
@ -2980,7 +2978,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_kernel(
|
||||
const int num_kv_heads,
|
||||
const float scale,
|
||||
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||
const int* __restrict__ context_lens, // [num_seqs]
|
||||
const int* __restrict__ seq_lens, // [num_seqs]
|
||||
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
|
||||
const int max_num_blocks_per_seq,
|
||||
const float* __restrict__ alibi_slopes, // [num_heads]
|
||||
@ -3007,7 +3005,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
|
||||
const int num_kv_heads,
|
||||
const float scale,
|
||||
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||
const int* __restrict__ context_lens, // [num_seqs]
|
||||
const int* __restrict__ seq_lens, // [num_seqs]
|
||||
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
|
||||
const int max_num_blocks_per_seq,
|
||||
const float* __restrict__ alibi_slopes, // [num_heads]
|
||||
@ -3031,7 +3029,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
|
||||
const float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
|
||||
const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions]
|
||||
const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
|
||||
const int* __restrict__ context_lens, // [num_seqs]
|
||||
const int* __restrict__ seq_lens, // [num_seqs]
|
||||
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
|
||||
const int max_num_partitions, const float* __restrict__ fp8_out_scale_ptr) {
|
||||
UNREACHABLE_CODE
|
||||
@ -3046,7 +3044,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
|
||||
GQA_RATIO> \
|
||||
<<<grid, block, 0, stream>>>( \
|
||||
query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
|
||||
block_tables_ptr, context_lens_ptr, query_start_loc_ptr, \
|
||||
block_tables_ptr, seq_lens_ptr, query_start_loc_ptr, \
|
||||
max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \
|
||||
kv_head_stride, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, \
|
||||
max_ctx_blocks, k_scale_ptr, v_scale_ptr);
|
||||
@ -3057,18 +3055,17 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
|
||||
GQA_RATIO> \
|
||||
<<<grid, block, 0, stream>>>( \
|
||||
query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
|
||||
block_tables_ptr, context_lens_ptr, query_start_loc_ptr, \
|
||||
block_tables_ptr, seq_lens_ptr, query_start_loc_ptr, \
|
||||
max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \
|
||||
kv_head_stride, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, \
|
||||
max_ctx_blocks, k_scale_ptr, v_scale_ptr);
|
||||
|
||||
#define LAUNCH_CUSTOM_REDUCTION(NPAR_LOOPS) \
|
||||
paged_attention_ll4mi_reduce_kernel<T, OUTT, HEAD_SIZE, HEAD_SIZE, \
|
||||
PARTITION_SIZE, NPAR_LOOPS> \
|
||||
<<<reduce_grid, reduce_block, 0, stream>>>( \
|
||||
out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, \
|
||||
context_lens_ptr, query_start_loc_ptr, max_num_partitions, \
|
||||
fp8_out_scale_ptr);
|
||||
#define LAUNCH_CUSTOM_REDUCTION(NPAR_LOOPS) \
|
||||
paged_attention_ll4mi_reduce_kernel<T, OUTT, HEAD_SIZE, HEAD_SIZE, \
|
||||
PARTITION_SIZE, NPAR_LOOPS> \
|
||||
<<<reduce_grid, reduce_block, 0, stream>>>( \
|
||||
out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \
|
||||
query_start_loc_ptr, max_num_partitions, fp8_out_scale_ptr);
|
||||
|
||||
template <typename T, typename KVT, vllm::Fp8KVCacheDataType KV_DTYPE,
|
||||
int BLOCK_SIZE, int HEAD_SIZE, typename OUTT, int PARTITION_SIZE_OLD,
|
||||
@ -3077,8 +3074,8 @@ void paged_attention_custom_launcher(
|
||||
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
|
||||
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache, const int num_kv_heads, float scale,
|
||||
torch::Tensor& block_tables, torch::Tensor& context_lens,
|
||||
const std::optional<torch::Tensor>& query_start_loc, int max_context_len,
|
||||
torch::Tensor& block_tables, torch::Tensor& seq_lens,
|
||||
const std::optional<torch::Tensor>& query_start_loc, int max_seq_len,
|
||||
const std::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
|
||||
torch::Tensor& v_scale, const std::optional<torch::Tensor>& fp8_out_scale) {
|
||||
int num_seqs = block_tables.size(0);
|
||||
@ -3109,7 +3106,7 @@ void paged_attention_custom_launcher(
|
||||
KVT* key_cache_ptr = reinterpret_cast<KVT*>(key_cache.data_ptr());
|
||||
KVT* value_cache_ptr = reinterpret_cast<KVT*>(value_cache.data_ptr());
|
||||
int* block_tables_ptr = block_tables.data_ptr<int>();
|
||||
int* context_lens_ptr = context_lens.data_ptr<int>();
|
||||
int* seq_lens_ptr = seq_lens.data_ptr<int>();
|
||||
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
|
||||
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
|
||||
// NOTE: fp8_out_scale is optional.
|
||||
@ -3119,13 +3116,12 @@ void paged_attention_custom_launcher(
|
||||
: nullptr;
|
||||
OUTT* out_ptr = reinterpret_cast<OUTT*>(out.data_ptr());
|
||||
|
||||
const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE);
|
||||
const int max_ctx_blocks = DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE);
|
||||
|
||||
// partition size is fixed at 256 since both mfma4 and mfma16 kernels support
|
||||
// it mfma4 kernel also supports partition size 512
|
||||
constexpr int PARTITION_SIZE = 256;
|
||||
const int max_num_partitions =
|
||||
DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE);
|
||||
const int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
|
||||
const int gqa_ratio = num_heads / num_kv_heads;
|
||||
assert(num_heads % num_kv_heads == 0);
|
||||
assert(head_size == HEAD_SIZE);
|
||||
@ -3234,8 +3230,8 @@ void paged_attention_custom_launcher_navi(
|
||||
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
|
||||
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache, const int num_kv_heads, float scale,
|
||||
torch::Tensor& block_tables, torch::Tensor& context_lens,
|
||||
const std::optional<torch::Tensor>& query_start_loc, int max_context_len,
|
||||
torch::Tensor& block_tables, torch::Tensor& seq_lens,
|
||||
const std::optional<torch::Tensor>& query_start_loc, int max_seq_len,
|
||||
const std::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
|
||||
torch::Tensor& v_scale) {
|
||||
int num_seqs = block_tables.size(0);
|
||||
@ -3263,7 +3259,7 @@ void paged_attention_custom_launcher_navi(
|
||||
KVT* key_cache_ptr = reinterpret_cast<KVT*>(key_cache.data_ptr());
|
||||
KVT* value_cache_ptr = reinterpret_cast<KVT*>(value_cache.data_ptr());
|
||||
int* block_tables_ptr = block_tables.data_ptr<int>();
|
||||
int* context_lens_ptr = context_lens.data_ptr<int>();
|
||||
int* seq_lens_ptr = seq_lens.data_ptr<int>();
|
||||
|
||||
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
|
||||
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
|
||||
@ -3271,11 +3267,10 @@ void paged_attention_custom_launcher_navi(
|
||||
const auto fp8_out_scale_ptr = nullptr;
|
||||
OUTT* out_ptr = reinterpret_cast<OUTT*>(out.data_ptr());
|
||||
|
||||
const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE);
|
||||
const int max_ctx_blocks = DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE);
|
||||
|
||||
constexpr int PARTITION_SIZE = 256;
|
||||
const int max_num_partitions =
|
||||
DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE);
|
||||
const int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
|
||||
const int gqa_ratio = num_heads / num_kv_heads;
|
||||
assert(num_heads % num_kv_heads == 0);
|
||||
assert(head_size == HEAD_SIZE);
|
||||
@ -3407,14 +3402,14 @@ void paged_attention_custom_launcher_navi(
|
||||
paged_attention_custom_launcher<T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \
|
||||
OUTT, PSIZE, ALIBI_ENABLED>( \
|
||||
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
|
||||
num_kv_heads, scale, block_tables, context_lens, query_start_loc, \
|
||||
max_context_len, alibi_slopes, k_scale, v_scale, fp8_out_scale); \
|
||||
num_kv_heads, scale, block_tables, seq_lens, query_start_loc, \
|
||||
max_seq_len, alibi_slopes, k_scale, v_scale, fp8_out_scale); \
|
||||
} else { \
|
||||
paged_attention_custom_launcher_navi< \
|
||||
T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE, ALIBI_ENABLED>( \
|
||||
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
|
||||
num_kv_heads, scale, block_tables, context_lens, query_start_loc, \
|
||||
max_context_len, alibi_slopes, k_scale, v_scale); \
|
||||
num_kv_heads, scale, block_tables, seq_lens, query_start_loc, \
|
||||
max_seq_len, alibi_slopes, k_scale, v_scale); \
|
||||
}
|
||||
|
||||
#define CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \
|
||||
@ -3502,9 +3497,9 @@ void paged_attention(
|
||||
int64_t num_kv_heads,
|
||||
double scale,
|
||||
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||
torch::Tensor& context_lens, // [num_seqs]
|
||||
torch::Tensor& seq_lens, // [num_seqs]
|
||||
const std::optional<torch::Tensor>& query_start_loc, // [num_seqs]
|
||||
int64_t block_size, int64_t max_context_len,
|
||||
int64_t block_size, int64_t max_seq_len,
|
||||
const std::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
|
||||
torch::Tensor& v_scale,
|
||||
|
@ -15,8 +15,8 @@ void paged_attention(
|
||||
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
|
||||
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
|
||||
torch::Tensor& block_tables, torch::Tensor& context_lens,
|
||||
torch::Tensor& block_tables, torch::Tensor& seq_lens,
|
||||
const std::optional<torch::Tensor>& query_start_loc, int64_t block_size,
|
||||
int64_t max_context_len, const std::optional<torch::Tensor>& alibi_slopes,
|
||||
int64_t max_seq_len, const std::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
|
||||
torch::Tensor& v_scale, const std::optional<torch::Tensor>& fp8_out_scale);
|
||||
|
@ -41,10 +41,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) {
|
||||
" Tensor query, Tensor key_cache,"
|
||||
" Tensor value_cache, int num_kv_heads,"
|
||||
" float scale, Tensor block_tables,"
|
||||
" Tensor context_lens,"
|
||||
" Tensor seq_lens,"
|
||||
" Tensor? query_start_loc,"
|
||||
" int block_size,"
|
||||
" int max_context_len,"
|
||||
" int max_seq_len,"
|
||||
" Tensor? alibi_slopes,"
|
||||
" str kv_cache_dtype,"
|
||||
" Tensor k_scale, Tensor v_scale,"
|
||||
|
Reference in New Issue
Block a user