From b5ce77c1f5964293299eb1366f341872a4e47fa6 Mon Sep 17 00:00:00 2001 From: Andy Lugo Date: Tue, 1 Jul 2025 02:53:22 +0000 Subject: [PATCH] [ROCm] Initial AITER Integration for mha_bwd asm kernels (#152630) Generates AITER plumbing via cmake. Calls into fav3 asm bwd CK kernels. Update submodule composable kernel for this change Pull Request resolved: https://github.com/pytorch/pytorch/pull/152630 Approved by: https://github.com/xw285cornell, https://github.com/yoyoyocmu --- .gitmodules | 3 + aten/src/ATen/CMakeLists.txt | 5 ++ .../hip/flash_attn/ck/CMakeLists.txt | 12 ++-- .../hip/flash_attn/ck/fav_v3/CMakeLists.txt | 20 +++++++ .../hip/flash_attn/ck/fmha_bwd.hpp | 1 + .../hip/flash_attn/ck/mha_bwd_ck.hip | 57 ++++++++++++++----- third_party/aiter | 1 + third_party/composable_kernel | 2 +- 8 files changed, 80 insertions(+), 21 deletions(-) create mode 100644 aten/src/ATen/native/transformers/hip/flash_attn/ck/fav_v3/CMakeLists.txt create mode 160000 third_party/aiter diff --git a/.gitmodules b/.gitmodules index b22e1026d6a7..4eb6e511127d 100644 --- a/.gitmodules +++ b/.gitmodules @@ -129,3 +129,6 @@ [submodule "third_party/flash-attention"] path = third_party/flash-attention url = https://github.com/Dao-AILab/flash-attention.git +[submodule "third_party/aiter"] + path = third_party/aiter + url = https://github.com/ROCm/aiter.git diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index c9cfd74b501e..af8fea252947 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -193,6 +193,10 @@ if(USE_FLASH_ATTENTION) add_subdirectory(native/transformers/hip/flash_attn/ck) file(GLOB flash_attention_hip_ck_hip "native/transformers/hip/flash_attn/ck/*.hip") list(APPEND native_transformers_hip_hip ${flash_attention_hip_ck_hip}) + # FAv3 Generation + add_subdirectory(native/transformers/hip/flash_attn/ck/fav_v3) + file(GLOB flash_attention_v3_hip "native/transformers/hip/flash_attn/ck/fav_v3/*.hip") + list(APPEND native_transformers_hip_hip ${flash_attention_v3_hip}) endif() endif() file(GLOB flash_attention_hip_aot_hip "native/transformers/hip/flash_attn/aot/*.hip") @@ -392,6 +396,7 @@ if(USE_ROCM) list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/include) list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/library/include) list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_BINARY_DIR}/composable_kernel) + list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/aiter/csrc/include) _pytorch_rocm_generate_ck_conf() # Next two lines are needed because TunableOp uses third-party/fmt diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/CMakeLists.txt b/aten/src/ATen/native/transformers/hip/flash_attn/ck/CMakeLists.txt index a72911cd510e..b30c39340036 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/ck/CMakeLists.txt +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/CMakeLists.txt @@ -1,7 +1,7 @@ # generate a list of kernels, but not actually emit files at config stage execute_process( - COMMAND python3 ${CMAKE_CURRENT_LIST_DIR}/../../../../../../../../third_party/composable_kernel/example/ck_tile/01_fmha/generate.py - --api fwd --receipt 4 --list_blobs ${CMAKE_CURRENT_LIST_DIR}/fwd_blob_list.txt + COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py + --api fwd --receipt 600 --list_blobs ${CMAKE_CURRENT_LIST_DIR}/fwd_blob_list.txt RESULT_VARIABLE ret ) @@ -10,8 +10,8 @@ if(ret AND NOT ret EQUAL 0) endif() execute_process( - COMMAND python3 ${CMAKE_CURRENT_LIST_DIR}/../../../../../../../../third_party/composable_kernel/example/ck_tile/01_fmha/generate.py - --api bwd --receipt 4 --list_blobs ${CMAKE_CURRENT_LIST_DIR}/bwd_blob_list.txt + COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py + --api bwd --receipt 600 --list_blobs ${CMAKE_CURRENT_LIST_DIR}/bwd_blob_list.txt RESULT_VARIABLE ret ) @@ -20,14 +20,14 @@ if(ret AND NOT ret EQUAL 0) endif() # Generate the files for both fwd and bwd -execute_process(COMMAND python3 ${CMAKE_CURRENT_LIST_DIR}/../../../../../../../../third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --api fwd --receipt 4 --output_dir ${CMAKE_CURRENT_LIST_DIR} +execute_process(COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --api fwd --receipt 600 --output_dir ${CMAKE_CURRENT_LIST_DIR} ) if(ret AND NOT ret EQUAL 0) message( FATAL_ERROR "CK Tile FMHA FAILED to generate FWD kernels.") endif() -execute_process(COMMAND python3 ${CMAKE_CURRENT_LIST_DIR}/../../../../../../../../third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --api bwd --receipt 4 --output_dir ${CMAKE_CURRENT_LIST_DIR} +execute_process(COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --api bwd --receipt 600 --output_dir ${CMAKE_CURRENT_LIST_DIR} RESULT_VARIABLE ret ) diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fav_v3/CMakeLists.txt b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fav_v3/CMakeLists.txt new file mode 100644 index 000000000000..cccf026690dc --- /dev/null +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fav_v3/CMakeLists.txt @@ -0,0 +1,20 @@ +include(CMakePrintHelpers) + +# Generate AITER/CK Asm code +execute_process( + COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/aiter/csrc/py_itfs_cu/fmha_v3_bwd_kernel_generate.py --receipt 1 --output_dir ${CMAKE_CURRENT_LIST_DIR} + RESULT_VARIABLE ret +) + +if(ret AND NOT ret EQUAL 0) + message( FATAL_ERROR "Failed to generate FAv3 CK Kernels") +endif() + +execute_process( + COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/aiter/csrc/cpp_itfs/mha_bwd_generate.py --receipt 3 --output_dir ${CMAKE_CURRENT_LIST_DIR} + RESULT_VARIABLE ret +) + + +# Change file extensions to .hip +execute_process(COMMAND bash -c "for file in ${CMAKE_CURRENT_LIST_DIR}/*.cpp; do mv -- \"$file\" \"\${file%.cpp}.hip\"; done") diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_bwd.hpp b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_bwd.hpp index 38ec2ef20c5c..affa40619b59 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_bwd.hpp +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_bwd.hpp @@ -453,4 +453,5 @@ struct fmha_bwd_traits bool is_deterministic; // TODO: padding check is inside this api }; +template float fmha_bwd(fmha_bwd_traits, fmha_bwd_args, const ck_tile::stream_config&); diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_bwd_ck.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_bwd_ck.hip index 0a99d5a81568..854ac950a867 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_bwd_ck.hip +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_bwd_ck.hip @@ -3,6 +3,7 @@ ******************************************************************************/ #include +#include #include #include @@ -28,6 +29,26 @@ fmha_bwd_traits get_ck_fmha_bwd_traits(const mask_info &mask, deterministic}; } + + +aiter::mha_bwd_traits get_mha_bwd_traits(fmha_bwd_traits t, mask_info mask) +{ + return aiter::mha_bwd_traits(t.hdim_q, + t.hdim_v, + t.data_type, + t.is_group_mode, + mask.type, + t.bias_type, + t.has_dbias, + t.has_dropout, + t.is_store_randval, + t.is_deterministic, + true, // use_ext_asm + true, // is_v3_atomic_fp32, + 1); // how_v3_bf16_cvt + +} + fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask, // sizes const int b, @@ -101,11 +122,11 @@ fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask, ck_tile::index_t stride_dv = dv.stride(1); ck_tile::index_t nhead_stride_dv = dv.stride(2); - // dq_acc: (split, batch_size, seqlen_q, nheads, hdim) + // dq_acc: (split, batch_size, nheads, seqlen_q, hdim) ck_tile::index_t split_stride_dq_acc = dq_acc.stride(0); ck_tile::index_t batch_stride_dq_acc = dq_acc.stride(1); - ck_tile::index_t stride_dq_acc = dq_acc.stride(2); - ck_tile::index_t nhead_stride_dq_acc = dq_acc.stride(3); + ck_tile::index_t stride_dq_acc = dq_acc.stride(3); + ck_tile::index_t nhead_stride_dq_acc = dq_acc.stride(2); // bias: (batch_size, nheads, seqlen_q, seqlen_k) void *attn_bias_ptr = nullptr; @@ -351,11 +372,11 @@ mha_bwd_ck(const at::Tensor &dout, // batch_size x seqlen_q x at::Tensor dq_accum; if (!deterministic) { - dq_accum = at::zeros({1, batch_size, seqlen_q, num_heads, head_size_8x}, opts.dtype(at::kFloat)); + dq_accum = at::zeros({1, batch_size, num_heads, seqlen_q, head_size_8x}, opts.dtype(at::kFloat)); } else { const ck_tile::index_t kN0 = head_size_8x <= 128 ? 128 : 64; const ck_tile::index_t nsplits = ck_tile::integer_divide_ceil(seqlen_k, kN0); - dq_accum = at::zeros({nsplits, batch_size, seqlen_q, num_heads, head_size_8x}, opts.dtype(at::kFloat)); + dq_accum = at::zeros({nsplits, batch_size, num_heads, seqlen_q, head_size_8x}, opts.dtype(at::kFloat)); } at::Tensor dk_expanded, dv_expanded; @@ -376,14 +397,6 @@ mha_bwd_ck(const at::Tensor &dout, // batch_size x seqlen_q x if (seqlen_q > 0) { ck_tile::stream_config stream_config{stream}; dq.zero_(); // ck use atomic operation on dq - auto traits = - get_ck_fmha_bwd_traits(mask, - q_dtype_str, - head_size_8x, - is_dropout, - attn_bias_.has_value(), - deterministic, - bias_requires_grad); auto args = get_ck_fmha_bwd_args( @@ -411,7 +424,23 @@ mha_bwd_ck(const at::Tensor &dout, // batch_size x seqlen_q x softmax_scale, p_dropout, drop_seed_offset); - float t = fmha_bwd(traits, args, stream_config); + + float t = aiter::mha_bwd(args, + stream_config, + q_dtype_str, + false, // is_group_mode + mask.type, + attn_bias_.has_value() ? bias_enum::elementwise_bias : bias_enum::no_bias, + bias_requires_grad, + false, // is_store_randval + deterministic, + true, // use_ext_asm + true, // is_v3_atomic_fp32 + 1); // how_v3_bf16_cvt + + + + TORCH_CHECK(t >= 0, "invalid argument for fmha_bwd"); } else { // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0. diff --git a/third_party/aiter b/third_party/aiter new file mode 160000 index 000000000000..01aae101b9e5 --- /dev/null +++ b/third_party/aiter @@ -0,0 +1 @@ +Subproject commit 01aae101b9e5e94d6c16a9514c9fb8df99c93150 diff --git a/third_party/composable_kernel b/third_party/composable_kernel index 8086bbe3a78d..434d19f696da 160000 --- a/third_party/composable_kernel +++ b/third_party/composable_kernel @@ -1 +1 @@ -Subproject commit 8086bbe3a78d931eb96fe12fdc014082e18d18d3 +Subproject commit 434d19f696da62c12b5372b32cbc9ba968588d7e