mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Update ck (#144799)
Updates the CK version and re-implements kernel generation. Pull Request resolved: https://github.com/pytorch/pytorch/pull/144799. Approved by: https://github.com/jianyuh
This commit is contained in:
committed by
PyTorch MergeBot
parent
a00d2b5144
commit
5d675de754
7
.gitignore
vendored
7
.gitignore
vendored
@ -125,6 +125,13 @@ torch/utils/benchmark/utils/valgrind_wrapper/callgrind.h
|
||||
torch/utils/benchmark/utils/valgrind_wrapper/valgrind.h
|
||||
torch/version.py
|
||||
minifier_launcher.py
|
||||
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_fwd_d*
|
||||
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_bwd_d*
|
||||
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_bwd_convert*
|
||||
aten/src/ATen/native/transformers/hip/flash_attn/ck/fwd_blob*
|
||||
aten/src/ATen/native/transformers/hip/flash_attn/ck/bwd_blob*
|
||||
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_fwd_api*
|
||||
aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_bwd_api*
|
||||
# Root level file used in CI to specify certain env configs.
|
||||
# E.g., see .circleci/config.yaml
|
||||
env
|
||||
|
@ -183,6 +183,8 @@ if(USE_FLASH_ATTENTION)
|
||||
endif()
|
||||
endif()
|
||||
message(STATUS "USE_CK_FLASH_ATTENTION is set; building PyTorch with CK Flash Attention enabled")
|
||||
message(STATUS "Generating CK kernel instances...")
|
||||
add_subdirectory(native/transformers/hip/flash_attn/ck)
|
||||
file(GLOB flash_attention_hip_ck_hip "native/transformers/hip/flash_attn/ck/*.hip")
|
||||
list(APPEND native_transformers_hip_hip ${flash_attention_hip_ck_hip})
|
||||
endif()
|
||||
|
@ -0,0 +1,63 @@
|
||||
# First, only list the kernels that generate.py would produce (written to the *_blob_list.txt files); this step does not emit any kernel files
|
||||
execute_process(
|
||||
COMMAND python3 ${CMAKE_CURRENT_LIST_DIR}/../../../../../../../../third_party/composable_kernel/example/ck_tile/01_fmha/generate.py
|
||||
--api fwd --receipt 4 --list_blobs ${CMAKE_CURRENT_LIST_DIR}/fwd_blob_list.txt
|
||||
RESULT_VARIABLE ret
|
||||
)
|
||||
|
||||
if(ret AND NOT ret EQUAL 0)
|
||||
message( FATAL_ERROR "CK Tile FMHA FAILED to generate a list of FWD kernels via Python.")
|
||||
endif()
|
||||
|
||||
execute_process(
|
||||
COMMAND python3 ${CMAKE_CURRENT_LIST_DIR}/../../../../../../../../third_party/composable_kernel/example/ck_tile/01_fmha/generate.py
|
||||
--api bwd --receipt 4 --list_blobs ${CMAKE_CURRENT_LIST_DIR}/bwd_blob_list.txt
|
||||
RESULT_VARIABLE ret
|
||||
)
|
||||
|
||||
if(ret AND NOT ret EQUAL 0)
|
||||
message( FATAL_ERROR "CK Tile FMHA FAILED to generate a list of BWD kernels via Python.")
|
||||
endif()
|
||||
|
||||
# Generate the files for both fwd and bwd
|
||||
execute_process(COMMAND python3 ${CMAKE_CURRENT_LIST_DIR}/../../../../../../../../third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --api fwd --receipt 4 --output_dir ${CMAKE_CURRENT_LIST_DIR} RESULT_VARIABLE ret
|
||||
)
|
||||
|
||||
if(ret AND NOT ret EQUAL 0)
|
||||
message( FATAL_ERROR "CK Tile FMHA FAILED to generate FWD kernels.")
|
||||
endif()
|
||||
|
||||
execute_process(COMMAND python3 ${CMAKE_CURRENT_LIST_DIR}/../../../../../../../../third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --api bwd --receipt 4 --output_dir ${CMAKE_CURRENT_LIST_DIR}
|
||||
RESULT_VARIABLE ret
|
||||
)
|
||||
|
||||
if(ret AND NOT ret EQUAL 0)
|
||||
message( FATAL_ERROR "CK Tile FMHA FAILED to generate BWD kernels.")
|
||||
endif()
|
||||
|
||||
# Change make_kernel to make_kernel_pt for fwd
|
||||
execute_process(
|
||||
COMMAND bash -c "${CMAKE_CURRENT_LIST_DIR}/add_make_kernel_pt.sh ${CMAKE_CURRENT_LIST_DIR}/fwd_blob_list.txt"
|
||||
RESULT_VARIABLE ret)
|
||||
|
||||
if(ret AND NOT ret EQUAL 0)
|
||||
message( FATAL_ERROR "CK Tile FMHA FAILED to change make_kernel to make_kernel_pt for the fwd pass")
|
||||
endif()
|
||||
|
||||
# Change make_kernel to make_kernel_pt for bwd
|
||||
execute_process(
|
||||
COMMAND bash -c "${CMAKE_CURRENT_LIST_DIR}/add_make_kernel_pt.sh ${CMAKE_CURRENT_LIST_DIR}/bwd_blob_list.txt"
|
||||
RESULT_VARIABLE ret)
|
||||
|
||||
if(ret AND NOT ret EQUAL 0)
|
||||
message( FATAL_ERROR "CK Tile FMHA FAILED to change make_kernel to make_kernel_pt for the bwd pass")
|
||||
endif()
|
||||
|
||||
# Change file extensions to .hip
|
||||
execute_process(COMMAND bash -c "for file in ${CMAKE_CURRENT_LIST_DIR}/*.cpp; do mv -- \"$file\" \"\${file%.cpp}.hip\"; done"
|
||||
RESULT_VARIABLE ret
|
||||
)
|
||||
|
||||
if(ret AND NOT ret EQUAL 0)
|
||||
message( FATAL_ERROR "CK Tile FMHA FAILED to change the generated instances extensions from .cpp to .hip")
|
||||
endif()
|
30
aten/src/ATen/native/transformers/hip/flash_attn/ck/add_make_kernel_pt.sh
Executable file
30
aten/src/ATen/native/transformers/hip/flash_attn/ck/add_make_kernel_pt.sh
Executable file
@ -0,0 +1,30 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Check if the input file is provided
|
||||
if [ "$#" -ne 1 ]; then
|
||||
echo "Usage: $0 <file_list.txt>"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Assign the input file to a variable
|
||||
file_list=$1
|
||||
|
||||
# Check if the file exists
|
||||
if [ ! -f "$file_list" ]; then
|
||||
echo "Error: File '$file_list' not found!"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Loop through each line in the file list
|
||||
while IFS= read -r file; do
|
||||
# Check if the file exists in the current directory
|
||||
if [ -f "$file" ]; then
|
||||
# Use sed to replace "make_kernel" with "make_kernel_pt" in place
|
||||
sed -i 's/\bmake_kernel\b/make_kernel_pt/g' "$file"
|
||||
echo "Updated: $file"
|
||||
else
|
||||
echo "Skipping: $file (not found)"
|
||||
fi
|
||||
done < "$file_list"
|
||||
|
||||
echo "Replacement completed."
|
@ -15,11 +15,19 @@
|
||||
#include <utility>
|
||||
#include <variant>
|
||||
|
||||
struct FmhaBwdFp16
|
||||
{
|
||||
};
|
||||
|
||||
struct FmhaBwdBf16
|
||||
{
|
||||
};
|
||||
|
||||
template <typename DataType>
|
||||
struct FmhaBwdTypeConfig;
|
||||
|
||||
template <>
|
||||
struct FmhaBwdTypeConfig<ck_tile::half_t>
|
||||
struct FmhaBwdTypeConfig<FmhaBwdFp16>
|
||||
{
|
||||
using QDataType = ck_tile::half_t;
|
||||
using KDataType = ck_tile::half_t;
|
||||
@ -39,7 +47,7 @@ struct FmhaBwdTypeConfig<ck_tile::half_t>
|
||||
};
|
||||
|
||||
template <>
|
||||
struct FmhaBwdTypeConfig<ck_tile::bf16_t>
|
||||
struct FmhaBwdTypeConfig<FmhaBwdBf16>
|
||||
{
|
||||
using QDataType = ck_tile::bf16_t;
|
||||
using KDataType = ck_tile::bf16_t;
|
||||
|
@ -1,138 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_bwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::fp16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::
|
||||
sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>;
|
||||
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
|
||||
using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>;
|
||||
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
|
||||
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
|
||||
|
||||
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
|
||||
// G0&G2 -> GSdP
|
||||
// G1&G3 -> GdKV
|
||||
// G4 -> GdQ
|
||||
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps2_0,
|
||||
fmha_warp_tile0_0>;
|
||||
|
||||
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
ck_tile::BlockAttentionBiasEnum::ALIBI,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
|
||||
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<false, true, false>;
|
||||
|
||||
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
|
||||
fmha_bwd_shape_0,
|
||||
false,
|
||||
true,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
fmha_bwd_trait_0>;
|
||||
|
||||
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::KGradDataType,
|
||||
true,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::VGradDataType,
|
||||
true,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_0 =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
|
||||
fmha_bwd_dk_epilogue_0,
|
||||
fmha_bwd_dv_epilogue_0>;
|
||||
|
||||
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32,
|
||||
ck_tile::fp16_t,
|
||||
false,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
ck_tile::BlockAttentionBiasEnum::ALIBI,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
true>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
return k_::GetName();
|
||||
}
|
@ -1,138 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_bwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::fp16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::
|
||||
sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>;
|
||||
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
|
||||
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
|
||||
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
|
||||
|
||||
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
|
||||
// G0&G2 -> GSdP
|
||||
// G1&G3 -> GdKV
|
||||
// G4 -> GdQ
|
||||
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps2_0,
|
||||
fmha_warp_tile0_0>;
|
||||
|
||||
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
ck_tile::BlockAttentionBiasEnum::ALIBI,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
|
||||
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<false, true, false>;
|
||||
|
||||
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
|
||||
fmha_bwd_shape_0,
|
||||
false,
|
||||
false,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
fmha_bwd_trait_0>;
|
||||
|
||||
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::KGradDataType,
|
||||
false,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::VGradDataType,
|
||||
false,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_0 =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
|
||||
fmha_bwd_dk_epilogue_0,
|
||||
fmha_bwd_dv_epilogue_0>;
|
||||
|
||||
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256,
|
||||
ck_tile::fp16_t,
|
||||
false,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
ck_tile::BlockAttentionBiasEnum::ALIBI,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
return k_::GetName();
|
||||
}
|
@ -1,138 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_bwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::bf16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::
|
||||
sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>;
|
||||
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
|
||||
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
|
||||
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
|
||||
|
||||
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
|
||||
// G0&G2 -> GSdP
|
||||
// G1&G3 -> GdKV
|
||||
// G4 -> GdQ
|
||||
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps2_0,
|
||||
fmha_warp_tile0_0>;
|
||||
|
||||
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
ck_tile::BlockAttentionBiasEnum::ALIBI,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
|
||||
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<true, false, false>;
|
||||
|
||||
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
|
||||
fmha_bwd_shape_0,
|
||||
false,
|
||||
false,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
fmha_bwd_trait_0>;
|
||||
|
||||
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR<fmha_bwd_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::KGradDataType,
|
||||
true,
|
||||
true>>;
|
||||
|
||||
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::VGradDataType,
|
||||
true,
|
||||
true>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_0 =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
|
||||
fmha_bwd_dk_epilogue_0,
|
||||
fmha_bwd_dv_epilogue_0>;
|
||||
|
||||
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256,
|
||||
ck_tile::bf16_t,
|
||||
false,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
ck_tile::BlockAttentionBiasEnum::ALIBI,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
false>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
return k_::GetName();
|
||||
}
|
@ -1,73 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_bwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::bf16_t;
|
||||
|
||||
using fmha_bwd_convert_dq_trait_0 =
|
||||
ck_tile::TileFmhaBwdConvertQGradTraits<true, true, 2>;
|
||||
|
||||
using fmha_bwd_convert_dq_pipeline_problem_0 =
|
||||
ck_tile::BlockFmhaBwdConvertQGradPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
|
||||
/* BlockSize = */ 256,
|
||||
64,
|
||||
128,
|
||||
128,
|
||||
false,
|
||||
true,
|
||||
fmha_bwd_convert_dq_trait_0>;
|
||||
|
||||
using fmha_bwd_convert_dq_0 =
|
||||
typename ck_tile::BlockFmhaBwdConvertQGrad<fmha_bwd_convert_dq_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_convert_dq_kernel_0 =
|
||||
ck_tile::FmhaBwdConvertQGradKernel<fmha_bwd_convert_dq_0>;
|
||||
|
||||
using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<128,
|
||||
ck_tile::bf16_t,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
true>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_convert_dq_<convert_dq_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_convert_dq_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_convert_dq_oneshot_<convert_dq_trait_0>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_convert_dq_kernel_0;
|
||||
auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_convert_dq_get_name_<convert_dq_trait_0>()
|
||||
{
|
||||
using k_ = fmha_bwd_convert_dq_kernel_0;
|
||||
return k_::GetName();
|
||||
}
|
@ -1,138 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_bwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::bf16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::
|
||||
sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>;
|
||||
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
|
||||
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
|
||||
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
|
||||
|
||||
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
|
||||
// G0&G2 -> GSdP
|
||||
// G1&G3 -> GdKV
|
||||
// G4 -> GdQ
|
||||
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps2_0,
|
||||
fmha_warp_tile0_0>;
|
||||
|
||||
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
|
||||
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<false, true, false>;
|
||||
|
||||
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
|
||||
fmha_bwd_shape_0,
|
||||
false,
|
||||
false,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
fmha_bwd_trait_0>;
|
||||
|
||||
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::KGradDataType,
|
||||
false,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::VGradDataType,
|
||||
false,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_0 =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
|
||||
fmha_bwd_dk_epilogue_0,
|
||||
fmha_bwd_dv_epilogue_0>;
|
||||
|
||||
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64,
|
||||
ck_tile::bf16_t,
|
||||
false,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
return k_::GetName();
|
||||
}
|
@ -1,80 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_fwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::bf16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>;
|
||||
using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>;
|
||||
|
||||
using fmha_shape_0 = ck_tile::TileFmhaShape<fmha_block_tile_0,
|
||||
ck_tile::sequence<4, 1, 1>,
|
||||
fmha_warp_tile_0,
|
||||
ck_tile::sequence<4, 1, 1>,
|
||||
fmha_warp_tile_0,
|
||||
true>;
|
||||
|
||||
using fmha_trait_0 = ck_tile::TileFmhaTraits<true,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
ck_tile::BlockAttentionBiasEnum::ALIBI,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
-1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
|
||||
|
||||
using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem<
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::SaccDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::SMPLComputeDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::PDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::OaccDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
fmha_shape_0,
|
||||
true,
|
||||
fmha_mask_0,
|
||||
fmha_trait_0>;
|
||||
|
||||
using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync<
|
||||
fmha_pipeline_problem_0>;
|
||||
|
||||
using fmha_epilogue_0 =
|
||||
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<ck_tile::bf16_t>::OaccDataType,
|
||||
typename FmhaFwdTypeConfig<ck_tile::bf16_t>::ODataType,
|
||||
true, true>>;
|
||||
|
||||
using fmha_kernel_0 =
|
||||
ck_tile::FmhaFwdKernel<ck_tile::FmhaFwdTilePartitioner_HBS<fmha_shape_0>,
|
||||
fmha_pipeline_0,
|
||||
fmha_epilogue_0>;
|
||||
|
||||
using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, true,128, 64, 32, 64, 32, 64, true,
|
||||
ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template<>
|
||||
float fmha_fwd_<trait_0>(const ck_tile::stream_config& s, fmha_fwd_args a)
|
||||
{
|
||||
using k_ = fmha_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_fwd_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
@ -1,138 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_bwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::bf16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::
|
||||
sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>;
|
||||
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
|
||||
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
|
||||
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
|
||||
|
||||
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
|
||||
// G0&G2 -> GSdP
|
||||
// G1&G3 -> GdKV
|
||||
// G4 -> GdQ
|
||||
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps2_0,
|
||||
fmha_warp_tile0_0>;
|
||||
|
||||
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
|
||||
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<false, true, false>;
|
||||
|
||||
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
|
||||
fmha_bwd_shape_0,
|
||||
false,
|
||||
true,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
fmha_bwd_trait_0>;
|
||||
|
||||
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::KGradDataType,
|
||||
true,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::VGradDataType,
|
||||
true,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_0 =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
|
||||
fmha_bwd_dk_epilogue_0,
|
||||
fmha_bwd_dv_epilogue_0>;
|
||||
|
||||
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64,
|
||||
ck_tile::bf16_t,
|
||||
false,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
true>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
return k_::GetName();
|
||||
}
|
@ -1,73 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_bwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::bf16_t;
|
||||
|
||||
using fmha_bwd_convert_dq_trait_0 =
|
||||
ck_tile::TileFmhaBwdConvertQGradTraits<true, true, 2>;
|
||||
|
||||
using fmha_bwd_convert_dq_pipeline_problem_0 =
|
||||
ck_tile::BlockFmhaBwdConvertQGradPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
|
||||
/* BlockSize = */ 256,
|
||||
64,
|
||||
128,
|
||||
128,
|
||||
true,
|
||||
true,
|
||||
fmha_bwd_convert_dq_trait_0>;
|
||||
|
||||
using fmha_bwd_convert_dq_0 =
|
||||
typename ck_tile::BlockFmhaBwdConvertQGrad<fmha_bwd_convert_dq_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_convert_dq_kernel_0 =
|
||||
ck_tile::FmhaBwdConvertQGradKernel<fmha_bwd_convert_dq_0>;
|
||||
|
||||
using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<128,
|
||||
ck_tile::bf16_t,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
true>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_convert_dq_<convert_dq_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_convert_dq_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_convert_dq_oneshot_<convert_dq_trait_0>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_convert_dq_kernel_0;
|
||||
auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_convert_dq_get_name_<convert_dq_trait_0>()
|
||||
{
|
||||
using k_ = fmha_bwd_convert_dq_kernel_0;
|
||||
return k_::GetName();
|
||||
}
|
@ -1,73 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_bwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::fp16_t;
|
||||
|
||||
using fmha_bwd_convert_dq_trait_0 =
|
||||
ck_tile::TileFmhaBwdConvertQGradTraits<false, false, 2>;
|
||||
|
||||
using fmha_bwd_convert_dq_pipeline_problem_0 =
|
||||
ck_tile::BlockFmhaBwdConvertQGradPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
|
||||
/* BlockSize = */ 256,
|
||||
64,
|
||||
64,
|
||||
256,
|
||||
false,
|
||||
false,
|
||||
fmha_bwd_convert_dq_trait_0>;
|
||||
|
||||
using fmha_bwd_convert_dq_0 =
|
||||
typename ck_tile::BlockFmhaBwdConvertQGrad<fmha_bwd_convert_dq_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_convert_dq_kernel_0 =
|
||||
ck_tile::FmhaBwdConvertQGradKernel<fmha_bwd_convert_dq_0>;
|
||||
|
||||
using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<256,
|
||||
ck_tile::fp16_t,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_convert_dq_<convert_dq_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_convert_dq_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_convert_dq_oneshot_<convert_dq_trait_0>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_convert_dq_kernel_0;
|
||||
auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_convert_dq_get_name_<convert_dq_trait_0>()
|
||||
{
|
||||
using k_ = fmha_bwd_convert_dq_kernel_0;
|
||||
return k_::GetName();
|
||||
}
|
@ -1,138 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_bwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::fp16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::
|
||||
sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>;
|
||||
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
|
||||
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
|
||||
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
|
||||
|
||||
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
|
||||
// G0&G2 -> GSdP
|
||||
// G1&G3 -> GdKV
|
||||
// G4 -> GdQ
|
||||
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps2_0,
|
||||
fmha_warp_tile0_0>;
|
||||
|
||||
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
|
||||
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<false, true, false>;
|
||||
|
||||
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
|
||||
fmha_bwd_shape_0,
|
||||
true,
|
||||
true,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
fmha_bwd_trait_0>;
|
||||
|
||||
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR<fmha_bwd_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::KGradDataType,
|
||||
true,
|
||||
true>>;
|
||||
|
||||
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::VGradDataType,
|
||||
true,
|
||||
true>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_0 =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
|
||||
fmha_bwd_dk_epilogue_0,
|
||||
fmha_bwd_dv_epilogue_0>;
|
||||
|
||||
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64,
|
||||
ck_tile::fp16_t,
|
||||
true,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
true>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
return k_::GetName();
|
||||
}
|
@ -1,73 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_bwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::bf16_t;
|
||||
|
||||
using fmha_bwd_convert_dq_trait_0 =
|
||||
ck_tile::TileFmhaBwdConvertQGradTraits<false, false, 2>;
|
||||
|
||||
using fmha_bwd_convert_dq_pipeline_problem_0 =
|
||||
ck_tile::BlockFmhaBwdConvertQGradPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
|
||||
/* BlockSize = */ 256,
|
||||
64,
|
||||
128,
|
||||
128,
|
||||
false,
|
||||
true,
|
||||
fmha_bwd_convert_dq_trait_0>;
|
||||
|
||||
using fmha_bwd_convert_dq_0 =
|
||||
typename ck_tile::BlockFmhaBwdConvertQGrad<fmha_bwd_convert_dq_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_convert_dq_kernel_0 =
|
||||
ck_tile::FmhaBwdConvertQGradKernel<fmha_bwd_convert_dq_0>;
|
||||
|
||||
using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<128,
|
||||
ck_tile::bf16_t,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
true>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_convert_dq_<convert_dq_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_convert_dq_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_convert_dq_oneshot_<convert_dq_trait_0>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_convert_dq_kernel_0;
|
||||
auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_convert_dq_get_name_<convert_dq_trait_0>()
|
||||
{
|
||||
using k_ = fmha_bwd_convert_dq_kernel_0;
|
||||
return k_::GetName();
|
||||
}
|
@ -1,138 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_bwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::fp16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::
|
||||
sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>;
|
||||
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
|
||||
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
|
||||
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
|
||||
|
||||
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
|
||||
// G0&G2 -> GSdP
|
||||
// G1&G3 -> GdKV
|
||||
// G4 -> GdQ
|
||||
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps2_0,
|
||||
fmha_warp_tile0_0>;
|
||||
|
||||
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
|
||||
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<true, false, false>;
|
||||
|
||||
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
|
||||
fmha_bwd_shape_0,
|
||||
false,
|
||||
true,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
fmha_bwd_trait_0>;
|
||||
|
||||
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR<fmha_bwd_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::KGradDataType,
|
||||
false,
|
||||
true>>;
|
||||
|
||||
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::VGradDataType,
|
||||
false,
|
||||
true>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_0 =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
|
||||
fmha_bwd_dk_epilogue_0,
|
||||
fmha_bwd_dv_epilogue_0>;
|
||||
|
||||
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128,
|
||||
ck_tile::fp16_t,
|
||||
false,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
true>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
return k_::GetName();
|
||||
}
|
@ -1,138 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_bwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::fp16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::
|
||||
sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>;
|
||||
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
|
||||
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
|
||||
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
|
||||
|
||||
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
|
||||
// G0&G2 -> GSdP
|
||||
// G1&G3 -> GdKV
|
||||
// G4 -> GdQ
|
||||
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps2_0,
|
||||
fmha_warp_tile0_0>;
|
||||
|
||||
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
ck_tile::BlockAttentionBiasEnum::ALIBI,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
|
||||
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<true, false, false>;
|
||||
|
||||
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
|
||||
fmha_bwd_shape_0,
|
||||
true,
|
||||
false,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
fmha_bwd_trait_0>;
|
||||
|
||||
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR<fmha_bwd_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::KGradDataType,
|
||||
true,
|
||||
true>>;
|
||||
|
||||
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::VGradDataType,
|
||||
true,
|
||||
true>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_0 =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
|
||||
fmha_bwd_dk_epilogue_0,
|
||||
fmha_bwd_dv_epilogue_0>;
|
||||
|
||||
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256,
|
||||
ck_tile::fp16_t,
|
||||
true,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
ck_tile::BlockAttentionBiasEnum::ALIBI,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
false>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
return k_::GetName();
|
||||
}
|
@ -1,138 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_bwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::fp16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::
|
||||
sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>;
|
||||
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
|
||||
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
|
||||
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
|
||||
|
||||
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
|
||||
// G0&G2 -> GSdP
|
||||
// G1&G3 -> GdKV
|
||||
// G4 -> GdQ
|
||||
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps2_0,
|
||||
fmha_warp_tile0_0>;
|
||||
|
||||
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
|
||||
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<true, false, false>;
|
||||
|
||||
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
|
||||
fmha_bwd_shape_0,
|
||||
false,
|
||||
true,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
fmha_bwd_trait_0>;
|
||||
|
||||
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::KGradDataType,
|
||||
true,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::VGradDataType,
|
||||
true,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_0 =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
|
||||
fmha_bwd_dk_epilogue_0,
|
||||
fmha_bwd_dv_epilogue_0>;
|
||||
|
||||
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256,
|
||||
ck_tile::fp16_t,
|
||||
false,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
true>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
return k_::GetName();
|
||||
}
|
@ -1,80 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_fwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::bf16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>;
|
||||
using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>;
|
||||
|
||||
using fmha_shape_0 = ck_tile::TileFmhaShape<fmha_block_tile_0,
|
||||
ck_tile::sequence<4, 1, 1>,
|
||||
fmha_warp_tile_0,
|
||||
ck_tile::sequence<4, 1, 1>,
|
||||
fmha_warp_tile_0,
|
||||
true>;
|
||||
|
||||
using fmha_trait_0 = ck_tile::TileFmhaTraits<true,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
ck_tile::BlockAttentionBiasEnum::ALIBI,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
-1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
|
||||
|
||||
using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem<
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::SaccDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::SMPLComputeDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::PDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::OaccDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
fmha_shape_0,
|
||||
false,
|
||||
fmha_mask_0,
|
||||
fmha_trait_0>;
|
||||
|
||||
using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync<
|
||||
fmha_pipeline_problem_0>;
|
||||
|
||||
using fmha_epilogue_0 =
|
||||
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<ck_tile::bf16_t>::OaccDataType,
|
||||
typename FmhaFwdTypeConfig<ck_tile::bf16_t>::ODataType,
|
||||
true, true>>;
|
||||
|
||||
using fmha_kernel_0 =
|
||||
ck_tile::FmhaFwdKernel<ck_tile::FmhaFwdTilePartitioner_SHB<fmha_shape_0>,
|
||||
fmha_pipeline_0,
|
||||
fmha_epilogue_0>;
|
||||
|
||||
using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, false,128, 64, 32, 64, 32, 64, true,
|
||||
ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, false, true, true>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template<>
|
||||
float fmha_fwd_<trait_0>(const ck_tile::stream_config& s, fmha_fwd_args a)
|
||||
{
|
||||
using k_ = fmha_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_fwd_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
@ -1,138 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_bwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::fp16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::
|
||||
sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>;
|
||||
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
|
||||
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
|
||||
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
|
||||
|
||||
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
|
||||
// G0&G2 -> GSdP
|
||||
// G1&G3 -> GdKV
|
||||
// G4 -> GdQ
|
||||
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps2_0,
|
||||
fmha_warp_tile0_0>;
|
||||
|
||||
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
ck_tile::BlockAttentionBiasEnum::ALIBI,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
|
||||
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<true, false, false>;
|
||||
|
||||
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
|
||||
fmha_bwd_shape_0,
|
||||
true,
|
||||
true,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
fmha_bwd_trait_0>;
|
||||
|
||||
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR<fmha_bwd_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::KGradDataType,
|
||||
true,
|
||||
true>>;
|
||||
|
||||
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::VGradDataType,
|
||||
true,
|
||||
true>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_0 =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
|
||||
fmha_bwd_dk_epilogue_0,
|
||||
fmha_bwd_dv_epilogue_0>;
|
||||
|
||||
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256,
|
||||
ck_tile::fp16_t,
|
||||
true,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
ck_tile::BlockAttentionBiasEnum::ALIBI,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
true>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
return k_::GetName();
|
||||
}
|
@ -1,138 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_bwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::fp16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::
|
||||
sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>;
|
||||
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
|
||||
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
|
||||
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
|
||||
|
||||
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
|
||||
// G0&G2 -> GSdP
|
||||
// G1&G3 -> GdKV
|
||||
// G4 -> GdQ
|
||||
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps2_0,
|
||||
fmha_warp_tile0_0>;
|
||||
|
||||
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<false,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
|
||||
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<false, true, false>;
|
||||
|
||||
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
|
||||
fmha_bwd_shape_0,
|
||||
false,
|
||||
true,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
fmha_bwd_trait_0>;
|
||||
|
||||
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR<fmha_bwd_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::KGradDataType,
|
||||
false,
|
||||
true>>;
|
||||
|
||||
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::VGradDataType,
|
||||
false,
|
||||
true>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_0 =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
|
||||
fmha_bwd_dk_epilogue_0,
|
||||
fmha_bwd_dv_epilogue_0>;
|
||||
|
||||
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128,
|
||||
ck_tile::fp16_t,
|
||||
false,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
true>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
return k_::GetName();
|
||||
}
|
@ -1,138 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_bwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::fp16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::
|
||||
sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>;
|
||||
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
|
||||
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
|
||||
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
|
||||
|
||||
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
|
||||
// G0&G2 -> GSdP
|
||||
// G1&G3 -> GdKV
|
||||
// G4 -> GdQ
|
||||
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps2_0,
|
||||
fmha_warp_tile0_0>;
|
||||
|
||||
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<false,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
|
||||
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<false, true, false>;
|
||||
|
||||
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
|
||||
fmha_bwd_shape_0,
|
||||
false,
|
||||
false,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
fmha_bwd_trait_0>;
|
||||
|
||||
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::KGradDataType,
|
||||
true,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::VGradDataType,
|
||||
true,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_0 =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
|
||||
fmha_bwd_dk_epilogue_0,
|
||||
fmha_bwd_dv_epilogue_0>;
|
||||
|
||||
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128,
|
||||
ck_tile::fp16_t,
|
||||
false,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
false>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
return k_::GetName();
|
||||
}
|
@ -1,138 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_bwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::fp16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::
|
||||
sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>;
|
||||
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
|
||||
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
|
||||
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
|
||||
|
||||
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
|
||||
// G0&G2 -> GSdP
|
||||
// G1&G3 -> GdKV
|
||||
// G4 -> GdQ
|
||||
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps2_0,
|
||||
fmha_warp_tile0_0>;
|
||||
|
||||
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
|
||||
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<true, false, false>;
|
||||
|
||||
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
|
||||
fmha_bwd_shape_0,
|
||||
true,
|
||||
false,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
fmha_bwd_trait_0>;
|
||||
|
||||
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::KGradDataType,
|
||||
true,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::VGradDataType,
|
||||
true,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_0 =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
|
||||
fmha_bwd_dk_epilogue_0,
|
||||
fmha_bwd_dv_epilogue_0>;
|
||||
|
||||
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128,
|
||||
ck_tile::fp16_t,
|
||||
true,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
false>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
return k_::GetName();
|
||||
}
|
@ -1,138 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_bwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::fp16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::
|
||||
sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>;
|
||||
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
|
||||
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
|
||||
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
|
||||
|
||||
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
|
||||
// G0&G2 -> GSdP
|
||||
// G1&G3 -> GdKV
|
||||
// G4 -> GdQ
|
||||
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps2_0,
|
||||
fmha_warp_tile0_0>;
|
||||
|
||||
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
|
||||
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<false, true, false>;
|
||||
|
||||
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
|
||||
fmha_bwd_shape_0,
|
||||
false,
|
||||
true,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
fmha_bwd_trait_0>;
|
||||
|
||||
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR<fmha_bwd_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::KGradDataType,
|
||||
false,
|
||||
true>>;
|
||||
|
||||
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::VGradDataType,
|
||||
false,
|
||||
true>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_0 =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
|
||||
fmha_bwd_dk_epilogue_0,
|
||||
fmha_bwd_dv_epilogue_0>;
|
||||
|
||||
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64,
|
||||
ck_tile::fp16_t,
|
||||
false,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
true>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
return k_::GetName();
|
||||
}
|
@ -1,138 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_bwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::bf16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::
|
||||
sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>;
|
||||
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
|
||||
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
|
||||
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
|
||||
|
||||
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
|
||||
// G0&G2 -> GSdP
|
||||
// G1&G3 -> GdKV
|
||||
// G4 -> GdQ
|
||||
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps2_0,
|
||||
fmha_warp_tile0_0>;
|
||||
|
||||
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
|
||||
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<false, true, false>;
|
||||
|
||||
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
|
||||
fmha_bwd_shape_0,
|
||||
false,
|
||||
false,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
fmha_bwd_trait_0>;
|
||||
|
||||
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::KGradDataType,
|
||||
true,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::VGradDataType,
|
||||
true,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_0 =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
|
||||
fmha_bwd_dk_epilogue_0,
|
||||
fmha_bwd_dv_epilogue_0>;
|
||||
|
||||
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256,
|
||||
ck_tile::bf16_t,
|
||||
false,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
false>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
return k_::GetName();
|
||||
}
|
@ -1,138 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_bwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::bf16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::
|
||||
sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>;
|
||||
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
|
||||
using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>;
|
||||
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
|
||||
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
|
||||
|
||||
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
|
||||
// G0&G2 -> GSdP
|
||||
// G1&G3 -> GdKV
|
||||
// G4 -> GdQ
|
||||
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps2_0,
|
||||
fmha_warp_tile0_0>;
|
||||
|
||||
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<false,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
|
||||
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<false, true, false>;
|
||||
|
||||
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
|
||||
fmha_bwd_shape_0,
|
||||
false,
|
||||
true,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
fmha_bwd_trait_0>;
|
||||
|
||||
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::KGradDataType,
|
||||
true,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::VGradDataType,
|
||||
true,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_0 =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
|
||||
fmha_bwd_dk_epilogue_0,
|
||||
fmha_bwd_dv_epilogue_0>;
|
||||
|
||||
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32,
|
||||
ck_tile::bf16_t,
|
||||
false,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
true>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
return k_::GetName();
|
||||
}
|
@ -1,138 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_bwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::fp16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::
|
||||
sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>;
|
||||
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
|
||||
using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>;
|
||||
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
|
||||
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
|
||||
|
||||
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
|
||||
// G0&G2 -> GSdP
|
||||
// G1&G3 -> GdKV
|
||||
// G4 -> GdQ
|
||||
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps2_0,
|
||||
fmha_warp_tile0_0>;
|
||||
|
||||
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
|
||||
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<true, false, false>;
|
||||
|
||||
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
|
||||
fmha_bwd_shape_0,
|
||||
true,
|
||||
false,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
fmha_bwd_trait_0>;
|
||||
|
||||
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::KGradDataType,
|
||||
true,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::VGradDataType,
|
||||
true,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_0 =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
|
||||
fmha_bwd_dk_epilogue_0,
|
||||
fmha_bwd_dv_epilogue_0>;
|
||||
|
||||
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32,
|
||||
ck_tile::fp16_t,
|
||||
true,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
false>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
return k_::GetName();
|
||||
}
|
@ -1,138 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_bwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::bf16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::
|
||||
sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>;
|
||||
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
|
||||
using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>;
|
||||
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
|
||||
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
|
||||
|
||||
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
|
||||
// G0&G2 -> GSdP
|
||||
// G1&G3 -> GdKV
|
||||
// G4 -> GdQ
|
||||
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps2_0,
|
||||
fmha_warp_tile0_0>;
|
||||
|
||||
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
|
||||
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<true, false, false>;
|
||||
|
||||
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
|
||||
fmha_bwd_shape_0,
|
||||
false,
|
||||
false,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
fmha_bwd_trait_0>;
|
||||
|
||||
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::KGradDataType,
|
||||
false,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::VGradDataType,
|
||||
false,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_0 =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
|
||||
fmha_bwd_dk_epilogue_0,
|
||||
fmha_bwd_dv_epilogue_0>;
|
||||
|
||||
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32,
|
||||
ck_tile::bf16_t,
|
||||
false,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
return k_::GetName();
|
||||
}
|
@ -1,138 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_bwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::fp16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::
|
||||
sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>;
|
||||
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
|
||||
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
|
||||
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
|
||||
|
||||
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
|
||||
// G0&G2 -> GSdP
|
||||
// G1&G3 -> GdKV
|
||||
// G4 -> GdQ
|
||||
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps2_0,
|
||||
fmha_warp_tile0_0>;
|
||||
|
||||
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
ck_tile::BlockAttentionBiasEnum::ALIBI,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
|
||||
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<true, false, false>;
|
||||
|
||||
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
|
||||
fmha_bwd_shape_0,
|
||||
false,
|
||||
true,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
fmha_bwd_trait_0>;
|
||||
|
||||
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR<fmha_bwd_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::KGradDataType,
|
||||
false,
|
||||
true>>;
|
||||
|
||||
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::VGradDataType,
|
||||
false,
|
||||
true>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_0 =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
|
||||
fmha_bwd_dk_epilogue_0,
|
||||
fmha_bwd_dv_epilogue_0>;
|
||||
|
||||
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64,
|
||||
ck_tile::fp16_t,
|
||||
false,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
ck_tile::BlockAttentionBiasEnum::ALIBI,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
true>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
return k_::GetName();
|
||||
}
|
@ -1,138 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_bwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::bf16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::
|
||||
sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>;
|
||||
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
|
||||
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
|
||||
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
|
||||
|
||||
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
|
||||
// G0&G2 -> GSdP
|
||||
// G1&G3 -> GdKV
|
||||
// G4 -> GdQ
|
||||
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps2_0,
|
||||
fmha_warp_tile0_0>;
|
||||
|
||||
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
ck_tile::BlockAttentionBiasEnum::ALIBI,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
|
||||
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<false, true, false>;
|
||||
|
||||
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
|
||||
fmha_bwd_shape_0,
|
||||
false,
|
||||
false,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
fmha_bwd_trait_0>;
|
||||
|
||||
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::KGradDataType,
|
||||
false,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::VGradDataType,
|
||||
false,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_0 =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
|
||||
fmha_bwd_dk_epilogue_0,
|
||||
fmha_bwd_dv_epilogue_0>;
|
||||
|
||||
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64,
|
||||
ck_tile::bf16_t,
|
||||
false,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
ck_tile::BlockAttentionBiasEnum::ALIBI,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
return k_::GetName();
|
||||
}
|
@ -1,138 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_bwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::bf16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::
|
||||
sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>;
|
||||
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
|
||||
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
|
||||
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
|
||||
|
||||
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
|
||||
// G0&G2 -> GSdP
|
||||
// G1&G3 -> GdKV
|
||||
// G4 -> GdQ
|
||||
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps2_0,
|
||||
fmha_warp_tile0_0>;
|
||||
|
||||
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<false,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
|
||||
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<true, false, false>;
|
||||
|
||||
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
|
||||
fmha_bwd_shape_0,
|
||||
false,
|
||||
false,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
fmha_bwd_trait_0>;
|
||||
|
||||
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR<fmha_bwd_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::KGradDataType,
|
||||
false,
|
||||
true>>;
|
||||
|
||||
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::VGradDataType,
|
||||
false,
|
||||
true>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_0 =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
|
||||
fmha_bwd_dk_epilogue_0,
|
||||
fmha_bwd_dv_epilogue_0>;
|
||||
|
||||
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64,
|
||||
ck_tile::bf16_t,
|
||||
false,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
false>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
return k_::GetName();
|
||||
}
|
@ -1,138 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_bwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::fp16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::
|
||||
sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>;
|
||||
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
|
||||
using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>;
|
||||
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
|
||||
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
|
||||
|
||||
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
|
||||
// G0&G2 -> GSdP
|
||||
// G1&G3 -> GdKV
|
||||
// G4 -> GdQ
|
||||
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps2_0,
|
||||
fmha_warp_tile0_0>;
|
||||
|
||||
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<false,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
ck_tile::BlockAttentionBiasEnum::ALIBI,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
|
||||
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<true, false, false>;
|
||||
|
||||
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
|
||||
fmha_bwd_shape_0,
|
||||
false,
|
||||
true,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
fmha_bwd_trait_0>;
|
||||
|
||||
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::KGradDataType,
|
||||
true,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::VGradDataType,
|
||||
true,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_0 =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
|
||||
fmha_bwd_dk_epilogue_0,
|
||||
fmha_bwd_dv_epilogue_0>;
|
||||
|
||||
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32,
|
||||
ck_tile::fp16_t,
|
||||
false,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
ck_tile::BlockAttentionBiasEnum::ALIBI,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
true>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
return k_::GetName();
|
||||
}
|
@ -1,138 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_bwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::fp16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::
|
||||
sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>;
|
||||
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
|
||||
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
|
||||
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
|
||||
|
||||
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
|
||||
// G0&G2 -> GSdP
|
||||
// G1&G3 -> GdKV
|
||||
// G4 -> GdQ
|
||||
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps2_0,
|
||||
fmha_warp_tile0_0>;
|
||||
|
||||
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
ck_tile::BlockAttentionBiasEnum::ALIBI,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
|
||||
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<false, true, false>;
|
||||
|
||||
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
|
||||
fmha_bwd_shape_0,
|
||||
false,
|
||||
true,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
fmha_bwd_trait_0>;
|
||||
|
||||
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR<fmha_bwd_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::KGradDataType,
|
||||
true,
|
||||
true>>;
|
||||
|
||||
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::VGradDataType,
|
||||
true,
|
||||
true>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_0 =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
|
||||
fmha_bwd_dk_epilogue_0,
|
||||
fmha_bwd_dv_epilogue_0>;
|
||||
|
||||
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64,
|
||||
ck_tile::fp16_t,
|
||||
false,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
ck_tile::BlockAttentionBiasEnum::ALIBI,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
true>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
return k_::GetName();
|
||||
}
|
@ -1,138 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_bwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::fp16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::
|
||||
sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>;
|
||||
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
|
||||
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
|
||||
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
|
||||
|
||||
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
|
||||
// G0&G2 -> GSdP
|
||||
// G1&G3 -> GdKV
|
||||
// G4 -> GdQ
|
||||
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps2_0,
|
||||
fmha_warp_tile0_0>;
|
||||
|
||||
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
ck_tile::BlockAttentionBiasEnum::ALIBI,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
|
||||
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<false, true, false>;
|
||||
|
||||
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
|
||||
fmha_bwd_shape_0,
|
||||
false,
|
||||
false,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
fmha_bwd_trait_0>;
|
||||
|
||||
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR<fmha_bwd_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::KGradDataType,
|
||||
false,
|
||||
true>>;
|
||||
|
||||
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::VGradDataType,
|
||||
false,
|
||||
true>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_0 =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
|
||||
fmha_bwd_dk_epilogue_0,
|
||||
fmha_bwd_dv_epilogue_0>;
|
||||
|
||||
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128,
|
||||
ck_tile::fp16_t,
|
||||
false,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
ck_tile::BlockAttentionBiasEnum::ALIBI,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
false>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
return k_::GetName();
|
||||
}
|
@ -1,138 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_bwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::fp16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::
|
||||
sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>;
|
||||
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
|
||||
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
|
||||
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
|
||||
|
||||
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
|
||||
// G0&G2 -> GSdP
|
||||
// G1&G3 -> GdKV
|
||||
// G4 -> GdQ
|
||||
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps2_0,
|
||||
fmha_warp_tile0_0>;
|
||||
|
||||
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
|
||||
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<false, true, false>;
|
||||
|
||||
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
|
||||
fmha_bwd_shape_0,
|
||||
false,
|
||||
true,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
fmha_bwd_trait_0>;
|
||||
|
||||
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR<fmha_bwd_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::KGradDataType,
|
||||
false,
|
||||
true>>;
|
||||
|
||||
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::VGradDataType,
|
||||
false,
|
||||
true>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_0 =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
|
||||
fmha_bwd_dk_epilogue_0,
|
||||
fmha_bwd_dv_epilogue_0>;
|
||||
|
||||
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256,
|
||||
ck_tile::fp16_t,
|
||||
false,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
true>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
return k_::GetName();
|
||||
}
|
@ -1,80 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_fwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::bf16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>;
|
||||
using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>;
|
||||
|
||||
using fmha_shape_0 = ck_tile::TileFmhaShape<fmha_block_tile_0,
|
||||
ck_tile::sequence<4, 1, 1>,
|
||||
fmha_warp_tile_0,
|
||||
ck_tile::sequence<4, 1, 1>,
|
||||
fmha_warp_tile_0,
|
||||
true>;
|
||||
|
||||
using fmha_trait_0 = ck_tile::TileFmhaTraits<true,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
ck_tile::BlockAttentionBiasEnum::ALIBI,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
-1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
|
||||
|
||||
using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem<
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::SaccDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::SMPLComputeDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::PDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::OaccDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
fmha_shape_0,
|
||||
false,
|
||||
fmha_mask_0,
|
||||
fmha_trait_0>;
|
||||
|
||||
using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync<
|
||||
fmha_pipeline_problem_0>;
|
||||
|
||||
using fmha_epilogue_0 =
|
||||
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<ck_tile::bf16_t>::OaccDataType,
|
||||
typename FmhaFwdTypeConfig<ck_tile::bf16_t>::ODataType,
|
||||
true, true>>;
|
||||
|
||||
using fmha_kernel_0 =
|
||||
ck_tile::FmhaFwdKernel<ck_tile::FmhaFwdTilePartitioner_SHB<fmha_shape_0>,
|
||||
fmha_pipeline_0,
|
||||
fmha_epilogue_0>;
|
||||
|
||||
using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, false,128, 128, 32, 128, 32, 128, true,
|
||||
ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template<>
|
||||
float fmha_fwd_<trait_0>(const ck_tile::stream_config& s, fmha_fwd_args a)
|
||||
{
|
||||
using k_ = fmha_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_fwd_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
@ -1,138 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_bwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::fp16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::
|
||||
sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>;
|
||||
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
|
||||
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
|
||||
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
|
||||
|
||||
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
|
||||
// G0&G2 -> GSdP
|
||||
// G1&G3 -> GdKV
|
||||
// G4 -> GdQ
|
||||
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps2_0,
|
||||
fmha_warp_tile0_0>;
|
||||
|
||||
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
|
||||
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<true, false, false>;
|
||||
|
||||
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
|
||||
fmha_bwd_shape_0,
|
||||
false,
|
||||
false,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
fmha_bwd_trait_0>;
|
||||
|
||||
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::KGradDataType,
|
||||
false,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::VGradDataType,
|
||||
false,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_0 =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
|
||||
fmha_bwd_dk_epilogue_0,
|
||||
fmha_bwd_dv_epilogue_0>;
|
||||
|
||||
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128,
|
||||
ck_tile::fp16_t,
|
||||
false,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
return k_::GetName();
|
||||
}
|
@ -1,138 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_bwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::bf16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::
|
||||
sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>;
|
||||
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
|
||||
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
|
||||
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
|
||||
|
||||
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
|
||||
// G0&G2 -> GSdP
|
||||
// G1&G3 -> GdKV
|
||||
// G4 -> GdQ
|
||||
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps2_0,
|
||||
fmha_warp_tile0_0>;
|
||||
|
||||
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<false,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
ck_tile::BlockAttentionBiasEnum::ALIBI,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
|
||||
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<false, true, false>;
|
||||
|
||||
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
|
||||
fmha_bwd_shape_0,
|
||||
false,
|
||||
true,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
fmha_bwd_trait_0>;
|
||||
|
||||
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR<fmha_bwd_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::KGradDataType,
|
||||
true,
|
||||
true>>;
|
||||
|
||||
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::VGradDataType,
|
||||
true,
|
||||
true>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_0 =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
|
||||
fmha_bwd_dk_epilogue_0,
|
||||
fmha_bwd_dv_epilogue_0>;
|
||||
|
||||
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64,
|
||||
ck_tile::bf16_t,
|
||||
false,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
ck_tile::BlockAttentionBiasEnum::ALIBI,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
true>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
return k_::GetName();
|
||||
}
|
@ -1,138 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_bwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::bf16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::
|
||||
sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>;
|
||||
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
|
||||
using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>;
|
||||
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
|
||||
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
|
||||
|
||||
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
|
||||
// G0&G2 -> GSdP
|
||||
// G1&G3 -> GdKV
|
||||
// G4 -> GdQ
|
||||
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps2_0,
|
||||
fmha_warp_tile0_0>;
|
||||
|
||||
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
|
||||
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<false, true, false>;
|
||||
|
||||
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
|
||||
fmha_bwd_shape_0,
|
||||
false,
|
||||
true,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
fmha_bwd_trait_0>;
|
||||
|
||||
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::KGradDataType,
|
||||
true,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::VGradDataType,
|
||||
true,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_0 =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
|
||||
fmha_bwd_dk_epilogue_0,
|
||||
fmha_bwd_dv_epilogue_0>;
|
||||
|
||||
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32,
|
||||
ck_tile::bf16_t,
|
||||
false,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
true>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
return k_::GetName();
|
||||
}
|
@ -1,138 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_bwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::bf16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::
|
||||
sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>;
|
||||
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
|
||||
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
|
||||
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
|
||||
|
||||
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
|
||||
// G0&G2 -> GSdP
|
||||
// G1&G3 -> GdKV
|
||||
// G4 -> GdQ
|
||||
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps2_0,
|
||||
fmha_warp_tile0_0>;
|
||||
|
||||
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
|
||||
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<true, false, false>;
|
||||
|
||||
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
|
||||
fmha_bwd_shape_0,
|
||||
false,
|
||||
false,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
fmha_bwd_trait_0>;
|
||||
|
||||
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::KGradDataType,
|
||||
false,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::VGradDataType,
|
||||
false,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_0 =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
|
||||
fmha_bwd_dk_epilogue_0,
|
||||
fmha_bwd_dv_epilogue_0>;
|
||||
|
||||
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64,
|
||||
ck_tile::bf16_t,
|
||||
false,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
return k_::GetName();
|
||||
}
|
@ -1,138 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_bwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::bf16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::
|
||||
sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>;
|
||||
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
|
||||
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
|
||||
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
|
||||
|
||||
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
|
||||
// G0&G2 -> GSdP
|
||||
// G1&G3 -> GdKV
|
||||
// G4 -> GdQ
|
||||
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps2_0,
|
||||
fmha_warp_tile0_0>;
|
||||
|
||||
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
|
||||
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<false, true, false>;
|
||||
|
||||
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
|
||||
fmha_bwd_shape_0,
|
||||
false,
|
||||
false,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
fmha_bwd_trait_0>;
|
||||
|
||||
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR<fmha_bwd_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::KGradDataType,
|
||||
true,
|
||||
true>>;
|
||||
|
||||
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::VGradDataType,
|
||||
true,
|
||||
true>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_0 =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
|
||||
fmha_bwd_dk_epilogue_0,
|
||||
fmha_bwd_dv_epilogue_0>;
|
||||
|
||||
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64,
|
||||
ck_tile::bf16_t,
|
||||
false,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
false>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
return k_::GetName();
|
||||
}
|
@ -1,138 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_bwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::fp16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::
|
||||
sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>;
|
||||
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
|
||||
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
|
||||
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
|
||||
|
||||
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
|
||||
// G0&G2 -> GSdP
|
||||
// G1&G3 -> GdKV
|
||||
// G4 -> GdQ
|
||||
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps2_0,
|
||||
fmha_warp_tile0_0>;
|
||||
|
||||
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
|
||||
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<true, false, false>;
|
||||
|
||||
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
|
||||
fmha_bwd_shape_0,
|
||||
false,
|
||||
false,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
fmha_bwd_trait_0>;
|
||||
|
||||
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::KGradDataType,
|
||||
false,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::VGradDataType,
|
||||
false,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_0 =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
|
||||
fmha_bwd_dk_epilogue_0,
|
||||
fmha_bwd_dv_epilogue_0>;
|
||||
|
||||
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256,
|
||||
ck_tile::fp16_t,
|
||||
false,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
return k_::GetName();
|
||||
}
|
@ -1,138 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_bwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::fp16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::
|
||||
sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>;
|
||||
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
|
||||
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
|
||||
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
|
||||
|
||||
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
|
||||
// G0&G2 -> GSdP
|
||||
// G1&G3 -> GdKV
|
||||
// G4 -> GdQ
|
||||
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps2_0,
|
||||
fmha_warp_tile0_0>;
|
||||
|
||||
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<false,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
|
||||
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<false, true, false>;
|
||||
|
||||
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
|
||||
fmha_bwd_shape_0,
|
||||
false,
|
||||
false,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
fmha_bwd_trait_0>;
|
||||
|
||||
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::KGradDataType,
|
||||
true,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::VGradDataType,
|
||||
true,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_0 =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
|
||||
fmha_bwd_dk_epilogue_0,
|
||||
fmha_bwd_dv_epilogue_0>;
|
||||
|
||||
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256,
|
||||
ck_tile::fp16_t,
|
||||
false,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
false>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
return k_::GetName();
|
||||
}
|
@ -1,80 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_fwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::fp16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>;
|
||||
using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>;
|
||||
|
||||
using fmha_shape_0 = ck_tile::TileFmhaShape<fmha_block_tile_0,
|
||||
ck_tile::sequence<4, 1, 1>,
|
||||
fmha_warp_tile_0,
|
||||
ck_tile::sequence<4, 1, 1>,
|
||||
fmha_warp_tile_0,
|
||||
true>;
|
||||
|
||||
using fmha_trait_0 = ck_tile::TileFmhaTraits<true,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
ck_tile::BlockAttentionBiasEnum::ALIBI,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
-1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
|
||||
|
||||
using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem<
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::SaccDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::SMPLComputeDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::PDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::OaccDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
fmha_shape_0,
|
||||
true,
|
||||
fmha_mask_0,
|
||||
fmha_trait_0>;
|
||||
|
||||
using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync<
|
||||
fmha_pipeline_problem_0>;
|
||||
|
||||
using fmha_epilogue_0 =
|
||||
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<ck_tile::fp16_t>::OaccDataType,
|
||||
typename FmhaFwdTypeConfig<ck_tile::fp16_t>::ODataType,
|
||||
true, true>>;
|
||||
|
||||
using fmha_kernel_0 =
|
||||
ck_tile::FmhaFwdKernel<ck_tile::FmhaFwdTilePartitioner_HBS<fmha_shape_0>,
|
||||
fmha_pipeline_0,
|
||||
fmha_epilogue_0>;
|
||||
|
||||
using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, true,128, 128, 32, 128, 32, 128, true,
|
||||
ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template<>
|
||||
float fmha_fwd_<trait_0>(const ck_tile::stream_config& s, fmha_fwd_args a)
|
||||
{
|
||||
using k_ = fmha_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_fwd_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
@ -1,138 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_bwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::bf16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::
|
||||
sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>;
|
||||
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
|
||||
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
|
||||
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
|
||||
|
||||
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
|
||||
// G0&G2 -> GSdP
|
||||
// G1&G3 -> GdKV
|
||||
// G4 -> GdQ
|
||||
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps2_0,
|
||||
fmha_warp_tile0_0>;
|
||||
|
||||
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
ck_tile::BlockAttentionBiasEnum::ALIBI,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
|
||||
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<true, false, false>;
|
||||
|
||||
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
|
||||
fmha_bwd_shape_0,
|
||||
false,
|
||||
false,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
fmha_bwd_trait_0>;
|
||||
|
||||
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR<fmha_bwd_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::KGradDataType,
|
||||
true,
|
||||
true>>;
|
||||
|
||||
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::VGradDataType,
|
||||
true,
|
||||
true>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_0 =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
|
||||
fmha_bwd_dk_epilogue_0,
|
||||
fmha_bwd_dv_epilogue_0>;
|
||||
|
||||
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64,
|
||||
ck_tile::bf16_t,
|
||||
false,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
ck_tile::BlockAttentionBiasEnum::ALIBI,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
false>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
return k_::GetName();
|
||||
}
|
@ -1,138 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_bwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::fp16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::
|
||||
sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>;
|
||||
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
|
||||
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
|
||||
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
|
||||
|
||||
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
|
||||
// G0&G2 -> GSdP
|
||||
// G1&G3 -> GdKV
|
||||
// G4 -> GdQ
|
||||
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps2_0,
|
||||
fmha_warp_tile0_0>;
|
||||
|
||||
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
|
||||
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<true, false, false>;
|
||||
|
||||
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
|
||||
fmha_bwd_shape_0,
|
||||
true,
|
||||
false,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
fmha_bwd_trait_0>;
|
||||
|
||||
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR<fmha_bwd_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::KGradDataType,
|
||||
true,
|
||||
true>>;
|
||||
|
||||
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::VGradDataType,
|
||||
true,
|
||||
true>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_0 =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
|
||||
fmha_bwd_dk_epilogue_0,
|
||||
fmha_bwd_dv_epilogue_0>;
|
||||
|
||||
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128,
|
||||
ck_tile::fp16_t,
|
||||
true,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
false>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
return k_::GetName();
|
||||
}
|
@ -1,138 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_bwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::fp16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::
|
||||
sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>;
|
||||
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
|
||||
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
|
||||
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
|
||||
|
||||
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
|
||||
// G0&G2 -> GSdP
|
||||
// G1&G3 -> GdKV
|
||||
// G4 -> GdQ
|
||||
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps2_0,
|
||||
fmha_warp_tile0_0>;
|
||||
|
||||
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
ck_tile::BlockAttentionBiasEnum::ALIBI,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
|
||||
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<true, false, false>;
|
||||
|
||||
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
|
||||
fmha_bwd_shape_0,
|
||||
false,
|
||||
false,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
fmha_bwd_trait_0>;
|
||||
|
||||
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::KGradDataType,
|
||||
false,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::VGradDataType,
|
||||
false,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_0 =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
|
||||
fmha_bwd_dk_epilogue_0,
|
||||
fmha_bwd_dv_epilogue_0>;
|
||||
|
||||
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128,
|
||||
ck_tile::fp16_t,
|
||||
false,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
ck_tile::BlockAttentionBiasEnum::ALIBI,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
return k_::GetName();
|
||||
}
|
@ -1,138 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_bwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::bf16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::
|
||||
sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>;
|
||||
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
|
||||
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
|
||||
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
|
||||
|
||||
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
|
||||
// G0&G2 -> GSdP
|
||||
// G1&G3 -> GdKV
|
||||
// G4 -> GdQ
|
||||
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps2_0,
|
||||
fmha_warp_tile0_0>;
|
||||
|
||||
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<false,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
|
||||
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<true, false, false>;
|
||||
|
||||
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
|
||||
fmha_bwd_shape_0,
|
||||
false,
|
||||
true,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
fmha_bwd_trait_0>;
|
||||
|
||||
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR<fmha_bwd_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::KGradDataType,
|
||||
false,
|
||||
true>>;
|
||||
|
||||
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::VGradDataType,
|
||||
false,
|
||||
true>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_0 =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
|
||||
fmha_bwd_dk_epilogue_0,
|
||||
fmha_bwd_dv_epilogue_0>;
|
||||
|
||||
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64,
|
||||
ck_tile::bf16_t,
|
||||
false,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
true>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
return k_::GetName();
|
||||
}
|
@ -1,138 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_bwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::fp16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::
|
||||
sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>;
|
||||
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
|
||||
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
|
||||
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
|
||||
|
||||
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
|
||||
// G0&G2 -> GSdP
|
||||
// G1&G3 -> GdKV
|
||||
// G4 -> GdQ
|
||||
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps2_0,
|
||||
fmha_warp_tile0_0>;
|
||||
|
||||
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
ck_tile::BlockAttentionBiasEnum::ALIBI,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
|
||||
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<true, false, false>;
|
||||
|
||||
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
|
||||
fmha_bwd_shape_0,
|
||||
false,
|
||||
false,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
fmha_bwd_trait_0>;
|
||||
|
||||
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR<fmha_bwd_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::KGradDataType,
|
||||
true,
|
||||
true>>;
|
||||
|
||||
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::VGradDataType,
|
||||
true,
|
||||
true>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_0 =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
|
||||
fmha_bwd_dk_epilogue_0,
|
||||
fmha_bwd_dv_epilogue_0>;
|
||||
|
||||
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256,
|
||||
ck_tile::fp16_t,
|
||||
false,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
ck_tile::BlockAttentionBiasEnum::ALIBI,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
false>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
return k_::GetName();
|
||||
}
|
@ -1,138 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_bwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::fp16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::
|
||||
sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>;
|
||||
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
|
||||
using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>;
|
||||
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
|
||||
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
|
||||
|
||||
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
|
||||
// G0&G2 -> GSdP
|
||||
// G1&G3 -> GdKV
|
||||
// G4 -> GdQ
|
||||
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps2_0,
|
||||
fmha_warp_tile0_0>;
|
||||
|
||||
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<false,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
ck_tile::BlockAttentionBiasEnum::ALIBI,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
|
||||
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<true, false, false>;
|
||||
|
||||
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
|
||||
fmha_bwd_shape_0,
|
||||
false,
|
||||
true,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
fmha_bwd_trait_0>;
|
||||
|
||||
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR<fmha_bwd_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::KGradDataType,
|
||||
false,
|
||||
true>>;
|
||||
|
||||
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::VGradDataType,
|
||||
false,
|
||||
true>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_0 =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
|
||||
fmha_bwd_dk_epilogue_0,
|
||||
fmha_bwd_dv_epilogue_0>;
|
||||
|
||||
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32,
|
||||
ck_tile::fp16_t,
|
||||
false,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
ck_tile::BlockAttentionBiasEnum::ALIBI,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
true>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
return k_::GetName();
|
||||
}
|
@ -1,138 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_bwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::bf16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::
|
||||
sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>;
|
||||
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
|
||||
using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>;
|
||||
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
|
||||
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
|
||||
|
||||
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
|
||||
// G0&G2 -> GSdP
|
||||
// G1&G3 -> GdKV
|
||||
// G4 -> GdQ
|
||||
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps2_0,
|
||||
fmha_warp_tile0_0>;
|
||||
|
||||
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
ck_tile::BlockAttentionBiasEnum::ALIBI,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
|
||||
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<false, true, false>;
|
||||
|
||||
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
|
||||
fmha_bwd_shape_0,
|
||||
false,
|
||||
true,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
fmha_bwd_trait_0>;
|
||||
|
||||
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR<fmha_bwd_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::KGradDataType,
|
||||
false,
|
||||
true>>;
|
||||
|
||||
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::VGradDataType,
|
||||
false,
|
||||
true>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_0 =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
|
||||
fmha_bwd_dk_epilogue_0,
|
||||
fmha_bwd_dv_epilogue_0>;
|
||||
|
||||
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32,
|
||||
ck_tile::bf16_t,
|
||||
false,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
ck_tile::BlockAttentionBiasEnum::ALIBI,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
true>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
return k_::GetName();
|
||||
}
|
@ -1,138 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_bwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::bf16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::
|
||||
sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>;
|
||||
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
|
||||
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
|
||||
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
|
||||
|
||||
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
|
||||
// G0&G2 -> GSdP
|
||||
// G1&G3 -> GdKV
|
||||
// G4 -> GdQ
|
||||
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps2_0,
|
||||
fmha_warp_tile0_0>;
|
||||
|
||||
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<false,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
|
||||
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<true, false, false>;
|
||||
|
||||
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
|
||||
fmha_bwd_shape_0,
|
||||
false,
|
||||
false,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
fmha_bwd_trait_0>;
|
||||
|
||||
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::KGradDataType,
|
||||
true,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::VGradDataType,
|
||||
true,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_0 =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
|
||||
fmha_bwd_dk_epilogue_0,
|
||||
fmha_bwd_dv_epilogue_0>;
|
||||
|
||||
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128,
|
||||
ck_tile::bf16_t,
|
||||
false,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
false>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
return k_::GetName();
|
||||
}
|
@ -1,80 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_fwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::bf16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>;
|
||||
using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>;
|
||||
|
||||
using fmha_shape_0 = ck_tile::TileFmhaShape<fmha_block_tile_0,
|
||||
ck_tile::sequence<2, 1, 1>,
|
||||
fmha_warp_tile_0,
|
||||
ck_tile::sequence<2, 1, 1>,
|
||||
fmha_warp_tile_0,
|
||||
true>;
|
||||
|
||||
using fmha_trait_0 = ck_tile::TileFmhaTraits<true,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
-1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
|
||||
|
||||
using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem<
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::SaccDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::SMPLComputeDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::PDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::OaccDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
fmha_shape_0,
|
||||
true,
|
||||
fmha_mask_0,
|
||||
fmha_trait_0>;
|
||||
|
||||
using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync<
|
||||
fmha_pipeline_problem_0>;
|
||||
|
||||
using fmha_epilogue_0 =
|
||||
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<ck_tile::bf16_t>::OaccDataType,
|
||||
typename FmhaFwdTypeConfig<ck_tile::bf16_t>::ODataType,
|
||||
true, true>>;
|
||||
|
||||
using fmha_kernel_0 =
|
||||
ck_tile::FmhaFwdKernel<ck_tile::FmhaFwdTilePartitioner_HBS<fmha_shape_0>,
|
||||
fmha_pipeline_0,
|
||||
fmha_epilogue_0>;
|
||||
|
||||
using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, true,128, 64, 16, 32, 32, 32, true,
|
||||
ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template<>
|
||||
float fmha_fwd_<trait_0>(const ck_tile::stream_config& s, fmha_fwd_args a)
|
||||
{
|
||||
using k_ = fmha_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_fwd_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
@ -1,138 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_bwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::fp16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::
|
||||
sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>;
|
||||
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
|
||||
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
|
||||
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
|
||||
|
||||
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
|
||||
// G0&G2 -> GSdP
|
||||
// G1&G3 -> GdKV
|
||||
// G4 -> GdQ
|
||||
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps2_0,
|
||||
fmha_warp_tile0_0>;
|
||||
|
||||
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<false,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
ck_tile::BlockAttentionBiasEnum::ALIBI,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
|
||||
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<false, true, false>;
|
||||
|
||||
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
|
||||
fmha_bwd_shape_0,
|
||||
false,
|
||||
true,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
fmha_bwd_trait_0>;
|
||||
|
||||
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::KGradDataType,
|
||||
true,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::VGradDataType,
|
||||
true,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_0 =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
|
||||
fmha_bwd_dk_epilogue_0,
|
||||
fmha_bwd_dv_epilogue_0>;
|
||||
|
||||
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256,
|
||||
ck_tile::fp16_t,
|
||||
false,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
ck_tile::BlockAttentionBiasEnum::ALIBI,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
true>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
return k_::GetName();
|
||||
}
|
@ -1,138 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_bwd.hpp>
|
||||
|
||||
// Element type this kernel instance is generated for.
using fmha_dtype_0 = ck_tile::fp16_t;

