mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 04:44:13 +08:00)
[ROCm] Initial AITER Integration for mha_bwd asm kernels (#152630)
Generates the AITER plumbing via CMake and calls into the FAv3 asm bwd CK kernels. Updates the composable_kernel submodule accordingly.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152630
Approved by: https://github.com/xw285cornell, https://github.com/yoyoyocmu
committed by PyTorch MergeBot
parent f40efde2a4
commit b5ce77c1f5
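
The user-facing surface does not change: the new kernels sit behind the existing ROCm flash-attention backward (mha_bwd_ck). As a rough illustration only (not taken from this commit; whether the AITER asm path is actually selected depends on the build configuration and SDPA backend selection), a libtorch program along the following lines drives the backward pass this integration serves on a ROCm build:

// Illustrative sketch: exercises the SDPA backward on a GPU build of libtorch.
// On a ROCm build with flash attention enabled this reaches mha_bwd_ck; on
// other builds the same call simply runs a different backend.
#include <torch/torch.h>
#include <iostream>

int main() {
  if (!torch::cuda::is_available()) return 0;  // HIP devices report as "cuda" on ROCm builds
  auto opts = torch::TensorOptions().device(torch::kCUDA).dtype(torch::kHalf);
  // batch_size x num_heads x seqlen x head_size
  auto q = torch::randn({2, 8, 128, 64}, opts).requires_grad_(true);
  auto k = torch::randn({2, 8, 128, 64}, opts).requires_grad_(true);
  auto v = torch::randn({2, 8, 128, 64}, opts).requires_grad_(true);
  auto out = at::scaled_dot_product_attention(q, k, v);
  out.sum().backward();  // the backward pass is what mha_bwd_ck implements
  std::cout << q.grad().sizes() << "\n";
  return 0;
}
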
.gitmodules (vendored, 3 changes)
@@ -129,3 +129,6 @@
 [submodule "third_party/flash-attention"]
 	path = third_party/flash-attention
 	url = https://github.com/Dao-AILab/flash-attention.git
+[submodule "third_party/aiter"]
+	path = third_party/aiter
+	url = https://github.com/ROCm/aiter.git

@@ -193,6 +193,10 @@ if(USE_FLASH_ATTENTION)
       add_subdirectory(native/transformers/hip/flash_attn/ck)
       file(GLOB flash_attention_hip_ck_hip "native/transformers/hip/flash_attn/ck/*.hip")
       list(APPEND native_transformers_hip_hip ${flash_attention_hip_ck_hip})
+      # FAv3 Generation
+      add_subdirectory(native/transformers/hip/flash_attn/ck/fav_v3)
+      file(GLOB flash_attention_v3_hip "native/transformers/hip/flash_attn/ck/fav_v3/*.hip")
+      list(APPEND native_transformers_hip_hip ${flash_attention_v3_hip})
     endif()
   endif()
   file(GLOB flash_attention_hip_aot_hip "native/transformers/hip/flash_attn/aot/*.hip")
@@ -392,6 +396,7 @@ if(USE_ROCM)
   list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/include)
   list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/library/include)
   list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_BINARY_DIR}/composable_kernel)
+  list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/aiter/csrc/include)
   _pytorch_rocm_generate_ck_conf()

   # Next two lines are needed because TunableOp uses third-party/fmt

@@ -1,7 +1,7 @@
 # generate a list of kernels, but not actually emit files at config stage
 execute_process(
-  COMMAND python3 ${CMAKE_CURRENT_LIST_DIR}/../../../../../../../../third_party/composable_kernel/example/ck_tile/01_fmha/generate.py
-  --api fwd --receipt 4 --list_blobs ${CMAKE_CURRENT_LIST_DIR}/fwd_blob_list.txt
+  COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py
+  --api fwd --receipt 600 --list_blobs ${CMAKE_CURRENT_LIST_DIR}/fwd_blob_list.txt
   RESULT_VARIABLE ret
 )
@@ -10,8 +10,8 @@ if(ret AND NOT ret EQUAL 0)
 endif()

 execute_process(
-  COMMAND python3 ${CMAKE_CURRENT_LIST_DIR}/../../../../../../../../third_party/composable_kernel/example/ck_tile/01_fmha/generate.py
-  --api bwd --receipt 4 --list_blobs ${CMAKE_CURRENT_LIST_DIR}/bwd_blob_list.txt
+  COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py
+  --api bwd --receipt 600 --list_blobs ${CMAKE_CURRENT_LIST_DIR}/bwd_blob_list.txt
   RESULT_VARIABLE ret
 )
@@ -20,14 +20,14 @@ if(ret AND NOT ret EQUAL 0)
 endif()

 # Generate the files for both fwd and bwd
-execute_process(COMMAND python3 ${CMAKE_CURRENT_LIST_DIR}/../../../../../../../../third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --api fwd --receipt 4 --output_dir ${CMAKE_CURRENT_LIST_DIR}
+execute_process(COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --api fwd --receipt 600 --output_dir ${CMAKE_CURRENT_LIST_DIR}
 )

 if(ret AND NOT ret EQUAL 0)
   message( FATAL_ERROR "CK Tile FMHA FAILED to generate FWD kernels.")
 endif()

-execute_process(COMMAND python3 ${CMAKE_CURRENT_LIST_DIR}/../../../../../../../../third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --api bwd --receipt 4 --output_dir ${CMAKE_CURRENT_LIST_DIR}
+execute_process(COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --api bwd --receipt 600 --output_dir ${CMAKE_CURRENT_LIST_DIR}
 RESULT_VARIABLE ret
 )

@@ -0,0 +1,20 @@
+include(CMakePrintHelpers)
+
+# Generate AITER/CK Asm code
+execute_process(
+  COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/aiter/csrc/py_itfs_cu/fmha_v3_bwd_kernel_generate.py --receipt 1 --output_dir ${CMAKE_CURRENT_LIST_DIR}
+  RESULT_VARIABLE ret
+)
+
+if(ret AND NOT ret EQUAL 0)
+  message( FATAL_ERROR "Failed to generate FAv3 CK Kernels")
+endif()
+
+execute_process(
+  COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/aiter/csrc/cpp_itfs/mha_bwd_generate.py --receipt 3 --output_dir ${CMAKE_CURRENT_LIST_DIR}
+  RESULT_VARIABLE ret
+)
+
+# Change file extensions to .hip
+execute_process(COMMAND bash -c "for file in ${CMAKE_CURRENT_LIST_DIR}/*.cpp; do mv -- \"$file\" \"\${file%.cpp}.hip\"; done")

@@ -453,4 +453,5 @@ struct fmha_bwd_traits
     bool is_deterministic;
     // TODO: padding check is inside this api
 };
+template <int Version = 2>
 float fmha_bwd(fmha_bwd_traits, fmha_bwd_args, const ck_tile::stream_config&);

@@ -3,6 +3,7 @@
  ******************************************************************************/

 #include <ATen/native/transformers/hip/flash_attn/flash_common_hip.hpp>
+#include <mha_bwd.h>
 #include <fmha_bwd.hpp>
 #include <mask.hpp>

@@ -28,6 +29,26 @@ fmha_bwd_traits get_ck_fmha_bwd_traits(const mask_info &mask,
             deterministic};
 }

+aiter::mha_bwd_traits get_mha_bwd_traits(fmha_bwd_traits t, mask_info mask)
+{
+    return aiter::mha_bwd_traits(t.hdim_q,
+                                 t.hdim_v,
+                                 t.data_type,
+                                 t.is_group_mode,
+                                 mask.type,
+                                 t.bias_type,
+                                 t.has_dbias,
+                                 t.has_dropout,
+                                 t.is_store_randval,
+                                 t.is_deterministic,
+                                 true,  // use_ext_asm
+                                 true,  // is_v3_atomic_fp32
+                                 1);    // how_v3_bf16_cvt
+}
+
 fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
                                    // sizes
                                    const int b,
@@ -101,11 +122,11 @@ fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
     ck_tile::index_t stride_dv = dv.stride(1);
     ck_tile::index_t nhead_stride_dv = dv.stride(2);

-    // dq_acc: (split, batch_size, seqlen_q, nheads, hdim)
+    // dq_acc: (split, batch_size, nheads, seqlen_q, hdim)
     ck_tile::index_t split_stride_dq_acc = dq_acc.stride(0);
     ck_tile::index_t batch_stride_dq_acc = dq_acc.stride(1);
-    ck_tile::index_t stride_dq_acc = dq_acc.stride(2);
-    ck_tile::index_t nhead_stride_dq_acc = dq_acc.stride(3);
+    ck_tile::index_t stride_dq_acc = dq_acc.stride(3);
+    ck_tile::index_t nhead_stride_dq_acc = dq_acc.stride(2);

     // bias: (batch_size, nheads, seqlen_q, seqlen_k)
     void *attn_bias_ptr = nullptr;
@@ -351,11 +372,11 @@ mha_bwd_ck(const at::Tensor &dout, // batch_size x seqlen_q x
     at::Tensor dq_accum;

     if (!deterministic) {
-        dq_accum = at::zeros({1, batch_size, seqlen_q, num_heads, head_size_8x}, opts.dtype(at::kFloat));
+        dq_accum = at::zeros({1, batch_size, num_heads, seqlen_q, head_size_8x}, opts.dtype(at::kFloat));
     } else {
         const ck_tile::index_t kN0 = head_size_8x <= 128 ? 128 : 64;
         const ck_tile::index_t nsplits = ck_tile::integer_divide_ceil(seqlen_k, kN0);
-        dq_accum = at::zeros({nsplits, batch_size, seqlen_q, num_heads, head_size_8x}, opts.dtype(at::kFloat));
+        dq_accum = at::zeros({nsplits, batch_size, num_heads, seqlen_q, head_size_8x}, opts.dtype(at::kFloat));
     }

     at::Tensor dk_expanded, dv_expanded;
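
Note on the dq_accum layout change: the scratch buffer is now allocated as (split, batch_size, nheads, seqlen_q, hdim) instead of (split, batch_size, seqlen_q, nheads, hdim), which is why the head and sequence stride indices swap in get_ck_fmha_bwd_args above. A small self-contained check of that relationship (sizes here are placeholders, not from the commit):

#include <ATen/ATen.h>
#include <cstdint>
#include <iostream>

// For a contiguous tensor of shape (split, batch, nheads, seqlen_q, hdim),
// stride(2) steps over heads (seqlen_q * hdim elements) and stride(3) steps
// over query positions (hdim elements), matching the new
//   nhead_stride_dq_acc = dq_acc.stride(2);
//   stride_dq_acc       = dq_acc.stride(3);
int main() {
  const int64_t nsplits = 2, batch = 4, nheads = 8, seqlen_q = 128, hdim = 64;  // placeholder sizes
  at::Tensor dq_accum = at::zeros({nsplits, batch, nheads, seqlen_q, hdim}, at::kFloat);
  std::cout << dq_accum.stride(2) << " == " << seqlen_q * hdim << "\n";  // 8192 == 8192
  std::cout << dq_accum.stride(3) << " == " << hdim << "\n";             // 64 == 64
  return 0;
}
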
@@ -376,14 +397,6 @@ mha_bwd_ck(const at::Tensor &dout, // batch_size x seqlen_q x
     if (seqlen_q > 0) {
         ck_tile::stream_config stream_config{stream};
         dq.zero_(); // ck use atomic operation on dq
-        auto traits =
-            get_ck_fmha_bwd_traits(mask,
-                                   q_dtype_str,
-                                   head_size_8x,
-                                   is_dropout,
-                                   attn_bias_.has_value(),
-                                   deterministic,
-                                   bias_requires_grad);

         auto args =
             get_ck_fmha_bwd_args(
@@ -411,7 +424,23 @@ mha_bwd_ck(const at::Tensor &dout, // batch_size x seqlen_q x
                 softmax_scale,
                 p_dropout,
                 drop_seed_offset);
-        float t = fmha_bwd(traits, args, stream_config);
+
+        float t = aiter::mha_bwd(args,
+                                 stream_config,
+                                 q_dtype_str,
+                                 false, // is_group_mode
+                                 mask.type,
+                                 attn_bias_.has_value() ? bias_enum::elementwise_bias : bias_enum::no_bias,
+                                 bias_requires_grad,
+                                 false, // is_store_randval
+                                 deterministic,
+                                 true,  // use_ext_asm
+                                 true,  // is_v3_atomic_fp32
+                                 1);    // how_v3_bf16_cvt
+
         TORCH_CHECK(t >= 0, "invalid argument for fmha_bwd");
     } else {
         // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.

third_party/aiter (vendored submodule, 1 change)
Submodule third_party/aiter added at 01aae101b9
third_party/composable_kernel (vendored submodule, 2 changes)
Submodule third_party/composable_kernel updated: 8086bbe3a7...434d19f696