[ROCm] CK Flash Attention Backend (#143695)

Replace https://github.com/pytorch/pytorch/pull/138947 for re-import.

Replaces https://github.com/ROCm/pytorch/pull/1592

This PR contains the initial implementation of SDPA with the composable_kernel (CK) backend. The CK path can be forced by simply calling torch.backends.cuda.preferred_rocm_fa_library("ck"). Similarly, you can force the incumbent aotriton implementation by passing in "aotriton" or "default". As you'd expect, leaving this option unset results in aotriton being used as the backend. In the case of CK, if PyTorch deems flash attention usable, it will use the CK path in all the same places aotriton would have been used. This PR makes no changes to the heuristics that select which attention scheme to use (i.e., flash attention vs. memory-efficient attention vs. math, etc.). The CK backend is only invoked when flash attention is both enabled (via USE_FLASH_ATTENTION) and selected at runtime by the existing heuristics.
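
For illustration, a minimal C++ sketch of the Context-level API this PR adds (the Python call above routes through it); the include set and standalone main are assumptions, not part of this PR:

#include <ATen/Context.h>
#include <ATen/ROCmFABackend.h>
#include <iostream>

int main() {
  // Prefer the CK flash attention path; on an unsupported ROCm arch the
  // setter warns once and leaves the preferred backend unchanged.
  at::globalContext().setROCmFAPreferredBackend(at::ROCmFABackend::Ck);
  // Query and print the current preference via the operator<< overload
  // defined in ROCmFABackend.h.
  std::cout << at::globalContext().getROCmFAPreferredBackend() << std::endl;
  return 0;
}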

Files located in pytorch/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha* have been pulled from https://github.com/Dao-AILab/flash-attention, courtesy of @tridao's hard work; he is credited as a co-author.

NOTE: In order to use this backend, the user MUST set USE_CK_FLASH_ATTENTION=1 in their environment when building PyTorch.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143695
Approved by: https://github.com/malfet

Co-authored-by: Andy Lugo <Andy.LugoReyes@amd.com>
Co-authored-by: Jithun Nair <jithun.nair@amd.com>
Xiaodong Wang
2025-01-03 22:01:36 +00:00
committed by PyTorch MergeBot
parent 3251171ae8
commit 0a94bb432e
1840 changed files with 249657 additions and 38 deletions

View File

@ -32,6 +32,10 @@ All contributions by Cruise LLC:
Copyright (c) 2022 Cruise LLC.
All rights reserved.
All contributions by Tri Dao:
Copyright (c) 2024 Tri Dao.
All rights reserved.
All contributions by Arm:
Copyright (c) 2021, 2023-2024 Arm Limited and/or its affiliates

View File

@ -168,9 +168,28 @@ file(GLOB flash_attention_cuda_cu "native/transformers/cuda/flash_attn/*.cu")
file(GLOB flash_attention_cuda_kernels_cu "native/transformers/cuda/flash_attn/kernels/*.cu")
file(GLOB flash_attention_cuda_cpp "native/transformers/cuda/flash_attn/*.cpp")
# flash_attention sources
# flash_attention hip sources
file(GLOB flash_attention_hip_hip "native/transformers/hip/flash_attn/*.hip")
file(GLOB flash_attention_src_hip_hip "native/transformers/hip/flash_attn/src/*.hip")
# if USE_FLASH_ATTENTION is set, ensure CK instances get generated
if(USE_FLASH_ATTENTION)
if(DEFINED ENV{USE_CK_FLASH_ATTENTION})
set(USE_CK_FLASH_ATTENTION $ENV{USE_CK_FLASH_ATTENTION})
if(USE_CK_FLASH_ATTENTION STREQUAL "1")
if(DEFINED ENV{PYTORCH_ROCM_ARCH})
list(LENGTH PYTORCH_ROCM_ARCH NUM_ARCHS)
if(NUM_ARCHS GREATER 1)
message(WARNING "Building CK for multiple archs can increase build time considerably!
Consider setting the PYTORCH_ROCM_ARCH env var to the gfx arch you need to build for")
endif()
endif()
message(STATUS "USE_CK_FLASH_ATTENTION is set; building PyTorch with CK Flash Attention enabled")
file(GLOB flash_attention_hip_ck_hip "native/transformers/hip/flash_attn/ck/*.hip")
list(APPEND native_transformers_hip_hip ${flash_attention_hip_ck_hip})
endif()
endif()
file(GLOB flash_attention_hip_aot_hip "native/transformers/hip/flash_attn/aot/*.hip")
file(GLOB flash_attention_src_hip_hip "native/transformers/hip/flash_attn/src/*.hip")
endif()
#Mem_eff attention sources
file(GLOB mem_eff_attention_cuda_cu "native/transformers/cuda/mem_eff_attention/*.cu")
@ -185,6 +204,7 @@ if(USE_FLASH_ATTENTION)
list(APPEND ATen_ATTENTION_KERNEL_SRCS ${flash_attention_cuda_kernels_cu})
list(APPEND native_transformers_hip_hip ${flash_attention_hip_hip})
list(APPEND native_transformers_hip_hip ${flash_attention_hip_aot_hip})
list(APPEND native_transformers_src_hip_hip ${flash_attention_src_hip_hip})
endif()
@ -325,6 +345,9 @@ if(USE_ROCM)
# Next two lines are needed because TunableOp uses third-party/fmt
list(APPEND ATen_HIP_INCLUDE $<TARGET_PROPERTY:fmt::fmt-header-only,INTERFACE_INCLUDE_DIRECTORIES>)
list(APPEND ATen_HIP_DEPENDENCY_LIBS fmt::fmt-header-only)
if(USE_FLASH_ATTENTION)
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/native/transformers/hip/flash_attn/ck)
endif()
list(APPEND ATen_HIP_SRCS
${ATen_HIP_SRCS}
${hip_hip}

View File

@ -343,6 +343,40 @@ void Context::setBlasPreferredBackend(at::BlasBackend b) {
#endif
}
at::ROCmFABackend Context::getROCmFAPreferredBackend() const {
return rocm_fa_preferred_backend;
}
void Context::setROCmFAPreferredBackend(at::ROCmFABackend b) {
// TODO: add plumbing for hasCK for validity checking
TORCH_CHECK((b != at::ROCmFABackend::Ck) || hasROCM(),
"Cannot set preferred flash attention backend to Ck if PyTorch has not been compiled for ROCm.");
#ifdef USE_ROCM
if(b == at::ROCmFABackend::Ck) {
static const bool ck_unsupported = []() {
static const std::vector<std::string> archs = {
"gfx90a", "gfx942"
};
for (auto index: c10::irange(getNumGPUs())) {
if (!detail::getCUDAHooks().isGPUArch(index, archs)) {
TORCH_WARN_ONCE(
"Attempting to use CK on an unsupported architecture! Cannot set backend to CK");
return true;
}
}
return false;
}();
if(!ck_unsupported) rocm_fa_preferred_backend = b;
}
else {
rocm_fa_preferred_backend = b;
}
#else
rocm_fa_preferred_backend = b;
#endif
}
bool Context::allowFP16ReductionCuBLAS() const {
return allow_fp16_reduction_cublas;
}

View File

@ -4,6 +4,7 @@
#include <ATen/CPUGeneratorImpl.h>
#include <ATen/DeviceAccelerator.h>
#include <ATen/LinalgBackend.h>
#include <ATen/ROCmFABackend.h>
#include <ATen/SDPBackend.h>
#include <ATen/core/ATenGeneral.h>
#include <ATen/core/DeprecatedTypeProperties.h>
@ -239,6 +240,9 @@ class TORCH_API Context {
at::BlasBackend blasPreferredBackend();
void setBlasPreferredBackend(at::BlasBackend);
at::ROCmFABackend getROCmFAPreferredBackend() const;
void setROCmFAPreferredBackend(at::ROCmFABackend);
// Note [Enabling Deterministic Operations]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Operations in PyTorch that normally act nondeterministically, but have an
@ -428,6 +432,10 @@ class TORCH_API Context {
#endif
? at::BlasBackend::Cublaslt
: at::BlasBackend::Cublas;
at::ROCmFABackend rocm_fa_preferred_backend =
c10::utils::check_env("TORCH_ROCM_FA_PREFER_CK") == true
? at::ROCmFABackend::Ck
: at::ROCmFABackend::Default;
#ifdef C10_MOBILE
bool release_original_weights = true;
#else

View File

@ -0,0 +1,31 @@
#pragma once
#include <c10/util/Exception.h>
#include <ostream>
#include <string>
namespace at {
enum class ROCmFABackend : int8_t { Default, AOTriton, Ck };
inline std::string ROCmFABackendToString(at::ROCmFABackend backend) {
switch (backend) {
case ROCmFABackend::Default:
return "at::ROCmFABackend::Default";
case ROCmFABackend::AOTriton:
return "at::ROCmFABackend::AOTriton";
case ROCmFABackend::Ck:
return "at::ROCmFABackend::Ck";
default:
TORCH_CHECK(false, "Unknown ROCm flash attention backend");
}
}
inline std::ostream& operator<<(
std::ostream& stream,
at::ROCmFABackend backend) {
return stream << ROCmFABackendToString(backend);
}
} // namespace at
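
A tiny hypothetical usage sketch for the helpers above (standalone main is an assumption):

#include <ATen/ROCmFABackend.h>
#include <iostream>

int main() {
  at::ROCmFABackend b = at::ROCmFABackend::AOTriton;
  // ROCmFABackendToString returns the fully qualified enumerator name.
  std::cout << at::ROCmFABackendToString(b) << "\n";  // at::ROCmFABackend::AOTriton
  // operator<< delegates to the same helper.
  std::cout << b << std::endl;
  return 0;
}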

View File

@ -28,7 +28,7 @@
#if USE_ROCM
#if defined(USE_FLASH_ATTENTION) || defined(USE_MEM_EFF_ATTENTION)
#include <aotriton/flash.h>
#define USE_AOTRITON 1
#define USE_ROCM_ATTENTION 1
#endif
#endif
@ -219,15 +219,21 @@ bool check_flash_attention_hardware_support(sdp_params const& params, bool debug
using sm80 = SMVersion<8, 0>;
using sm90 = SMVersion<9, 0>;
#if USE_ROCM
#if USE_AOTRITON
auto stream = at::cuda::getCurrentCUDAStream().stream();
if (hipSuccess != aotriton::v2::flash::check_gpu(stream)) {
auto dprops = at::cuda::getCurrentDeviceProperties();
if (debug) {
TORCH_WARN(
"Flash attention was not compiled for current AMD GPU architecture. Attempting to run on architecture ", dprops->gcnArchName);
}
return false;
#if USE_ROCM_ATTENTION
if(at::globalContext().getROCmFAPreferredBackend() == at::ROCmFABackend::Ck) {
// User explicitly set CK as the flash attention backend. Return true for now
// TODO: Flesh out sanity checks
return true;
} else {
auto stream = at::cuda::getCurrentCUDAStream().stream();
if (hipSuccess != aotriton::v2::flash::check_gpu(stream)) {
auto dprops = at::cuda::getCurrentDeviceProperties();
if (debug) {
TORCH_WARN(
"Flash attention was not compiled for current AMD GPU architecture. Attempting to run on architecture ", dprops->gcnArchName);
}
return false;
}
}
#else
return false;
@ -254,7 +260,7 @@ bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug)
using sm50 = SMVersion<5, 0>;
using sm90 = SMVersion<9, 0>;
#if USE_ROCM
#if USE_AOTRITON
#if USE_ROCM_ATTENTION
auto stream = at::cuda::getCurrentCUDAStream().stream();
if (hipSuccess != aotriton::v2::flash::check_gpu(stream)) {
auto dprops = at::cuda::getCurrentDeviceProperties();

View File

@ -124,7 +124,7 @@ inline aotriton::TensorView<0> mk_aoscalartensor(const at::Tensor& q)
inline aotriton::TensorView<0> mk_philoxtensor(const int64_t* ptr)
{
return aotriton::TensorView<0>(reinterpret_cast<intptr_t>(ptr),
aotriton::DType::kUInt64); // AOTriton excepts unsigned int64
aotriton::DType::kUInt64); // AOTriton accepts unsigned int64
}
} // namespace aotriton_adapter

View File

@ -115,24 +115,18 @@ prepare_philox_arguments(float p_dropout, int64_t counter_offset) {
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
std::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
std::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
const float p_dropout,
const float softmax_scale,
bool is_causal,
int window_size_left,
int window_size_right,
const bool return_softmax,
std::optional<at::Generator> gen_) {
// Otherwise the kernel will be launched from cuda:0 device
// Cast to char to avoid compiler warning about narrowing
// [ROCM specific]: must be at the beginning of the function
// Otherwise check_gpu_arch() checks cuda:0 device.
at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()};
mha_fwd_aot(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
std::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
std::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
const float p_dropout,
const float softmax_scale,
bool is_causal,
int window_size_left,
int window_size_right,
const bool return_softmax,
std::optional<at::Generator> gen_) {
auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
check_gpu_arch(stream);
@ -242,7 +236,7 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
}
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
std::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
@ -408,7 +402,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
}
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_size_og
mha_bwd_aot(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_size_og
const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
@ -559,7 +553,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
}
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
mha_varlen_bwd_aot(const at::Tensor &dout, // total_q x num_heads, x head_size
const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
@ -747,7 +741,6 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
return { dq, dk, dv, softmax_d };
}
} // namespace pytorch_fmha
} // namespace pytorch_flash
#endif

View File

@ -0,0 +1,100 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <ostream>
#include <string>
#include <ck_tile/core.hpp>
#include <ck_tile/ops/fmha.hpp>
// keep in sync with BlockAttentionBiasEnum
enum class bias_enum
{
no_bias = 0,
elementwise_bias = 1,
alibi = 2,
};
struct bias_info
{
bias_enum type;
/*
* simple dispatch logic
*
* if type == elementwise_bias:
* if rank_info == 0:
* bias is 1*1*s*s
* elif rank_info == 1:
* bias is 1*h*s*s
* elif rank_info == 2:
* bias is b*h*s*s
*
* elif type == alibi:
* if rank_info == 0:
* alibi in 1*h
* elif rank_info == 1:
* alibi in b*h
*/
int rank_info;
void serialize(std::ostream& os) const
{
if(type == bias_enum::no_bias)
os << "n";
else if(type == bias_enum::elementwise_bias)
{
os << "e";
if(rank_info != 0)
{
os << "[" << rank_info << "]";
}
}
else if(type == bias_enum::alibi)
{
os << "alibi";
if(rank_info != 0)
{
os << "[" << rank_info << "]";
}
}
}
static bias_info decode(std::string str)
{
bias_info info{bias_enum::no_bias, 0};
if(str == "0" || str == "n")
{
info.type = bias_enum::no_bias;
}
else if(str.compare(0, 1, "1") == 0 || str.compare(0, 1, "e") == 0 ||
str.compare(0, 11, "elementwise") == 0)
{
info.type = bias_enum::elementwise_bias;
auto found_0 = str.find(':');
if(found_0 != std::string::npos)
{
std::string e = str.substr(found_0 + 1);
info.rank_info = atoi(e.c_str());
}
}
else if(str.compare(0, 1, "2") == 0 || str.compare(0, 1, "a") == 0 ||
str.compare(0, 5, "alibi") == 0)
{
info.type = bias_enum::alibi;
auto found_0 = str.find(':');
if(found_0 != std::string::npos)
{
std::string e = str.substr(found_0 + 1);
info.rank_info = atoi(e.c_str());
}
}
return info;
}
friend std::ostream& operator<<(std::ostream& os, const bias_info& bi)
{
bi.serialize(os);
return os;
}
};
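
For illustration, a small sketch (hypothetical, standalone apart from the header above; the include path is an assumption) exercising the decode/serialize round-trip described in the dispatch comment:

#include <bias.hpp>
#include <iostream>

int main() {
  // "e:1" -> elementwise bias with rank_info 1 (a 1*h*s*s bias tensor).
  bias_info bi = bias_info::decode("e:1");
  std::cout << bi << "\n";          // prints "e[1]"

  // "alibi" with no ":<rank>" suffix -> rank_info 0 (slopes shaped 1*h).
  bias_info alibi = bias_info::decode("alibi");
  std::cout << alibi << std::endl;  // prints "alibi"
  return 0;
}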

View File

@ -0,0 +1,447 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <ck_tile/core.hpp>
#include <ck_tile/host/kernel_launch.hpp>
#include <ck_tile/ops/fmha.hpp>
#include <ck_tile/ops/epilogue.hpp>
#include <mask.hpp>
#include <bias.hpp>
#include <type_traits>
#include <utility>
#include <variant>
template <typename DataType>
struct FmhaBwdTypeConfig;
template <>
struct FmhaBwdTypeConfig<ck_tile::half_t>
{
using QDataType = ck_tile::half_t;
using KDataType = ck_tile::half_t;
using VDataType = ck_tile::half_t;
using GemmDataType = ck_tile::half_t;
using BiasDataType = ck_tile::half_t;
using LSEDataType = float;
using AccDataType = float; // data type for gemm accumulation
using DDataType = float;
using RandValOutputDataType = uint8_t;
using ODataType = ck_tile::half_t;
using OGradDataType = ck_tile::half_t;
using QGradDataType = ck_tile::half_t;
using KGradDataType = ck_tile::half_t;
using VGradDataType = ck_tile::half_t;
using BiasGradDataType = ck_tile::half_t;
};
template <>
struct FmhaBwdTypeConfig<ck_tile::bf16_t>
{
using QDataType = ck_tile::bf16_t;
using KDataType = ck_tile::bf16_t;
using VDataType = ck_tile::bf16_t;
using GemmDataType = ck_tile::bf16_t;
using BiasDataType = ck_tile::bf16_t;
using LSEDataType = float;
using AccDataType = float; // data type for gemm accumulation
using DDataType = float;
using RandValOutputDataType = uint8_t;
using ODataType = ck_tile::bf16_t;
using OGradDataType = ck_tile::bf16_t;
using QGradDataType = ck_tile::bf16_t;
using KGradDataType = ck_tile::bf16_t;
using VGradDataType = ck_tile::bf16_t;
using BiasGradDataType = ck_tile::bf16_t;
};
struct FmhaMasks
{
using NoMask = ck_tile::GenericAttentionMask<false>;
using GenericMask = ck_tile::GenericAttentionMask<true, true>;
using CausalMask = ck_tile::GenericAttentionMask<true, false>;
};
// runtime args; some will be passed to the kargs, some will be used to compute grids/blocks
struct fmha_bwd_args
{
const void* q_ptr;
const void* k_ptr;
const void* v_ptr;
const void* bias_ptr; // bias or alibi_slope pointer
const void* o_ptr;
const void* lse_ptr;
const void* do_ptr;
void* d_ptr;
void* rand_val_ptr;
void* dq_ptr;
void* dk_ptr;
void* dv_ptr;
void* dbias_ptr;
void* dq_acc_ptr;
const void* seqstart_q_ptr;
const void* seqstart_k_ptr;
const void* seqlen_k_ptr;
ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
ck_tile::index_t batch;
ck_tile::index_t max_seqlen_q;
ck_tile::index_t max_seqlen_k;
ck_tile::index_t hdim_q;
ck_tile::index_t hdim_v;
ck_tile::index_t nhead_q;
ck_tile::index_t nhead_k;
float scale;
ck_tile::index_t stride_q;
ck_tile::index_t stride_k;
ck_tile::index_t stride_v;
ck_tile::index_t stride_bias; // if alibi: set to h for b*h slopes, or 0 for 1*h slopes
ck_tile::index_t stride_o;
ck_tile::index_t stride_randval;
ck_tile::index_t stride_do;
ck_tile::index_t stride_dq_acc;
ck_tile::index_t stride_dq;
ck_tile::index_t stride_dk;
ck_tile::index_t stride_dv;
ck_tile::index_t stride_dbias;
ck_tile::index_t nhead_stride_q;
ck_tile::index_t nhead_stride_k;
ck_tile::index_t nhead_stride_v;
ck_tile::index_t nhead_stride_bias;
ck_tile::index_t nhead_stride_o;
ck_tile::index_t nhead_stride_randval;
ck_tile::index_t nhead_stride_do;
ck_tile::index_t nhead_stride_lsed;
ck_tile::index_t nhead_stride_dq_acc;
ck_tile::index_t nhead_stride_dq;
ck_tile::index_t nhead_stride_dk;
ck_tile::index_t nhead_stride_dv;
ck_tile::index_t nhead_stride_dbias;
ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_v;
ck_tile::index_t batch_stride_bias;
ck_tile::index_t batch_stride_o;
ck_tile::index_t batch_stride_randval;
ck_tile::index_t batch_stride_do;
ck_tile::index_t batch_stride_lsed;
ck_tile::index_t batch_stride_dq_acc;
ck_tile::index_t batch_stride_dq;
ck_tile::index_t batch_stride_dk;
ck_tile::index_t batch_stride_dv;
ck_tile::index_t batch_stride_dbias;
ck_tile::index_t split_stride_dq_acc;
ck_tile::index_t window_size_left;
ck_tile::index_t window_size_right;
ck_tile::index_t mask_type;
float p_drop;
float p_undrop;
std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
drop_seed_offset;
};
template <typename FmhaBwdDQDKDVKernel>
auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
{
assert(args.nhead_q % args.nhead_k == 0);
auto kargs = [&] {
// create group mode kernel arguments
if constexpr(FmhaBwdDQDKDVKernel::kIsGroupMode)
{
return FmhaBwdDQDKDVKernel::MakeKargsImpl(args.q_ptr,
args.k_ptr,
args.v_ptr,
args.bias_ptr,
args.lse_ptr,
args.do_ptr,
args.d_ptr,
args.rand_val_ptr,
args.dk_ptr,
args.dv_ptr,
args.dbias_ptr,
args.dq_acc_ptr,
args.seqstart_q_ptr,
args.seqstart_k_ptr,
args.seqlen_k_ptr,
args.hdim_q,
args.hdim_v,
args.nhead_q,
args.nhead_q / args.nhead_k,
args.scale,
args.stride_q,
args.stride_k,
args.stride_v,
args.stride_bias,
args.stride_randval,
args.stride_do,
args.stride_dq_acc,
args.stride_dk,
args.stride_dv,
args.stride_dbias,
args.nhead_stride_q,
args.nhead_stride_k,
args.nhead_stride_v,
args.nhead_stride_bias,
args.nhead_stride_randval,
args.nhead_stride_do,
args.nhead_stride_lsed,
args.nhead_stride_dq_acc,
args.nhead_stride_dk,
args.nhead_stride_dv,
args.nhead_stride_dbias,
args.split_stride_dq_acc,
args.window_size_left,
args.window_size_right,
args.mask_type,
args.p_drop,
args.drop_seed_offset);
}
else
{ // create batch mode kernel arguments
return FmhaBwdDQDKDVKernel::MakeKargsImpl(args.q_ptr,
args.k_ptr,
args.v_ptr,
args.bias_ptr,
args.lse_ptr,
args.do_ptr,
args.d_ptr,
args.rand_val_ptr,
args.dk_ptr,
args.dv_ptr,
args.dbias_ptr,
args.dq_acc_ptr,
args.seqlen_q,
args.seqlen_k,
args.hdim_q,
args.hdim_v,
args.nhead_q,
args.nhead_q / args.nhead_k,
args.scale,
args.stride_q,
args.stride_k,
args.stride_v,
args.stride_bias,
args.stride_randval,
args.stride_do,
args.stride_dq_acc,
args.stride_dk,
args.stride_dv,
args.stride_dbias,
args.nhead_stride_q,
args.nhead_stride_k,
args.nhead_stride_v,
args.nhead_stride_bias,
args.nhead_stride_randval,
args.nhead_stride_do,
args.nhead_stride_lsed,
args.nhead_stride_dq_acc,
args.nhead_stride_dk,
args.nhead_stride_dv,
args.nhead_stride_dbias,
args.batch_stride_q,
args.batch_stride_k,
args.batch_stride_v,
args.batch_stride_bias,
args.batch_stride_randval,
args.batch_stride_do,
args.batch_stride_lsed,
args.batch_stride_dq_acc,
args.batch_stride_dk,
args.batch_stride_dv,
args.batch_stride_dbias,
args.split_stride_dq_acc,
args.window_size_left,
args.window_size_right,
args.mask_type,
args.p_drop,
args.drop_seed_offset);
}
}();
dim3 grids = FmhaBwdDQDKDVKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_k);
return ck_tile::make_tuple(kargs, grids);
}
template <typename FmhaBwdOGradDotOKernel>
auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args)
{
auto kargs = [&] {
// create group mode kernel arguments
if constexpr(FmhaBwdOGradDotOKernel::kIsGroupMode)
{
return FmhaBwdOGradDotOKernel::MakeKargs(args.o_ptr,
args.do_ptr,
args.d_ptr,
args.p_undrop,
args.seqstart_q_ptr,
args.hdim_v,
args.stride_do,
args.stride_o,
args.nhead_stride_do,
args.nhead_stride_o,
args.nhead_stride_lsed);
}
else
{ // create batch mode kernel arguments
return FmhaBwdOGradDotOKernel::MakeKargs(args.o_ptr,
args.do_ptr,
args.d_ptr,
args.p_undrop,
args.seqlen_q,
args.hdim_v,
args.stride_do,
args.stride_o,
args.nhead_stride_do,
args.nhead_stride_o,
args.nhead_stride_lsed,
args.batch_stride_do,
args.batch_stride_o,
args.batch_stride_lsed);
}
}();
dim3 grids = FmhaBwdOGradDotOKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q);
return ck_tile::make_tuple(kargs, grids);
}
template <typename FmhaBwdConvertQGradKernel>
auto fmha_bwd_convert_dq_create_kargs_and_grids(fmha_bwd_args args)
{
auto kargs = [&] {
// create group mode kernel arguments
if constexpr(FmhaBwdConvertQGradKernel::kIsGroupMode)
{
return FmhaBwdConvertQGradKernel::MakeKargs(args.dq_acc_ptr,
args.dq_ptr,
args.seqstart_q_ptr,
args.seqstart_k_ptr,
args.hdim_q,
args.stride_dq,
args.stride_dq_acc,
args.nhead_stride_dq,
args.nhead_stride_dq_acc,
args.split_stride_dq_acc);
}
else
{ // create batch mode kernel arguments
return FmhaBwdConvertQGradKernel::MakeKargs(args.dq_acc_ptr,
args.dq_ptr,
args.seqlen_q,
args.seqlen_k,
args.hdim_q,
args.stride_dq,
args.stride_dq_acc,
args.nhead_stride_dq,
args.nhead_stride_dq_acc,
args.batch_stride_dq,
args.batch_stride_dq_acc,
args.split_stride_dq_acc);
}
}();
dim3 grids = FmhaBwdConvertQGradKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q);
return ck_tile::make_tuple(kargs, grids);
}
// this is used to pattern-match the internal kernel implementation, not to instantiate a kernel
template <ck_tile::index_t HDim_,
typename DataType_,
bool kIsGroupMode_,
ck_tile::BlockFmhaBwdPipelineEnum FmhaBwdPipelineEnum_,
typename FmhaMask_,
typename FmhaDropout_,
ck_tile::BlockAttentionBiasEnum BiasEnum_,
bool kHasBiasGrad_,
bool kPadS_,
bool kPadSK_,
bool kPadD_,
bool kPadDv_,
bool kIsDeterministic_>
struct fmha_bwd_dq_dk_dv_traits_
{
static constexpr ck_tile::index_t HDim = HDim_;
using DataType = ck_tile::remove_cvref_t<DataType_>;
static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr auto FmhaBwdPipelineEnum = FmhaBwdPipelineEnum_;
using FmhaMask = ck_tile::remove_cvref_t<FmhaMask_>;
using FmhaDropout = ck_tile::remove_cvref_t<FmhaDropout_>;
static constexpr auto BiasEnum = BiasEnum_;
static constexpr bool kHasBiasGrad = kHasBiasGrad_;
static constexpr bool kPadS = kPadS_;
static constexpr bool kPadSK = kPadSK_;
static constexpr bool kPadD = kPadD_;
static constexpr bool kPadDv = kPadDv_;
static constexpr bool kIsDeterministic = kIsDeterministic_;
};
template <typename Traits_>
float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config&, fmha_bwd_args);
template <typename Traits_>
void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config&, fmha_bwd_args);
template <typename Traits_>
std::string fmha_bwd_dq_dk_dv_get_name_();
template <ck_tile::index_t HDim_, typename DataType_, bool kIsGroupMode_, bool kPadS_, bool kPadDv_>
struct fmha_bwd_dot_do_o_traits_
{
static constexpr ck_tile::index_t HDim = HDim_;
using DataType = ck_tile::remove_cvref_t<DataType_>;
static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr bool kPadS = kPadS_;
static constexpr bool kPadDv = kPadDv_;
};
template <typename Traits_>
float fmha_bwd_dot_do_o_(const ck_tile::stream_config&, fmha_bwd_args);
template <typename Traits_>
void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config&, fmha_bwd_args);
template <typename Traits_>
std::string fmha_bwd_dot_do_o_get_name_();
template <ck_tile::index_t HDim_,
typename DataType_,
bool kIsGroupMode_,
bool kPadS_,
bool kPadD_,
bool kIsDeterministic_>
struct fmha_bwd_convert_dq_traits_
{
static constexpr ck_tile::index_t HDim = HDim_;
using DataType = ck_tile::remove_cvref_t<DataType_>;
static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr bool kPadS = kPadS_;
static constexpr bool kPadD = kPadD_;
static constexpr bool kIsDeterministic = kIsDeterministic_;
};
template <typename Traits_>
float fmha_bwd_convert_dq_(const ck_tile::stream_config&, fmha_bwd_args);
template <typename Traits_>
void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config&, fmha_bwd_args);
template <typename Traits_>
std::string fmha_bwd_convert_dq_get_name_();
// This is the public API; it will be generated by script
struct fmha_bwd_traits
{
int hdim_q;
int hdim_v;
std::string data_type;
bool is_group_mode;
mask_enum mask_type;
bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum
bool has_dbias;
bool has_dropout;
bool is_store_randval;
bool is_deterministic;
// TODO: padding check is inside this api
};
float fmha_bwd(fmha_bwd_traits, fmha_bwd_args, const ck_tile::stream_config&);
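
As a rough illustration of how a caller might drive this entry point, a hypothetical sketch (buffer allocation, stride setup, and kernel linkage are assumed to exist; field names come from the structs above):

#include <fmha_bwd.hpp>
#include <hip/hip_runtime.h>

// Hypothetical fp16, batch-mode backward launch; every pointer is assumed
// to be a device buffer of the right shape.
float run_fmha_bwd_fp16(void* q, void* k, void* v, void* o, void* lse,
                        void* dout, void* d, void* dq, void* dk, void* dv,
                        void* dq_acc, hipStream_t stream) {
  fmha_bwd_traits traits{};
  traits.hdim_q = 64;
  traits.hdim_v = 64;
  traits.data_type = "fp16";             // matched against generated kernels
  traits.is_group_mode = false;          // batch mode, not varlen/group mode
  traits.mask_type = mask_enum::no_mask;
  traits.bias_type = bias_enum::no_bias;
  traits.has_dbias = false;
  traits.has_dropout = false;
  traits.is_store_randval = false;
  traits.is_deterministic = false;

  fmha_bwd_args args{};                  // value-init, then fill what we use
  args.q_ptr = q;  args.k_ptr = k;  args.v_ptr = v;
  args.o_ptr = o;  args.lse_ptr = lse;  args.do_ptr = dout;
  args.d_ptr = d;  args.dq_ptr = dq;  args.dk_ptr = dk;  args.dv_ptr = dv;
  args.dq_acc_ptr = dq_acc;
  args.batch = 2;
  args.seqlen_q = 1024;  args.seqlen_k = 1024;
  args.max_seqlen_q = 1024;  args.max_seqlen_k = 1024;
  args.nhead_q = 8;  args.nhead_k = 8;   // nhead_q % nhead_k == 0
  args.hdim_q = 64;  args.hdim_v = 64;
  args.scale = 0.125f;                   // 1/sqrt(hdim_q)
  args.p_drop = 0.0f;  args.p_undrop = 1.0f;
  args.drop_seed_offset = std::make_pair(uint64_t{0}, uint64_t{0});
  // ...stride_* / nhead_stride_* / batch_stride_* fields elided; all of
  // them must be set for a real launch.

  // Returns the measured kernel time when timing is enabled in the config.
  return fmha_bwd(traits, args, ck_tile::stream_config{stream});
}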

View File

@ -0,0 +1,144 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_bwd.hpp>
using fmha_dtype_0 = ck_tile::fp16_t;
using fmha_block_tile_0 = ck_tile::
sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>;
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>;
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
// G1&G3 -> GdKV
// G4 -> GdQ
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps2_0,
fmha_warp_tile0_0>;
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
true,
false,
false,
ck_tile::BlockAttentionBiasEnum::ALIBI,
false,
false,
false,
false,
1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<false, true, false>;
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
fmha_bwd_shape_0,
false,
true,
fmha_mask_0,
fmha_dropout_0,
fmha_bwd_trait_0>;
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::KGradDataType,
true,
false>>;
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::VGradDataType,
true,
false>>;
using fmha_bwd_dq_dk_dv_kernel_0 =
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
fmha_bwd_dk_epilogue_0,
fmha_bwd_dv_epilogue_0>;
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32,
ck_tile::fp16_t,
false,
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
fmha_mask_0,
fmha_dropout_0,
ck_tile::BlockAttentionBiasEnum::ALIBI,
false,
true,
true,
false,
false,
true>;
#include <iostream>
template <>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}
template <>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
ck_tile::stream_config{s.stream_id_});
#endif
}
template <>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
return k_::GetName();
}

View File

@ -0,0 +1,144 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_bwd.hpp>
using fmha_dtype_0 = ck_tile::fp16_t;
using fmha_block_tile_0 = ck_tile::
sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>;
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
// G1&G3 -> GdKV
// G4 -> GdQ
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps2_0,
fmha_warp_tile0_0>;
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<false,
false,
false,
false,
ck_tile::BlockAttentionBiasEnum::ALIBI,
false,
false,
false,
false,
1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<false, true, false>;
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
fmha_bwd_shape_0,
false,
false,
fmha_mask_0,
fmha_dropout_0,
fmha_bwd_trait_0>;
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::KGradDataType,
false,
false>>;
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::VGradDataType,
false,
false>>;
using fmha_bwd_dq_dk_dv_kernel_0 =
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
fmha_bwd_dk_epilogue_0,
fmha_bwd_dv_epilogue_0>;
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256,
ck_tile::fp16_t,
false,
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
fmha_mask_0,
fmha_dropout_0,
ck_tile::BlockAttentionBiasEnum::ALIBI,
false,
false,
false,
false,
false,
false>;
#include <iostream>
template <>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}
template <>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
ck_tile::stream_config{s.stream_id_});
#endif
}
template <>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
return k_::GetName();
}

View File

@ -0,0 +1,144 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_bwd.hpp>
using fmha_dtype_0 = ck_tile::bf16_t;
using fmha_block_tile_0 = ck_tile::
sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>;
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
// G1&G3 -> GdKV
// G4 -> GdQ
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps2_0,
fmha_warp_tile0_0>;
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
true,
true,
true,
ck_tile::BlockAttentionBiasEnum::ALIBI,
false,
false,
false,
false,
1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<true, false, false>;
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
fmha_bwd_shape_0,
false,
false,
fmha_mask_0,
fmha_dropout_0,
fmha_bwd_trait_0>;
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR<fmha_bwd_pipeline_problem_0>;
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::KGradDataType,
true,
true>>;
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::VGradDataType,
true,
true>>;
using fmha_bwd_dq_dk_dv_kernel_0 =
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
fmha_bwd_dk_epilogue_0,
fmha_bwd_dv_epilogue_0>;
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256,
ck_tile::bf16_t,
false,
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR,
fmha_mask_0,
fmha_dropout_0,
ck_tile::BlockAttentionBiasEnum::ALIBI,
false,
true,
true,
true,
true,
false>;
#include <iostream>
template <>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}
template <>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
ck_tile::stream_config{s.stream_id_});
#endif
}
template <>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
return k_::GetName();
}

View File

@ -0,0 +1,79 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_bwd.hpp>
using fmha_dtype_0 = ck_tile::bf16_t;
using fmha_bwd_convert_dq_trait_0 =
ck_tile::TileFmhaBwdConvertQGradTraits<true, true, 2>;
using fmha_bwd_convert_dq_pipeline_problem_0 =
ck_tile::BlockFmhaBwdConvertQGradPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
/* BlockSize = */ 256,
64,
128,
128,
false,
true,
fmha_bwd_convert_dq_trait_0>;
using fmha_bwd_convert_dq_0 =
typename ck_tile::BlockFmhaBwdConvertQGrad<fmha_bwd_convert_dq_pipeline_problem_0>;
using fmha_bwd_convert_dq_kernel_0 =
ck_tile::FmhaBwdConvertQGradKernel<fmha_bwd_convert_dq_0>;
using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<128,
ck_tile::bf16_t,
false,
true,
true,
true>;
#include <iostream>
template <>
float fmha_bwd_convert_dq_<convert_dq_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
{
using k_ = fmha_bwd_convert_dq_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}
template <>
void fmha_bwd_convert_dq_oneshot_<convert_dq_trait_0>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{
using k_ = fmha_bwd_convert_dq_kernel_0;
auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
ck_tile::stream_config{s.stream_id_});
#endif
}
template <>
std::string fmha_bwd_convert_dq_get_name_<convert_dq_trait_0>()
{
using k_ = fmha_bwd_convert_dq_kernel_0;
return k_::GetName();
}

View File

@ -0,0 +1,144 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_bwd.hpp>
using fmha_dtype_0 = ck_tile::bf16_t;
using fmha_block_tile_0 = ck_tile::
sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>;
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
// G1&G3 -> GdKV
// G4 -> GdQ
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps2_0,
fmha_warp_tile0_0>;
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
false,
false,
false,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
false,
false,
false,
1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<false, true, false>;
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
fmha_bwd_shape_0,
false,
false,
fmha_mask_0,
fmha_dropout_0,
fmha_bwd_trait_0>;
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::KGradDataType,
false,
false>>;
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::VGradDataType,
false,
false>>;
using fmha_bwd_dq_dk_dv_kernel_0 =
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
fmha_bwd_dk_epilogue_0,
fmha_bwd_dv_epilogue_0>;
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64,
ck_tile::bf16_t,
false,
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
fmha_mask_0,
fmha_dropout_0,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
true,
false,
false,
false,
false>;
#include <iostream>
template <>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}
template <>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
ck_tile::stream_config{s.stream_id_});
#endif
}
template <>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
return k_::GetName();
}

View File

@ -0,0 +1,84 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_fwd.hpp>
using fmha_dtype_0 = ck_tile::bf16_t;
using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>;
using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>;
using fmha_shape_0 = ck_tile::TileFmhaShape<fmha_block_tile_0,
ck_tile::sequence<4, 1, 1>,
fmha_warp_tile_0,
ck_tile::sequence<4, 1, 1>,
fmha_warp_tile_0,
true>;
using fmha_trait_0 = ck_tile::TileFmhaTraits<true,
true,
true,
true,
ck_tile::BlockAttentionBiasEnum::ALIBI,
false,
true,
false,
false,
-1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem<
typename FmhaFwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::SaccDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::SMPLComputeDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::PDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::OaccDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::ODataType,
fmha_shape_0,
true,
fmha_mask_0,
fmha_trait_0>;
using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync<
fmha_pipeline_problem_0>;
using fmha_epilogue_0 =
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<ck_tile::bf16_t>::OaccDataType,
typename FmhaFwdTypeConfig<ck_tile::bf16_t>::ODataType,
true, true>>;
using fmha_kernel_0 =
ck_tile::FmhaFwdKernel<ck_tile::FmhaFwdTilePartitioner_HBS<fmha_shape_0>,
fmha_pipeline_0,
fmha_epilogue_0>;
using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, true, 128, 64, 32, 64, 32, 64, true,
ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>;
#include <iostream>
template<>
float fmha_fwd_<trait_0>(const ck_tile::stream_config& s, fmha_fwd_args a)
{
using k_ = fmha_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_fwd_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}

View File

@ -0,0 +1,144 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_bwd.hpp>
using fmha_dtype_0 = ck_tile::bf16_t;
using fmha_block_tile_0 = ck_tile::
sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>;
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
// G1&G3 -> GdKV
// G4 -> GdQ
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps2_0,
fmha_warp_tile0_0>;
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
true,
false,
false,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
false,
false,
false,
1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<false, true, false>;
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
fmha_bwd_shape_0,
false,
true,
fmha_mask_0,
fmha_dropout_0,
fmha_bwd_trait_0>;
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::KGradDataType,
true,
false>>;
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::VGradDataType,
true,
false>>;
using fmha_bwd_dq_dk_dv_kernel_0 =
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
fmha_bwd_dk_epilogue_0,
fmha_bwd_dv_epilogue_0>;
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64,
ck_tile::bf16_t,
false,
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
fmha_mask_0,
fmha_dropout_0,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
true,
true,
false,
false,
true>;
#include <iostream>
template <>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}
template <>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
ck_tile::stream_config{s.stream_id_});
#endif
}
template <>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
return k_::GetName();
}

View File

@ -0,0 +1,79 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_bwd.hpp>
using fmha_dtype_0 = ck_tile::bf16_t;
using fmha_bwd_convert_dq_trait_0 =
ck_tile::TileFmhaBwdConvertQGradTraits<true, true, 2>;
using fmha_bwd_convert_dq_pipeline_problem_0 =
ck_tile::BlockFmhaBwdConvertQGradPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
/* BlockSize = */ 256,
64,
128,
128,
true,
true,
fmha_bwd_convert_dq_trait_0>;
using fmha_bwd_convert_dq_0 =
typename ck_tile::BlockFmhaBwdConvertQGrad<fmha_bwd_convert_dq_pipeline_problem_0>;
using fmha_bwd_convert_dq_kernel_0 =
ck_tile::FmhaBwdConvertQGradKernel<fmha_bwd_convert_dq_0>;
using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<128,
ck_tile::bf16_t,
true,
true,
true,
true>;
#include <iostream>
template <>
float fmha_bwd_convert_dq_<convert_dq_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
{
using k_ = fmha_bwd_convert_dq_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}
template <>
void fmha_bwd_convert_dq_oneshot_<convert_dq_trait_0>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{
using k_ = fmha_bwd_convert_dq_kernel_0;
auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
ck_tile::stream_config{s.stream_id_});
#endif
}
template <>
std::string fmha_bwd_convert_dq_get_name_<convert_dq_trait_0>()
{
using k_ = fmha_bwd_convert_dq_kernel_0;
return k_::GetName();
}
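
The convert_dq instantiations appear to implement a second stage of the backward pass: BlockFmhaBwdConvertQGradPipelineProblem reads AccDataType and writes QGradDataType, which suggests dQ is accumulated in fp32 by the dq_dk_dv kernels and then cast to the storage dtype (bf16 here). A hedged invocation sketch, assuming the same fmha_bwd.hpp entry points and eliding fmha_bwd_args setup:

// Hypothetical caller (illustration only). Would typically run after the
// matching dq_dk_dv kernel has filled the fp32 dQ accumulator.
float convert_dq_example(const ck_tile::stream_config& s, fmha_bwd_args a)
{
    return fmha_bwd_convert_dq_<convert_dq_trait_0>(s, a);
}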


@ -0,0 +1,79 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_bwd.hpp>
using fmha_dtype_0 = ck_tile::fp16_t;
using fmha_bwd_convert_dq_trait_0 =
ck_tile::TileFmhaBwdConvertQGradTraits<false, false, 2>;
using fmha_bwd_convert_dq_pipeline_problem_0 =
ck_tile::BlockFmhaBwdConvertQGradPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
/* BlockSize = */ 256,
64,
64,
256,
false,
false,
fmha_bwd_convert_dq_trait_0>;
using fmha_bwd_convert_dq_0 =
typename ck_tile::BlockFmhaBwdConvertQGrad<fmha_bwd_convert_dq_pipeline_problem_0>;
using fmha_bwd_convert_dq_kernel_0 =
ck_tile::FmhaBwdConvertQGradKernel<fmha_bwd_convert_dq_0>;
using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<256,
ck_tile::fp16_t,
false,
false,
false,
false>;
#include <iostream>
template <>
float fmha_bwd_convert_dq_<convert_dq_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
{
using k_ = fmha_bwd_convert_dq_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}
template <>
void fmha_bwd_convert_dq_oneshot_<convert_dq_trait_0>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{
using k_ = fmha_bwd_convert_dq_kernel_0;
auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
ck_tile::stream_config{s.stream_id_});
#endif
}
template <>
std::string fmha_bwd_convert_dq_get_name_<convert_dq_trait_0>()
{
using k_ = fmha_bwd_convert_dq_kernel_0;
return k_::GetName();
}


@ -0,0 +1,144 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_bwd.hpp>
using fmha_dtype_0 = ck_tile::fp16_t;
using fmha_block_tile_0 = ck_tile::
sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>;
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
// G1&G3 -> GdKV
// G4 -> GdQ
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps2_0,
fmha_warp_tile0_0>;
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
true,
true,
true,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
false,
false,
false,
1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<false, true, false>;
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
fmha_bwd_shape_0,
true,
true,
fmha_mask_0,
fmha_dropout_0,
fmha_bwd_trait_0>;
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR<fmha_bwd_pipeline_problem_0>;
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::KGradDataType,
true,
true>>;
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::VGradDataType,
true,
true>>;
using fmha_bwd_dq_dk_dv_kernel_0 =
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
fmha_bwd_dk_epilogue_0,
fmha_bwd_dv_epilogue_0>;
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64,
ck_tile::fp16_t,
true,
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR,
fmha_mask_0,
fmha_dropout_0,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
true,
true,
true,
true,
true>;
#include <iostream>
template <>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}
template <>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
ck_tile::stream_config{s.stream_id_});
#endif
}
template <>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
return k_::GetName();
}


@ -0,0 +1,79 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_bwd.hpp>
using fmha_dtype_0 = ck_tile::bf16_t;
using fmha_bwd_convert_dq_trait_0 =
ck_tile::TileFmhaBwdConvertQGradTraits<false, false, 2>;
using fmha_bwd_convert_dq_pipeline_problem_0 =
ck_tile::BlockFmhaBwdConvertQGradPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
/* BlockSize = */ 256,
64,
128,
128,
false,
true,
fmha_bwd_convert_dq_trait_0>;
using fmha_bwd_convert_dq_0 =
typename ck_tile::BlockFmhaBwdConvertQGrad<fmha_bwd_convert_dq_pipeline_problem_0>;
using fmha_bwd_convert_dq_kernel_0 =
ck_tile::FmhaBwdConvertQGradKernel<fmha_bwd_convert_dq_0>;
using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<128,
ck_tile::bf16_t,
false,
false,
false,
true>;
#include <iostream>
template <>
float fmha_bwd_convert_dq_<convert_dq_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
{
using k_ = fmha_bwd_convert_dq_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}
template <>
void fmha_bwd_convert_dq_oneshot_<convert_dq_trait_0>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{
using k_ = fmha_bwd_convert_dq_kernel_0;
auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
ck_tile::stream_config{s.stream_id_});
#endif
}
template <>
std::string fmha_bwd_convert_dq_get_name_<convert_dq_trait_0>()
{
using k_ = fmha_bwd_convert_dq_kernel_0;
return k_::GetName();
}


@ -0,0 +1,144 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_bwd.hpp>
using fmha_dtype_0 = ck_tile::fp16_t;
using fmha_block_tile_0 = ck_tile::
sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>;
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
// G1&G3 -> GdKV
// G4 -> GdQ
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps2_0,
fmha_warp_tile0_0>;
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
false,
true,
true,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
false,
false,
false,
1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<true, false, false>;
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
fmha_bwd_shape_0,
false,
true,
fmha_mask_0,
fmha_dropout_0,
fmha_bwd_trait_0>;
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR<fmha_bwd_pipeline_problem_0>;
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::KGradDataType,
false,
true>>;
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::VGradDataType,
false,
true>>;
using fmha_bwd_dq_dk_dv_kernel_0 =
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
fmha_bwd_dk_epilogue_0,
fmha_bwd_dv_epilogue_0>;
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128,
ck_tile::fp16_t,
false,
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR,
fmha_mask_0,
fmha_dropout_0,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
true,
false,
true,
true,
true>;
#include <iostream>
template <>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}
template <>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
ck_tile::stream_config{s.stream_id_});
#endif
}
template <>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
return k_::GetName();
}


@ -0,0 +1,144 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_bwd.hpp>
using fmha_dtype_0 = ck_tile::fp16_t;
using fmha_block_tile_0 = ck_tile::
sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>;
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
// G1&G3 -> GdKV
// G4 -> GdQ
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps2_0,
fmha_warp_tile0_0>;
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
true,
true,
true,
ck_tile::BlockAttentionBiasEnum::ALIBI,
false,
false,
false,
false,
1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<true, false, false>;
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
fmha_bwd_shape_0,
true,
false,
fmha_mask_0,
fmha_dropout_0,
fmha_bwd_trait_0>;
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR<fmha_bwd_pipeline_problem_0>;
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::KGradDataType,
true,
true>>;
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::VGradDataType,
true,
true>>;
using fmha_bwd_dq_dk_dv_kernel_0 =
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
fmha_bwd_dk_epilogue_0,
fmha_bwd_dv_epilogue_0>;
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256,
ck_tile::fp16_t,
true,
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR,
fmha_mask_0,
fmha_dropout_0,
ck_tile::BlockAttentionBiasEnum::ALIBI,
false,
true,
true,
true,
true,
false>;
#include <iostream>
template <>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}
template <>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
ck_tile::stream_config{s.stream_id_});
#endif
}
template <>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
return k_::GetName();
}


@ -0,0 +1,144 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_bwd.hpp>
using fmha_dtype_0 = ck_tile::fp16_t;
using fmha_block_tile_0 = ck_tile::
sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>;
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
// G1&G3 -> GdKV
// G4 -> GdQ
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps2_0,
fmha_warp_tile0_0>;
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
true,
false,
false,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
false,
false,
false,
1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<true, false, false>;
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
fmha_bwd_shape_0,
false,
true,
fmha_mask_0,
fmha_dropout_0,
fmha_bwd_trait_0>;
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::KGradDataType,
true,
false>>;
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::VGradDataType,
true,
false>>;
using fmha_bwd_dq_dk_dv_kernel_0 =
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
fmha_bwd_dk_epilogue_0,
fmha_bwd_dv_epilogue_0>;
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256,
ck_tile::fp16_t,
false,
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
fmha_mask_0,
fmha_dropout_0,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
true,
true,
false,
false,
true>;
#include <iostream>
template <>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}
template <>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
ck_tile::stream_config{s.stream_id_});
#endif
}
template <>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
return k_::GetName();
}


@ -0,0 +1,84 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_fwd.hpp>
using fmha_dtype_0 = ck_tile::bf16_t;
using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>;
using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>;
using fmha_shape_0 = ck_tile::TileFmhaShape<fmha_block_tile_0,
ck_tile::sequence<4, 1, 1>,
fmha_warp_tile_0,
ck_tile::sequence<4, 1, 1>,
fmha_warp_tile_0,
true>;
using fmha_trait_0 = ck_tile::TileFmhaTraits<true,
false,
true,
true,
ck_tile::BlockAttentionBiasEnum::ALIBI,
false,
false,
false,
false,
-1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem<
typename FmhaFwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::SaccDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::SMPLComputeDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::PDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::OaccDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::ODataType,
fmha_shape_0,
false,
fmha_mask_0,
fmha_trait_0>;
using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync<
fmha_pipeline_problem_0>;
using fmha_epilogue_0 =
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<ck_tile::bf16_t>::OaccDataType,
typename FmhaFwdTypeConfig<ck_tile::bf16_t>::ODataType,
true, true>>;
using fmha_kernel_0 =
ck_tile::FmhaFwdKernel<ck_tile::FmhaFwdTilePartitioner_SHB<fmha_shape_0>,
fmha_pipeline_0,
fmha_epilogue_0>;
using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, false, 128, 64, 32, 64, 32, 64, true,
    ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, false, true, true>;
#include <iostream>
template<>
float fmha_fwd_<trait_0>(const ck_tile::stream_config& s, fmha_fwd_args a)
{
using k_ = fmha_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_fwd_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}
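
This is a forward-path instantiation: it wires a QRKSVS_ASYNC pipeline with ALIBI bias into FmhaFwdKernel and, per the +84-line hunk, exposes only the timed fmha_fwd_ launcher (no oneshot or get_name specializations). A minimal host-side sketch under the same assumptions as the backward examples, using only the fmha_fwd.hpp declarations visible here:

// Hypothetical caller (illustration only, not generated code).
float run_fwd_example(const ck_tile::stream_config& s, fmha_fwd_args a)
{
    // Returns 0.0 when compiled for targets other than gfx90a/gfx942,
    // per the guard in the specialization above.
    return fmha_fwd_<trait_0>(s, a);
}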


@ -0,0 +1,144 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_bwd.hpp>
using fmha_dtype_0 = ck_tile::fp16_t;
using fmha_block_tile_0 = ck_tile::
sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>;
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
// G1&G3 -> GdKV
// G4 -> GdQ
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps2_0,
fmha_warp_tile0_0>;
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
true,
true,
true,
ck_tile::BlockAttentionBiasEnum::ALIBI,
false,
false,
false,
false,
1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<true, false, false>;
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
fmha_bwd_shape_0,
true,
true,
fmha_mask_0,
fmha_dropout_0,
fmha_bwd_trait_0>;
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR<fmha_bwd_pipeline_problem_0>;
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::KGradDataType,
true,
true>>;
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::VGradDataType,
true,
true>>;
using fmha_bwd_dq_dk_dv_kernel_0 =
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
fmha_bwd_dk_epilogue_0,
fmha_bwd_dv_epilogue_0>;
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256,
ck_tile::fp16_t,
true,
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR,
fmha_mask_0,
fmha_dropout_0,
ck_tile::BlockAttentionBiasEnum::ALIBI,
false,
true,
true,
true,
true,
true>;
#include <iostream>
template <>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}
template <>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
ck_tile::stream_config{s.stream_id_});
#endif
}
template <>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
return k_::GetName();
}


@ -0,0 +1,144 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_bwd.hpp>
using fmha_dtype_0 = ck_tile::fp16_t;
using fmha_block_tile_0 = ck_tile::
sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>;
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
// G1&G3 -> GdKV
// G4 -> GdQ
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps2_0,
fmha_warp_tile0_0>;
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<false,
false,
true,
true,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
false,
false,
false,
1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<false, true, false>;
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
fmha_bwd_shape_0,
false,
true,
fmha_mask_0,
fmha_dropout_0,
fmha_bwd_trait_0>;
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR<fmha_bwd_pipeline_problem_0>;
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::KGradDataType,
false,
true>>;
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::VGradDataType,
false,
true>>;
using fmha_bwd_dq_dk_dv_kernel_0 =
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
fmha_bwd_dk_epilogue_0,
fmha_bwd_dv_epilogue_0>;
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128,
ck_tile::fp16_t,
false,
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR,
fmha_mask_0,
fmha_dropout_0,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
false,
false,
true,
true,
true>;
#include <iostream>
template <>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}
template <>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
ck_tile::stream_config{s.stream_id_});
#endif
}
template <>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
return k_::GetName();
}


@ -0,0 +1,144 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_bwd.hpp>
using fmha_dtype_0 = ck_tile::fp16_t;
using fmha_block_tile_0 = ck_tile::
sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>;
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
// G1&G3 -> GdKV
// G4 -> GdQ
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps2_0,
fmha_warp_tile0_0>;
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<false,
true,
false,
false,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
false,
false,
false,
1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<false, true, false>;
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
fmha_bwd_shape_0,
false,
false,
fmha_mask_0,
fmha_dropout_0,
fmha_bwd_trait_0>;
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::KGradDataType,
true,
false>>;
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::VGradDataType,
true,
false>>;
using fmha_bwd_dq_dk_dv_kernel_0 =
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
fmha_bwd_dk_epilogue_0,
fmha_bwd_dv_epilogue_0>;
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128,
ck_tile::fp16_t,
false,
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
fmha_mask_0,
fmha_dropout_0,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
false,
true,
false,
false,
false>;
#include <iostream>
template <>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}
template <>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
ck_tile::stream_config{s.stream_id_});
#endif
}
template <>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
return k_::GetName();
}


@ -0,0 +1,144 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_bwd.hpp>
using fmha_dtype_0 = ck_tile::fp16_t;
using fmha_block_tile_0 = ck_tile::
sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>;
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
// G1&G3 -> GdKV
// G4 -> GdQ
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps2_0,
fmha_warp_tile0_0>;
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
true,
false,
false,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
false,
false,
false,
1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<true, false, false>;
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
fmha_bwd_shape_0,
true,
false,
fmha_mask_0,
fmha_dropout_0,
fmha_bwd_trait_0>;
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::KGradDataType,
true,
false>>;
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::VGradDataType,
true,
false>>;
using fmha_bwd_dq_dk_dv_kernel_0 =
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
fmha_bwd_dk_epilogue_0,
fmha_bwd_dv_epilogue_0>;
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128,
ck_tile::fp16_t,
true,
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
fmha_mask_0,
fmha_dropout_0,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
true,
true,
false,
false,
false>;
#include <iostream>
template <>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}
template <>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
ck_tile::stream_config{s.stream_id_});
#endif
}
template <>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
return k_::GetName();
}


@ -0,0 +1,144 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_bwd.hpp>
using fmha_dtype_0 = ck_tile::fp16_t;
using fmha_block_tile_0 = ck_tile::
sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>;
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
// G1&G3 -> GdKV
// G4 -> GdQ
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps2_0,
fmha_warp_tile0_0>;
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
false,
true,
true,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
false,
false,
false,
1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<false, true, false>;
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
fmha_bwd_shape_0,
false,
true,
fmha_mask_0,
fmha_dropout_0,
fmha_bwd_trait_0>;
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR<fmha_bwd_pipeline_problem_0>;
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::KGradDataType,
false,
true>>;
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::VGradDataType,
false,
true>>;
using fmha_bwd_dq_dk_dv_kernel_0 =
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
fmha_bwd_dk_epilogue_0,
fmha_bwd_dv_epilogue_0>;
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64,
ck_tile::fp16_t,
false,
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR,
fmha_mask_0,
fmha_dropout_0,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
true,
false,
true,
true,
true>;
#include <iostream>
template <>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}
template <>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
ck_tile::stream_config{s.stream_id_});
#endif
}
template <>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
return k_::GetName();
}


@ -0,0 +1,144 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_bwd.hpp>
using fmha_dtype_0 = ck_tile::bf16_t;
using fmha_block_tile_0 = ck_tile::
sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>;
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
// G1&G3 -> GdKV
// G4 -> GdQ
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps2_0,
fmha_warp_tile0_0>;
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
true,
false,
false,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
false,
false,
false,
1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<false, true, false>;
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
fmha_bwd_shape_0,
false,
false,
fmha_mask_0,
fmha_dropout_0,
fmha_bwd_trait_0>;
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::KGradDataType,
true,
false>>;
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::VGradDataType,
true,
false>>;
using fmha_bwd_dq_dk_dv_kernel_0 =
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
fmha_bwd_dk_epilogue_0,
fmha_bwd_dv_epilogue_0>;
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256,
ck_tile::bf16_t,
false,
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
fmha_mask_0,
fmha_dropout_0,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
true,
true,
false,
false,
false>;
#include <iostream>
template <>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}
template <>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
ck_tile::stream_config{s.stream_id_});
#endif
}
template <>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
return k_::GetName();
}


@ -0,0 +1,144 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_bwd.hpp>
using fmha_dtype_0 = ck_tile::bf16_t;
using fmha_block_tile_0 = ck_tile::
sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>;
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>;
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
// G1&G3 -> GdKV
// G4 -> GdQ
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps2_0,
fmha_warp_tile0_0>;
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<false,
true,
false,
false,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
false,
false,
false,
1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<false, true, false>;
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
fmha_bwd_shape_0,
false,
true,
fmha_mask_0,
fmha_dropout_0,
fmha_bwd_trait_0>;
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::KGradDataType,
true,
false>>;
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::VGradDataType,
true,
false>>;
using fmha_bwd_dq_dk_dv_kernel_0 =
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
fmha_bwd_dk_epilogue_0,
fmha_bwd_dv_epilogue_0>;
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32,
ck_tile::bf16_t,
false,
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
fmha_mask_0,
fmha_dropout_0,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
false,
true,
false,
false,
true>;
#include <iostream>
template <>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}
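// One-shot variant: enqueues the kernel on the given stream exactly once,
// without the timing wrapper that ck_tile::launch_kernel applies above.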
template <>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
ck_tile::stream_config{s.stream_id_});
#endif
}
template <>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
return k_::GetName();
}
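// Usage sketch (an assumption for illustration, not generator output): how a
// caller holding a populated fmha_bwd_args is expected to drive the
// specializations above. The wrapper name run_bwd_instance_0 is hypothetical;
// the entry points it uses are the ones specialized in this file.
static inline float run_bwd_instance_0(const ck_tile::stream_config& s,
                                       fmha_bwd_args a)
{
    if(s.log_level_ > 0)
        std::cout << fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>() << std::endl;
    // Timed path: returns the kernel time measured by ck_tile::launch_kernel,
    // or 0.0 when compiled for an unsupported gfx target.
    return fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(s, a);
}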

View File

@ -0,0 +1,144 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_bwd.hpp>
using fmha_dtype_0 = ck_tile::fp16_t;
using fmha_block_tile_0 = ck_tile::
sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>;
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>;
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
// G1&G3 -> GdKV
// G4 -> GdQ
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps2_0,
fmha_warp_tile0_0>;
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
true,
false,
false,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
false,
false,
false,
1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<true, false, false>;
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
fmha_bwd_shape_0,
true,
false,
fmha_mask_0,
fmha_dropout_0,
fmha_bwd_trait_0>;
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::KGradDataType,
true,
false>>;
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::VGradDataType,
true,
false>>;
using fmha_bwd_dq_dk_dv_kernel_0 =
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
fmha_bwd_dk_epilogue_0,
fmha_bwd_dv_epilogue_0>;
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32,
ck_tile::fp16_t,
true,
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
fmha_mask_0,
fmha_dropout_0,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
true,
true,
false,
false,
false>;
#include <iostream>
template <>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}
template <>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
ck_tile::stream_config{s.stream_id_});
#endif
}
template <>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
return k_::GetName();
}

View File

@ -0,0 +1,144 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_bwd.hpp>
using fmha_dtype_0 = ck_tile::bf16_t;
using fmha_block_tile_0 = ck_tile::
sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>;
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>;
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
// G1&G3 -> GdKV
// G4 -> GdQ
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps2_0,
fmha_warp_tile0_0>;
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
false,
false,
false,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
false,
false,
false,
1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<true, false, false>;
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
fmha_bwd_shape_0,
false,
false,
fmha_mask_0,
fmha_dropout_0,
fmha_bwd_trait_0>;
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::KGradDataType,
false,
false>>;
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::VGradDataType,
false,
false>>;
using fmha_bwd_dq_dk_dv_kernel_0 =
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
fmha_bwd_dk_epilogue_0,
fmha_bwd_dv_epilogue_0>;
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32,
ck_tile::bf16_t,
false,
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
fmha_mask_0,
fmha_dropout_0,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
true,
false,
false,
false,
false>;
#include <iostream>
template <>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}
template <>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
ck_tile::stream_config{s.stream_id_});
#endif
}
template <>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
return k_::GetName();
}

View File

@ -0,0 +1,144 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_bwd.hpp>
using fmha_dtype_0 = ck_tile::fp16_t;
using fmha_block_tile_0 = ck_tile::
sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>;
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
// G1&G3 -> GdKV
// G4 -> GdQ
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps2_0,
fmha_warp_tile0_0>;
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
false,
true,
true,
ck_tile::BlockAttentionBiasEnum::ALIBI,
false,
false,
false,
false,
1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<true, false, false>;
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
fmha_bwd_shape_0,
false,
true,
fmha_mask_0,
fmha_dropout_0,
fmha_bwd_trait_0>;
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR<fmha_bwd_pipeline_problem_0>;
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::KGradDataType,
false,
true>>;
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::VGradDataType,
false,
true>>;
using fmha_bwd_dq_dk_dv_kernel_0 =
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
fmha_bwd_dk_epilogue_0,
fmha_bwd_dv_epilogue_0>;
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64,
ck_tile::fp16_t,
false,
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR,
fmha_mask_0,
fmha_dropout_0,
ck_tile::BlockAttentionBiasEnum::ALIBI,
false,
true,
false,
true,
true,
true>;
#include <iostream>
template <>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}
template <>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
ck_tile::stream_config{s.stream_id_});
#endif
}
template <>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
return k_::GetName();
}

View File

@ -0,0 +1,144 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_bwd.hpp>
using fmha_dtype_0 = ck_tile::bf16_t;
using fmha_block_tile_0 = ck_tile::
sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>;
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
// G1&G3 -> GdKV
// G4 -> GdQ
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps2_0,
fmha_warp_tile0_0>;
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
false,
false,
false,
ck_tile::BlockAttentionBiasEnum::ALIBI,
false,
false,
false,
false,
1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<false, true, false>;
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
fmha_bwd_shape_0,
false,
false,
fmha_mask_0,
fmha_dropout_0,
fmha_bwd_trait_0>;
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::KGradDataType,
false,
false>>;
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::VGradDataType,
false,
false>>;
using fmha_bwd_dq_dk_dv_kernel_0 =
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
fmha_bwd_dk_epilogue_0,
fmha_bwd_dv_epilogue_0>;
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64,
ck_tile::bf16_t,
false,
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
fmha_mask_0,
fmha_dropout_0,
ck_tile::BlockAttentionBiasEnum::ALIBI,
false,
true,
false,
false,
false,
false>;
#include <iostream>
template <>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}
template <>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
ck_tile::stream_config{s.stream_id_});
#endif
}
template <>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
return k_::GetName();
}

View File

@ -0,0 +1,144 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_bwd.hpp>
using fmha_dtype_0 = ck_tile::bf16_t;
using fmha_block_tile_0 = ck_tile::
sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>;
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
// G1&G3 -> GdKV
// G4 -> GdQ
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps2_0,
fmha_warp_tile0_0>;
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<false,
false,
true,
true,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
false,
false,
false,
1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<true, false, false>;
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
fmha_bwd_shape_0,
false,
false,
fmha_mask_0,
fmha_dropout_0,
fmha_bwd_trait_0>;
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR<fmha_bwd_pipeline_problem_0>;
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::KGradDataType,
false,
true>>;
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::VGradDataType,
false,
true>>;
using fmha_bwd_dq_dk_dv_kernel_0 =
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
fmha_bwd_dk_epilogue_0,
fmha_bwd_dv_epilogue_0>;
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64,
ck_tile::bf16_t,
false,
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR,
fmha_mask_0,
fmha_dropout_0,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
false,
false,
true,
true,
false>;
#include <iostream>
template <>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}
template <>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
ck_tile::stream_config{s.stream_id_});
#endif
}
template <>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
return k_::GetName();
}

View File

@ -0,0 +1,144 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_bwd.hpp>
using fmha_dtype_0 = ck_tile::fp16_t;
using fmha_block_tile_0 = ck_tile::
sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>;
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>;
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
// G1&G3 -> GdKV
// G4 -> GdQ
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps2_0,
fmha_warp_tile0_0>;
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<false,
true,
false,
false,
ck_tile::BlockAttentionBiasEnum::ALIBI,
false,
false,
false,
false,
1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<true, false, false>;
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
fmha_bwd_shape_0,
false,
true,
fmha_mask_0,
fmha_dropout_0,
fmha_bwd_trait_0>;
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::KGradDataType,
true,
false>>;
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::VGradDataType,
true,
false>>;
using fmha_bwd_dq_dk_dv_kernel_0 =
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
fmha_bwd_dk_epilogue_0,
fmha_bwd_dv_epilogue_0>;
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32,
ck_tile::fp16_t,
false,
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
fmha_mask_0,
fmha_dropout_0,
ck_tile::BlockAttentionBiasEnum::ALIBI,
false,
false,
true,
false,
false,
true>;
#include <iostream>
template <>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}
template <>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
ck_tile::stream_config{s.stream_id_});
#endif
}
template <>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
return k_::GetName();
}

View File

@ -0,0 +1,144 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_bwd.hpp>
using fmha_dtype_0 = ck_tile::fp16_t;
using fmha_block_tile_0 = ck_tile::
sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>;
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
// G1&G3 -> GdKV
// G4 -> GdQ
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps2_0,
fmha_warp_tile0_0>;
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
true,
true,
true,
ck_tile::BlockAttentionBiasEnum::ALIBI,
false,
false,
false,
false,
1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<false, true, false>;
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
fmha_bwd_shape_0,
false,
true,
fmha_mask_0,
fmha_dropout_0,
fmha_bwd_trait_0>;
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR<fmha_bwd_pipeline_problem_0>;
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::KGradDataType,
true,
true>>;
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::VGradDataType,
true,
true>>;
using fmha_bwd_dq_dk_dv_kernel_0 =
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
fmha_bwd_dk_epilogue_0,
fmha_bwd_dv_epilogue_0>;
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64,
ck_tile::fp16_t,
false,
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR,
fmha_mask_0,
fmha_dropout_0,
ck_tile::BlockAttentionBiasEnum::ALIBI,
false,
true,
true,
true,
true,
true>;
#include <iostream>
template <>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}
template <>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
ck_tile::stream_config{s.stream_id_});
#endif
}
template <>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
return k_::GetName();
}

View File

@ -0,0 +1,144 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_bwd.hpp>
using fmha_dtype_0 = ck_tile::fp16_t;
using fmha_block_tile_0 = ck_tile::
sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>;
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
// G1&G3 -> GdKV
// G4 -> GdQ
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps2_0,
fmha_warp_tile0_0>;
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
false,
true,
true,
ck_tile::BlockAttentionBiasEnum::ALIBI,
false,
false,
false,
false,
1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<false, true, false>;
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
fmha_bwd_shape_0,
false,
false,
fmha_mask_0,
fmha_dropout_0,
fmha_bwd_trait_0>;
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR<fmha_bwd_pipeline_problem_0>;
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::KGradDataType,
false,
true>>;
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::VGradDataType,
false,
true>>;
using fmha_bwd_dq_dk_dv_kernel_0 =
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
fmha_bwd_dk_epilogue_0,
fmha_bwd_dv_epilogue_0>;
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128,
ck_tile::fp16_t,
false,
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR,
fmha_mask_0,
fmha_dropout_0,
ck_tile::BlockAttentionBiasEnum::ALIBI,
false,
true,
false,
true,
true,
false>;
#include <iostream>
template <>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}
template <>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
ck_tile::stream_config{s.stream_id_});
#endif
}
template <>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
return k_::GetName();
}

View File

@ -0,0 +1,144 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_bwd.hpp>
using fmha_dtype_0 = ck_tile::fp16_t;
using fmha_block_tile_0 = ck_tile::
sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>;
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
// G1&G3 -> GdKV
// G4 -> GdQ
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps2_0,
fmha_warp_tile0_0>;
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
false,
true,
true,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
false,
false,
false,
1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<false, true, false>;
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
fmha_bwd_shape_0,
false,
true,
fmha_mask_0,
fmha_dropout_0,
fmha_bwd_trait_0>;
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR<fmha_bwd_pipeline_problem_0>;
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::KGradDataType,
false,
true>>;
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::VGradDataType,
false,
true>>;
using fmha_bwd_dq_dk_dv_kernel_0 =
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
fmha_bwd_dk_epilogue_0,
fmha_bwd_dv_epilogue_0>;
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256,
ck_tile::fp16_t,
false,
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR,
fmha_mask_0,
fmha_dropout_0,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
true,
false,
true,
true,
true>;
#include <iostream>
template <>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}
template <>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
ck_tile::stream_config{s.stream_id_});
#endif
}
template <>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
return k_::GetName();
}

View File

@ -0,0 +1,84 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_fwd.hpp>
using fmha_dtype_0 = ck_tile::bf16_t;
using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>;
using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>;
using fmha_shape_0 = ck_tile::TileFmhaShape<fmha_block_tile_0,
ck_tile::sequence<4, 1, 1>,
fmha_warp_tile_0,
ck_tile::sequence<4, 1, 1>,
fmha_warp_tile_0,
true>;
using fmha_trait_0 = ck_tile::TileFmhaTraits<true,
true,
true,
true,
ck_tile::BlockAttentionBiasEnum::ALIBI,
false,
true,
false,
false,
-1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem<
typename FmhaFwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::SaccDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::SMPLComputeDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::PDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::OaccDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::ODataType,
fmha_shape_0,
false,
fmha_mask_0,
fmha_trait_0>;
using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync<
fmha_pipeline_problem_0>;
using fmha_epilogue_0 =
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<ck_tile::bf16_t>::OaccDataType,
typename FmhaFwdTypeConfig<ck_tile::bf16_t>::ODataType,
true, true>>;
using fmha_kernel_0 =
ck_tile::FmhaFwdKernel<ck_tile::FmhaFwdTilePartitioner_SHB<fmha_shape_0>,
fmha_pipeline_0,
fmha_epilogue_0>;
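// The forward kernel is composed from a tile partitioner, the asynchronous
// QRKSVS pipeline, and a 2D epilogue, all selected by the aliases above.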
using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, false, 128, 128, 32, 128, 32, 128, true,
    ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>;
#include <iostream>
template<>
float fmha_fwd_<trait_0>(const ck_tile::stream_config& s, fmha_fwd_args a)
{
using k_ = fmha_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_fwd_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}
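// Usage sketch (an assumption for illustration, not generator output): a
// hypothetical wrapper run_fwd_instance_0 showing how the specialization above
// is invoked; fmha_fwd_<trait_0> is the entry point defined in this file.
static inline float run_fwd_instance_0(const ck_tile::stream_config& s,
                                       fmha_fwd_args a)
{
    // Returns the time reported by ck_tile::launch_kernel, or 0.0 when this
    // translation unit is compiled for a gfx target other than gfx90a/gfx942.
    return fmha_fwd_<trait_0>(s, a);
}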

View File

@ -0,0 +1,144 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_bwd.hpp>
using fmha_dtype_0 = ck_tile::fp16_t;
using fmha_block_tile_0 = ck_tile::
sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>;
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
// G1&G3 -> GdKV
// G4 -> GdQ
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps2_0,
fmha_warp_tile0_0>;
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
false,
false,
false,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
false,
false,
false,
1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<true, false, false>;
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
fmha_bwd_shape_0,
false,
false,
fmha_mask_0,
fmha_dropout_0,
fmha_bwd_trait_0>;
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::KGradDataType,
false,
false>>;
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::VGradDataType,
false,
false>>;
using fmha_bwd_dq_dk_dv_kernel_0 =
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
fmha_bwd_dk_epilogue_0,
fmha_bwd_dv_epilogue_0>;
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128,
ck_tile::fp16_t,
false,
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
fmha_mask_0,
fmha_dropout_0,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
true,
false,
false,
false,
false>;
#include <iostream>
template <>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}
template <>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
ck_tile::stream_config{s.stream_id_});
#endif
}
template <>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
return k_::GetName();
}

View File

@ -0,0 +1,144 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_bwd.hpp>
using fmha_dtype_0 = ck_tile::bf16_t;
using fmha_block_tile_0 = ck_tile::
sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>;
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
// G1&G3 -> GdKV
// G4 -> GdQ
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps2_0,
fmha_warp_tile0_0>;
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<false,
true,
true,
true,
ck_tile::BlockAttentionBiasEnum::ALIBI,
false,
false,
false,
false,
1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<false, true, false>;
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
fmha_bwd_shape_0,
false,
true,
fmha_mask_0,
fmha_dropout_0,
fmha_bwd_trait_0>;
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR<fmha_bwd_pipeline_problem_0>;
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::KGradDataType,
true,
true>>;
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::VGradDataType,
true,
true>>;
using fmha_bwd_dq_dk_dv_kernel_0 =
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
fmha_bwd_dk_epilogue_0,
fmha_bwd_dv_epilogue_0>;
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64,
ck_tile::bf16_t,
false,
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR,
fmha_mask_0,
fmha_dropout_0,
ck_tile::BlockAttentionBiasEnum::ALIBI,
false,
false,
true,
true,
true,
true>;
#include <iostream>
template <>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}
template <>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
ck_tile::stream_config{s.stream_id_});
#endif
}
template <>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
return k_::GetName();
}

View File

@ -0,0 +1,144 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_bwd.hpp>
using fmha_dtype_0 = ck_tile::bf16_t;
using fmha_block_tile_0 = ck_tile::
sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>;
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>;
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
// G1&G3 -> GdKV
// G4 -> GdQ
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps2_0,
fmha_warp_tile0_0>;
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
true,
false,
false,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
false,
false,
false,
1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<false, true, false>;
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
fmha_bwd_shape_0,
false,
true,
fmha_mask_0,
fmha_dropout_0,
fmha_bwd_trait_0>;
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::KGradDataType,
true,
false>>;
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::VGradDataType,
true,
false>>;
using fmha_bwd_dq_dk_dv_kernel_0 =
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
fmha_bwd_dk_epilogue_0,
fmha_bwd_dv_epilogue_0>;
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32,
ck_tile::bf16_t,
false,
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
fmha_mask_0,
fmha_dropout_0,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
true,
true,
false,
false,
true>;
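// Dispatch key for this instantiation: head dim 32, bf16, KRKTRVR_IGLP
// pipeline, no bias, with the mask and dropout types defined above. The
// trailing booleans are flags emitted by generate.py (presumably padding and
// determinism options); their exact order is fixed by
// fmha_bwd_dq_dk_dv_traits_ in fmha_bwd.hpp.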
#include <iostream>
template <>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
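// Kernel bodies are compiled only for gfx90a (MI200-class) and gfx942
// (MI300-class) targets; on any other architecture this specialization
// reduces to a stub that returns 0.0.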
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}
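// The launcher above goes through ck_tile::launch_kernel, which can repeat
// the kernel and return a measured duration depending on the stream_config;
// the _oneshot_ variant below invokes the kernel functor exactly once on the
// caller's stream.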
template <>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
ck_tile::stream_config{s.stream_id_});
#endif
}
template <>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
return k_::GetName();
}

@ -0,0 +1,144 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_bwd.hpp>
using fmha_dtype_0 = ck_tile::bf16_t;
using fmha_block_tile_0 = ck_tile::
sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>;
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
// G1&G3 -> GdKV
// G4 -> GdQ
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps2_0,
fmha_warp_tile0_0>;
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
false,
false,
false,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
false,
false,
false,
1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<true, false, false>;
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
fmha_bwd_shape_0,
false,
false,
fmha_mask_0,
fmha_dropout_0,
fmha_bwd_trait_0>;
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::KGradDataType,
false,
false>>;
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::VGradDataType,
false,
false>>;
using fmha_bwd_dq_dk_dv_kernel_0 =
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
fmha_bwd_dk_epilogue_0,
fmha_bwd_dv_epilogue_0>;
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64,
ck_tile::bf16_t,
false,
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
fmha_mask_0,
fmha_dropout_0,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
true,
false,
false,
false,
false>;
#include <iostream>
template <>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}
template <>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
ck_tile::stream_config{s.stream_id_});
#endif
}
template <>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
return k_::GetName();
}

@ -0,0 +1,144 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_bwd.hpp>
using fmha_dtype_0 = ck_tile::bf16_t;
using fmha_block_tile_0 = ck_tile::
sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>;
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
// G1&G3 -> GdKV
// G4 -> GdQ
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps2_0,
fmha_warp_tile0_0>;
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
true,
true,
true,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
false,
false,
false,
1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<false, true, false>;
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
fmha_bwd_shape_0,
false,
false,
fmha_mask_0,
fmha_dropout_0,
fmha_bwd_trait_0>;
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR<fmha_bwd_pipeline_problem_0>;
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::KGradDataType,
true,
true>>;
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::VGradDataType,
true,
true>>;
using fmha_bwd_dq_dk_dv_kernel_0 =
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
fmha_bwd_dk_epilogue_0,
fmha_bwd_dv_epilogue_0>;
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64,
ck_tile::bf16_t,
false,
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR,
fmha_mask_0,
fmha_dropout_0,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
true,
true,
true,
true,
false>;
#include <iostream>
template <>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}
template <>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
ck_tile::stream_config{s.stream_id_});
#endif
}
template <>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
return k_::GetName();
}

@ -0,0 +1,144 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_bwd.hpp>
using fmha_dtype_0 = ck_tile::fp16_t;
using fmha_block_tile_0 = ck_tile::
sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>;
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
// G1&G3 -> GdKV
// G4 -> GdQ
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps2_0,
fmha_warp_tile0_0>;
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<false,
false,
false,
false,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
false,
false,
false,
1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<true, false, false>;
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
fmha_bwd_shape_0,
false,
false,
fmha_mask_0,
fmha_dropout_0,
fmha_bwd_trait_0>;
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::KGradDataType,
false,
false>>;
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::VGradDataType,
false,
false>>;
using fmha_bwd_dq_dk_dv_kernel_0 =
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
fmha_bwd_dk_epilogue_0,
fmha_bwd_dv_epilogue_0>;
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256,
ck_tile::fp16_t,
false,
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
fmha_mask_0,
fmha_dropout_0,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
false,
false,
false,
false,
false>;
#include <iostream>
template <>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}
template <>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
ck_tile::stream_config{s.stream_id_});
#endif
}
template <>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
return k_::GetName();
}

@ -0,0 +1,144 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_bwd.hpp>
using fmha_dtype_0 = ck_tile::fp16_t;
using fmha_block_tile_0 = ck_tile::
sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>;
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
// G1&G3 -> GdKV
// G4 -> GdQ
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps2_0,
fmha_warp_tile0_0>;
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<false,
true,
false,
false,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
false,
false,
false,
1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<false, true, false>;
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
fmha_bwd_shape_0,
false,
false,
fmha_mask_0,
fmha_dropout_0,
fmha_bwd_trait_0>;
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::KGradDataType,
true,
false>>;
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::VGradDataType,
true,
false>>;
using fmha_bwd_dq_dk_dv_kernel_0 =
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
fmha_bwd_dk_epilogue_0,
fmha_bwd_dv_epilogue_0>;
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256,
ck_tile::fp16_t,
false,
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
fmha_mask_0,
fmha_dropout_0,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
false,
true,
false,
false,
false>;
#include <iostream>
template <>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}
template <>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
ck_tile::stream_config{s.stream_id_});
#endif
}
template <>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
return k_::GetName();
}

@ -0,0 +1,84 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_fwd.hpp>
using fmha_dtype_0 = ck_tile::fp16_t;
using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>;
using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>;
using fmha_shape_0 = ck_tile::TileFmhaShape<fmha_block_tile_0,
ck_tile::sequence<4, 1, 1>,
fmha_warp_tile_0,
ck_tile::sequence<4, 1, 1>,
fmha_warp_tile_0,
true>;
using fmha_trait_0 = ck_tile::TileFmhaTraits<true,
true,
true,
true,
ck_tile::BlockAttentionBiasEnum::ALIBI,
false,
false,
false,
false,
-1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem<
typename FmhaFwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::SaccDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::SMPLComputeDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::PDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::OaccDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::ODataType,
fmha_shape_0,
true,
fmha_mask_0,
fmha_trait_0>;
using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync<
fmha_pipeline_problem_0>;
using fmha_epilogue_0 =
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<ck_tile::fp16_t>::OaccDataType,
typename FmhaFwdTypeConfig<ck_tile::fp16_t>::ODataType,
true, true>>;
using fmha_kernel_0 =
ck_tile::FmhaFwdKernel<ck_tile::FmhaFwdTilePartitioner_HBS<fmha_shape_0>,
fmha_pipeline_0,
fmha_epilogue_0>;
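// Forward-pass instance: fp16 with ALIBI bias and a generic attention mask,
// using the asynchronous QRKSVS pipeline and the _HBS tile partitioner
// declared in fmha_fwd.hpp.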
using trait_0 = fmha_fwd_traits_<128, ck_tile::fp16_t, true, 128, 128, 32, 128, 32, 128, true,
    ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, false, false, true, true, true, true>;
#include <iostream>
template<>
float fmha_fwd_<trait_0>(const ck_tile::stream_config& s, fmha_fwd_args a)
{
using k_ = fmha_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_fwd_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}

@ -0,0 +1,144 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_bwd.hpp>
using fmha_dtype_0 = ck_tile::bf16_t;
using fmha_block_tile_0 = ck_tile::
sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>;
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
// G1&G3 -> GdKV
// G4 -> GdQ
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps2_0,
fmha_warp_tile0_0>;
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
true,
true,
true,
ck_tile::BlockAttentionBiasEnum::ALIBI,
false,
false,
false,
false,
1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<true, false, false>;
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
fmha_bwd_shape_0,
false,
false,
fmha_mask_0,
fmha_dropout_0,
fmha_bwd_trait_0>;
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR<fmha_bwd_pipeline_problem_0>;
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::KGradDataType,
true,
true>>;
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::VGradDataType,
true,
true>>;
using fmha_bwd_dq_dk_dv_kernel_0 =
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
fmha_bwd_dk_epilogue_0,
fmha_bwd_dv_epilogue_0>;
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64,
ck_tile::bf16_t,
false,
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR,
fmha_mask_0,
fmha_dropout_0,
ck_tile::BlockAttentionBiasEnum::ALIBI,
false,
true,
true,
true,
true,
false>;
#include <iostream>
template <>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}
template <>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
ck_tile::stream_config{s.stream_id_});
#endif
}
template <>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
return k_::GetName();
}

@ -0,0 +1,144 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_bwd.hpp>
using fmha_dtype_0 = ck_tile::fp16_t;
using fmha_block_tile_0 = ck_tile::
sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>;
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
// G1&G3 -> GdKV
// G4 -> GdQ
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps2_0,
fmha_warp_tile0_0>;
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
true,
true,
true,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
false,
false,
false,
1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<true, false, false>;
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
fmha_bwd_shape_0,
true,
false,
fmha_mask_0,
fmha_dropout_0,
fmha_bwd_trait_0>;
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR<fmha_bwd_pipeline_problem_0>;
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::KGradDataType,
true,
true>>;
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::VGradDataType,
true,
true>>;
using fmha_bwd_dq_dk_dv_kernel_0 =
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
fmha_bwd_dk_epilogue_0,
fmha_bwd_dv_epilogue_0>;
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128,
ck_tile::fp16_t,
true,
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR,
fmha_mask_0,
fmha_dropout_0,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
true,
true,
true,
true,
false>;
#include <iostream>
template <>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}
template <>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
ck_tile::stream_config{s.stream_id_});
#endif
}
template <>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
return k_::GetName();
}

@ -0,0 +1,144 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_bwd.hpp>
using fmha_dtype_0 = ck_tile::fp16_t;
using fmha_block_tile_0 = ck_tile::
sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>;
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
// G1&G3 -> GdKV
// G4 -> GdQ
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps2_0,
fmha_warp_tile0_0>;
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<false,
false,
false,
false,
ck_tile::BlockAttentionBiasEnum::ALIBI,
false,
false,
false,
false,
1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<true, false, false>;
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
fmha_bwd_shape_0,
false,
false,
fmha_mask_0,
fmha_dropout_0,
fmha_bwd_trait_0>;
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::KGradDataType,
false,
false>>;
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::VGradDataType,
false,
false>>;
using fmha_bwd_dq_dk_dv_kernel_0 =
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
fmha_bwd_dk_epilogue_0,
fmha_bwd_dv_epilogue_0>;
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128,
ck_tile::fp16_t,
false,
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
fmha_mask_0,
fmha_dropout_0,
ck_tile::BlockAttentionBiasEnum::ALIBI,
false,
false,
false,
false,
false,
false>;
#include <iostream>
template <>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}
template <>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
ck_tile::stream_config{s.stream_id_});
#endif
}
template <>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
return k_::GetName();
}

@ -0,0 +1,144 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_bwd.hpp>
using fmha_dtype_0 = ck_tile::bf16_t;
using fmha_block_tile_0 = ck_tile::
sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>;
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
// G1&G3 -> GdKV
// G4 -> GdQ
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps2_0,
fmha_warp_tile0_0>;
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<false,
false,
true,
true,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
false,
false,
false,
1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<true, false, false>;
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
fmha_bwd_shape_0,
false,
true,
fmha_mask_0,
fmha_dropout_0,
fmha_bwd_trait_0>;
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR<fmha_bwd_pipeline_problem_0>;
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::KGradDataType,
false,
true>>;
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::VGradDataType,
false,
true>>;
using fmha_bwd_dq_dk_dv_kernel_0 =
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
fmha_bwd_dk_epilogue_0,
fmha_bwd_dv_epilogue_0>;
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64,
ck_tile::bf16_t,
false,
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR,
fmha_mask_0,
fmha_dropout_0,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
false,
false,
true,
true,
true>;
#include <iostream>
template <>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}
template <>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
ck_tile::stream_config{s.stream_id_});
#endif
}
template <>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
return k_::GetName();
}

@ -0,0 +1,144 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_bwd.hpp>
using fmha_dtype_0 = ck_tile::fp16_t;
using fmha_block_tile_0 = ck_tile::
sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>;
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
// G1&G3 -> GdKV
// G4 -> GdQ
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps2_0,
fmha_warp_tile0_0>;
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
true,
true,
true,
ck_tile::BlockAttentionBiasEnum::ALIBI,
false,
false,
false,
false,
1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<true, false, false>;
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
fmha_bwd_shape_0,
false,
false,
fmha_mask_0,
fmha_dropout_0,
fmha_bwd_trait_0>;
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR<fmha_bwd_pipeline_problem_0>;
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::KGradDataType,
true,
true>>;
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::VGradDataType,
true,
true>>;
using fmha_bwd_dq_dk_dv_kernel_0 =
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
fmha_bwd_dk_epilogue_0,
fmha_bwd_dv_epilogue_0>;
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256,
ck_tile::fp16_t,
false,
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR,
fmha_mask_0,
fmha_dropout_0,
ck_tile::BlockAttentionBiasEnum::ALIBI,
false,
true,
true,
true,
true,
false>;
#include <iostream>
template <>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}
template <>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
ck_tile::stream_config{s.stream_id_});
#endif
}
template <>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
return k_::GetName();
}

@ -0,0 +1,144 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_bwd.hpp>
using fmha_dtype_0 = ck_tile::fp16_t;
using fmha_block_tile_0 = ck_tile::
sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>;
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>;
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
// G1&G3 -> GdKV
// G4 -> GdQ
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps2_0,
fmha_warp_tile0_0>;
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<false,
false,
true,
true,
ck_tile::BlockAttentionBiasEnum::ALIBI,
false,
false,
false,
false,
1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<true, false, false>;
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
fmha_bwd_shape_0,
false,
true,
fmha_mask_0,
fmha_dropout_0,
fmha_bwd_trait_0>;
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR<fmha_bwd_pipeline_problem_0>;
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::KGradDataType,
false,
true>>;
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::VGradDataType,
false,
true>>;
using fmha_bwd_dq_dk_dv_kernel_0 =
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
fmha_bwd_dk_epilogue_0,
fmha_bwd_dv_epilogue_0>;
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32,
ck_tile::fp16_t,
false,
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR,
fmha_mask_0,
fmha_dropout_0,
ck_tile::BlockAttentionBiasEnum::ALIBI,
false,
false,
false,
true,
true,
true>;
#include <iostream>
template <>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}
template <>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
ck_tile::stream_config{s.stream_id_});
#endif
}
template <>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
return k_::GetName();
}

@ -0,0 +1,144 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_bwd.hpp>
using fmha_dtype_0 = ck_tile::bf16_t;
using fmha_block_tile_0 = ck_tile::
sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>;
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>;
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
// G1&G3 -> GdKV
// G4 -> GdQ
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps2_0,
fmha_warp_tile0_0>;
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
false,
true,
true,
ck_tile::BlockAttentionBiasEnum::ALIBI,
false,
false,
false,
false,
1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<false, true, false>;
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
fmha_bwd_shape_0,
false,
true,
fmha_mask_0,
fmha_dropout_0,
fmha_bwd_trait_0>;
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR<fmha_bwd_pipeline_problem_0>;
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::KGradDataType,
false,
true>>;
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::VGradDataType,
false,
true>>;
using fmha_bwd_dq_dk_dv_kernel_0 =
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
fmha_bwd_dk_epilogue_0,
fmha_bwd_dv_epilogue_0>;
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32,
ck_tile::bf16_t,
false,
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR,
fmha_mask_0,
fmha_dropout_0,
ck_tile::BlockAttentionBiasEnum::ALIBI,
false,
true,
false,
true,
true,
true>;
#include <iostream>
template <>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}
template <>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
ck_tile::stream_config{s.stream_id_});
#endif
}
template <>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
return k_::GetName();
}

@ -0,0 +1,144 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_bwd.hpp>
using fmha_dtype_0 = ck_tile::bf16_t;
using fmha_block_tile_0 = ck_tile::
sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>;
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
// G1&G3 -> GdKV
// G4 -> GdQ
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps2_0,
fmha_warp_tile0_0>;
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<false,
true,
false,
false,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
false,
false,
false,
1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<true, false, false>;
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
fmha_bwd_shape_0,
false,
false,
fmha_mask_0,
fmha_dropout_0,
fmha_bwd_trait_0>;
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::KGradDataType,
true,
false>>;
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::VGradDataType,
true,
false>>;
using fmha_bwd_dq_dk_dv_kernel_0 =
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
fmha_bwd_dk_epilogue_0,
fmha_bwd_dv_epilogue_0>;
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128,
ck_tile::bf16_t,
false,
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
fmha_mask_0,
fmha_dropout_0,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
false,
true,
false,
false,
false>;
#include <iostream>
template <>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}
template <>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
ck_tile::stream_config{s.stream_id_});
#endif
}
template <>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
return k_::GetName();
}
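
Note the two launch flavors generated per backward kernel: fmha_bwd_dq_dk_dv_ goes through ck_tile::launch_kernel and returns a float (the elapsed time when the stream_config requests timing, judging by its use in these files), while the oneshot_ variant invokes the kernel functor exactly once with only the stream id and returns nothing. A hedged usage sketch (run_both is hypothetical):

#include <fmha_bwd.hpp>

void run_both(const ck_tile::stream_config& s, fmha_bwd_args a)
{
    // timed/configurable path: honors log_level_ and any timing knobs
    float elapsed_ms = fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(s, a);
    (void)elapsed_ms;
    // fire-and-forget path: a single launch on s.stream_id_, no timing
    fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(s, a);
}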

View File

@ -0,0 +1,84 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_fwd.hpp>
using fmha_dtype_0 = ck_tile::bf16_t;
using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>;
using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>;
using fmha_shape_0 = ck_tile::TileFmhaShape<fmha_block_tile_0,
ck_tile::sequence<2, 1, 1>,
fmha_warp_tile_0,
ck_tile::sequence<2, 1, 1>,
fmha_warp_tile_0,
true>;
using fmha_trait_0 = ck_tile::TileFmhaTraits<true,
true,
true,
true,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
true,
true,
false,
-1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem<
typename FmhaFwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::SaccDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::SMPLComputeDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::PDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::OaccDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::ODataType,
fmha_shape_0,
true,
fmha_mask_0,
fmha_trait_0>;
using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync<
fmha_pipeline_problem_0>;
using fmha_epilogue_0 =
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<ck_tile::bf16_t>::OaccDataType,
typename FmhaFwdTypeConfig<ck_tile::bf16_t>::ODataType,
true, true>>;
using fmha_kernel_0 =
ck_tile::FmhaFwdKernel<ck_tile::FmhaFwdTilePartitioner_HBS<fmha_shape_0>,
fmha_pipeline_0,
fmha_epilogue_0>;
using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, true, 128, 64, 16, 32, 32, 32, true,
    ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>;
#include <iostream>
template<>
float fmha_fwd_<trait_0>(const ck_tile::stream_config& s, fmha_fwd_args a)
{
using k_ = fmha_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_fwd_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}
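
The forward files are shorter: only the timed fmha_fwd_ entry point is generated, and on targets other than gfx90a/gfx942 the guarded body compiles to a stub that returns 0.0. A sketch of using that return value for benchmarking, assuming ck_tile::stream_config is constructible from a HIP stream as elsewhere in ck_tile (benchmark_fwd is hypothetical):

#include <fmha_fwd.hpp>

// Hypothetical helper: forwards the launcher's float result, which is the
// elapsed time when timing is enabled and 0.0 on unsupported architectures.
float benchmark_fwd(hipStream_t stream, fmha_fwd_args a)
{
    ck_tile::stream_config s{stream};
    return fmha_fwd_<trait_0>(s, a);
}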

View File

@ -0,0 +1,144 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_bwd.hpp>
using fmha_dtype_0 = ck_tile::fp16_t;
using fmha_block_tile_0 = ck_tile::sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>;
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
// G1&G3 -> GdKV
// G4 -> GdQ
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps2_0,
fmha_warp_tile0_0>;
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<false,
true,
false,
false,
ck_tile::BlockAttentionBiasEnum::ALIBI,
false,
false,
false,
false,
1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<false, true, false>;
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
fmha_bwd_shape_0,
false,
true,
fmha_mask_0,
fmha_dropout_0,
fmha_bwd_trait_0>;
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::KGradDataType,
true,
false>>;
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::VGradDataType,
true,
false>>;
using fmha_bwd_dq_dk_dv_kernel_0 =
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
fmha_bwd_dk_epilogue_0,
fmha_bwd_dv_epilogue_0>;
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256,
ck_tile::fp16_t,
false,
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
fmha_mask_0,
fmha_dropout_0,
ck_tile::BlockAttentionBiasEnum::ALIBI,
false,
false,
true,
false,
false,
true>;
#include <iostream>
template <>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}
template <>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
ck_tile::stream_config{s.stream_id_});
#endif
}
template <>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
return k_::GetName();
}
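
The instantiations differ in their SimplifiedGenericAttentionMask<IsMasking> parameter: <false> variants skip masking entirely, while <true> variants evaluate a window predicate (e.g. for causal attention). A conceptual stand-in, not the CK implementation, assuming a (left, right) window around the diagonal:

// Conceptual stand-in for SimplifiedGenericAttentionMask<IsMasking>; the
// real CK type is more general. The window-bound formulation is an assumption.
template <bool IsMasking>
struct SimpleMask
{
    int left;
    int right;
    bool is_masked(int q, int k) const
    {
        if constexpr (!IsMasking)
            return false; // <false>: no (q, k) pair is ever masked out
        else
            return k < q - left || k > q + right; // outside the visible window
    }
};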

View File

@ -0,0 +1,144 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_bwd.hpp>
using fmha_dtype_0 = ck_tile::fp16_t;
using fmha_block_tile_0 = ck_tile::sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>;
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
// G1&G3 -> GdKV
// G4 -> GdQ
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps2_0,
fmha_warp_tile0_0>;
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
true,
false,
false,
ck_tile::BlockAttentionBiasEnum::ALIBI,
false,
false,
false,
false,
1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<true, false, false>;
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
fmha_bwd_shape_0,
false,
false,
fmha_mask_0,
fmha_dropout_0,
fmha_bwd_trait_0>;
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::KGradDataType,
true,
false>>;
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::VGradDataType,
true,
false>>;
using fmha_bwd_dq_dk_dv_kernel_0 =
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
fmha_bwd_dk_epilogue_0,
fmha_bwd_dv_epilogue_0>;
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64,
ck_tile::fp16_t,
false,
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
fmha_mask_0,
fmha_dropout_0,
ck_tile::BlockAttentionBiasEnum::ALIBI,
false,
true,
true,
false,
false,
false>;
#include <iostream>
template <>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}
template <>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
ck_tile::stream_config{s.stream_id_});
#endif
}
template <>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
return k_::GetName();
}

View File

@ -0,0 +1,144 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_bwd.hpp>
using fmha_dtype_0 = ck_tile::bf16_t;
using fmha_block_tile_0 = ck_tile::sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>;
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
// G1&G3 -> GdKV
// G4 -> GdQ
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps2_0,
fmha_warp_tile0_0>;
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<false,
true,
false,
false,
ck_tile::BlockAttentionBiasEnum::ALIBI,
false,
false,
false,
false,
1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<false, true, false>;
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
fmha_bwd_shape_0,
false,
true,
fmha_mask_0,
fmha_dropout_0,
fmha_bwd_trait_0>;
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::KGradDataType,
true,
false>>;
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::VGradDataType,
true,
false>>;
using fmha_bwd_dq_dk_dv_kernel_0 =
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
fmha_bwd_dk_epilogue_0,
fmha_bwd_dv_epilogue_0>;
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64,
ck_tile::bf16_t,
false,
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
fmha_mask_0,
fmha_dropout_0,
ck_tile::BlockAttentionBiasEnum::ALIBI,
false,
false,
true,
false,
false,
true>;
#include <iostream>
template <>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}
template <>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
ck_tile::stream_config{s.stream_id_});
#endif
}
template <>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
return k_::GetName();
}

View File

@ -0,0 +1,144 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_bwd.hpp>
using fmha_dtype_0 = ck_tile::fp16_t;
using fmha_block_tile_0 = ck_tile::sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>;
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
// G1&G3 -> GdKV
// G4 -> GdQ
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps2_0,
fmha_warp_tile0_0>;
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
true,
false,
false,
ck_tile::BlockAttentionBiasEnum::ALIBI,
false,
false,
false,
false,
1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<false, true, false>;
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
fmha_bwd_shape_0,
true,
true,
fmha_mask_0,
fmha_dropout_0,
fmha_bwd_trait_0>;
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::KGradDataType,
true,
false>>;
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::VGradDataType,
true,
false>>;
using fmha_bwd_dq_dk_dv_kernel_0 =
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
fmha_bwd_dk_epilogue_0,
fmha_bwd_dv_epilogue_0>;
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64,
ck_tile::fp16_t,
true,
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
fmha_mask_0,
fmha_dropout_0,
ck_tile::BlockAttentionBiasEnum::ALIBI,
false,
true,
true,
false,
false,
true>;
#include <iostream>
template <>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}
template <>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
ck_tile::stream_config{s.stream_id_});
#endif
}
template <>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
return k_::GetName();
}

View File

@ -0,0 +1,84 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_fwd.hpp>
using fmha_dtype_0 = ck_tile::bf16_t;
using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>;
using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>;
using fmha_shape_0 = ck_tile::TileFmhaShape<fmha_block_tile_0,
ck_tile::sequence<4, 1, 1>,
fmha_warp_tile_0,
ck_tile::sequence<4, 1, 1>,
fmha_warp_tile_0,
true>;
using fmha_trait_0 = ck_tile::TileFmhaTraits<true,
false,
true,
true,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
true,
true,
false,
-1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem<
typename FmhaFwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::SaccDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::SMPLComputeDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::PDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::OaccDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::ODataType,
fmha_shape_0,
false,
fmha_mask_0,
fmha_trait_0>;
using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync<
fmha_pipeline_problem_0>;
using fmha_epilogue_0 =
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<ck_tile::bf16_t>::OaccDataType,
typename FmhaFwdTypeConfig<ck_tile::bf16_t>::ODataType,
true, true>>;
using fmha_kernel_0 =
ck_tile::FmhaFwdKernel<ck_tile::FmhaFwdTilePartitioner_SHB<fmha_shape_0>,
fmha_pipeline_0,
fmha_epilogue_0>;
using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, false, 128, 128, 32, 128, 32, 128, true,
    ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, false, true, true>;
#include <iostream>
template<>
float fmha_fwd_<trait_0>(const ck_tile::stream_config& s, fmha_fwd_args a)
{
using k_ = fmha_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_fwd_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}
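
This file uses FmhaFwdTilePartitioner_SHB where an earlier forward file used FmhaFwdTilePartitioner_HBS; the suffixes presumably name the order in which (seqlen tiles, heads, batches) are mapped onto the 3D launch grid. A purely illustrative sketch of that assumption (dim3_ and both helpers are hypothetical):

struct dim3_ { unsigned x, y, z; };

// Assumed axis orders; only the ordering would differ between the variants.
dim3_ grid_shb(unsigned s, unsigned h, unsigned b) { return {s, h, b}; }
dim3_ grid_hbs(unsigned s, unsigned h, unsigned b) { return {h, b, s}; }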

View File

@ -0,0 +1,84 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_fwd.hpp>
using fmha_dtype_0 = ck_tile::bf16_t;
using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>;
using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>;
using fmha_shape_0 = ck_tile::TileFmhaShape<fmha_block_tile_0,
ck_tile::sequence<2, 1, 1>,
fmha_warp_tile_0,
ck_tile::sequence<2, 1, 1>,
fmha_warp_tile_0,
true>;
using fmha_trait_0 = ck_tile::TileFmhaTraits<true,
false,
true,
true,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
true,
true,
false,
-1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem<
typename FmhaFwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::SaccDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::SMPLComputeDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::PDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::OaccDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::ODataType,
fmha_shape_0,
false,
fmha_mask_0,
fmha_trait_0>;
using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync<
fmha_pipeline_problem_0>;
using fmha_epilogue_0 =
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<ck_tile::bf16_t>::OaccDataType,
typename FmhaFwdTypeConfig<ck_tile::bf16_t>::ODataType,
true, true>>;
using fmha_kernel_0 =
ck_tile::FmhaFwdKernel<ck_tile::FmhaFwdTilePartitioner_SHB<fmha_shape_0>,
fmha_pipeline_0,
fmha_epilogue_0>;
using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, false, 128, 64, 16, 32, 32, 32, true,
    ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, false, true, true>;
#include <iostream>
template<>
float fmha_fwd_<trait_0>(const ck_tile::stream_config& s, fmha_fwd_args a)
{
using k_ = fmha_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_fwd_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}

View File

@ -0,0 +1,144 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_bwd.hpp>
using fmha_dtype_0 = ck_tile::fp16_t;
using fmha_block_tile_0 = ck_tile::sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>;
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
// G1&G3 -> GdKV
// G4 -> GdQ
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps2_0,
fmha_warp_tile0_0>;
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
true,
true,
true,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
false,
false,
false,
1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<false, true, false>;
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
fmha_bwd_shape_0,
false,
false,
fmha_mask_0,
fmha_dropout_0,
fmha_bwd_trait_0>;
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR<fmha_bwd_pipeline_problem_0>;
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::KGradDataType,
true,
true>>;
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::VGradDataType,
true,
true>>;
using fmha_bwd_dq_dk_dv_kernel_0 =
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
fmha_bwd_dk_epilogue_0,
fmha_bwd_dv_epilogue_0>;
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128,
ck_tile::fp16_t,
false,
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR,
fmha_mask_0,
fmha_dropout_0,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
true,
true,
true,
true,
false>;
#include <iostream>
template <>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}
template <>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
ck_tile::stream_config{s.stream_id_});
#endif
}
template <>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
return k_::GetName();
}
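
Unlike the neighboring files, this instantiation selects BlockFmhaBwdPipelineEnum::KRKTRVR rather than KRKTRVR_IGLP; the IGLP suffix presumably denotes the variant scheduled with AMD's instruction-group-level-parallelism hints, with the same math otherwise. The generator emits one translation unit per enum value, so the choice is strictly compile-time, as in this illustrative reduction (all names hypothetical):

enum class BwdPipeline { KRKTRVR, KRKTRVR_IGLP };

// One specialization per enum value stands in for one generated file each.
template <BwdPipeline P>
const char* pipeline_name()
{
    if constexpr (P == BwdPipeline::KRKTRVR)
        return "krktrvr";
    else
        return "krktrvr_iglp";
}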

View File

@ -0,0 +1,84 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_fwd.hpp>
using fmha_dtype_0 = ck_tile::fp16_t;
using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>;
using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>;
using fmha_shape_0 = ck_tile::TileFmhaShape<fmha_block_tile_0,
ck_tile::sequence<4, 1, 1>,
fmha_warp_tile_0,
ck_tile::sequence<4, 1, 1>,
fmha_warp_tile_0,
true>;
using fmha_trait_0 = ck_tile::TileFmhaTraits<true,
false,
true,
true,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
false,
true,
false,
-1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem<
typename FmhaFwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::SaccDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::SMPLComputeDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::PDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::OaccDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::ODataType,
fmha_shape_0,
false,
fmha_mask_0,
fmha_trait_0>;
using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync<
fmha_pipeline_problem_0>;
using fmha_epilogue_0 =
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<ck_tile::fp16_t>::OaccDataType,
typename FmhaFwdTypeConfig<ck_tile::fp16_t>::ODataType,
true, true>>;
using fmha_kernel_0 =
ck_tile::FmhaFwdKernel<ck_tile::FmhaFwdTilePartitioner_SHB<fmha_shape_0>,
fmha_pipeline_0,
fmha_epilogue_0>;
using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, false, 128, 64, 32, 64, 32, 64, true,
    ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, true, false, true, false, true, true>;
#include <iostream>
template<>
float fmha_fwd_<trait_0>(const ck_tile::stream_config& s, fmha_fwd_args a)
{
using k_ = fmha_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_fwd_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}

View File

@ -0,0 +1,144 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_bwd.hpp>
using fmha_dtype_0 = ck_tile::bf16_t;
using fmha_block_tile_0 = ck_tile::sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>;
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>;
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
// G1&G3 -> GdKV
// G4 -> GdQ
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps2_0,
fmha_warp_tile0_0>;
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
true,
true,
true,
ck_tile::BlockAttentionBiasEnum::ALIBI,
false,
false,
false,
false,
1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<false, true, false>;
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
fmha_bwd_shape_0,
true,
false,
fmha_mask_0,
fmha_dropout_0,
fmha_bwd_trait_0>;
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR<fmha_bwd_pipeline_problem_0>;
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::KGradDataType,
true,
true>>;
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::VGradDataType,
true,
true>>;
using fmha_bwd_dq_dk_dv_kernel_0 =
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
fmha_bwd_dk_epilogue_0,
fmha_bwd_dv_epilogue_0>;
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32,
ck_tile::bf16_t,
true,
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR,
fmha_mask_0,
fmha_dropout_0,
ck_tile::BlockAttentionBiasEnum::ALIBI,
false,
true,
true,
true,
true,
false>;
#include <iostream>
template <>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}
template <>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
ck_tile::stream_config{s.stream_id_});
#endif
}
template <>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
return k_::GetName();
}

View File

@ -0,0 +1,84 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_fwd.hpp>
using fmha_dtype_0 = ck_tile::bf16_t;
using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>;
using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>;
using fmha_shape_0 = ck_tile::TileFmhaShape<fmha_block_tile_0,
ck_tile::sequence<2, 1, 1>,
fmha_warp_tile_0,
ck_tile::sequence<2, 1, 1>,
fmha_warp_tile_0,
true>;
using fmha_trait_0 = ck_tile::TileFmhaTraits<true,
false,
true,
true,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
false,
false,
false,
-1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem<
typename FmhaFwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::SaccDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::SMPLComputeDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::PDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::OaccDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::ODataType,
fmha_shape_0,
false,
fmha_mask_0,
fmha_trait_0>;
using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync<
fmha_pipeline_problem_0>;
using fmha_epilogue_0 =
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<ck_tile::bf16_t>::OaccDataType,
typename FmhaFwdTypeConfig<ck_tile::bf16_t>::ODataType,
true, true>>;
using fmha_kernel_0 =
ck_tile::FmhaFwdKernel<ck_tile::FmhaFwdTilePartitioner_SHB<fmha_shape_0>,
fmha_pipeline_0,
fmha_epilogue_0>;
using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, false, 128, 64, 16, 32, 32, 32, true,
    ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, false, true, true>;
#include <iostream>
template<>
float fmha_fwd_<trait_0>(const ck_tile::stream_config& s, fmha_fwd_args a)
{
using k_ = fmha_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_fwd_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}

View File

@ -0,0 +1,144 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_bwd.hpp>
using fmha_dtype_0 = ck_tile::fp16_t;
using fmha_block_tile_0 = ck_tile::sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>;
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
// G1&G3 -> GdKV
// G4 -> GdQ
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps2_0,
fmha_warp_tile0_0>;
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
true,
false,
false,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
false,
false,
false,
1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<false, true, false>;
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
fmha_bwd_shape_0,
true,
false,
fmha_mask_0,
fmha_dropout_0,
fmha_bwd_trait_0>;
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::KGradDataType,
true,
false>>;
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::VGradDataType,
true,
false>>;
using fmha_bwd_dq_dk_dv_kernel_0 =
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
fmha_bwd_dk_epilogue_0,
fmha_bwd_dv_epilogue_0>;
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64,
ck_tile::fp16_t,
true,
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
fmha_mask_0,
fmha_dropout_0,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
true,
true,
false,
false,
false>;
#include <iostream>
template <>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}
template <>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
ck_tile::stream_config{s.stream_id_});
#endif
}
template <>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
return k_::GetName();
}

View File

@ -0,0 +1,79 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_bwd.hpp>
using fmha_dtype_0 = ck_tile::fp16_t;
using fmha_bwd_convert_dq_trait_0 =
ck_tile::TileFmhaBwdConvertQGradTraits<true, false, 2>;
using fmha_bwd_convert_dq_pipeline_problem_0 =
ck_tile::BlockFmhaBwdConvertQGradPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
/* BlockSize = */ 256,
64,
128,
128,
false,
false,
fmha_bwd_convert_dq_trait_0>;
using fmha_bwd_convert_dq_0 =
typename ck_tile::BlockFmhaBwdConvertQGrad<fmha_bwd_convert_dq_pipeline_problem_0>;
using fmha_bwd_convert_dq_kernel_0 =
ck_tile::FmhaBwdConvertQGradKernel<fmha_bwd_convert_dq_0>;
using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<128,
ck_tile::fp16_t,
false,
true,
false,
false>;
#include <iostream>
template <>
float fmha_bwd_convert_dq_<convert_dq_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
{
using k_ = fmha_bwd_convert_dq_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}
template <>
void fmha_bwd_convert_dq_oneshot_<convert_dq_trait_0>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{
using k_ = fmha_bwd_convert_dq_kernel_0;
auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
ck_tile::stream_config{s.stream_id_});
#endif
}
template <>
std::string fmha_bwd_convert_dq_get_name_<convert_dq_trait_0>()
{
using k_ = fmha_bwd_convert_dq_kernel_0;
return k_::GetName();
}
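
This smaller file wires up the dQ convert pass: its pipeline problem maps AccDataType to QGradDataType, which suggests the backward kernels accumulate dQ in float and this kernel then narrows the buffer to the storage dtype. A conceptual CPU-side sketch of that narrowing (types are placeholders; the real kernel works tile-parallel on the GPU):

#include <cstddef>
#include <vector>

using AccDataType   = float; // stand-in for FmhaBwdTypeConfig<>::AccDataType
using QGradDataType = float; // placeholder; the real target is fp16/bf16

void convert_dq(const std::vector<AccDataType>& acc, std::vector<QGradDataType>& dq)
{
    dq.resize(acc.size());
    for (std::size_t i = 0; i < acc.size(); ++i)
        dq[i] = static_cast<QGradDataType>(acc[i]); // elementwise narrowing
}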

View File

@ -0,0 +1,144 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_bwd.hpp>
using fmha_dtype_0 = ck_tile::fp16_t;
using fmha_block_tile_0 = ck_tile::sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>;
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
// G1&G3 -> GdKV
// G4 -> GdQ
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps2_0,
fmha_warp_tile0_0>;
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
true,
false,
false,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
false,
false,
false,
1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<false, true, false>;
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
fmha_bwd_shape_0,
false,
false,
fmha_mask_0,
fmha_dropout_0,
fmha_bwd_trait_0>;
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::KGradDataType,
true,
false>>;
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::VGradDataType,
true,
false>>;
using fmha_bwd_dq_dk_dv_kernel_0 =
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
fmha_bwd_dk_epilogue_0,
fmha_bwd_dv_epilogue_0>;
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64,
ck_tile::fp16_t,
false,
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
fmha_mask_0,
fmha_dropout_0,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
true,
true,
false,
false,
false>;
#include <iostream>
template <>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}
template <>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
ck_tile::stream_config{s.stream_id_});
#endif
}
template <>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
return k_::GetName();
}

View File

@ -0,0 +1,144 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_bwd.hpp>
using fmha_dtype_0 = ck_tile::fp16_t;
using fmha_block_tile_0 = ck_tile::sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>;
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
// G1&G3 -> GdKV
// G4 -> GdQ
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps2_0,
fmha_warp_tile0_0>;
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
true,
false,
false,
ck_tile::BlockAttentionBiasEnum::ALIBI,
false,
false,
false,
false,
1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<false, true, false>;
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
fmha_bwd_shape_0,
true,
false,
fmha_mask_0,
fmha_dropout_0,
fmha_bwd_trait_0>;
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::KGradDataType,
true,
false>>;
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::VGradDataType,
true,
false>>;
using fmha_bwd_dq_dk_dv_kernel_0 =
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
fmha_bwd_dk_epilogue_0,
fmha_bwd_dv_epilogue_0>;
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128,
ck_tile::fp16_t,
true,
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
fmha_mask_0,
fmha_dropout_0,
ck_tile::BlockAttentionBiasEnum::ALIBI,
false,
true,
true,
false,
false,
false>;
#include <iostream>
template <>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}
template <>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
ck_tile::stream_config{s.stream_id_});
#endif
}
template <>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
return k_::GetName();
}

@@ -0,0 +1,144 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_bwd.hpp>
using fmha_dtype_0 = ck_tile::bf16_t;
using fmha_block_tile_0 = ck_tile::
sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>;
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>;
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
// G1&G3 -> GdKV
// G4 -> GdQ
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps2_0,
fmha_warp_tile0_0>;
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
true,
false,
false,
ck_tile::BlockAttentionBiasEnum::ALIBI,
false,
false,
false,
false,
1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<true, false, false>;
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
fmha_bwd_shape_0,
true,
false,
fmha_mask_0,
fmha_dropout_0,
fmha_bwd_trait_0>;
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::KGradDataType,
true,
false>>;
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::VGradDataType,
true,
false>>;
using fmha_bwd_dq_dk_dv_kernel_0 =
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
fmha_bwd_dk_epilogue_0,
fmha_bwd_dv_epilogue_0>;
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32,
ck_tile::bf16_t,
true,
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
fmha_mask_0,
fmha_dropout_0,
ck_tile::BlockAttentionBiasEnum::ALIBI,
false,
true,
true,
false,
false,
false>;
#include <iostream>
template <>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}
template <>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
ck_tile::stream_config{s.stream_id_});
#endif
}
template <>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
return k_::GetName();
}

@@ -0,0 +1,144 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_bwd.hpp>
using fmha_dtype_0 = ck_tile::fp16_t;
using fmha_block_tile_0 = ck_tile::
sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>;
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
// G1&G3 -> GdKV
// G4 -> GdQ
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps2_0,
fmha_warp_tile0_0>;
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
false,
false,
false,
ck_tile::BlockAttentionBiasEnum::ALIBI,
false,
false,
false,
false,
1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<true, false, false>;
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
fmha_bwd_shape_0,
false,
false,
fmha_mask_0,
fmha_dropout_0,
fmha_bwd_trait_0>;
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::KGradDataType,
false,
false>>;
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::VGradDataType,
false,
false>>;
using fmha_bwd_dq_dk_dv_kernel_0 =
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
fmha_bwd_dk_epilogue_0,
fmha_bwd_dv_epilogue_0>;
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64,
ck_tile::fp16_t,
false,
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
fmha_mask_0,
fmha_dropout_0,
ck_tile::BlockAttentionBiasEnum::ALIBI,
false,
true,
false,
false,
false,
false>;
#include <iostream>
template <>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}
template <>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
ck_tile::stream_config{s.stream_id_});
#endif
}
template <>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
return k_::GetName();
}

@@ -0,0 +1,144 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_bwd.hpp>
using fmha_dtype_0 = ck_tile::fp16_t;
using fmha_block_tile_0 = ck_tile::
sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>;
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>;
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
// G1&G3 -> GdKV
// G4 -> GdQ
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps2_0,
fmha_warp_tile0_0>;
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
true,
false,
false,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
false,
false,
false,
1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<true, false, false>;
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
fmha_bwd_shape_0,
true,
false,
fmha_mask_0,
fmha_dropout_0,
fmha_bwd_trait_0>;
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::KGradDataType,
true,
false>>;
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::VGradDataType,
true,
false>>;
using fmha_bwd_dq_dk_dv_kernel_0 =
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
fmha_bwd_dk_epilogue_0,
fmha_bwd_dv_epilogue_0>;
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32,
ck_tile::fp16_t,
true,
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
fmha_mask_0,
fmha_dropout_0,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
true,
true,
false,
false,
false>;
#include <iostream>
template <>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}
template <>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
ck_tile::stream_config{s.stream_id_});
#endif
}
template <>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
return k_::GetName();
}

@@ -0,0 +1,144 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_bwd.hpp>
using fmha_dtype_0 = ck_tile::bf16_t;
using fmha_block_tile_0 = ck_tile::
sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>;
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
// G1&G3 -> GdKV
// G4 -> GdQ
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps2_0,
fmha_warp_tile0_0>;
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
false,
true,
true,
ck_tile::BlockAttentionBiasEnum::ALIBI,
false,
false,
false,
false,
1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<true, false, false>;
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
fmha_bwd_shape_0,
false,
true,
fmha_mask_0,
fmha_dropout_0,
fmha_bwd_trait_0>;
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR<fmha_bwd_pipeline_problem_0>;
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::KGradDataType,
false,
true>>;
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::VGradDataType,
false,
true>>;
using fmha_bwd_dq_dk_dv_kernel_0 =
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
fmha_bwd_dk_epilogue_0,
fmha_bwd_dv_epilogue_0>;
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256,
ck_tile::bf16_t,
false,
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR,
fmha_mask_0,
fmha_dropout_0,
ck_tile::BlockAttentionBiasEnum::ALIBI,
false,
true,
false,
true,
true,
true>;
#include <iostream>
template <>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}
template <>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
ck_tile::stream_config{s.stream_id_});
#endif
}
template <>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
return k_::GetName();
}
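The two launch paths specialized here differ only in instrumentation: fmha_bwd_dq_dk_dv_ goes through ck_tile::launch_kernel, optionally logging the kernel name when s.log_level_ > 0, and returns a float (presumably the timing reported by launch_kernel); fmha_bwd_dq_dk_dv_oneshot_ invokes the kernel functor directly on a fresh stream_config built from s.stream_id_, with no timing or logging. A minimal usage sketch, assuming an fmha_bwd_args value has already been populated by the caller (its fields are declared in fmha_bwd.hpp and do not appear in this diff):

// Minimal usage sketch; launch_example is an invented name, assumed to
// live in a TU that defines dq_dk_dv_trait_0 as above.
inline void launch_example(const ck_tile::stream_config& cfg, fmha_bwd_args args)
{
    // Benchmarking/verbose path: may print ", <kernel name>" and returns
    // whatever ck_tile::launch_kernel reports for the run.
    float t = fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(cfg, args);
    (void)t;

    // Hot path: one untimed launch; the specialization internally re-wraps
    // cfg.stream_id_ in a fresh stream_config.
    fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(cfg, args);
}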

@@ -0,0 +1,144 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_bwd.hpp>
using fmha_dtype_0 = ck_tile::fp16_t;
using fmha_block_tile_0 = ck_tile::
sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>;
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
// G1&G3 -> GdKV
// G4 -> GdQ
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps2_0,
fmha_warp_tile0_0>;
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<false,
false,
true,
true,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
false,
false,
false,
1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<false, true, false>;
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
fmha_bwd_shape_0,
false,
true,
fmha_mask_0,
fmha_dropout_0,
fmha_bwd_trait_0>;
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR<fmha_bwd_pipeline_problem_0>;
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::KGradDataType,
false,
true>>;
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::VGradDataType,
false,
true>>;
using fmha_bwd_dq_dk_dv_kernel_0 =
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
fmha_bwd_dk_epilogue_0,
fmha_bwd_dv_epilogue_0>;
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256,
ck_tile::fp16_t,
false,
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR,
fmha_mask_0,
fmha_dropout_0,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
false,
false,
true,
true,
true>;
#include <iostream>
template <>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}
template <>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
ck_tile::stream_config{s.stream_id_});
#endif
}
template <>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
return k_::GetName();
}

@@ -0,0 +1,144 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_bwd.hpp>
using fmha_dtype_0 = ck_tile::bf16_t;
using fmha_block_tile_0 = ck_tile::
sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>;
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>;
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
// G1&G3 -> GdKV
// G4 -> GdQ
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps2_0,
fmha_warp_tile0_0>;
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<false,
true,
false,
false,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
false,
false,
false,
1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<true, false, false>;
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
fmha_bwd_shape_0,
false,
false,
fmha_mask_0,
fmha_dropout_0,
fmha_bwd_trait_0>;
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::KGradDataType,
true,
false>>;
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::VGradDataType,
true,
false>>;
using fmha_bwd_dq_dk_dv_kernel_0 =
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
fmha_bwd_dk_epilogue_0,
fmha_bwd_dv_epilogue_0>;
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32,
ck_tile::bf16_t,
false,
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
fmha_mask_0,
fmha_dropout_0,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
false,
true,
false,
false,
false>;
#include <iostream>
template <>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}
template <>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
ck_tile::stream_config{s.stream_id_});
#endif
}
template <>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
return k_::GetName();
}

@@ -0,0 +1,144 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_bwd.hpp>
using fmha_dtype_0 = ck_tile::bf16_t;
using fmha_block_tile_0 = ck_tile::
sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>;
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
// G1&G3 -> GdKV
// G4 -> GdQ
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps2_0,
fmha_warp_tile0_0>;
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
true,
true,
true,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
false,
false,
false,
1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<false, true, false>;
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
fmha_bwd_shape_0,
false,
false,
fmha_mask_0,
fmha_dropout_0,
fmha_bwd_trait_0>;
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR<fmha_bwd_pipeline_problem_0>;
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::KGradDataType,
true,
true>>;
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::VGradDataType,
true,
true>>;
using fmha_bwd_dq_dk_dv_kernel_0 =
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
fmha_bwd_dk_epilogue_0,
fmha_bwd_dv_epilogue_0>;
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256,
ck_tile::bf16_t,
false,
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR,
fmha_mask_0,
fmha_dropout_0,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
true,
true,
true,
true,
false>;
#include <iostream>
template <>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}
template <>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
ck_tile::stream_config{s.stream_id_});
#endif
}
template <>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
return k_::GetName();
}
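Note that none of these translation units hard-code launch geometry: each specialization derives its kernel arguments and grid from the shared fmha_bwd_args via fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a), while the block size and occupancy hint come from the kernel type itself (k_::BlockSize() and k_::kBlockPerCu) and are baked into ck_tile::make_kernel as template arguments. A small sketch of reading that geometry back out, assuming it sits in the same TU (the aliases and the <iostream> include are already present above):

// Sketch: inspect the launch geometry this TU was generated with.
inline void print_geometry()
{
    using k_ = fmha_bwd_dq_dk_dv_kernel_0;
    constexpr dim3 blocks = k_::BlockSize();        // per-block thread shape
    std::cout << k_::GetName()
              << ": blockDim.x=" << blocks.x
              << ", kBlockPerCu=" << k_::kBlockPerCu << '\n';
}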

@@ -0,0 +1,144 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_bwd.hpp>
using fmha_dtype_0 = ck_tile::fp16_t;
using fmha_block_tile_0 = ck_tile::
sequence<16, 64, 256, 16, 256, 16, 32, 256, 256>;
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
// G1&G3 -> GdKV
// G4 -> GdQ
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps2_0,
fmha_warp_tile0_0>;
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
true,
false,
false,
ck_tile::BlockAttentionBiasEnum::ALIBI,
false,
false,
false,
false,
1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<false, true, false>;
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
fmha_bwd_shape_0,
false,
false,
fmha_mask_0,
fmha_dropout_0,
fmha_bwd_trait_0>;
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::KGradDataType,
true,
false>>;
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::VGradDataType,
true,
false>>;
using fmha_bwd_dq_dk_dv_kernel_0 =
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
fmha_bwd_dk_epilogue_0,
fmha_bwd_dv_epilogue_0>;
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<256,
ck_tile::fp16_t,
false,
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
fmha_mask_0,
fmha_dropout_0,
ck_tile::BlockAttentionBiasEnum::ALIBI,
false,
true,
true,
false,
false,
false>;
#include <iostream>
template <>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}
template <>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
ck_tile::stream_config{s.stream_id_});
#endif
}
template <>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
return k_::GetName();
}
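Every launch body in these files sits behind #if (defined(__gfx90a__) || defined(__gfx942__)), so the kernels are only materialized when the corresponding gfx90a/gfx942 macro is defined for the compilation pass; for any other target the preprocessor drops the launch, the timed entry point degenerates to a stub returning 0.0, and the oneshot entry point to a no-op. That 0.0 is ambiguous from the caller's side, since a real run can also report zero, so treating it as an "unsupported arch" signal, as in the sketch below, is a heuristic assumption rather than a documented contract.

// Heuristic sketch only -- 0.0 may also be a legitimate measurement.
// Assumed to sit in one of these TUs, next to dq_dk_dv_trait_0.
inline bool possibly_stubbed(const ck_tile::stream_config& cfg, fmha_bwd_args args)
{
    return fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(cfg, args) == 0.0f;
}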

@@ -0,0 +1,144 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_bwd.hpp>
using fmha_dtype_0 = ck_tile::bf16_t;
using fmha_block_tile_0 = ck_tile::
sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>;
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
// G1&G3 -> GdKV
// G4 -> GdQ
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps2_0,
fmha_warp_tile0_0>;
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<false,
true,
true,
true,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
false,
false,
false,
1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<true, false, false>;
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
fmha_bwd_shape_0,
false,
true,
fmha_mask_0,
fmha_dropout_0,
fmha_bwd_trait_0>;
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR<fmha_bwd_pipeline_problem_0>;
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::KGradDataType,
true,
true>>;
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::VGradDataType,
true,
true>>;
using fmha_bwd_dq_dk_dv_kernel_0 =
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
fmha_bwd_dk_epilogue_0,
fmha_bwd_dv_epilogue_0>;
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128,
ck_tile::bf16_t,
false,
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR,
fmha_mask_0,
fmha_dropout_0,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
false,
true,
true,
true,
true>;
#include <iostream>
template <>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}
template <>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
ck_tile::stream_config{s.stream_id_});
#endif
}
template <>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
return k_::GetName();
}

@@ -0,0 +1,144 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_bwd.hpp>
using fmha_dtype_0 = ck_tile::bf16_t;
using fmha_block_tile_0 = ck_tile::
sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>;
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
// G1&G3 -> GdKV
// G4 -> GdQ
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps2_0,
fmha_warp_tile0_0>;
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<false,
true,
true,
true,
ck_tile::BlockAttentionBiasEnum::ALIBI,
false,
false,
false,
false,
1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<true, false, false>;
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
fmha_bwd_shape_0,
false,
true,
fmha_mask_0,
fmha_dropout_0,
fmha_bwd_trait_0>;
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR<fmha_bwd_pipeline_problem_0>;
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::KGradDataType,
true,
true>>;
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::VGradDataType,
true,
true>>;
using fmha_bwd_dq_dk_dv_kernel_0 =
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
fmha_bwd_dk_epilogue_0,
fmha_bwd_dv_epilogue_0>;
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128,
ck_tile::bf16_t,
false,
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR,
fmha_mask_0,
fmha_dropout_0,
ck_tile::BlockAttentionBiasEnum::ALIBI,
false,
false,
true,
true,
true,
true>;
#include <iostream>
template <>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}
template <>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
ck_tile::stream_config{s.stream_id_});
#endif
}
template <>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
return k_::GetName();
}

@@ -0,0 +1,144 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_bwd.hpp>
using fmha_dtype_0 = ck_tile::fp16_t;
using fmha_block_tile_0 = ck_tile::
sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>;
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
// G1&G3 -> GdKV
// G4 -> GdQ
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps2_0,
fmha_warp_tile0_0>;
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<false,
false,
true,
true,
ck_tile::BlockAttentionBiasEnum::ALIBI,
false,
false,
false,
false,
1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<false, true, false>;
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
fmha_bwd_shape_0,
false,
true,
fmha_mask_0,
fmha_dropout_0,
fmha_bwd_trait_0>;
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR<fmha_bwd_pipeline_problem_0>;
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::KGradDataType,
false,
true>>;
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::VGradDataType,
false,
true>>;
using fmha_bwd_dq_dk_dv_kernel_0 =
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
fmha_bwd_dk_epilogue_0,
fmha_bwd_dv_epilogue_0>;
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128,
ck_tile::fp16_t,
false,
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR,
fmha_mask_0,
fmha_dropout_0,
ck_tile::BlockAttentionBiasEnum::ALIBI,
false,
false,
false,
true,
true,
true>;
#include <iostream>
template <>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}
template <>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
ck_tile::stream_config{s.stream_id_});
#endif
}
template <>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
return k_::GetName();
}

@@ -0,0 +1,144 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_bwd.hpp>
using fmha_dtype_0 = ck_tile::fp16_t;
using fmha_block_tile_0 = ck_tile::
sequence<32, 128, 32, 32, 32, 32, 64, 32, 32>;
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
using fmha_block_warps2_0 = ck_tile::sequence<2, 2, 1>;
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
// G1&G3 -> GdKV
// G4 -> GdQ
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps2_0,
fmha_warp_tile0_0>;
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<false,
true,
true,
true,
ck_tile::BlockAttentionBiasEnum::ALIBI,
false,
false,
false,
false,
1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<true, false, false>;
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
fmha_bwd_shape_0,
false,
false,
fmha_mask_0,
fmha_dropout_0,
fmha_bwd_trait_0>;
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR<fmha_bwd_pipeline_problem_0>;
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::KGradDataType,
true,
true>>;
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::VGradDataType,
true,
true>>;
using fmha_bwd_dq_dk_dv_kernel_0 =
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
fmha_bwd_dk_epilogue_0,
fmha_bwd_dv_epilogue_0>;
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<32,
ck_tile::fp16_t,
false,
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR,
fmha_mask_0,
fmha_dropout_0,
ck_tile::BlockAttentionBiasEnum::ALIBI,
false,
false,
true,
true,
true,
false>;
#include <iostream>
template <>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}
template <>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
ck_tile::stream_config{s.stream_id_});
#endif
}
template <>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
return k_::GetName();
}
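
As a brief, hedged illustration (not part of the generated file): assuming a populated fmha_bwd_args from the host side, the three specializations above could be exercised as in the sketch below. run_bwd_example and its parameters are illustrative names; the stream_config construction from a raw stream id mirrors its use in the oneshot variant.

#include <fmha_bwd.hpp>
#include <iostream>

// Sketch: report the generated kernel's name, then run the timed entry point.
float run_bwd_example(hipStream_t stream, fmha_bwd_args a)
{
    ck_tile::stream_config s{stream};  // single-argument ctor, as used above
    std::cout << "running " << fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>() << '\n';
    // Forwards whatever ck_tile::launch_kernel reports (0.0 on unsupported archs).
    return fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(s, a);
}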

View File

@ -0,0 +1,144 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_bwd.hpp>
using fmha_dtype_0 = ck_tile::bf16_t;
using fmha_block_tile_0 = ck_tile::
sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>;
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
// G1&G3 -> GdKV
// G4 -> GdQ
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps2_0,
fmha_warp_tile0_0>;
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
true,
false,
false,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
false,
false,
false,
1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<true, false, false>;
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
fmha_bwd_shape_0,
true,
true,
fmha_mask_0,
fmha_dropout_0,
fmha_bwd_trait_0>;
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::KGradDataType,
true,
false>>;
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::VGradDataType,
true,
false>>;
using fmha_bwd_dq_dk_dv_kernel_0 =
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
fmha_bwd_dk_epilogue_0,
fmha_bwd_dv_epilogue_0>;
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64,
ck_tile::bf16_t,
true,
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
fmha_mask_0,
fmha_dropout_0,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
true,
true,
false,
false,
true>;
#include <iostream>
template <>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}
template <>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
ck_tile::stream_config{s.stream_id_});
#endif
}
template <>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
return k_::GetName();
}

View File

@ -0,0 +1,84 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_fwd.hpp>
using fmha_dtype_0 = ck_tile::fp16_t;
using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>;
using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>;
using fmha_shape_0 = ck_tile::TileFmhaShape<fmha_block_tile_0,
ck_tile::sequence<2, 1, 1>,
fmha_warp_tile_0,
ck_tile::sequence<2, 1, 1>,
fmha_warp_tile_0,
true>;
using fmha_trait_0 = ck_tile::TileFmhaTraits<true,
true,
true,
true,
ck_tile::BlockAttentionBiasEnum::ALIBI,
false,
true,
false,
false,
-1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem<
typename FmhaFwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::SaccDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::SMPLComputeDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::PDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::OaccDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::ODataType,
fmha_shape_0,
true,
fmha_mask_0,
fmha_trait_0>;
using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync<
fmha_pipeline_problem_0>;
using fmha_epilogue_0 =
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<ck_tile::fp16_t>::OaccDataType,
typename FmhaFwdTypeConfig<ck_tile::fp16_t>::ODataType,
true, true>>;
using fmha_kernel_0 =
ck_tile::FmhaFwdKernel<ck_tile::FmhaFwdTilePartitioner_HBS<fmha_shape_0>,
fmha_pipeline_0,
fmha_epilogue_0>;
using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, true, 128, 64, 16, 32, 32, 32, true,
ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, false, false, true, true, true, true>;
#include <iostream>
template<>
float fmha_fwd_<trait_0>(const ck_tile::stream_config& s, fmha_fwd_args a)
{
using k_ = fmha_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_fwd_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}
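
For orientation, a hedged sketch of how a runtime dispatcher might index explicit specializations like the one above; the names fwd_fn, fwd_registry, and fwd_registered_0 are invented for illustration, and the actual selection logic ships in the generated fmha_fwd glue, which is not shown in this diff.

#include <fmha_fwd.hpp>
#include <functional>
#include <map>
#include <string>
#include <utility>

// Hypothetical registry: maps an assumed (head-dim, dtype-tag) key to the
// generated entry point. The real dispatcher differs.
using fwd_fn = std::function<float(const ck_tile::stream_config&, fmha_fwd_args)>;

inline std::map<std::pair<int, std::string>, fwd_fn>& fwd_registry()
{
    static std::map<std::pair<int, std::string>, fwd_fn> r;
    return r;
}

// Each generated translation unit could then self-register its specialization
// (32 is assumed here to be this file's head dim, per the first trait argument).
static const bool fwd_registered_0 = [] {
    fwd_registry()[{32, "fp16"}] = fmha_fwd_<trait_0>;
    return true;
}();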

View File

@ -0,0 +1,84 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_fwd.hpp>
using fmha_dtype_0 = ck_tile::bf16_t;
using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>;
using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>;
using fmha_shape_0 = ck_tile::TileFmhaShape<fmha_block_tile_0,
ck_tile::sequence<4, 1, 1>,
fmha_warp_tile_0,
ck_tile::sequence<4, 1, 1>,
fmha_warp_tile_0,
true>;
using fmha_trait_0 = ck_tile::TileFmhaTraits<true,
true,
true,
true,
ck_tile::BlockAttentionBiasEnum::ALIBI,
false,
true,
true,
false,
-1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem<
typename FmhaFwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::SaccDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::SMPLComputeDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::PDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::OaccDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::ODataType,
fmha_shape_0,
false,
fmha_mask_0,
fmha_trait_0>;
using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS<
fmha_pipeline_problem_0>;
using fmha_epilogue_0 =
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<ck_tile::bf16_t>::OaccDataType,
typename FmhaFwdTypeConfig<ck_tile::bf16_t>::ODataType,
true, true>>;
using fmha_kernel_0 =
ck_tile::FmhaFwdKernel<ck_tile::FmhaFwdTilePartitioner_SHB<fmha_shape_0>,
fmha_pipeline_0,
fmha_epilogue_0>;
using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, false, 128, 128, 32, 256, 32, 256, true,
ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>;
#include <iostream>
template<>
float fmha_fwd_<trait_0>(const ck_tile::stream_config& s, fmha_fwd_args a)
{
using k_ = fmha_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_fwd_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}

View File

@ -0,0 +1,84 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_fwd.hpp>
using fmha_dtype_0 = ck_tile::bf16_t;
using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>;
using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>;
using fmha_shape_0 = ck_tile::TileFmhaShape<fmha_block_tile_0,
ck_tile::sequence<2, 1, 1>,
fmha_warp_tile_0,
ck_tile::sequence<2, 1, 1>,
fmha_warp_tile_0,
true>;
using fmha_trait_0 = ck_tile::TileFmhaTraits<true,
false,
true,
true,
ck_tile::BlockAttentionBiasEnum::ALIBI,
false,
true,
true,
false,
-1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem<
typename FmhaFwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::SaccDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::SMPLComputeDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::PDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::OaccDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::ODataType,
fmha_shape_0,
false,
fmha_mask_0,
fmha_trait_0>;
using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync<
fmha_pipeline_problem_0>;
using fmha_epilogue_0 =
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<ck_tile::bf16_t>::OaccDataType,
typename FmhaFwdTypeConfig<ck_tile::bf16_t>::ODataType,
true, true>>;
using fmha_kernel_0 =
ck_tile::FmhaFwdKernel<ck_tile::FmhaFwdTilePartitioner_SHB<fmha_shape_0>,
fmha_pipeline_0,
fmha_epilogue_0>;
using trait_0 = fmha_fwd_traits_<32, ck_tile::bf16_t, false, 128, 64, 16, 32, 32, 32, true,
ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, false, true, true>;
#include <iostream>
template<>
float fmha_fwd_<trait_0>(const ck_tile::stream_config& s, fmha_fwd_args a)
{
using k_ = fmha_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_fwd_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}

View File

@ -0,0 +1,144 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_bwd.hpp>
using fmha_dtype_0 = ck_tile::bf16_t;
using fmha_block_tile_0 = ck_tile::
sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>;
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
// G1&G3 -> GdKV
// G4 -> GdQ
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps2_0,
fmha_warp_tile0_0>;
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
true,
true,
true,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
false,
false,
false,
1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<false, true, false>;
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
fmha_bwd_shape_0,
true,
false,
fmha_mask_0,
fmha_dropout_0,
fmha_bwd_trait_0>;
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVR<fmha_bwd_pipeline_problem_0>;
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::KGradDataType,
true,
true>>;
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::VGradDataType,
true,
true>>;
using fmha_bwd_dq_dk_dv_kernel_0 =
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
fmha_bwd_dk_epilogue_0,
fmha_bwd_dv_epilogue_0>;
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128,
ck_tile::bf16_t,
true,
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR,
fmha_mask_0,
fmha_dropout_0,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
true,
true,
true,
true,
false>;
#include <iostream>
template <>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}
template <>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
ck_tile::stream_config{s.stream_id_});
#endif
}
template <>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
return k_::GetName();
}

View File

@ -0,0 +1,144 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_bwd.hpp>
using fmha_dtype_0 = ck_tile::fp16_t;
using fmha_block_tile_0 = ck_tile::
sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>;
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
// G1&G3 -> GdKV
// G4 -> GdQ
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps2_0,
fmha_warp_tile0_0>;
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
false,
false,
false,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
false,
false,
false,
1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<true, false, false>;
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
fmha_bwd_shape_0,
false,
false,
fmha_mask_0,
fmha_dropout_0,
fmha_bwd_trait_0>;
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::KGradDataType,
false,
false>>;
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::VGradDataType,
false,
false>>;
using fmha_bwd_dq_dk_dv_kernel_0 =
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
fmha_bwd_dk_epilogue_0,
fmha_bwd_dv_epilogue_0>;
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128,
ck_tile::fp16_t,
false,
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
fmha_mask_0,
fmha_dropout_0,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
true,
false,
false,
false,
false>;
#include <iostream>
template <>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}
template <>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
ck_tile::stream_config{s.stream_id_});
#endif
}
template <>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
return k_::GetName();
}

View File

@ -0,0 +1,84 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_fwd.hpp>
using fmha_dtype_0 = ck_tile::fp16_t;
using fmha_block_tile_0 = ck_tile::sequence<128, 64, 16, 32, 32, 32>;
using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>;
using fmha_shape_0 = ck_tile::TileFmhaShape<fmha_block_tile_0,
ck_tile::sequence<2, 1, 1>,
fmha_warp_tile_0,
ck_tile::sequence<2, 1, 1>,
fmha_warp_tile_0,
true>;
using fmha_trait_0 = ck_tile::TileFmhaTraits<true,
true,
true,
true,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
false,
false,
false,
-1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem<
typename FmhaFwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::SaccDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::SMPLComputeDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::PDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::OaccDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::ODataType,
fmha_shape_0,
false,
fmha_mask_0,
fmha_trait_0>;
using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync<
fmha_pipeline_problem_0>;
using fmha_epilogue_0 =
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<ck_tile::fp16_t>::OaccDataType,
typename FmhaFwdTypeConfig<ck_tile::fp16_t>::ODataType,
true, true>>;
using fmha_kernel_0 =
ck_tile::FmhaFwdKernel<ck_tile::FmhaFwdTilePartitioner_SHB<fmha_shape_0>,
fmha_pipeline_0,
fmha_epilogue_0>;
using trait_0 = fmha_fwd_traits_<32, ck_tile::fp16_t, false, 128, 64, 16, 32, 32, 32, true,
ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>;
#include <iostream>
template<>
float fmha_fwd_<trait_0>(const ck_tile::stream_config& s, fmha_fwd_args a)
{
using k_ = fmha_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_fwd_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}

View File

@ -0,0 +1,144 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_bwd.hpp>
using fmha_dtype_0 = ck_tile::bf16_t;
using fmha_block_tile_0 = ck_tile::
sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>;
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
// G1&G3 -> GdKV
// G4 -> GdQ
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps2_0,
fmha_warp_tile0_0>;
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
true,
false,
false,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
false,
false,
false,
1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<false, true, false>;
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
fmha_bwd_shape_0,
false,
false,
fmha_mask_0,
fmha_dropout_0,
fmha_bwd_trait_0>;
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::KGradDataType,
true,
false>>;
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::VGradDataType,
true,
false>>;
using fmha_bwd_dq_dk_dv_kernel_0 =
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
fmha_bwd_dk_epilogue_0,
fmha_bwd_dv_epilogue_0>;
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128,
ck_tile::bf16_t,
false,
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
fmha_mask_0,
fmha_dropout_0,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
true,
true,
false,
false,
false>;
#include <iostream>
template <>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}
template <>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
ck_tile::stream_config{s.stream_id_});
#endif
}
template <>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
return k_::GetName();
}

View File

@ -0,0 +1,144 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_bwd.hpp>
using fmha_dtype_0 = ck_tile::fp16_t;
using fmha_block_tile_0 = ck_tile::
sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>;
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
// G1&G3 -> GdKV
// G4 -> GdQ
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps2_0,
fmha_warp_tile0_0>;
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<false,
false,
false,
false,
ck_tile::BlockAttentionBiasEnum::ALIBI,
false,
false,
false,
false,
1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<true, false, false>;
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
fmha_bwd_shape_0,
false,
false,
fmha_mask_0,
fmha_dropout_0,
fmha_bwd_trait_0>;
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::KGradDataType,
false,
false>>;
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::fp16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::fp16_t>::VGradDataType,
false,
false>>;
using fmha_bwd_dq_dk_dv_kernel_0 =
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
fmha_bwd_dk_epilogue_0,
fmha_bwd_dv_epilogue_0>;
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128,
ck_tile::fp16_t,
false,
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
fmha_mask_0,
fmha_dropout_0,
ck_tile::BlockAttentionBiasEnum::ALIBI,
false,
false,
false,
false,
false,
false>;
#include <iostream>
template <>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}
template <>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
ck_tile::stream_config{s.stream_id_});
#endif
}
template <>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
return k_::GetName();
}

View File

@ -0,0 +1,84 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_fwd.hpp>
using fmha_dtype_0 = ck_tile::bf16_t;
using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 256, 32, 256>;
using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>;
using fmha_shape_0 = ck_tile::TileFmhaShape<fmha_block_tile_0,
ck_tile::sequence<4, 1, 1>,
fmha_warp_tile_0,
ck_tile::sequence<4, 1, 1>,
fmha_warp_tile_0,
true>;
using fmha_trait_0 = ck_tile::TileFmhaTraits<true,
true,
true,
true,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
false,
false,
false,
-1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem<
typename FmhaFwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::SaccDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::SMPLComputeDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::PDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::OaccDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::ODataType,
fmha_shape_0,
true,
fmha_mask_0,
fmha_trait_0>;
using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVS<
fmha_pipeline_problem_0>;
using fmha_epilogue_0 =
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<ck_tile::bf16_t>::OaccDataType,
typename FmhaFwdTypeConfig<ck_tile::bf16_t>::ODataType,
true, true>>;
using fmha_kernel_0 =
ck_tile::FmhaFwdKernel<ck_tile::FmhaFwdTilePartitioner_HBS<fmha_shape_0>,
fmha_pipeline_0,
fmha_epilogue_0>;
using trait_0 = fmha_fwd_traits_<256, ck_tile::bf16_t, true, 128, 128, 32, 256, 32, 256, true,
ck_tile::BlockFmhaPipelineEnum::QRKSVS, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, true, true, true, true>;
#include <iostream>
template<>
float fmha_fwd_<trait_0>(const ck_tile::stream_config& s, fmha_fwd_args a)
{
using k_ = fmha_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_fwd_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}

View File

@ -0,0 +1,84 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_fwd.hpp>
using fmha_dtype_0 = ck_tile::bf16_t;
using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>;
using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>;
using fmha_shape_0 = ck_tile::TileFmhaShape<fmha_block_tile_0,
ck_tile::sequence<4, 1, 1>,
fmha_warp_tile_0,
ck_tile::sequence<4, 1, 1>,
fmha_warp_tile_0,
true>;
using fmha_trait_0 = ck_tile::TileFmhaTraits<true,
true,
true,
true,
ck_tile::BlockAttentionBiasEnum::ALIBI,
false,
true,
true,
false,
-1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem<
typename FmhaFwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::SaccDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::SMPLComputeDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::PDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::OaccDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::ODataType,
fmha_shape_0,
true,
fmha_mask_0,
fmha_trait_0>;
using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync<
fmha_pipeline_problem_0>;
using fmha_epilogue_0 =
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<ck_tile::bf16_t>::OaccDataType,
typename FmhaFwdTypeConfig<ck_tile::bf16_t>::ODataType,
true, true>>;
using fmha_kernel_0 =
ck_tile::FmhaFwdKernel<ck_tile::FmhaFwdTilePartitioner_HBS<fmha_shape_0>,
fmha_pipeline_0,
fmha_epilogue_0>;
using trait_0 = fmha_fwd_traits_<64, ck_tile::bf16_t, true, 128, 64, 32, 64, 32, 64, true,
ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, true, true, false, true, true, true, true>;
#include <iostream>
template<>
float fmha_fwd_<trait_0>(const ck_tile::stream_config& s, fmha_fwd_args a)
{
using k_ = fmha_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_fwd_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}

View File

@ -0,0 +1,84 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_fwd.hpp>
using fmha_dtype_0 = ck_tile::fp16_t;
using fmha_block_tile_0 = ck_tile::sequence<128, 64, 32, 64, 32, 64>;
using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>;
using fmha_shape_0 = ck_tile::TileFmhaShape<fmha_block_tile_0,
ck_tile::sequence<4, 1, 1>,
fmha_warp_tile_0,
ck_tile::sequence<4, 1, 1>,
fmha_warp_tile_0,
true>;
using fmha_trait_0 = ck_tile::TileFmhaTraits<true,
true,
true,
true,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
true,
true,
false,
-1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem<
typename FmhaFwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::SaccDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::SMPLComputeDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::PDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::OaccDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::ODataType,
fmha_shape_0,
false,
fmha_mask_0,
fmha_trait_0>;
using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync<
fmha_pipeline_problem_0>;
using fmha_epilogue_0 =
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<ck_tile::fp16_t>::OaccDataType,
typename FmhaFwdTypeConfig<ck_tile::fp16_t>::ODataType,
true, true>>;
using fmha_kernel_0 =
ck_tile::FmhaFwdKernel<ck_tile::FmhaFwdTilePartitioner_SHB<fmha_shape_0>,
fmha_pipeline_0,
fmha_epilogue_0>;
using trait_0 = fmha_fwd_traits_<64, ck_tile::fp16_t, false, 128, 64, 32, 64, 32, 64, true,
ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::NO_BIAS, true, true, false, true, true, true, true>;
#include <iostream>
template<>
float fmha_fwd_<trait_0>(const ck_tile::stream_config& s, fmha_fwd_args a)
{
using k_ = fmha_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_fwd_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}

View File

@ -0,0 +1,84 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_fwd.hpp>
using fmha_dtype_0 = ck_tile::bf16_t;
using fmha_block_tile_0 = ck_tile::sequence<128, 128, 32, 128, 32, 128>;
using fmha_warp_tile_0 = ck_tile::sequence<32, 32, 16>;
using fmha_shape_0 = ck_tile::TileFmhaShape<fmha_block_tile_0,
ck_tile::sequence<4, 1, 1>,
fmha_warp_tile_0,
ck_tile::sequence<4, 1, 1>,
fmha_warp_tile_0,
true>;
using fmha_trait_0 = ck_tile::TileFmhaTraits<true,
false,
true,
true,
ck_tile::BlockAttentionBiasEnum::ALIBI,
false,
false,
true,
false,
-1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
using fmha_pipeline_problem_0 = ck_tile::BlockFmhaPipelineProblem<
typename FmhaFwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::SaccDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::SMPLComputeDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::PDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::OaccDataType,
typename FmhaFwdTypeConfig<fmha_dtype_0>::ODataType,
fmha_shape_0,
false,
fmha_mask_0,
fmha_trait_0>;
using fmha_pipeline_0 = ck_tile::BlockFmhaPipelineQRKSVSAsync<
fmha_pipeline_problem_0>;
using fmha_epilogue_0 =
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<ck_tile::bf16_t>::OaccDataType,
typename FmhaFwdTypeConfig<ck_tile::bf16_t>::ODataType,
true, true>>;
using fmha_kernel_0 =
ck_tile::FmhaFwdKernel<ck_tile::FmhaFwdTilePartitioner_SHB<fmha_shape_0>,
fmha_pipeline_0,
fmha_epilogue_0>;
using trait_0 = fmha_fwd_traits_<128, ck_tile::bf16_t, false, 128, 128, 32, 128, 32, 128, true,
ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC, fmha_mask_0, ck_tile::BlockAttentionBiasEnum::ALIBI, false, true, false, true, false, true, true>;
#include <iostream>
template<>
float fmha_fwd_<trait_0>(const ck_tile::stream_config& s, fmha_fwd_args a)
{
using k_ = fmha_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_fwd_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}

View File

@ -0,0 +1,144 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_bwd.hpp>
using fmha_dtype_0 = ck_tile::bf16_t;
using fmha_block_tile_0 = ck_tile::
sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>;
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
// G1&G3 -> GdKV
// G4 -> GdQ
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps2_0,
fmha_warp_tile0_0>;
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<true,
true,
false,
false,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
false,
false,
false,
1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<true>;
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<true, false, false>;
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
fmha_bwd_shape_0,
true,
true,
fmha_mask_0,
fmha_dropout_0,
fmha_bwd_trait_0>;
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::KGradDataType,
true,
false>>;
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::VGradDataType,
true,
false>>;
using fmha_bwd_dq_dk_dv_kernel_0 =
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
fmha_bwd_dk_epilogue_0,
fmha_bwd_dv_epilogue_0>;
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128,
ck_tile::bf16_t,
true,
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
fmha_mask_0,
fmha_dropout_0,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
true,
true,
false,
false,
true>;
#include <iostream>
template <>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}
template <>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
ck_tile::stream_config{s.stream_id_});
#endif
}
template <>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
return k_::GetName();
}

View File

@ -0,0 +1,144 @@
// ==========================================
// THIS CODE IS AUTOGENERATED. DO NOT MODIFY.
// @generated
// ==========================================
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// auto generated by generate.py
#include <fmha_bwd.hpp>
using fmha_dtype_0 = ck_tile::bf16_t;
using fmha_block_tile_0 = ck_tile::
sequence<32, 128, 64, 32, 64, 32, 32, 64, 64>;
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>;
using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
// G0&G2 -> GSdP
// G1&G3 -> GdKV
// G4 -> GdQ
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps0_0,
fmha_warp_tile0_0,
fmha_block_warps1_0,
fmha_warp_tile1_0,
fmha_block_warps2_0,
fmha_warp_tile0_0>;
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<false,
true,
false,
false,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
false,
false,
false,
1>;
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<true, false, false>;
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::GemmDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::LSEDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::DDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::RandValOutputDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::OGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::KGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::VGradDataType,
typename FmhaBwdTypeConfig<fmha_dtype_0>::BiasGradDataType,
fmha_bwd_shape_0,
false,
false,
fmha_mask_0,
fmha_dropout_0,
fmha_bwd_trait_0>;
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::KGradDataType,
true,
false>>;
using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<ck_tile::bf16_t>::AccDataType,
typename FmhaBwdTypeConfig<ck_tile::bf16_t>::VGradDataType,
true,
false>>;
using fmha_bwd_dq_dk_dv_kernel_0 =
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
fmha_bwd_dk_epilogue_0,
fmha_bwd_dv_epilogue_0>;
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<64,
ck_tile::bf16_t,
false,
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
fmha_mask_0,
fmha_dropout_0,
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
false,
false,
true,
false,
false,
false>;
#include <iostream>
template <>
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs));
#else
return 0.0;
#endif
}
template <>
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
fmha_bwd_args a)
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
#if (defined(__gfx90a__) || defined(__gfx942__))
ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{}, grids, blocks, 0, kargs)(
ck_tile::stream_config{s.stream_id_});
#endif
}
template <>
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
{
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
return k_::GetName();
}

Some files were not shown because too many files have changed in this diff.