// Tile-size and warp-layout sequences consumed by TileFmhaBwdShape below.
using fmha_block_tile_0   = ck_tile::sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>;
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
using fmha_warp_tile0_0   = ck_tile::sequence<16, 16, 32>;
using fmha_warp_tile1_0   = ck_tile::sequence<16, 16, 16>;

// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
// G1&G3 -> GdKV
// G4    -> GdQ
using fmha_bwd_shape_0 =
    ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
                              fmha_block_warps0_0, fmha_warp_tile0_0, // G0
                              fmha_block_warps1_0, fmha_warp_tile1_0, // G1
                              fmha_block_warps0_0, fmha_warp_tile0_0, // G2
                              fmha_block_warps1_0, fmha_warp_tile1_0, // G3
                              fmha_block_warps2_0, fmha_warp_tile0_0>; // G4

// NOTE(review): the boolean/integer arguments are positional flags of
// ck_tile::TileFmhaTraits; their meaning is not visible in this file.
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true, true, false, false,
                                                 ck_tile::BlockAttentionBiasEnum::ALIBI,
                                                 false, false, false, false,
                                                 1>;
using fmha_mask_0    = ck_tile::SimplifiedGenericAttentionMask<true>;
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<true, false, false>;

// Pipeline problem: every data type comes from FmhaBwdTypeConfig<fmha_dtype_0>.
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
    typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
    typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
    typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
    typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
    typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
    typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
    typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
    typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
    typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
    typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
    typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
    typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
    typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
    typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
    typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
    fmha_bwd_shape_0,
    false,
    false,
    fmha_mask_0,
    fmha_dropout_0,
    fmha_bwd_trait_0>;

