[ROCm][FP8][Kernel] FP8 quantization fused into Custom Paged Attention (#17139)

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Author: Gregory Shtrasberg
Date: 2025-05-07 10:12:35 -04:00
Committed by: GitHub
parent 7377dd0307
commit 32aa74c09c
4 changed files with 73 additions and 43 deletions


@@ -1287,7 +1287,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
// max_num_partitions, head_size]
const int* __restrict__ context_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_partitions) {
const int max_num_partitions, const float* __restrict__ fp8_out_scale_ptr) {
const auto num_heads = gridDim.x;
const auto head_idx = blockIdx.x;
const auto seq_idx = blockIdx.y;
@@ -1465,8 +1465,10 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
const float inv_global_exp_sum =
__fdividef(1.0f, shared_global_exp_sum + 1e-6f);
const float out_scale =
(fp8_out_scale_ptr != nullptr) ? 1.0f / (*fp8_out_scale_ptr) : 1.0f;
acc *= inv_global_exp_sum;
acc *= out_scale;
const int64_t query_start_off = static_cast<int64_t>(
query_start_loc_ptr ? query_start_loc_ptr[seq_idx] : seq_idx);
OUTT* out_ptr = out + query_start_off * num_heads * HEAD_SIZE +
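Note on the epilogue math above: the reduce kernel folds the reciprocal of the optional FP8 output scale into the final softmax normalization, so the accumulator is divided by both the global exp-sum and the output scale before it is written through the OUTT pointer. A minimal host-side sketch of the same numerics, assuming a PyTorch build that exposes torch.float8_e4m3fnuz (function and variable names here are illustrative, not part of the commit):

```python
import torch

def reduce_epilogue(acc: torch.Tensor,
                    global_exp_sum: torch.Tensor,
                    fp8_out_scale: torch.Tensor | None) -> torch.Tensor:
    # Softmax normalization: same as acc *= inv_global_exp_sum in the kernel.
    out = acc / (global_exp_sum + 1e-6)
    if fp8_out_scale is not None:
        # out_scale = 1 / (*fp8_out_scale_ptr); dividing by the scale keeps the
        # result inside the representable FP8 range before the narrowing store.
        out = out / fp8_out_scale
        return out.to(torch.float8_e4m3fnuz)
    return out
```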
@@ -1548,7 +1550,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
const int* __restrict__ context_lens, // [num_seqs]
const int* __restrict__ query_start_loc_ptr, // [num_seqs]
const int max_num_partitions) {
const int max_num_partitions, const float* __restrict__ fp8_out_scale_ptr) {
UNREACHABLE_CODE
}
// clang-format on
@@ -1582,7 +1584,8 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
PARTITION_SIZE, NPAR_LOOPS> \
<<<reduce_grid, reduce_block, 0, stream>>>( \
out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, \
context_lens_ptr, query_start_loc_ptr, max_num_partitions);
context_lens_ptr, query_start_loc_ptr, max_num_partitions, \
fp8_out_scale_ptr);
template <typename T, typename KVT, vllm::Fp8KVCacheDataType KV_DTYPE,
int BLOCK_SIZE, int HEAD_SIZE, typename OUTT, int PARTITION_SIZE_OLD,
@@ -1594,7 +1597,7 @@ void paged_attention_custom_launcher(
torch::Tensor& block_tables, torch::Tensor& context_lens,
const std::optional<torch::Tensor>& query_start_loc, int max_context_len,
const std::optional<torch::Tensor>& alibi_slopes, torch::Tensor& k_scale,
torch::Tensor& v_scale) {
torch::Tensor& v_scale, const c10::optional<torch::Tensor>& fp8_out_scale) {
int num_seqs = block_tables.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
@@ -1626,6 +1629,11 @@ void paged_attention_custom_launcher(
int* context_lens_ptr = context_lens.data_ptr<int>();
const float* k_scale_ptr = reinterpret_cast<const float*>(k_scale.data_ptr());
const float* v_scale_ptr = reinterpret_cast<const float*>(v_scale.data_ptr());
// NOTE: fp8_out_scale is optional.
const auto fp8_out_scale_ptr =
fp8_out_scale
? static_cast<const float*>(fp8_out_scale.value().data_ptr())
: nullptr;
OUTT* out_ptr = reinterpret_cast<OUTT*>(out.data_ptr());
const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE);
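The launcher treats the optional tensor as a nullable device pointer: when fp8_out_scale is absent the kernel sees nullptr and leaves the output unscaled. A hedged sketch of the preconditions a Python caller would keep in mind before handing the tensor down (the checks are illustrative; the commit itself does not add validation here):

```python
import torch
from typing import Optional

def check_fp8_out_scale(fp8_out_scale: Optional[torch.Tensor]) -> None:
    """Mirror what the launcher assumes when it reinterprets the optional
    tensor's data_ptr() as a single const float* on the device."""
    if fp8_out_scale is None:
        return  # kernel receives nullptr and skips output quantization
    assert fp8_out_scale.dtype == torch.float32, "read as const float*"
    assert fp8_out_scale.numel() == 1, "kernel dereferences a single scale"
    assert fp8_out_scale.is_cuda, "must live on the ROCm device"
```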
@@ -1736,29 +1744,50 @@ void paged_attention_custom_launcher(
}
}
#define CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, PSIZE, \
ALIBI_ENABLED) \
paged_attention_custom_launcher<T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T, \
#define CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, \
PSIZE, ALIBI_ENABLED) \
paged_attention_custom_launcher<T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, \
PSIZE, ALIBI_ENABLED>( \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
num_kv_heads, scale, block_tables, context_lens, query_start_loc, \
max_context_len, alibi_slopes, k_scale, v_scale);
max_context_len, alibi_slopes, k_scale, v_scale, fp8_out_scale);
#define CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \
PSIZE) \
OUTT, PSIZE) \
if (alibi_slopes) { \
CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, PSIZE, true); \
CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE, \
true); \
} else { \
CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, PSIZE, false); \
CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE, \
false); \
}
#if defined(__HIPCC__) && defined(__gfx90a__)
#define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE) \
if (fp8_out_scale) { \
TORCH_CHECK(false, "fp8 out scale unsupported for gfx90a"); \
} else { \
CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T, \
256); \
}
#else
#define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE) \
if (fp8_out_scale) { \
CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \
uint8_t, 256); \
} else { \
CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T, \
256); \
}
#endif
#define CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, HEAD_SIZE) \
switch (block_size) { \
case 16: \
CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, 16, HEAD_SIZE, 256); \
CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 16, HEAD_SIZE); \
break; \
case 32: \
CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, 32, HEAD_SIZE, 256); \
CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 32, HEAD_SIZE); \
break; \
default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
@@ -1795,7 +1824,8 @@ void paged_attention(
int64_t block_size, int64_t max_context_len,
const std::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
torch::Tensor& v_scale) {
torch::Tensor& v_scale,
const c10::optional<torch::Tensor>& fp8_out_scale) {
// clang-format on
const int head_size = query.size(2);
if (kv_cache_dtype == "auto") {


@@ -11,14 +11,12 @@ torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b,
void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
at::Tensor& scale_a, at::Tensor& scale_b, const int64_t CuCount);
void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums,
torch::Tensor& max_logits, torch::Tensor& tmp_out,
torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads,
double scale, torch::Tensor& block_tables,
torch::Tensor& context_lens,
const std::optional<torch::Tensor>& query_start_loc,
int64_t block_size, int64_t max_context_len,
const std::optional<torch::Tensor>& alibi_slopes,
void paged_attention(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& context_lens,
const std::optional<torch::Tensor>& query_start_loc, int64_t block_size,
int64_t max_context_len, const std::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
torch::Tensor& v_scale);
torch::Tensor& v_scale, const c10::optional<torch::Tensor>& fp8_out_scale);


@@ -47,7 +47,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) {
" int max_context_len,"
" Tensor? alibi_slopes,"
" str kv_cache_dtype,"
" Tensor k_scale, Tensor v_scale) -> ()");
" Tensor k_scale, Tensor v_scale,"
" Tensor? fp8_out_scale) -> ()");
rocm_ops.impl("paged_attention", torch::kCUDA, &paged_attention);
}
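With the schema extended to `Tensor? fp8_out_scale`, the optional argument flows as c10::optional<torch::Tensor> into the C++ entry point and ends up as the nullable float pointer shown earlier. A quick, illustrative way to confirm the registered schema from Python (assumes a ROCm build of vLLM whose _rocm_C extension is loaded):

```python
import torch
import vllm._custom_ops  # noqa: F401  (loads the _rocm_C extension on ROCm builds)

# The trailing "Tensor? fp8_out_scale" should show up in the op's schema.
print(torch.ops._rocm_C.paged_attention.default._schema)
```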


@@ -117,13 +117,14 @@ def paged_attention_rocm(
kv_cache_dtype: str,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
fp8_out_scale: Optional[torch.Tensor] = None,
) -> None:
torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query,
key_cache, value_cache, num_kv_heads,
scale, block_tables, seq_lens,
query_start_loc, block_size, max_seq_len,
alibi_slopes, kv_cache_dtype, k_scale,
v_scale)
v_scale, fp8_out_scale)
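Because the new keyword defaults to None, existing callers of paged_attention_rocm are unchanged; opting into FP8 output means passing a one-element float32 scale and allocating `out` as a byte tensor. A hedged helper for building that scale under the usual amax / fp8_max convention (the helper name and convention are assumptions, not part of this commit):

```python
import torch

def make_fp8_out_scale(out_amax: float, device: str = "cuda") -> torch.Tensor:
    """Build the optional 1-element float32 scale for paged_attention_rocm.
    The kernel divides the attention output by this value before storing FP8,
    so a scale of amax / finfo(fp8).max keeps results in range."""
    fp8_max = torch.finfo(torch.float8_e4m3fnuz).max  # 240 for e4m3fnuz
    return torch.tensor([out_amax / fp8_max], dtype=torch.float32, device=device)
```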
def mla_decode_kvcache_cpu(