[ROCm] Initial AITER Integration for mha_bwd asm kernels (#152630)

Generates the AITER plumbing via CMake and calls into the FAv3 asm backward CK kernels.

Updates the composable_kernel submodule for this change.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152630
Approved by: https://github.com/xw285cornell, https://github.com/yoyoyocmu
Author: Andy Lugo
Date: 2025-07-01 02:53:22 +00:00
Committed by: PyTorch MergeBot
Parent: f40efde2a4
Commit: b5ce77c1f5
8 changed files with 80 additions and 21 deletions

.gitmodules (vendored)

@@ -129,3 +129,6 @@
 [submodule "third_party/flash-attention"]
 	path = third_party/flash-attention
 	url = https://github.com/Dao-AILab/flash-attention.git
+[submodule "third_party/aiter"]
+	path = third_party/aiter
+	url = https://github.com/ROCm/aiter.git


@@ -193,6 +193,10 @@ if(USE_FLASH_ATTENTION)
     add_subdirectory(native/transformers/hip/flash_attn/ck)
     file(GLOB flash_attention_hip_ck_hip "native/transformers/hip/flash_attn/ck/*.hip")
     list(APPEND native_transformers_hip_hip ${flash_attention_hip_ck_hip})
+    # FAv3 Generation
+    add_subdirectory(native/transformers/hip/flash_attn/ck/fav_v3)
+    file(GLOB flash_attention_v3_hip "native/transformers/hip/flash_attn/ck/fav_v3/*.hip")
+    list(APPEND native_transformers_hip_hip ${flash_attention_v3_hip})
   endif()
 endif()
 file(GLOB flash_attention_hip_aot_hip "native/transformers/hip/flash_attn/aot/*.hip")
@@ -392,6 +396,7 @@ if(USE_ROCM)
   list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/include)
   list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/library/include)
   list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_BINARY_DIR}/composable_kernel)
+  list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/aiter/csrc/include)
   _pytorch_rocm_generate_ck_conf()
   # Next two lines are needed because TunableOp uses third-party/fmt


@@ -1,7 +1,7 @@
 # generate a list of kernels, but not actually emit files at config stage
 execute_process(
-  COMMAND python3 ${CMAKE_CURRENT_LIST_DIR}/../../../../../../../../third_party/composable_kernel/example/ck_tile/01_fmha/generate.py
-  --api fwd --receipt 4 --list_blobs ${CMAKE_CURRENT_LIST_DIR}/fwd_blob_list.txt
+  COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py
+  --api fwd --receipt 600 --list_blobs ${CMAKE_CURRENT_LIST_DIR}/fwd_blob_list.txt
   RESULT_VARIABLE ret
 )
@@ -10,8 +10,8 @@ if(ret AND NOT ret EQUAL 0)
 endif()
 execute_process(
-  COMMAND python3 ${CMAKE_CURRENT_LIST_DIR}/../../../../../../../../third_party/composable_kernel/example/ck_tile/01_fmha/generate.py
-  --api bwd --receipt 4 --list_blobs ${CMAKE_CURRENT_LIST_DIR}/bwd_blob_list.txt
+  COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py
+  --api bwd --receipt 600 --list_blobs ${CMAKE_CURRENT_LIST_DIR}/bwd_blob_list.txt
   RESULT_VARIABLE ret
 )
@@ -20,14 +20,14 @@ if(ret AND NOT ret EQUAL 0)
 endif()
 # Generate the files for both fwd and bwd
-execute_process(COMMAND python3 ${CMAKE_CURRENT_LIST_DIR}/../../../../../../../../third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --api fwd --receipt 4 --output_dir ${CMAKE_CURRENT_LIST_DIR}
+execute_process(COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --api fwd --receipt 600 --output_dir ${CMAKE_CURRENT_LIST_DIR}
 )
 if(ret AND NOT ret EQUAL 0)
   message( FATAL_ERROR "CK Tile FMHA FAILED to generate FWD kernels.")
 endif()
-execute_process(COMMAND python3 ${CMAKE_CURRENT_LIST_DIR}/../../../../../../../../third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --api bwd --receipt 4 --output_dir ${CMAKE_CURRENT_LIST_DIR}
+execute_process(COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --api bwd --receipt 600 --output_dir ${CMAKE_CURRENT_LIST_DIR}
 RESULT_VARIABLE ret
 )


@@ -0,0 +1,20 @@
+include(CMakePrintHelpers)
+
+# Generate AITER/CK Asm code
+execute_process(
+  COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/aiter/csrc/py_itfs_cu/fmha_v3_bwd_kernel_generate.py --receipt 1 --output_dir ${CMAKE_CURRENT_LIST_DIR}
+  RESULT_VARIABLE ret
+)
+if(ret AND NOT ret EQUAL 0)
+  message( FATAL_ERROR "Failed to generate FAv3 CK Kernels")
+endif()
+
+execute_process(
+  COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/aiter/csrc/cpp_itfs/mha_bwd_generate.py --receipt 3 --output_dir ${CMAKE_CURRENT_LIST_DIR}
+  RESULT_VARIABLE ret
+)
+
+# Change file extensions to .hip
+execute_process(COMMAND bash -c "for file in ${CMAKE_CURRENT_LIST_DIR}/*.cpp; do mv -- \"$file\" \"\${file%.cpp}.hip\"; done")
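The bash loop at the end of this new CMakeLists exists so the generated .cpp files match the `file(GLOB ... "*.hip")` pattern that collects the FAv3 sources. Purely as an illustration of that step (not part of the PR), a C++17 equivalent of the rename, assuming the generated files sit in a single flat directory:

```cpp
// Illustration only: a C++17 equivalent of the bash rename step above.
// It renames every generated *.cpp file in a directory to *.hip so that a
// file(GLOB ... "*.hip") in CMake picks the kernels up as HIP sources.
#include <filesystem>
#include <iostream>
#include <vector>

int main(int argc, char** argv) {
    namespace fs = std::filesystem;
    const fs::path dir = (argc > 1) ? argv[1] : ".";  // directory holding the generated kernels

    // Collect first, then rename, so the directory is never mutated while iterating it.
    std::vector<fs::path> sources;
    for (const auto& entry : fs::directory_iterator(dir)) {
        if (entry.is_regular_file() && entry.path().extension() == ".cpp") {
            sources.push_back(entry.path());
        }
    }
    for (const auto& src : sources) {
        fs::path dst = src;
        dst.replace_extension(".hip");
        fs::rename(src, dst);
        std::cout << src << " -> " << dst << '\n';
    }
    return 0;
}
```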


@@ -453,4 +453,5 @@ struct fmha_bwd_traits
     bool is_deterministic;
     // TODO: padding check is inside this api
 };
+template <int Version = 2>
 float fmha_bwd(fmha_bwd_traits, fmha_bwd_args, const ck_tile::stream_config&);
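For context, the added `template <int Version = 2>` line turns `fmha_bwd` into a function template with a defaulted version parameter: existing call sites keep compiling unchanged (they get Version = 2), while a caller can opt into the v3 path explicitly. A minimal, self-contained sketch of that pattern, with stand-in types rather than the real CK/AITER declarations:

```cpp
// Sketch of the defaulted-template-parameter pattern added above.
// All types and bodies here are stand-ins, not the real CK/AITER code.
#include <iostream>

struct fmha_bwd_traits {};   // placeholder for the real traits struct
struct fmha_bwd_args {};     // placeholder for the real args struct
struct stream_config {};     // placeholder for ck_tile::stream_config

template <int Version = 2>
float fmha_bwd(fmha_bwd_traits, fmha_bwd_args, const stream_config&) {
    // A real implementation would dispatch to the Version-specific kernel set;
    // the CK API returns a float, with a negative value signalling an invalid config.
    std::cout << "running fmha_bwd v" << Version << '\n';
    return 0.0f;
}

int main() {
    fmha_bwd_traits t;
    fmha_bwd_args a;
    stream_config s;
    fmha_bwd(t, a, s);     // existing call sites: Version defaults to 2
    fmha_bwd<3>(t, a, s);  // an explicit opt-in to the v3 (asm) kernels
    return 0;
}
```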


@@ -3,6 +3,7 @@
  ******************************************************************************/
 #include <ATen/native/transformers/hip/flash_attn/flash_common_hip.hpp>
+#include <mha_bwd.h>
 #include <fmha_bwd.hpp>
 #include <mask.hpp>
@@ -28,6 +29,26 @@ fmha_bwd_traits get_ck_fmha_bwd_traits(const mask_info &mask,
             deterministic};
 }
 
+aiter::mha_bwd_traits get_mha_bwd_traits(fmha_bwd_traits t, mask_info mask)
+{
+    return aiter::mha_bwd_traits(t.hdim_q,
+                                 t.hdim_v,
+                                 t.data_type,
+                                 t.is_group_mode,
+                                 mask.type,
+                                 t.bias_type,
+                                 t.has_dbias,
+                                 t.has_dropout,
+                                 t.is_store_randval,
+                                 t.is_deterministic,
+                                 true, // use_ext_asm
+                                 true, // is_v3_atomic_fp32
+                                 1);   // how_v3_bf16_cvt
+}
+
 fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
                                    // sizes
                                    const int b,
@@ -101,11 +122,11 @@ fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
     ck_tile::index_t stride_dv = dv.stride(1);
     ck_tile::index_t nhead_stride_dv = dv.stride(2);
 
-    // dq_acc: (split, batch_size, seqlen_q, nheads, hdim)
+    // dq_acc: (split, batch_size, nheads, seqlen_q, hdim)
     ck_tile::index_t split_stride_dq_acc = dq_acc.stride(0);
     ck_tile::index_t batch_stride_dq_acc = dq_acc.stride(1);
-    ck_tile::index_t stride_dq_acc = dq_acc.stride(2);
-    ck_tile::index_t nhead_stride_dq_acc = dq_acc.stride(3);
+    ck_tile::index_t stride_dq_acc = dq_acc.stride(3);
+    ck_tile::index_t nhead_stride_dq_acc = dq_acc.stride(2);
 
     // bias: (batch_size, nheads, seqlen_q, seqlen_k)
     void *attn_bias_ptr = nullptr;
@@ -351,11 +372,11 @@ mha_bwd_ck(const at::Tensor &dout, // batch_size x seqlen_q x
     at::Tensor dq_accum;
     if (!deterministic) {
-        dq_accum = at::zeros({1, batch_size, seqlen_q, num_heads, head_size_8x}, opts.dtype(at::kFloat));
+        dq_accum = at::zeros({1, batch_size, num_heads, seqlen_q, head_size_8x}, opts.dtype(at::kFloat));
     } else {
         const ck_tile::index_t kN0 = head_size_8x <= 128 ? 128 : 64;
         const ck_tile::index_t nsplits = ck_tile::integer_divide_ceil(seqlen_k, kN0);
-        dq_accum = at::zeros({nsplits, batch_size, seqlen_q, num_heads, head_size_8x}, opts.dtype(at::kFloat));
+        dq_accum = at::zeros({nsplits, batch_size, num_heads, seqlen_q, head_size_8x}, opts.dtype(at::kFloat));
     }
 
     at::Tensor dk_expanded, dv_expanded;
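The two hunks above go together: dq_accum now uses a (split, batch, nheads, seqlen_q, hdim) layout instead of (split, batch, seqlen_q, nheads, hdim), so the per-row stride comes from dim 3 and the per-head stride from dim 2. A small standalone sketch (made-up sizes, contiguous row-major strides only) showing why the two stride indices swap:

```cpp
// Illustration only: contiguous (row-major) strides for the old vs. new
// dq_accum layouts, showing why stride_dq_acc / nhead_stride_dq_acc swap
// from stride(2)/stride(3) to stride(3)/stride(2). Sizes are made up.
#include <array>
#include <cstdint>
#include <iostream>

// Contiguous strides for a shape, innermost dimension last (as at::zeros produces).
template <std::size_t N>
std::array<std::int64_t, N> contiguous_strides(const std::array<std::int64_t, N>& shape) {
    std::array<std::int64_t, N> strides{};
    std::int64_t running = 1;
    for (std::size_t i = N; i-- > 0;) {
        strides[i] = running;
        running *= shape[i];
    }
    return strides;
}

int main() {
    const std::int64_t split = 1, batch = 2, heads = 4, seqlen_q = 8, hdim = 64;

    // Old layout: (split, batch, seqlen_q, heads, hdim)
    const auto old_s = contiguous_strides<5>({split, batch, seqlen_q, heads, hdim});
    // New layout: (split, batch, heads, seqlen_q, hdim)
    const auto new_s = contiguous_strides<5>({split, batch, heads, seqlen_q, hdim});

    // Old layout: the seqlen_q step is stride(2), the head step is stride(3).
    std::cout << "old: stride(2)=" << old_s[2] << " (row), stride(3)=" << old_s[3] << " (head)\n";
    // New layout: the roles trade places, matching the change in get_ck_fmha_bwd_args.
    std::cout << "new: stride(3)=" << new_s[3] << " (row), stride(2)=" << new_s[2] << " (head)\n";
    return 0;
}
```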
@@ -376,14 +397,6 @@ mha_bwd_ck(const at::Tensor &dout, // batch_size x seqlen_q x
     if (seqlen_q > 0) {
         ck_tile::stream_config stream_config{stream};
         dq.zero_(); // ck use atomic operation on dq
-        auto traits =
-            get_ck_fmha_bwd_traits(mask,
-                                   q_dtype_str,
-                                   head_size_8x,
-                                   is_dropout,
-                                   attn_bias_.has_value(),
-                                   deterministic,
-                                   bias_requires_grad);
 
         auto args =
             get_ck_fmha_bwd_args(
@@ -411,7 +424,23 @@ mha_bwd_ck(const at::Tensor &dout, // batch_size x seqlen_q x
                 softmax_scale,
                 p_dropout,
                 drop_seed_offset);
-        float t = fmha_bwd(traits, args, stream_config);
+
+        float t = aiter::mha_bwd(args,
+                                 stream_config,
+                                 q_dtype_str,
+                                 false, // is_group_mode
+                                 mask.type,
+                                 attn_bias_.has_value() ? bias_enum::elementwise_bias : bias_enum::no_bias,
+                                 bias_requires_grad,
+                                 false, // is_store_randval
+                                 deterministic,
+                                 true,  // use_ext_asm
+                                 true,  // is_v3_atomic_fp32
+                                 1);    // how_v3_bf16_cvt
+
         TORCH_CHECK(t >= 0, "invalid argument for fmha_bwd");
     } else {
         // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.

third_party/aiter (vendored submodule)

Submodule third_party/aiter added at 01aae101b9