using fmha_bwd_pipeline_0 =
    ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::KGradDataType,
|
||||
true,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::VGradDataType,
|
||||
true,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_0 =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
|
||||
fmha_bwd_dk_epilogue_0,
|
||||
fmha_bwd_dv_epilogue_0>;
|
||||
|
||||
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64,
|
||||
ck_tile::fp16_t,
|
||||
false,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
ck_tile::BlockAttentionBiasEnum::ALIBI,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
false>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
return k_::GetName();
|
||||
}
|
@ -1,138 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_bwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::bf16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::
|
||||
sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>;
|
||||
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
|
||||
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
|
||||
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
|
||||
|
||||
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
|
||||
// G0&G2 -> GSdP
|
||||
// G1&G3 -> GdKV
|
||||
// G4 -> GdQ
|
||||
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps2_0,
|
||||
fmha_warp_tile0_0>;
|
||||
|
||||
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<false,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
ck_tile::BlockAttentionBiasEnum::ALIBI,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
|
||||
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<false, true, false>;
|
||||
|
||||
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
|
||||
fmha_bwd_shape_0,
|
||||
false,
|
||||
true,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
fmha_bwd_trait_0>;
|
||||
|
||||
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::KGradDataType,
|
||||
true,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::VGradDataType,
|
||||
true,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_0 =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
|
||||
fmha_bwd_dk_epilogue_0,
|
||||
fmha_bwd_dv_epilogue_0>;
|
||||
|
||||
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64,
|
||||
ck_tile::bf16_t,
|
||||
false,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
ck_tile::BlockAttentionBiasEnum::ALIBI,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
true>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
return k_::GetName();
|
||||
}
|
@ -1,138 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_bwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::fp16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::
|
||||
sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>;
|
||||
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
|
||||
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
|
||||
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
|
||||
|
||||
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
|
||||
// G0&G2 -> GSdP
|
||||
// G1&G3 -> GdKV
|
||||
// G4 -> GdQ
|
||||
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps2_0,
|
||||
fmha_warp_tile0_0>;
|
||||
|
||||
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
ck_tile::BlockAttentionBiasEnum::ALIBI,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
|
||||
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<false, true, false>;
|
||||
|
||||
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
|
||||
fmha_bwd_shape_0,
|
||||
true,
|
||||
true,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
fmha_bwd_trait_0>;
|
||||
|
||||
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::KGradDataType,
|
||||
true,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::VGradDataType,
|
||||
true,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_0 =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
|
||||
fmha_bwd_dk_epilogue_0,
|
||||
fmha_bwd_dv_epilogue_0>;
|
||||
|
||||
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64,
|
||||
ck_tile::fp16_t,
|
||||
true,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
ck_tile::BlockAttentionBiasEnum::ALIBI,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
true>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
return k_::GetName();
|
||||
}
|
@ -1,80 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_fwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::bf16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>;
|
||||
using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>;
|
||||
|
||||
using fmha_shape_0 = ck_tile::TileFmhaShape<fmha_block_tile_0,
|
||||
ck_tile::sequence<4, 1, 1>,
|
||||
fmha_warp_tile_0,
|
||||
ck_tile::sequence<4, 1, 1>,
|
||||
fmha_warp_tile_0,
|
||||
true>;
|
||||
|
||||
using fmha_trait_0 = ck_tile::TileFmhaTraits<true,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
-1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
|
||||
|
||||
using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem<
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::SaccDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::SMPLComputeDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::PDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::OaccDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
fmha_shape_0,
|
||||
false,
|
||||
fmha_mask_0,
|
||||
fmha_trait_0>;
|
||||
|
||||
using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync<
|
||||
fmha_pipeline_problem_0>;
|
||||
|
||||
using fmha_epilogue_0 =
|
||||
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<ck_tile::bf16_t>::OaccDataType,
|
||||
typename FmhaFwdTypeConfig<ck_tile::bf16_t>::ODataType,
|
||||
true, true>>;
|
||||
|
||||
using fmha_kernel_0 =
|
||||
ck_tile::FmhaFwdKernel<ck_tile::FmhaFwdTilePartitioner_SHB<fmha_shape_0>,
|
||||
fmha_pipeline_0,
|
||||
fmha_epilogue_0>;
|
||||
|
||||
using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, false,128, 128, 32, 128, 32, 128, true,
|
||||
ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, false, true, true>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template<>
|
||||
float fmha_fwd_<trait_0>(const ck_tile::stream_config& s, fmha_fwd_args a)
|
||||
{
|
||||
using k_ = fmha_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_fwd_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
@ -1,80 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_fwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::bf16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>;
|
||||
using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>;
|
||||
|
||||
using fmha_shape_0 = ck_tile::TileFmhaShape<fmha_block_tile_0,
|
||||
ck_tile::sequence<2, 1, 1>,
|
||||
fmha_warp_tile_0,
|
||||
ck_tile::sequence<2, 1, 1>,
|
||||
fmha_warp_tile_0,
|
||||
true>;
|
||||
|
||||
using fmha_trait_0 = ck_tile::TileFmhaTraits<true,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
-1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
|
||||
|
||||
using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem<
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::SaccDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::SMPLComputeDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::PDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::OaccDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
fmha_shape_0,
|
||||
false,
|
||||
fmha_mask_0,
|
||||
fmha_trait_0>;
|
||||
|
||||
using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync<
|
||||
fmha_pipeline_problem_0>;
|
||||
|
||||
using fmha_epilogue_0 =
|
||||
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<ck_tile::bf16_t>::OaccDataType,
|
||||
typename FmhaFwdTypeConfig<ck_tile::bf16_t>::ODataType,
|
||||
true, true>>;
|
||||
|
||||
using fmha_kernel_0 =
|
||||
ck_tile::FmhaFwdKernel<ck_tile::FmhaFwdTilePartitioner_SHB<fmha_shape_0>,
|
||||
fmha_pipeline_0,
|
||||
fmha_epilogue_0>;
|
||||
|
||||
using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, false,128, 64, 16, 32, 32, 32, true,
|
||||
ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, false, true, true>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template<>
|
||||
float fmha_fwd_<trait_0>(const ck_tile::stream_config& s, fmha_fwd_args a)
|
||||
{
|
||||
using k_ = fmha_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_fwd_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
@ -1,138 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_bwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::fp16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::
|
||||
sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>;
|
||||
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
|
||||
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
|
||||
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
|
||||
|
||||
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
|
||||
// G0&G2 -> GSdP
|
||||
// G1&G3 -> GdKV
|
||||
// G4 -> GdQ
|
||||
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps2_0,
|
||||
fmha_warp_tile0_0>;
|
||||
|
||||
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
|
||||
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<false, true, false>;
|
||||
|
||||
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
|
||||
fmha_bwd_shape_0,
|
||||
false,
|
||||
false,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
fmha_bwd_trait_0>;
|
||||
|
||||
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR<fmha_bwd_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::KGradDataType,
|
||||
true,
|
||||
true>>;
|
||||
|
||||
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::VGradDataType,
|
||||
true,
|
||||
true>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_0 =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
|
||||
fmha_bwd_dk_epilogue_0,
|
||||
fmha_bwd_dv_epilogue_0>;
|
||||
|
||||
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128,
|
||||
ck_tile::fp16_t,
|
||||
false,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
false>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
return k_::GetName();
|
||||
}
|
@ -1,80 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_fwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::fp16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>;
|
||||
using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>;
|
||||
|
||||
using fmha_shape_0 = ck_tile::TileFmhaShape<fmha_block_tile_0,
|
||||
ck_tile::sequence<4, 1, 1>,
|
||||
fmha_warp_tile_0,
|
||||
ck_tile::sequence<4, 1, 1>,
|
||||
fmha_warp_tile_0,
|
||||
true>;
|
||||
|
||||
using fmha_trait_0 = ck_tile::TileFmhaTraits<true,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
-1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
|
||||
|
||||
using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem<
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::SaccDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::SMPLComputeDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::PDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::OaccDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
fmha_shape_0,
|
||||
false,
|
||||
fmha_mask_0,
|
||||
fmha_trait_0>;
|
||||
|
||||
using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync<
|
||||
fmha_pipeline_problem_0>;
|
||||
|
||||
using fmha_epilogue_0 =
|
||||
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<ck_tile::fp16_t>::OaccDataType,
|
||||
typename FmhaFwdTypeConfig<ck_tile::fp16_t>::ODataType,
|
||||
true, true>>;
|
||||
|
||||
using fmha_kernel_0 =
|
||||
ck_tile::FmhaFwdKernel<ck_tile::FmhaFwdTilePartitioner_SHB<fmha_shape_0>,
|
||||
fmha_pipeline_0,
|
||||
fmha_epilogue_0>;
|
||||
|
||||
using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, false,128, 64, 32, 64, 32, 64, true,
|
||||
ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, false, true, true>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template<>
|
||||
float fmha_fwd_<trait_0>(const ck_tile::stream_config& s, fmha_fwd_args a)
|
||||
{
|
||||
using k_ = fmha_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_fwd_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
@ -1,138 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_bwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::bf16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::
|
||||
sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>;
|
||||
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
|
||||
using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>;
|
||||
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
|
||||
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
|
||||
|
||||
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
|
||||
// G0&G2 -> GSdP
|
||||
// G1&G3 -> GdKV
|
||||
// G4 -> GdQ
|
||||
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps2_0,
|
||||
fmha_warp_tile0_0>;
|
||||
|
||||
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
ck_tile::BlockAttentionBiasEnum::ALIBI,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
|
||||
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<false, true, false>;
|
||||
|
||||
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
|
||||
fmha_bwd_shape_0,
|
||||
true,
|
||||
false,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
fmha_bwd_trait_0>;
|
||||
|
||||
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR<fmha_bwd_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::KGradDataType,
|
||||
true,
|
||||
true>>;
|
||||
|
||||
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::VGradDataType,
|
||||
true,
|
||||
true>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_0 =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
|
||||
fmha_bwd_dk_epilogue_0,
|
||||
fmha_bwd_dv_epilogue_0>;
|
||||
|
||||
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32,
|
||||
ck_tile::bf16_t,
|
||||
true,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
ck_tile::BlockAttentionBiasEnum::ALIBI,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
false>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
return k_::GetName();
|
||||
}
|
@ -1,80 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_fwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::bf16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>;
|
||||
using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>;
|
||||
|
||||
using fmha_shape_0 = ck_tile::TileFmhaShape<fmha_block_tile_0,
|
||||
ck_tile::sequence<2, 1, 1>,
|
||||
fmha_warp_tile_0,
|
||||
ck_tile::sequence<2, 1, 1>,
|
||||
fmha_warp_tile_0,
|
||||
true>;
|
||||
|
||||
using fmha_trait_0 = ck_tile::TileFmhaTraits<true,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
-1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
|
||||
|
||||
using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem<
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::SaccDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::SMPLComputeDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::PDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::OaccDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
fmha_shape_0,
|
||||
false,
|
||||
fmha_mask_0,
|
||||
fmha_trait_0>;
|
||||
|
||||
using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync<
|
||||
fmha_pipeline_problem_0>;
|
||||
|
||||
using fmha_epilogue_0 =
|
||||
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<ck_tile::bf16_t>::OaccDataType,
|
||||
typename FmhaFwdTypeConfig<ck_tile::bf16_t>::ODataType,
|
||||
true, true>>;
|
||||
|
||||
using fmha_kernel_0 =
|
||||
ck_tile::FmhaFwdKernel<ck_tile::FmhaFwdTilePartitioner_SHB<fmha_shape_0>,
|
||||
fmha_pipeline_0,
|
||||
fmha_epilogue_0>;
|
||||
|
||||
using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, false,128, 64, 16, 32, 32, 32, true,
|
||||
ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, false, true, true>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template<>
|
||||
float fmha_fwd_<trait_0>(const ck_tile::stream_config& s, fmha_fwd_args a)
|
||||
{
|
||||
using k_ = fmha_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_fwd_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
@ -1,138 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_bwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::fp16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::
|
||||
sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>;
|
||||
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
|
||||
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
|
||||
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
|
||||
|
||||
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
|
||||
// G0&G2 -> GSdP
|
||||
// G1&G3 -> GdKV
|
||||
// G4 -> GdQ
|
||||
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps2_0,
|
||||
fmha_warp_tile0_0>;
|
||||
|
||||
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
|
||||
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<false, true, false>;
|
||||
|
||||
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
|
||||
fmha_bwd_shape_0,
|
||||
true,
|
||||
false,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
fmha_bwd_trait_0>;
|
||||
|
||||
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::KGradDataType,
|
||||
true,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::VGradDataType,
|
||||
true,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_0 =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
|
||||
fmha_bwd_dk_epilogue_0,
|
||||
fmha_bwd_dv_epilogue_0>;
|
||||
|
||||
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64,
|
||||
ck_tile::fp16_t,
|
||||
true,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
false>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
return k_::GetName();
|
||||
}
|
@ -1,73 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_bwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::fp16_t;
|
||||
|
||||
using fmha_bwd_convert_dq_trait_0 =
|
||||
ck_tile::TileFmhaBwdConvertQGradTraits<true, false, 2>;
|
||||
|
||||
using fmha_bwd_convert_dq_pipeline_problem_0 =
|
||||
ck_tile::BlockFmhaBwdConvertQGradPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
|
||||
/* BlockSize = */ 256,
|
||||
64,
|
||||
128,
|
||||
128,
|
||||
false,
|
||||
false,
|
||||
fmha_bwd_convert_dq_trait_0>;
|
||||
|
||||
using fmha_bwd_convert_dq_0 =
|
||||
typename ck_tile::BlockFmhaBwdConvertQGrad<fmha_bwd_convert_dq_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_convert_dq_kernel_0 =
|
||||
ck_tile::FmhaBwdConvertQGradKernel<fmha_bwd_convert_dq_0>;
|
||||
|
||||
using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<128,
|
||||
ck_tile::fp16_t,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
false>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_convert_dq_<convert_dq_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_convert_dq_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_convert_dq_oneshot_<convert_dq_trait_0>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_convert_dq_kernel_0;
|
||||
auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_convert_dq_get_name_<convert_dq_trait_0>()
|
||||
{
|
||||
using k_ = fmha_bwd_convert_dq_kernel_0;
|
||||
return k_::GetName();
|
||||
}
|
@ -1,138 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_bwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::fp16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::
|
||||
sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>;
|
||||
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
|
||||
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
|
||||
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
|
||||
|
||||
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
|
||||
// G0&G2 -> GSdP
|
||||
// G1&G3 -> GdKV
|
||||
// G4 -> GdQ
|
||||
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps2_0,
|
||||
fmha_warp_tile0_0>;
|
||||
|
||||
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
|
||||
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<false, true, false>;
|
||||
|
||||
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
|
||||
fmha_bwd_shape_0,
|
||||
false,
|
||||
false,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
fmha_bwd_trait_0>;
|
||||
|
||||
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::KGradDataType,
|
||||
true,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::VGradDataType,
|
||||
true,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_0 =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
|
||||
fmha_bwd_dk_epilogue_0,
|
||||
fmha_bwd_dv_epilogue_0>;
|
||||
|
||||
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64,
|
||||
ck_tile::fp16_t,
|
||||
false,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
false>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
return k_::GetName();
|
||||
}
|
@ -1,138 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_bwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::fp16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::
|
||||
sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>;
|
||||
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
|
||||
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
|
||||
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
|
||||
|
||||
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
|
||||
// G0&G2 -> GSdP
|
||||
// G1&G3 -> GdKV
|
||||
// G4 -> GdQ
|
||||
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps2_0,
|
||||
fmha_warp_tile0_0>;
|
||||
|
||||
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
ck_tile::BlockAttentionBiasEnum::ALIBI,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
|
||||
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<false, true, false>;
|
||||
|
||||
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
|
||||
fmha_bwd_shape_0,
|
||||
true,
|
||||
false,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
fmha_bwd_trait_0>;
|
||||
|
||||
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::KGradDataType,
|
||||
true,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::VGradDataType,
|
||||
true,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_0 =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
|
||||
fmha_bwd_dk_epilogue_0,
|
||||
fmha_bwd_dv_epilogue_0>;
|
||||
|
||||
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128,
|
||||
ck_tile::fp16_t,
|
||||
true,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
ck_tile::BlockAttentionBiasEnum::ALIBI,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
false>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
return k_::GetName();
|
||||
}
|
@ -1,138 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_bwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::bf16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::
|
||||
sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>;
|
||||
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
|
||||
using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>;
|
||||
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
|
||||
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
|
||||
|
||||
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
|
||||
// G0&G2 -> GSdP
|
||||
// G1&G3 -> GdKV
|
||||
// G4 -> GdQ
|
||||
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps2_0,
|
||||
fmha_warp_tile0_0>;
|
||||
|
||||
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
ck_tile::BlockAttentionBiasEnum::ALIBI,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
|
||||
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<true, false, false>;
|
||||
|
||||
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
|
||||
fmha_bwd_shape_0,
|
||||
true,
|
||||
false,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
fmha_bwd_trait_0>;
|
||||
|
||||
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::KGradDataType,
|
||||
true,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::VGradDataType,
|
||||
true,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_0 =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
|
||||
fmha_bwd_dk_epilogue_0,
|
||||
fmha_bwd_dv_epilogue_0>;
|
||||
|
||||
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32,
|
||||
ck_tile::bf16_t,
|
||||
true,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
ck_tile::BlockAttentionBiasEnum::ALIBI,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
false>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
return k_::GetName();
|
||||
}
|
@ -1,138 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_bwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::fp16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::
|
||||
sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>;
|
||||
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
|
||||
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
|
||||
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
|
||||
|
||||
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
|
||||
// G0&G2 -> GSdP
|
||||
// G1&G3 -> GdKV
|
||||
// G4 -> GdQ
|
||||
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps2_0,
|
||||
fmha_warp_tile0_0>;
|
||||
|
||||
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
ck_tile::BlockAttentionBiasEnum::ALIBI,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
|
||||
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<true, false, false>;
|
||||
|
||||
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
|
||||
fmha_bwd_shape_0,
|
||||
false,
|
||||
false,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
fmha_bwd_trait_0>;
|
||||
|
||||
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::KGradDataType,
|
||||
false,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::VGradDataType,
|
||||
false,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_0 =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
|
||||
fmha_bwd_dk_epilogue_0,
|
||||
fmha_bwd_dv_epilogue_0>;
|
||||
|
||||
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64,
|
||||
ck_tile::fp16_t,
|
||||
false,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
ck_tile::BlockAttentionBiasEnum::ALIBI,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
return k_::GetName();
|
||||
}
|
@ -1,138 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_bwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::fp16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::
|
||||
sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>;
|
||||
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
|
||||
using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>;
|
||||
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
|
||||
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
|
||||
|
||||
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
|
||||
// G0&G2 -> GSdP
|
||||
// G1&G3 -> GdKV
|
||||
// G4 -> GdQ
|
||||
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps2_0,
|
||||
fmha_warp_tile0_0>;
|
||||
|
||||
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
|
||||
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<true, false, false>;
|
||||
|
||||
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
|
||||
fmha_bwd_shape_0,
|
||||
true,
|
||||
false,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
fmha_bwd_trait_0>;
|
||||
|
||||
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::KGradDataType,
|
||||
true,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::VGradDataType,
|
||||
true,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_0 =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
|
||||
fmha_bwd_dk_epilogue_0,
|
||||
fmha_bwd_dv_epilogue_0>;
|
||||
|
||||
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32,
|
||||
ck_tile::fp16_t,
|
||||
true,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
false>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
return k_::GetName();
|
||||
}
|
@ -1,138 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_bwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::bf16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::
|
||||
sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>;
|
||||
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
|
||||
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
|
||||
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
|
||||
|
||||
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
|
||||
// G0&G2 -> GSdP
|
||||
// G1&G3 -> GdKV
|
||||
// G4 -> GdQ
|
||||
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps2_0,
|
||||
fmha_warp_tile0_0>;
|
||||
|
||||
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
ck_tile::BlockAttentionBiasEnum::ALIBI,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
|
||||
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<true, false, false>;
|
||||
|
||||
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
|
||||
fmha_bwd_shape_0,
|
||||
false,
|
||||
true,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
fmha_bwd_trait_0>;
|
||||
|
||||
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR<fmha_bwd_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::KGradDataType,
|
||||
false,
|
||||
true>>;
|
||||
|
||||
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::VGradDataType,
|
||||
false,
|
||||
true>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_0 =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
|
||||
fmha_bwd_dk_epilogue_0,
|
||||
fmha_bwd_dv_epilogue_0>;
|
||||
|
||||
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256,
|
||||
ck_tile::bf16_t,
|
||||
false,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
ck_tile::BlockAttentionBiasEnum::ALIBI,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
true>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
return k_::GetName();
|
||||
}
|
@ -1,138 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_bwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::fp16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::
|
||||
sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>;
|
||||
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
|
||||
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
|
||||
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
|
||||
|
||||
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
|
||||
// G0&G2 -> GSdP
|
||||
// G1&G3 -> GdKV
|
||||
// G4 -> GdQ
|
||||
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps2_0,
|
||||
fmha_warp_tile0_0>;
|
||||
|
||||
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<false,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
|
||||
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<false, true, false>;
|
||||
|
||||
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
|
||||
fmha_bwd_shape_0,
|
||||
false,
|
||||
true,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
fmha_bwd_trait_0>;
|
||||
|
||||
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR<fmha_bwd_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::KGradDataType,
|
||||
false,
|
||||
true>>;
|
||||
|
||||
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::VGradDataType,
|
||||
false,
|
||||
true>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_0 =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
|
||||
fmha_bwd_dk_epilogue_0,
|
||||
fmha_bwd_dv_epilogue_0>;
|
||||
|
||||
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256,
|
||||
ck_tile::fp16_t,
|
||||
false,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
true>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
return k_::GetName();
|
||||
}
|
@ -1,138 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_bwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::bf16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::
|
||||
sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>;
|
||||
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
|
||||
using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>;
|
||||
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
|
||||
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
|
||||
|
||||
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
|
||||
// G0&G2 -> GSdP
|
||||
// G1&G3 -> GdKV
|
||||
// G4 -> GdQ
|
||||
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps2_0,
|
||||
fmha_warp_tile0_0>;
|
||||
|
||||
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<false,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
|
||||
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<true, false, false>;
|
||||
|
||||
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
|
||||
fmha_bwd_shape_0,
|
||||
false,
|
||||
false,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
fmha_bwd_trait_0>;
|
||||
|
||||
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::KGradDataType,
|
||||
true,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::VGradDataType,
|
||||
true,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_0 =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
|
||||
fmha_bwd_dk_epilogue_0,
|
||||
fmha_bwd_dv_epilogue_0>;
|
||||
|
||||
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32,
|
||||
ck_tile::bf16_t,
|
||||
false,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
false>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
return k_::GetName();
|
||||
}
|
@ -1,138 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_bwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::bf16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::
|
||||
sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>;
|
||||
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
|
||||
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
|
||||
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
|
||||
|
||||
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
|
||||
// G0&G2 -> GSdP
|
||||
// G1&G3 -> GdKV
|
||||
// G4 -> GdQ
|
||||
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps2_0,
|
||||
fmha_warp_tile0_0>;
|
||||
|
||||
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
|
||||
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<false, true, false>;
|
||||
|
||||
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
|
||||
fmha_bwd_shape_0,
|
||||
false,
|
||||
false,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
fmha_bwd_trait_0>;
|
||||
|
||||
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR<fmha_bwd_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::KGradDataType,
|
||||
true,
|
||||
true>>;
|
||||
|
||||
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::VGradDataType,
|
||||
true,
|
||||
true>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_0 =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
|
||||
fmha_bwd_dk_epilogue_0,
|
||||
fmha_bwd_dv_epilogue_0>;
|
||||
|
||||
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256,
|
||||
ck_tile::bf16_t,
|
||||
false,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
false>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
return k_::GetName();
|
||||
}
|
@ -1,138 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_bwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::fp16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::
|
||||
sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>;
|
||||
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
|
||||
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
|
||||
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
|
||||
|
||||
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
|
||||
// G0&G2 -> GSdP
|
||||
// G1&G3 -> GdKV
|
||||
// G4 -> GdQ
|
||||
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps2_0,
|
||||
fmha_warp_tile0_0>;
|
||||
|
||||
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
ck_tile::BlockAttentionBiasEnum::ALIBI,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
|
||||
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<false, true, false>;
|
||||
|
||||
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
|
||||
fmha_bwd_shape_0,
|
||||
false,
|
||||
false,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
fmha_bwd_trait_0>;
|
||||
|
||||
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::KGradDataType,
|
||||
true,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::VGradDataType,
|
||||
true,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_0 =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
|
||||
fmha_bwd_dk_epilogue_0,
|
||||
fmha_bwd_dv_epilogue_0>;
|
||||
|
||||
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256,
|
||||
ck_tile::fp16_t,
|
||||
false,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
ck_tile::BlockAttentionBiasEnum::ALIBI,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
false>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
return k_::GetName();
|
||||
}
|
@ -1,138 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_bwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::bf16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::
|
||||
sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>;
|
||||
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
|
||||
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
|
||||
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
|
||||
|
||||
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
|
||||
// G0&G2 -> GSdP
|
||||
// G1&G3 -> GdKV
|
||||
// G4 -> GdQ
|
||||
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps2_0,
|
||||
fmha_warp_tile0_0>;
|
||||
|
||||
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<false,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
|
||||
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<true, false, false>;
|
||||
|
||||
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
|
||||
fmha_bwd_shape_0,
|
||||
false,
|
||||
true,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
fmha_bwd_trait_0>;
|
||||
|
||||
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR<fmha_bwd_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::KGradDataType,
|
||||
true,
|
||||
true>>;
|
||||
|
||||
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::VGradDataType,
|
||||
true,
|
||||
true>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_0 =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
|
||||
fmha_bwd_dk_epilogue_0,
|
||||
fmha_bwd_dv_epilogue_0>;
|
||||
|
||||
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128,
|
||||
ck_tile::bf16_t,
|
||||
false,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
true>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
return k_::GetName();
|
||||
}
|
@ -1,138 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_bwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::bf16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::
|
||||
sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>;
|
||||
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
|
||||
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
|
||||
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
|
||||
|
||||
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
|
||||
// G0&G2 -> GSdP
|
||||
// G1&G3 -> GdKV
|
||||
// G4 -> GdQ
|
||||
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps2_0,
|
||||
fmha_warp_tile0_0>;
|
||||
|
||||
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<false,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
ck_tile::BlockAttentionBiasEnum::ALIBI,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
|
||||
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<true, false, false>;
|
||||
|
||||
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
|
||||
fmha_bwd_shape_0,
|
||||
false,
|
||||
true,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
fmha_bwd_trait_0>;
|
||||
|
||||
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR<fmha_bwd_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::KGradDataType,
|
||||
true,
|
||||
true>>;
|
||||
|
||||
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::VGradDataType,
|
||||
true,
|
||||
true>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_0 =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
|
||||
fmha_bwd_dk_epilogue_0,
|
||||
fmha_bwd_dv_epilogue_0>;
|
||||
|
||||
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128,
|
||||
ck_tile::bf16_t,
|
||||
false,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
ck_tile::BlockAttentionBiasEnum::ALIBI,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
true>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
return k_::GetName();
|
||||
}
|
@ -1,138 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_bwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::fp16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::
|
||||
sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>;
|
||||
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
|
||||
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
|
||||
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
|
||||
|
||||
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
|
||||
// G0&G2 -> GSdP
|
||||
// G1&G3 -> GdKV
|
||||
// G4 -> GdQ
|
||||
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps2_0,
|
||||
fmha_warp_tile0_0>;
|
||||
|
||||
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<false,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
ck_tile::BlockAttentionBiasEnum::ALIBI,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
|
||||
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<false, true, false>;
|
||||
|
||||
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
|
||||
fmha_bwd_shape_0,
|
||||
false,
|
||||
true,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
fmha_bwd_trait_0>;
|
||||
|
||||
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR<fmha_bwd_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::KGradDataType,
|
||||
false,
|
||||
true>>;
|
||||
|
||||
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::VGradDataType,
|
||||
false,
|
||||
true>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_0 =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
|
||||
fmha_bwd_dk_epilogue_0,
|
||||
fmha_bwd_dv_epilogue_0>;
|
||||
|
||||
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128,
|
||||
ck_tile::fp16_t,
|
||||
false,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
ck_tile::BlockAttentionBiasEnum::ALIBI,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
true>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
return k_::GetName();
|
||||
}
|
@ -1,138 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_bwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::fp16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::
|
||||
sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>;
|
||||
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
|
||||
using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>;
|
||||
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
|
||||
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
|
||||
|
||||
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
|
||||
// G0&G2 -> GSdP
|
||||
// G1&G3 -> GdKV
|
||||
// G4 -> GdQ
|
||||
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps2_0,
|
||||
fmha_warp_tile0_0>;
|
||||
|
||||
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<false,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
ck_tile::BlockAttentionBiasEnum::ALIBI,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
|
||||
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<true, false, false>;
|
||||
|
||||
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
|
||||
fmha_bwd_shape_0,
|
||||
false,
|
||||
false,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
fmha_bwd_trait_0>;
|
||||
|
||||
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR<fmha_bwd_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::KGradDataType,
|
||||
true,
|
||||
true>>;
|
||||
|
||||
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::VGradDataType,
|
||||
true,
|
||||
true>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_0 =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
|
||||
fmha_bwd_dk_epilogue_0,
|
||||
fmha_bwd_dv_epilogue_0>;
|
||||
|
||||
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32,
|
||||
ck_tile::fp16_t,
|
||||
false,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
ck_tile::BlockAttentionBiasEnum::ALIBI,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
false>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
return k_::GetName();
|
||||
}
|
@ -1,138 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_bwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::bf16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::
|
||||
sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>;
|
||||
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
|
||||
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
|
||||
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
|
||||
|
||||
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
|
||||
// G0&G2 -> GSdP
|
||||
// G1&G3 -> GdKV
|
||||
// G4 -> GdQ
|
||||
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps2_0,
|
||||
fmha_warp_tile0_0>;
|
||||
|
||||
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
|
||||
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<true, false, false>;
|
||||
|
||||
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
|
||||
fmha_bwd_shape_0,
|
||||
true,
|
||||
true,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
fmha_bwd_trait_0>;
|
||||
|
||||
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::KGradDataType,
|
||||
true,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::VGradDataType,
|
||||
true,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_0 =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
|
||||
fmha_bwd_dk_epilogue_0,
|
||||
fmha_bwd_dv_epilogue_0>;
|
||||
|
||||
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64,
|
||||
ck_tile::bf16_t,
|
||||
true,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
true>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
return k_::GetName();
|
||||
}
|
@ -1,80 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_fwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::fp16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>;
|
||||
using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>;
|
||||
|
||||
using fmha_shape_0 = ck_tile::TileFmhaShape<fmha_block_tile_0,
|
||||
ck_tile::sequence<2, 1, 1>,
|
||||
fmha_warp_tile_0,
|
||||
ck_tile::sequence<2, 1, 1>,
|
||||
fmha_warp_tile_0,
|
||||
true>;
|
||||
|
||||
using fmha_trait_0 = ck_tile::TileFmhaTraits<true,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
ck_tile::BlockAttentionBiasEnum::ALIBI,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
-1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
|
||||
|
||||
using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem<
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::SaccDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::SMPLComputeDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::PDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::OaccDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
fmha_shape_0,
|
||||
true,
|
||||
fmha_mask_0,
|
||||
fmha_trait_0>;
|
||||
|
||||
using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync<
|
||||
fmha_pipeline_problem_0>;
|
||||
|
||||
using fmha_epilogue_0 =
|
||||
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<ck_tile::fp16_t>::OaccDataType,
|
||||
typename FmhaFwdTypeConfig<ck_tile::fp16_t>::ODataType,
|
||||
true, true>>;
|
||||
|
||||
using fmha_kernel_0 =
|
||||
ck_tile::FmhaFwdKernel<ck_tile::FmhaFwdTilePartitioner_HBS<fmha_shape_0>,
|
||||
fmha_pipeline_0,
|
||||
fmha_epilogue_0>;
|
||||
|
||||
using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, true,128, 64, 16, 32, 32, 32, true,
|
||||
ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template<>
|
||||
float fmha_fwd_<trait_0>(const ck_tile::stream_config& s, fmha_fwd_args a)
|
||||
{
|
||||
using k_ = fmha_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_fwd_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
@ -1,80 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_fwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::bf16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>;
|
||||
using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>;
|
||||
|
||||
using fmha_shape_0 = ck_tile::TileFmhaShape<fmha_block_tile_0,
|
||||
ck_tile::sequence<4, 1, 1>,
|
||||
fmha_warp_tile_0,
|
||||
ck_tile::sequence<4, 1, 1>,
|
||||
fmha_warp_tile_0,
|
||||
true>;
|
||||
|
||||
using fmha_trait_0 = ck_tile::TileFmhaTraits<true,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
ck_tile::BlockAttentionBiasEnum::ALIBI,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
-1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
|
||||
|
||||
using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem<
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::SaccDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::SMPLComputeDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::PDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::OaccDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
fmha_shape_0,
|
||||
false,
|
||||
fmha_mask_0,
|
||||
fmha_trait_0>;
|
||||
|
||||
using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS<
|
||||
fmha_pipeline_problem_0>;
|
||||
|
||||
using fmha_epilogue_0 =
|
||||
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<ck_tile::bf16_t>::OaccDataType,
|
||||
typename FmhaFwdTypeConfig<ck_tile::bf16_t>::ODataType,
|
||||
true, true>>;
|
||||
|
||||
using fmha_kernel_0 =
|
||||
ck_tile::FmhaFwdKernel<ck_tile::FmhaFwdTilePartitioner_SHB<fmha_shape_0>,
|
||||
fmha_pipeline_0,
|
||||
fmha_epilogue_0>;
|
||||
|
||||
using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, false,128, 128, 32, 256, 32, 256, true,
|
||||
ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template<>
|
||||
float fmha_fwd_<trait_0>(const ck_tile::stream_config& s, fmha_fwd_args a)
|
||||
{
|
||||
using k_ = fmha_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_fwd_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
@ -1,80 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_fwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::bf16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>;
|
||||
using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>;
|
||||
|
||||
using fmha_shape_0 = ck_tile::TileFmhaShape<fmha_block_tile_0,
|
||||
ck_tile::sequence<2, 1, 1>,
|
||||
fmha_warp_tile_0,
|
||||
ck_tile::sequence<2, 1, 1>,
|
||||
fmha_warp_tile_0,
|
||||
true>;
|
||||
|
||||
using fmha_trait_0 = ck_tile::TileFmhaTraits<true,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
ck_tile::BlockAttentionBiasEnum::ALIBI,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
-1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
|
||||
|
||||
using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem<
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::SaccDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::SMPLComputeDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::PDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::OaccDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
fmha_shape_0,
|
||||
false,
|
||||
fmha_mask_0,
|
||||
fmha_trait_0>;
|
||||
|
||||
using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync<
|
||||
fmha_pipeline_problem_0>;
|
||||
|
||||
using fmha_epilogue_0 =
|
||||
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<ck_tile::bf16_t>::OaccDataType,
|
||||
typename FmhaFwdTypeConfig<ck_tile::bf16_t>::ODataType,
|
||||
true, true>>;
|
||||
|
||||
using fmha_kernel_0 =
|
||||
ck_tile::FmhaFwdKernel<ck_tile::FmhaFwdTilePartitioner_SHB<fmha_shape_0>,
|
||||
fmha_pipeline_0,
|
||||
fmha_epilogue_0>;
|
||||
|
||||
using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, false,128, 64, 16, 32, 32, 32, true,
|
||||
ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, false, true, true>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template<>
|
||||
float fmha_fwd_<trait_0>(const ck_tile::stream_config& s, fmha_fwd_args a)
|
||||
{
|
||||
using k_ = fmha_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_fwd_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
@ -1,138 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_bwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::bf16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::
|
||||
sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>;
|
||||
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
|
||||
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
|
||||
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
|
||||
|
||||
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
|
||||
// G0&G2 -> GSdP
|
||||
// G1&G3 -> GdKV
|
||||
// G4 -> GdQ
|
||||
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps2_0,
|
||||
fmha_warp_tile0_0>;
|
||||
|
||||
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
|
||||
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<false, true, false>;
|
||||
|
||||
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
|
||||
fmha_bwd_shape_0,
|
||||
true,
|
||||
false,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
fmha_bwd_trait_0>;
|
||||
|
||||
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR<fmha_bwd_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::KGradDataType,
|
||||
true,
|
||||
true>>;
|
||||
|
||||
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::VGradDataType,
|
||||
true,
|
||||
true>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_0 =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
|
||||
fmha_bwd_dk_epilogue_0,
|
||||
fmha_bwd_dv_epilogue_0>;
|
||||
|
||||
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128,
|
||||
ck_tile::bf16_t,
|
||||
true,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
false>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
return k_::GetName();
|
||||
}
|
@ -1,138 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_bwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::fp16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::
|
||||
sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>;
|
||||
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
|
||||
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
|
||||
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
|
||||
|
||||
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
|
||||
// G0&G2 -> GSdP
|
||||
// G1&G3 -> GdKV
|
||||
// G4 -> GdQ
|
||||
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps2_0,
|
||||
fmha_warp_tile0_0>;
|
||||
|
||||
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
|
||||
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<true, false, false>;
|
||||
|
||||
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
|
||||
fmha_bwd_shape_0,
|
||||
false,
|
||||
false,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
fmha_bwd_trait_0>;
|
||||
|
||||
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::KGradDataType,
|
||||
false,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::VGradDataType,
|
||||
false,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_0 =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
|
||||
fmha_bwd_dk_epilogue_0,
|
||||
fmha_bwd_dv_epilogue_0>;
|
||||
|
||||
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128,
|
||||
ck_tile::fp16_t,
|
||||
false,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
return k_::GetName();
|
||||
}
|
@ -1,80 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_fwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::fp16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>;
|
||||
using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>;
|
||||
|
||||
using fmha_shape_0 = ck_tile::TileFmhaShape<fmha_block_tile_0,
|
||||
ck_tile::sequence<2, 1, 1>,
|
||||
fmha_warp_tile_0,
|
||||
ck_tile::sequence<2, 1, 1>,
|
||||
fmha_warp_tile_0,
|
||||
true>;
|
||||
|
||||
using fmha_trait_0 = ck_tile::TileFmhaTraits<true,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
-1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
|
||||
|
||||
using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem<
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::SaccDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::SMPLComputeDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::PDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::OaccDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
fmha_shape_0,
|
||||
false,
|
||||
fmha_mask_0,
|
||||
fmha_trait_0>;
|
||||
|
||||
using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync<
|
||||
fmha_pipeline_problem_0>;
|
||||
|
||||
using fmha_epilogue_0 =
|
||||
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<ck_tile::fp16_t>::OaccDataType,
|
||||
typename FmhaFwdTypeConfig<ck_tile::fp16_t>::ODataType,
|
||||
true, true>>;
|
||||
|
||||
using fmha_kernel_0 =
|
||||
ck_tile::FmhaFwdKernel<ck_tile::FmhaFwdTilePartitioner_SHB<fmha_shape_0>,
|
||||
fmha_pipeline_0,
|
||||
fmha_epilogue_0>;
|
||||
|
||||
using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, false,128, 64, 16, 32, 32, 32, true,
|
||||
ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template<>
|
||||
float fmha_fwd_<trait_0>(const ck_tile::stream_config& s, fmha_fwd_args a)
|
||||
{
|
||||
using k_ = fmha_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_fwd_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
@ -1,138 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_bwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::bf16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::
|
||||
sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>;
|
||||
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
|
||||
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
|
||||
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
|
||||
|
||||
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
|
||||
// G0&G2 -> GSdP
|
||||
// G1&G3 -> GdKV
|
||||
// G4 -> GdQ
|
||||
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps2_0,
|
||||
fmha_warp_tile0_0>;
|
||||
|
||||
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
|
||||
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<false, true, false>;
|
||||
|
||||
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
|
||||
fmha_bwd_shape_0,
|
||||
false,
|
||||
false,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
fmha_bwd_trait_0>;
|
||||
|
||||
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::KGradDataType,
|
||||
true,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::VGradDataType,
|
||||
true,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_0 =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
|
||||
fmha_bwd_dk_epilogue_0,
|
||||
fmha_bwd_dv_epilogue_0>;
|
||||
|
||||
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128,
|
||||
ck_tile::bf16_t,
|
||||
false,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
false>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
return k_::GetName();
|
||||
}
|
@ -1,138 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_bwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::fp16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::
|
||||
sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>;
|
||||
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
|
||||
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
|
||||
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
|
||||
|
||||
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
|
||||
// G0&G2 -> GSdP
|
||||
// G1&G3 -> GdKV
|
||||
// G4 -> GdQ
|
||||
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps2_0,
|
||||
fmha_warp_tile0_0>;
|
||||
|
||||
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
ck_tile::BlockAttentionBiasEnum::ALIBI,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
|
||||
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<true, false, false>;
|
||||
|
||||
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
|
||||
fmha_bwd_shape_0,
|
||||
false,
|
||||
false,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
fmha_bwd_trait_0>;
|
||||
|
||||
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::KGradDataType,
|
||||
false,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::VGradDataType,
|
||||
false,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_0 =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
|
||||
fmha_bwd_dk_epilogue_0,
|
||||
fmha_bwd_dv_epilogue_0>;
|
||||
|
||||
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128,
|
||||
ck_tile::fp16_t,
|
||||
false,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
ck_tile::BlockAttentionBiasEnum::ALIBI,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
return k_::GetName();
|
||||
}
|
@ -1,80 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_fwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::bf16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>;
|
||||
using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>;
|
||||
|
||||
using fmha_shape_0 = ck_tile::TileFmhaShape<fmha_block_tile_0,
|
||||
ck_tile::sequence<4, 1, 1>,
|
||||
fmha_warp_tile_0,
|
||||
ck_tile::sequence<4, 1, 1>,
|
||||
fmha_warp_tile_0,
|
||||
true>;
|
||||
|
||||
using fmha_trait_0 = ck_tile::TileFmhaTraits<true,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
-1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
|
||||
|
||||
using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem<
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::SaccDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::SMPLComputeDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::PDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::OaccDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
fmha_shape_0,
|
||||
true,
|
||||
fmha_mask_0,
|
||||
fmha_trait_0>;
|
||||
|
||||
using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS<
|
||||
fmha_pipeline_problem_0>;
|
||||
|
||||
using fmha_epilogue_0 =
|
||||
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<ck_tile::bf16_t>::OaccDataType,
|
||||
typename FmhaFwdTypeConfig<ck_tile::bf16_t>::ODataType,
|
||||
true, true>>;
|
||||
|
||||
using fmha_kernel_0 =
|
||||
ck_tile::FmhaFwdKernel<ck_tile::FmhaFwdTilePartitioner_HBS<fmha_shape_0>,
|
||||
fmha_pipeline_0,
|
||||
fmha_epilogue_0>;
|
||||
|
||||
using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, true,128, 128, 32, 256, 32, 256, true,
|
||||
ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template<>
|
||||
float fmha_fwd_<trait_0>(const ck_tile::stream_config& s, fmha_fwd_args a)
|
||||
{
|
||||
using k_ = fmha_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_fwd_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
@ -1,80 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_fwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::bf16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>;
|
||||
using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>;
|
||||
|
||||
using fmha_shape_0 = ck_tile::TileFmhaShape<fmha_block_tile_0,
|
||||
ck_tile::sequence<4, 1, 1>,
|
||||
fmha_warp_tile_0,
|
||||
ck_tile::sequence<4, 1, 1>,
|
||||
fmha_warp_tile_0,
|
||||
true>;
|
||||
|
||||
using fmha_trait_0 = ck_tile::TileFmhaTraits<true,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
ck_tile::BlockAttentionBiasEnum::ALIBI,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
-1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
|
||||
|
||||
using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem<
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::SaccDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::SMPLComputeDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::PDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::OaccDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
fmha_shape_0,
|
||||
true,
|
||||
fmha_mask_0,
|
||||
fmha_trait_0>;
|
||||
|
||||
using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync<
|
||||
fmha_pipeline_problem_0>;
|
||||
|
||||
using fmha_epilogue_0 =
|
||||
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<ck_tile::bf16_t>::OaccDataType,
|
||||
typename FmhaFwdTypeConfig<ck_tile::bf16_t>::ODataType,
|
||||
true, true>>;
|
||||
|
||||
using fmha_kernel_0 =
|
||||
ck_tile::FmhaFwdKernel<ck_tile::FmhaFwdTilePartitioner_HBS<fmha_shape_0>,
|
||||
fmha_pipeline_0,
|
||||
fmha_epilogue_0>;
|
||||
|
||||
using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, true,128, 64, 32, 64, 32, 64, true,
|
||||
ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template<>
|
||||
float fmha_fwd_<trait_0>(const ck_tile::stream_config& s, fmha_fwd_args a)
|
||||
{
|
||||
using k_ = fmha_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_fwd_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
@ -1,80 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_fwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::fp16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>;
|
||||
using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>;
|
||||
|
||||
using fmha_shape_0 = ck_tile::TileFmhaShape<fmha_block_tile_0,
|
||||
ck_tile::sequence<4, 1, 1>,
|
||||
fmha_warp_tile_0,
|
||||
ck_tile::sequence<4, 1, 1>,
|
||||
fmha_warp_tile_0,
|
||||
true>;
|
||||
|
||||
using fmha_trait_0 = ck_tile::TileFmhaTraits<true,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
-1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
|
||||
|
||||
using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem<
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::SaccDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::SMPLComputeDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::PDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::OaccDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
fmha_shape_0,
|
||||
false,
|
||||
fmha_mask_0,
|
||||
fmha_trait_0>;
|
||||
|
||||
using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync<
|
||||
fmha_pipeline_problem_0>;
|
||||
|
||||
using fmha_epilogue_0 =
|
||||
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<ck_tile::fp16_t>::OaccDataType,
|
||||
typename FmhaFwdTypeConfig<ck_tile::fp16_t>::ODataType,
|
||||
true, true>>;
|
||||
|
||||
using fmha_kernel_0 =
|
||||
ck_tile::FmhaFwdKernel<ck_tile::FmhaFwdTilePartitioner_SHB<fmha_shape_0>,
|
||||
fmha_pipeline_0,
|
||||
fmha_epilogue_0>;
|
||||
|
||||
using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, false,128, 64, 32, 64, 32, 64, true,
|
||||
ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template<>
|
||||
float fmha_fwd_<trait_0>(const ck_tile::stream_config& s, fmha_fwd_args a)
|
||||
{
|
||||
using k_ = fmha_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_fwd_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
@ -1,80 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_fwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::bf16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>;
|
||||
using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>;
|
||||
|
||||
using fmha_shape_0 = ck_tile::TileFmhaShape<fmha_block_tile_0,
|
||||
ck_tile::sequence<4, 1, 1>,
|
||||
fmha_warp_tile_0,
|
||||
ck_tile::sequence<4, 1, 1>,
|
||||
fmha_warp_tile_0,
|
||||
true>;
|
||||
|
||||
using fmha_trait_0 = ck_tile::TileFmhaTraits<true,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
ck_tile::BlockAttentionBiasEnum::ALIBI,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
-1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
|
||||
|
||||
using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem<
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::SaccDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::SMPLComputeDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::PDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::OaccDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
fmha_shape_0,
|
||||
false,
|
||||
fmha_mask_0,
|
||||
fmha_trait_0>;
|
||||
|
||||
using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync<
|
||||
fmha_pipeline_problem_0>;
|
||||
|
||||
using fmha_epilogue_0 =
|
||||
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<ck_tile::bf16_t>::OaccDataType,
|
||||
typename FmhaFwdTypeConfig<ck_tile::bf16_t>::ODataType,
|
||||
true, true>>;
|
||||
|
||||
using fmha_kernel_0 =
|
||||
ck_tile::FmhaFwdKernel<ck_tile::FmhaFwdTilePartitioner_SHB<fmha_shape_0>,
|
||||
fmha_pipeline_0,
|
||||
fmha_epilogue_0>;
|
||||
|
||||
using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, false,128, 128, 32, 128, 32, 128, true,
|
||||
ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, false, true, true>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template<>
|
||||
float fmha_fwd_<trait_0>(const ck_tile::stream_config& s, fmha_fwd_args a)
|
||||
{
|
||||
using k_ = fmha_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_fwd_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
@ -1,138 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_bwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::bf16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::
|
||||
sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>;
|
||||
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
|
||||
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
|
||||
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
|
||||
|
||||
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
|
||||
// G0&G2 -> GSdP
|
||||
// G1&G3 -> GdKV
|
||||
// G4 -> GdQ
|
||||
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps2_0,
|
||||
fmha_warp_tile0_0>;
|
||||
|
||||
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
|
||||
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<true, false, false>;
|
||||
|
||||
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
|
||||
fmha_bwd_shape_0,
|
||||
true,
|
||||
true,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
fmha_bwd_trait_0>;
|
||||
|
||||
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::KGradDataType,
|
||||
true,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::VGradDataType,
|
||||
true,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_0 =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
|
||||
fmha_bwd_dk_epilogue_0,
|
||||
fmha_bwd_dv_epilogue_0>;
|
||||
|
||||
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128,
|
||||
ck_tile::bf16_t,
|
||||
true,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
true>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
return k_::GetName();
|
||||
}
|
@ -1,138 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_bwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::bf16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::
|
||||
sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>;
|
||||
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
|
||||
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
|
||||
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
|
||||
|
||||
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
|
||||
// G0&G2 -> GSdP
|
||||
// G1&G3 -> GdKV
|
||||
// G4 -> GdQ
|
||||
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps2_0,
|
||||
fmha_warp_tile0_0>;
|
||||
|
||||
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<false,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
|
||||
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<true, false, false>;
|
||||
|
||||
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
|
||||
fmha_bwd_shape_0,
|
||||
false,
|
||||
false,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
fmha_bwd_trait_0>;
|
||||
|
||||
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::KGradDataType,
|
||||
true,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::VGradDataType,
|
||||
true,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_0 =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
|
||||
fmha_bwd_dk_epilogue_0,
|
||||
fmha_bwd_dv_epilogue_0>;
|
||||
|
||||
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64,
|
||||
ck_tile::bf16_t,
|
||||
false,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
false>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
return k_::GetName();
|
||||
}
|
@ -1,138 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_bwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::bf16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::
|
||||
sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>;
|
||||
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
|
||||
using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>;
|
||||
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
|
||||
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
|
||||
|
||||
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
|
||||
// G0&G2 -> GSdP
|
||||
// G1&G3 -> GdKV
|
||||
// G4 -> GdQ
|
||||
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps2_0,
|
||||
fmha_warp_tile0_0>;
|
||||
|
||||
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
|
||||
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<true, false, false>;
|
||||
|
||||
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
|
||||
fmha_bwd_shape_0,
|
||||
false,
|
||||
true,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
fmha_bwd_trait_0>;
|
||||
|
||||
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR<fmha_bwd_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::KGradDataType,
|
||||
true,
|
||||
true>>;
|
||||
|
||||
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::VGradDataType,
|
||||
true,
|
||||
true>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_0 =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
|
||||
fmha_bwd_dk_epilogue_0,
|
||||
fmha_bwd_dv_epilogue_0>;
|
||||
|
||||
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32,
|
||||
ck_tile::bf16_t,
|
||||
false,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
true>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
return k_::GetName();
|
||||
}
|
@ -1,80 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_fwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::bf16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>;
|
||||
using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>;
|
||||
|
||||
using fmha_shape_0 = ck_tile::TileFmhaShape<fmha_block_tile_0,
|
||||
ck_tile::sequence<4, 1, 1>,
|
||||
fmha_warp_tile_0,
|
||||
ck_tile::sequence<4, 1, 1>,
|
||||
fmha_warp_tile_0,
|
||||
true>;
|
||||
|
||||
using fmha_trait_0 = ck_tile::TileFmhaTraits<false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
ck_tile::BlockAttentionBiasEnum::ALIBI,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
-1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
|
||||
|
||||
using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem<
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::SaccDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::SMPLComputeDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::PDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::OaccDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
fmha_shape_0,
|
||||
false,
|
||||
fmha_mask_0,
|
||||
fmha_trait_0>;
|
||||
|
||||
using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS<
|
||||
fmha_pipeline_problem_0>;
|
||||
|
||||
using fmha_epilogue_0 =
|
||||
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<ck_tile::bf16_t>::OaccDataType,
|
||||
typename FmhaFwdTypeConfig<ck_tile::bf16_t>::ODataType,
|
||||
false, false>>;
|
||||
|
||||
using fmha_kernel_0 =
|
||||
ck_tile::FmhaFwdKernel<ck_tile::FmhaFwdTilePartitioner_SHB<fmha_shape_0>,
|
||||
fmha_pipeline_0,
|
||||
fmha_epilogue_0>;
|
||||
|
||||
using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, false,128, 128, 32, 256, 32, 256, true,
|
||||
ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, false, false, false, false>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template<>
|
||||
float fmha_fwd_<trait_0>(const ck_tile::stream_config& s, fmha_fwd_args a)
|
||||
{
|
||||
using k_ = fmha_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_fwd_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
@ -1,138 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_bwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::bf16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::
|
||||
sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>;
|
||||
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
|
||||
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
|
||||
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
|
||||
|
||||
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
|
||||
// G0&G2 -> GSdP
|
||||
// G1&G3 -> GdKV
|
||||
// G4 -> GdQ
|
||||
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps2_0,
|
||||
fmha_warp_tile0_0>;
|
||||
|
||||
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<false,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
ck_tile::BlockAttentionBiasEnum::ALIBI,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
|
||||
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<false, true, false>;
|
||||
|
||||
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
|
||||
fmha_bwd_shape_0,
|
||||
false,
|
||||
false,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
fmha_bwd_trait_0>;
|
||||
|
||||
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::KGradDataType,
|
||||
true,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::VGradDataType,
|
||||
true,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_0 =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
|
||||
fmha_bwd_dk_epilogue_0,
|
||||
fmha_bwd_dv_epilogue_0>;
|
||||
|
||||
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256,
|
||||
ck_tile::bf16_t,
|
||||
false,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
ck_tile::BlockAttentionBiasEnum::ALIBI,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
false>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
return k_::GetName();
|
||||
}
|
@ -1,80 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_fwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::bf16_t;
|
||||
|
||||
using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>;
|
||||
using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>;
|
||||
|
||||
using fmha_shape_0 = ck_tile::TileFmhaShape<fmha_block_tile_0,
|
||||
ck_tile::sequence<4, 1, 1>,
|
||||
fmha_warp_tile_0,
|
||||
ck_tile::sequence<4, 1, 1>,
|
||||
fmha_warp_tile_0,
|
||||
true>;
|
||||
|
||||
using fmha_trait_0 = ck_tile::TileFmhaTraits<true,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
ck_tile::BlockAttentionBiasEnum::ALIBI,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
-1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
|
||||
|
||||
using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem<
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::KDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::VDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::SaccDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::SMPLComputeDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::BiasDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::LSEDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::PDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::OaccDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
fmha_shape_0,
|
||||
true,
|
||||
fmha_mask_0,
|
||||
fmha_trait_0>;
|
||||
|
||||
using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync<
|
||||
fmha_pipeline_problem_0>;
|
||||
|
||||
using fmha_epilogue_0 =
|
||||
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<ck_tile::bf16_t>::OaccDataType,
|
||||
typename FmhaFwdTypeConfig<ck_tile::bf16_t>::ODataType,
|
||||
true, true>>;
|
||||
|
||||
using fmha_kernel_0 =
|
||||
ck_tile::FmhaFwdKernel<ck_tile::FmhaFwdTilePartitioner_HBS<fmha_shape_0>,
|
||||
fmha_pipeline_0,
|
||||
fmha_epilogue_0>;
|
||||
|
||||
using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, true,128, 64, 32, 64, 32, 64, true,
|
||||
ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template<>
|
||||
float fmha_fwd_<trait_0>(const ck_tile::stream_config& s, fmha_fwd_args a)
|
||||
{
|
||||
using k_ = fmha_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_fwd_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
@ -1,65 +0,0 @@
|
||||
// ==========================================
|
||||
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
|
||||
// @generated
|
||||
// ==========================================
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// auto generated by generate.py
|
||||
#include <fmha_bwd.hpp>
|
||||
|
||||
using fmha_dtype_0 = ck_tile::fp16_t;
|
||||
|
||||
using fmha_bwd_dot_do_o_trait_0 =
|
||||
ck_tile::TileFmhaBwdOGradDotOTraits<true, false, 2>;
|
||||
|
||||
using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
|
||||
/* BlockSize = */ 64,
|
||||
64,
|
||||
false,
|
||||
fmha_bwd_dot_do_o_trait_0>;
|
||||
|
||||
using fmha_bwd_dot_do_o_0 =
|
||||
typename ck_tile::BlockFmhaBwdOGradDotO<fmha_bwd_dot_do_o_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_dot_do_o_kernel_0 =
|
||||
ck_tile::FmhaBwdOGradDotOKernel<fmha_bwd_dot_do_o_0>;
|
||||
|
||||
using dot_do_o_trait_0 =
|
||||
fmha_bwd_dot_do_o_traits_<64, ck_tile::fp16_t, false, true, false>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_dot_do_o_<dot_do_o_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dot_do_o_kernel_0;
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dot_do_o_kernel_0;
|
||||
auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids<k_>(a);
|
||||
constexpr dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel_pt<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{s.stream_id_});
|
||||
}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_0>()
|
||||
{
|
||||
using k_ = fmha_bwd_dot_do_o_kernel_0;
|
||||
return k_::GetName();
|
||||
}
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user