Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Summary:

# Summary

### Sticky points
CUDA graph RNG handling has changed / deviated from the original implementation. We are left with a dangling 'offset' value and confusing naming due to BC.

## Dependencies
- Flash PR: https://github.com/Dao-AILab/flash-attention/pull/1419

### Other Points
- The BC linter is complaining about losing generate.py and its functions, which is not real BC surface.

cc albanD

imported-using-ghimport

Test Plan: Imported from OSS

Building in dev:
`buck build @//mode/dev-nosan -c fbcode.nvcc_arch=h100a //caffe2:ATen-cu --show-full-output`

After building and running `nm` on the .so, I do see that the flash symbols are correctly named:

```
0000000001c3dfb0 t pytorch_flash::run_mha_bwd(pytorch_flash::Flash_bwd_params&, CUstream_st*)::$_0::operator()() const::{lambda()#1}::operator()() const::{lambda()#1}::operator()() const::{lambda()#7}::operator()() const
0000000001c36080 t pytorch_flash::run_mha_fwd(pytorch_flash::Flash_fwd_params&, CUstream_st*, bool)::$_0::operator()() const::{lambda()#2}::operator()() const::{lambda()#1}::operator()() const::{lambda()#6}::operator()() const
0000000001c360e0 t pytorch_flash::run_mha_fwd(pytorch_flash::Flash_fwd_params&, CUstream_st*, bool)::$_0::operator()() const::{lambda()#2}::operator()() const::{lambda()#1}::operator()() const::{lambda()#7}::operator()() const
0000000001c35fc0 t pytorch_flash::run_mha_fwd(pytorch_flash::Flash_fwd_params&, CUstream_st*, bool)::$_0::operator()() const::{lambda()#1}::operator()() const::{lambda()#1}::operator()() const::{lambda()#6}::operator()() const
0000000001c36020 t pytorch_flash::run_mha_fwd(pytorch_flash::Flash_fwd_params&, CUstream_st*, bool)::$_0::operator()() const::{lambda()#1}::operator()() const::{lambda()#1}::operator()() const::{lambda()#7}::operator()() const
```

Reviewed By: vkuzo

Differential Revision: D68502879

Pulled By: drisspg

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146372
Approved by: https://github.com/jbschlosser
Committed by: PyTorch MergeBot
Parent: 276dfe8150
Commit: 3ecfe6be25
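The schema changes in the diff below (see the native_functions.yaml hunks) rename the last two RNG outputs of the flash attention ops from `philox_seed`/`philox_offset` to `rng_state`/`unused`: the seed and offset are now packed into a single 2-element uint64 CUDA tensor, and the second output is only a placeholder kept for arity/BC. As a minimal sketch of how a downstream caller could read the packed values back on the host, assuming the ordering seed-then-offset described in the flash_api.cpp note; the helper name is hypothetical and not part of this PR:

```cpp
#include <ATen/ATen.h>
#include <cstdint>
#include <utility>

// Hypothetical helper, for illustration only: unpack the new 2-element uint64
// "rng_state" tensor (assumed layout: seed at index 0, philox offset at index 1)
// that replaces the separate philox_seed / philox_offset outputs.
inline std::pair<uint64_t, uint64_t> unpack_rng_state(const at::Tensor& rng_state) {
  // The tensor lives on CUDA; copy it to CPU before touching it from host code.
  at::Tensor cpu_state = rng_state.to(at::kCPU);
  const uint64_t* data = cpu_state.data_ptr<uint64_t>();
  return {data[0] /*seed*/, data[1] /*offset*/};
}
```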
@@ -872,11 +872,6 @@ cmake_dependent_option(
  "USE_CUDA OR USE_ROCM;NOT MSVC"
  OFF)

# We are currenlty not using alibi attention for Flash So we disable this
# feature by default We dont currently document this feature because we don't
# Suspect users building from source will need this
add_definitions(-DFLASHATTENTION_DISABLE_ALIBI)

# CAVEAT: Again, Flash Attention2 will error while building for sm52 while Mem
# Eff Attention won't
cmake_dependent_option(
@@ -164,9 +164,12 @@ file(GLOB native_quantized_cudnn_hip_cpp "native/quantized/cudnn/hip/*.cpp")
file(GLOB native_utils_cpp "native/utils/*.cpp")

# flash_attention sources
file(GLOB flash_attention_cuda_cu "native/transformers/cuda/flash_attn/*.cu")
file(GLOB flash_attention_cuda_kernels_cu "native/transformers/cuda/flash_attn/kernels/*.cu")
file(GLOB flash_attention_cuda_cpp "native/transformers/cuda/flash_attn/*.cpp")
file(GLOB flash_attention_cuda_kernels_cu ${PROJECT_SOURCE_DIR}/third_party/flash-attention/csrc/flash_attn/src/*.cu)
# Flash attention C++ sources
file(GLOB flash_attention_cuda_cpp
  "${PROJECT_SOURCE_DIR}/third_party/flash-attention/csrc/flash_attn/src/*.cpp"
  "native/transformers/cuda/flash_attn/flash_api.cpp"
)

# flash_attention hip sources
file(GLOB flash_attention_hip_hip "native/transformers/hip/flash_attn/*.hip")
@@ -14852,7 +14852,7 @@
  MPS: _scaled_dot_product_attention_math_mps
  tags: nondeterministic_seeded

- func: _scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
- func: _scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor rng_state, Tensor unused, Tensor debug_attn_mask)
  dispatch:
    CUDA: _scaled_dot_product_flash_attention_cuda
    NestedTensorCUDA: _scaled_dot_product_flash_attention_nestedtensor_cuda
@@ -14909,13 +14909,13 @@
    CUDA: _scaled_dot_product_cudnn_attention_backward_cuda
  tags: nondeterministic_seeded

- func: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None, SymInt? window_size_left=None, SymInt? window_size_right=None, Tensor? seqused_k=None, Tensor? alibi_slopes=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
- func: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None, SymInt? window_size_left=None, SymInt? window_size_right=None, Tensor? seqused_k=None, Tensor? alibi_slopes=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor rng_state, Tensor unused, Tensor debug_attn_mask)
  variants: function
  dispatch:
    CUDA: _flash_attention_forward
  tags: nondeterministic_seeded

- func: _flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? scale=None, SymInt? window_size_left=None, SymInt? window_size_right=None) -> (Tensor, Tensor, Tensor)
- func: _flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor rng_state, Tensor unused, *, float? scale=None, SymInt? window_size_left=None, SymInt? window_size_right=None) -> (Tensor, Tensor, Tensor)
  device_check: NoCheck
  variants: function
  dispatch:
@@ -70,6 +70,9 @@
#ifdef USE_FLASH_ATTENTION
// FlashAttention Specific Imports
#include <ATen/native/transformers/cuda/flash_attn/flash_api.h>
#if !defined(__HIP_PLATFORM_AMD__)
#include <namespace_config.h>
#endif
#endif
#ifdef USE_MEM_EFF_ATTENTION
#ifndef USE_ROCM
@@ -916,6 +919,7 @@ _flash_attention_forward(
  std::optional<Tensor> seqused_k = _seqused_k;
  std::optional<at::Tensor> block_table = std::nullopt; // we are not using the block table yet
  std::optional<Tensor> alibi_slopes = _alibi_slopes;
  const float softcap = 0.0;

  const int non_null_window_left = window_size_left.has_value() ? window_size_left.value() : -1;
  const int non_null_window_right = window_size_right.has_value() ? window_size_right.value() : -1;
@@ -939,7 +943,7 @@ _flash_attention_forward(
        philox_seed,
        philox_offset,
        debug_attn_mask) =
        pytorch_flash::mha_varlen_fwd(
        FLASH_NAMESPACE::mha_varlen_fwd(
            query,
            key,
            value,
@@ -957,6 +961,7 @@ _flash_attention_forward(
            is_causal,
            non_null_window_left,
            non_null_window_right,
            softcap,
            return_debug_mask,
            std::nullopt /*gen_*/);
  } else {
@@ -969,7 +974,7 @@ _flash_attention_forward(
        philox_seed,
        philox_offset,
        debug_attn_mask) =
        pytorch_flash::mha_fwd(
        FLASH_NAMESPACE::mha_fwd(
            query,
            key,
            value,
@@ -980,6 +985,7 @@ _flash_attention_forward(
            is_causal,
            non_null_window_left,
            non_null_window_right,
            softcap,
            return_debug_mask, /*return_softmax (this is used for testing)*/
            std::nullopt);
  }
@@ -94,6 +94,7 @@ std::tuple<Tensor, Tensor, Tensor> _flash_attention_backward(

  // Currently unused args:
  std::optional<at::Tensor> alibi_slopes{std::nullopt};
  const float softcap = 0.0;

  bool determinisitic{false};
  auto& ctx = at::globalContext();
@@ -111,7 +112,7 @@ std::tuple<Tensor, Tensor, Tensor> _flash_attention_backward(
  // in order to determine whether we are using varlen or dense forward
  if (cumulative_sequence_length_q.defined()) {
    // Varlen forward
    auto [dQuery, dKey, dValue, dSoftmax] = pytorch_flash::mha_varlen_bwd(
    auto [dQuery, dKey, dValue, dSoftmax] = FLASH_NAMESPACE::mha_varlen_bwd(
        contiguous_grad_out,
        query,
        key,
@@ -132,13 +133,14 @@ std::tuple<Tensor, Tensor, Tensor> _flash_attention_backward(
        is_causal,
        non_null_window_left,
        non_null_window_right,
        softcap,
        determinisitic,
        philox_seed,
        philox_offset);
    return std::make_tuple(std::move(dQuery), std::move(dKey), std::move(dValue));
  } else {
    // Dense forward
    auto [dQuery, dKey, dValue, dSoftmax] = pytorch_flash::mha_bwd(
    auto [dQuery, dKey, dValue, dSoftmax] = FLASH_NAMESPACE::mha_bwd(
        contiguous_grad_out,
        query,
        key,
@@ -154,6 +156,7 @@ std::tuple<Tensor, Tensor, Tensor> _flash_attention_backward(
        is_causal,
        non_null_window_left,
        non_null_window_right,
        softcap,
        determinisitic,
        philox_seed,
        philox_offset);
@ -1,74 +0,0 @@
|
||||
#include <cmath>
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/array.h>
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/utils.h>
|
||||
|
||||
namespace pytorch_flash {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <bool Is_causal>
|
||||
struct Alibi {
|
||||
|
||||
const float alibi_slope;
|
||||
const int max_seqlen_k, max_seqlen_q;
|
||||
|
||||
__forceinline__ __device__ Alibi(const float alibi_slope, const int max_seqlen_k, const int max_seqlen_q)
|
||||
: alibi_slope(alibi_slope)
|
||||
, max_seqlen_k(max_seqlen_k)
|
||||
, max_seqlen_q(max_seqlen_q) {
|
||||
};
|
||||
|
||||
|
||||
template <typename Engine, typename Layout>
|
||||
__forceinline__ __device__ void apply_alibi(Tensor<Engine, Layout> &tensor,
|
||||
const int col_idx_offset_,
|
||||
const int row_idx_offset,
|
||||
const int warp_row_stride) {
|
||||
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
|
||||
static_assert(Layout::rank == 2, "Only support 2D Tensor");
|
||||
const int lane_id = threadIdx.x % 32;
|
||||
const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
|
||||
if constexpr (Is_causal) { // Simpler, we add the same bias vector to all rows
|
||||
#pragma unroll
|
||||
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
|
||||
const int col_idx_base = col_idx_offset + nj * 8;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < size<1, 0>(tensor); ++j) {
|
||||
const int col_idx = col_idx_base + j;
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0>(tensor); ++mi) {
|
||||
tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx;
|
||||
}
|
||||
}
|
||||
}
|
||||
} else { // Bias depends on both row_idx and col_idx
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
|
||||
const int row_idx_base = row_idx_offset + mi * warp_row_stride;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size<0, 0>(tensor); ++i) {
|
||||
const int row_idx = row_idx_base + i * 8;
|
||||
#pragma unroll
|
||||
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
|
||||
const int col_idx_base = col_idx_offset + nj * 8;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < size<1, 0>(tensor); ++j) {
|
||||
const int col_idx = col_idx_base + j;
|
||||
tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
} // namespace pytorch_flash
|
@ -1,46 +0,0 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2023, Tri Dao.
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace pytorch_flash {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<bool Varlen=true>
|
||||
struct BlockInfo {
|
||||
|
||||
template<typename Params>
|
||||
__device__ BlockInfo(const Params ¶ms, const int bidb)
|
||||
: sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb])
|
||||
, sum_s_k(!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative ? -1 : params.cu_seqlens_k[bidb])
|
||||
, actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q)
|
||||
// If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
|
||||
// Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
|
||||
, seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb]))
|
||||
, actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew))
|
||||
{
|
||||
}
|
||||
|
||||
template <typename index_t>
|
||||
__forceinline__ __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
|
||||
return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride;
|
||||
}
|
||||
|
||||
template <typename index_t>
|
||||
__forceinline__ __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
|
||||
return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride;
|
||||
}
|
||||
|
||||
const int sum_s_q;
|
||||
const int sum_s_k;
|
||||
const int actual_seqlen_q;
|
||||
// We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0.
|
||||
const int seqlen_k_cache;
|
||||
const int actual_seqlen_k;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace pytorch_flash
|
@ -1,96 +0,0 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2024, Tri Dao.
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/philox.cuh>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/utils.h>
|
||||
|
||||
namespace pytorch_flash {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
struct Dropout {
|
||||
|
||||
const unsigned long long seed, offset;
|
||||
const uint8_t p_dropout_in_uint8_t;
|
||||
|
||||
__forceinline__ __device__ Dropout(const unsigned long long seed, const unsigned long long offset,
|
||||
const uint8_t p_dropout_in_uint8_t,
|
||||
const int bid, const int hid, const int tid, const int nheads)
|
||||
: seed(seed)
|
||||
, offset(offset + (bid * nheads + hid) * 32 + tid % 32)
|
||||
, p_dropout_in_uint8_t(p_dropout_in_uint8_t) {
|
||||
}
|
||||
|
||||
template <bool encode_dropout_in_sign_bit=false, typename Engine, typename Layout>
|
||||
__forceinline__ __device__ void apply_dropout(Tensor<Engine, Layout> &tensor_,
|
||||
int block_row_start, int block_col_start, int block_row_stride) {
|
||||
// convert shape from (4, MMA_M, MMA_N) to (8, MMA_M, MMA_N / 2)
|
||||
Tensor tensor = make_tensor(tensor_.data(), pytorch_flash::convert_layout_acc_dropout(tensor_.layout()));
|
||||
using T = typename Engine::value_type;
|
||||
auto encode_dropout = [](bool keep, T val) {
|
||||
return keep ? val : (encode_dropout_in_sign_bit ? -val : T(0));
|
||||
};
|
||||
static_assert(decltype(size<2>(tensor))::value % 2 == 0);
|
||||
const uint16_t p_dropout_8bit_in_uint16_t = uint16_t(p_dropout_in_uint8_t);
|
||||
const uint32_t p_dropout_8bit_in_uint32_t = (uint32_t(p_dropout_8bit_in_uint16_t) << 16) | uint32_t(p_dropout_8bit_in_uint16_t);
|
||||
// if (cute::thread0()) { printf("threshold2 = 0x%x\n", p_dropout_8bit_in_uint32_t); }
|
||||
#pragma unroll
|
||||
for (int m = 0; m < size<1>(tensor); ++m, block_row_start += block_row_stride) {
|
||||
uint2 rowcol = make_uint2(block_row_start, block_col_start);
|
||||
#pragma unroll
|
||||
for (int n = 0; n < size<2>(tensor) / 2; ++n, ++rowcol.y) {
|
||||
// if (cute::thread(32, 0)) { printf("m = %d, n = %d, row = %d, col = %d\n", m, n, int(rowcol.x), int(rowcol.y));}
|
||||
uint4 random_uint4 = pytorch_flash::philox(seed, reinterpret_cast<unsigned long long&>(rowcol), offset);
|
||||
// if (cute::thread0()) { printf("philox = %u, %d, %d, %d\n", random_uint4.x, random_uint4.y, random_uint4.z, random_uint4.w);}
|
||||
uint8_t (&rnd_8)[16] = reinterpret_cast<uint8_t (&)[16]>(random_uint4);
|
||||
// Special implementation for 16-bit types: we duplicate the threshold to the
|
||||
// low and high 16 bits of a 32-bit value, then use the f16x2 comparison instruction
|
||||
// to get a mask. The low 16 bits of the mask will be either 0xffff or 0x0000,
|
||||
// and the high 16 bits will be either 0xffff or 0x0000, depending on whether
|
||||
// the random value is less than the threshold.
|
||||
// We then do a bit-wise AND between the mask and the original value (in 32-bit).
|
||||
// We're exploiting the fact that floating point comparison is equivalent to integer
|
||||
// comparison, since we're comparing unsigned integers whose top 8-bits are zero.
|
||||
if (!encode_dropout_in_sign_bit
|
||||
&& (std::is_same_v<T, cutlass::half_t> || std::is_same_v<T, cutlass::bfloat16_t>)) {
|
||||
uint16_t rnd_16[16];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 16; i++) { rnd_16[i] = uint16_t(rnd_8[i]); }
|
||||
uint32_t (&rnd_32)[8] = reinterpret_cast<uint32_t (&)[8]>(rnd_16);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 2; j++) {
|
||||
Tensor tensor_uint32 = recast<uint32_t>(tensor(_, m, n * 2 + j));
|
||||
// if (cute::thread0()) { printf("random = 0x%x, 0x%x, 0x%x, 0x%x\n", rnd_32[j * 4 + 0], rnd_32[j * 4 + 1], rnd_32[j * 4 + 2], rnd_32[j * 4 + 3]); }
|
||||
// if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
uint32_t mask;
|
||||
asm volatile("set.le.u32.f16x2 %0, %1, %2;\n" : "=r"(mask) : "r"(rnd_32[j * 4 + i]), "r"(p_dropout_8bit_in_uint32_t));
|
||||
tensor_uint32(i) &= mask;
|
||||
}
|
||||
// if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 2; j++) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 8; i++) {
|
||||
tensor(i, m, n * 2 + j) = encode_dropout(rnd_8[j * 8 + i] <= p_dropout_in_uint8_t, tensor(i, m, n * 2 + j));
|
||||
}
|
||||
Tensor tensor_uint32 = recast<uint32_t>(tensor(_, m, n * 2 + j));
|
||||
// if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
|
||||
}
|
||||
}
|
||||
// // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
|
||||
// // printf("n = %d, ph Philox: %u, %u, %u, %u\n", n, rnd_8.x, rnd_8.y, rnd_8.z, rnd_8.w);
|
||||
// // }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
} // namespace pytorch_flash
|
@ -1,190 +0,0 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2023, Tri Dao.
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cuda.h>
|
||||
|
||||
#ifdef OLD_GENERATOR_PATH
|
||||
#include <ATen/CUDAGeneratorImpl.h>
|
||||
#else
|
||||
#include <ATen/cuda/CUDAGeneratorImpl.h>
|
||||
#endif
|
||||
|
||||
#include <ATen/cuda/CUDAGraphsUtils.cuh> // For at::cuda::philox::unpack
|
||||
namespace pytorch_flash {
|
||||
constexpr int TOTAL_DIM = 0;
|
||||
constexpr int H_DIM = 1;
|
||||
constexpr int D_DIM = 2;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct Qkv_params {
|
||||
using index_t = int64_t;
|
||||
// The QKV matrices.
|
||||
void *__restrict__ q_ptr;
|
||||
void *__restrict__ k_ptr;
|
||||
void *__restrict__ v_ptr;
|
||||
|
||||
// The stride between rows of the Q, K and V matrices.
|
||||
index_t q_batch_stride;
|
||||
index_t k_batch_stride;
|
||||
index_t v_batch_stride;
|
||||
index_t q_row_stride;
|
||||
index_t k_row_stride;
|
||||
index_t v_row_stride;
|
||||
index_t q_head_stride;
|
||||
index_t k_head_stride;
|
||||
index_t v_head_stride;
|
||||
|
||||
// The number of heads.
|
||||
int h, h_k;
|
||||
// In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be
|
||||
// different from nheads (query).
|
||||
int h_h_k_ratio; // precompute h / h_k,
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct Flash_fwd_params : public Qkv_params {
|
||||
|
||||
// The O matrix (output).
|
||||
void * __restrict__ o_ptr;
|
||||
void * __restrict__ oaccum_ptr;
|
||||
|
||||
// The stride between rows of O.
|
||||
index_t o_batch_stride;
|
||||
index_t o_row_stride;
|
||||
index_t o_head_stride;
|
||||
|
||||
// The pointer to the P matrix.
|
||||
void * __restrict__ p_ptr;
|
||||
|
||||
// The pointer to the softmax sum.
|
||||
void * __restrict__ softmax_lse_ptr;
|
||||
void * __restrict__ softmax_lseaccum_ptr;
|
||||
|
||||
// The dimensions.
|
||||
int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim;
|
||||
|
||||
// The scaling factors for the kernel.
|
||||
float scale_softmax;
|
||||
float scale_softmax_log2;
|
||||
|
||||
// array of length b+1 holding starting offset of each sequence.
|
||||
int * __restrict__ cu_seqlens_q;
|
||||
int * __restrict__ cu_seqlens_k;
|
||||
|
||||
// If provided, the actual length of each k sequence.
|
||||
int * __restrict__ seqused_k;
|
||||
|
||||
int *__restrict__ blockmask;
|
||||
|
||||
// The K_new and V_new matrices.
|
||||
void * __restrict__ knew_ptr;
|
||||
void * __restrict__ vnew_ptr;
|
||||
|
||||
// The stride between rows of the Q, K and V matrices.
|
||||
index_t knew_batch_stride;
|
||||
index_t vnew_batch_stride;
|
||||
index_t knew_row_stride;
|
||||
index_t vnew_row_stride;
|
||||
index_t knew_head_stride;
|
||||
index_t vnew_head_stride;
|
||||
|
||||
// The cos and sin matrices for rotary embedding.
|
||||
void * __restrict__ rotary_cos_ptr;
|
||||
void * __restrict__ rotary_sin_ptr;
|
||||
|
||||
// The indices to index into the KV cache.
|
||||
int * __restrict__ cache_batch_idx;
|
||||
|
||||
// Paged KV cache
|
||||
int * __restrict__ block_table;
|
||||
index_t block_table_batch_stride;
|
||||
int page_block_size;
|
||||
|
||||
// The dropout probability (probability of keeping an activation).
|
||||
float p_dropout;
|
||||
// uint32_t p_dropout_in_uint;
|
||||
// uint16_t p_dropout_in_uint16_t;
|
||||
uint8_t p_dropout_in_uint8_t;
|
||||
|
||||
// Scale factor of 1 / (1 - p_dropout).
|
||||
float rp_dropout;
|
||||
float scale_softmax_rp_dropout;
|
||||
|
||||
// Local window size
|
||||
int window_size_left, window_size_right;
|
||||
|
||||
// Random state.
|
||||
at::PhiloxCudaState philox_args;
|
||||
int64_t * extragraph_offset;
|
||||
int64_t * seed;
|
||||
|
||||
bool is_bf16;
|
||||
bool is_causal;
|
||||
|
||||
// If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
|
||||
// Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
|
||||
bool is_seqlens_k_cumulative;
|
||||
|
||||
bool is_rotary_interleaved;
|
||||
|
||||
int num_splits; // For split-KV version
|
||||
|
||||
void * __restrict__ alibi_slopes_ptr;
|
||||
index_t alibi_slopes_batch_stride;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct Flash_bwd_params : public Flash_fwd_params {
|
||||
|
||||
// The dO and dQKV matrices.
|
||||
void *__restrict__ do_ptr;
|
||||
void *__restrict__ dq_ptr;
|
||||
void *__restrict__ dk_ptr;
|
||||
void *__restrict__ dv_ptr;
|
||||
|
||||
// To accumulate dQ
|
||||
void *__restrict__ dq_accum_ptr;
|
||||
void *__restrict__ dk_accum_ptr;
|
||||
void *__restrict__ dv_accum_ptr;
|
||||
|
||||
// // To accumulate dK and dV in case we're splitting the bwd along seqlen_q
|
||||
// dimension void *__restrict__ dk_accum_ptr; void *__restrict__
|
||||
// dv_accum_ptr;
|
||||
|
||||
// The stride between rows of the dO, dQ, dK and dV matrices.
|
||||
// TD [2022-04-16]: We're using 32-bit indexing to save registers.
|
||||
// The code probably won't work for arrays larger than 2GB.
|
||||
index_t do_batch_stride;
|
||||
index_t do_row_stride;
|
||||
index_t do_head_stride;
|
||||
index_t dq_batch_stride;
|
||||
index_t dk_batch_stride;
|
||||
index_t dv_batch_stride;
|
||||
index_t dq_row_stride;
|
||||
index_t dk_row_stride;
|
||||
index_t dv_row_stride;
|
||||
index_t dq_head_stride;
|
||||
index_t dk_head_stride;
|
||||
index_t dv_head_stride;
|
||||
|
||||
// The pointer to the softmax d sum.
|
||||
void *__restrict__ dsoftmax_sum;
|
||||
|
||||
bool deterministic;
|
||||
index_t dq_accum_split_stride;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename T, int Headdim> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream);
|
||||
template<typename T, int Headdim> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream);
|
||||
|
||||
template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream);
|
||||
|
||||
} // namespace pytorch_flash
|
@ -2,6 +2,7 @@
|
||||
* Copyright (c) 2024, Tri Dao.
|
||||
******************************************************************************/
|
||||
#include <c10/core/ScalarType.h>
|
||||
#include <c10/core/DeviceType.h>
|
||||
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
|
||||
|
||||
#include <cstdint>
|
||||
@ -9,6 +10,7 @@
|
||||
|
||||
|
||||
#ifdef USE_FLASH_ATTENTION
|
||||
|
||||
#include <ATen/core/Tensor.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
@ -32,13 +34,16 @@
|
||||
|
||||
#include <cutlass/numeric_types.h>
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash.h>
|
||||
|
||||
#include <flash.h>
|
||||
#include <namespace_config.h>
|
||||
#include <static_switch.h>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_api.h>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/static_switch.h>
|
||||
|
||||
|
||||
#include <c10/util/Exception.h>
|
||||
|
||||
namespace pytorch_flash {
|
||||
namespace FLASH_NAMESPACE {
|
||||
|
||||
#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA")
|
||||
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == at::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
|
||||
@ -70,7 +75,9 @@ void set_params_fprop(Flash_fwd_params ¶ms,
|
||||
float softmax_scale,
|
||||
int window_size_left,
|
||||
int window_size_right,
|
||||
bool seqlenq_ngroups_swapped=false) {
|
||||
const float softcap,
|
||||
bool seqlenq_ngroups_swapped=false,
|
||||
const bool unpadded_lse=false) {
|
||||
|
||||
// Reset the parameters
|
||||
params = {};
|
||||
@@ -126,8 +133,19 @@ void set_params_fprop(Flash_fwd_params &params,
    params.d_rounded = d_rounded;

    // Set the different scale values.
    params.scale_softmax = softmax_scale;
    params.scale_softmax_log2 = softmax_scale * M_LOG2E;
    #ifdef FLASHATTENTION_DISABLE_SOFTCAP
    TORCH_CHECK(softcap <= 0.0, "This flash attention build does not support softcap.");
    #endif
    if (softcap > 0.0) {
        params.softcap = softmax_scale / softcap;
        params.scale_softmax = softcap;
        params.scale_softmax_log2 = softcap * M_LOG2E;
    } else{
        // Remove potential NaN
        params.softcap = 0.0;
        params.scale_softmax = softmax_scale;
        params.scale_softmax_log2 = softmax_scale * M_LOG2E;
    }

    // Set this to probability of keeping an element to simplify things.
    params.p_dropout = 1.f - p_dropout;
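For reference, the softcap folding added above can be read as a single expression. Assuming the usual FlashAttention-2 softcap formulation (tanh capping of the scaled scores; this is an inference from the parameterization, not text from this PR), the folded parameters combine as sketched below:

```cpp
#include <cmath>

// Illustrative only: with softcap > 0, params.softcap = softmax_scale / softcap
// and params.scale_softmax = softcap combine so the kernel effectively computes
//   softcap * tanh(qk * softmax_scale / softcap),
// which matches softmax_scale * qk for small scores and saturates at +/- softcap.
inline float capped_score(float qk, float softmax_scale, float softcap) {
  return softcap * std::tanh(qk * softmax_scale / softcap);
}
```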
@ -162,6 +180,8 @@ void set_params_fprop(Flash_fwd_params ¶ms,
|
||||
#ifdef FLASHATTENTION_DISABLE_UNEVEN_K
|
||||
TORCH_CHECK(d == d_rounded, "This flash attention build does not support headdim not being a multiple of 32.");
|
||||
#endif
|
||||
params.unpadded_lse = unpadded_lse;
|
||||
params.seqlenq_ngroups_swapped = seqlenq_ngroups_swapped;
|
||||
}
|
||||
|
||||
void set_params_dgrad(Flash_bwd_params ¶ms,
|
||||
@ -195,7 +215,9 @@ void set_params_dgrad(Flash_bwd_params ¶ms,
|
||||
float softmax_scale,
|
||||
int window_size_left,
|
||||
int window_size_right,
|
||||
bool deterministic) {
|
||||
const float softcap,
|
||||
bool deterministic,
|
||||
const bool unpadded_lse) {
|
||||
|
||||
set_params_fprop(params,
|
||||
b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded,
|
||||
@ -208,7 +230,10 @@ void set_params_dgrad(Flash_bwd_params ¶ms,
|
||||
p_dropout,
|
||||
softmax_scale,
|
||||
window_size_left,
|
||||
window_size_right);
|
||||
window_size_right,
|
||||
softcap,
|
||||
false, // seqlenq_ngroups_swapped
|
||||
unpadded_lse);
|
||||
|
||||
// Set the pointers and strides.
|
||||
params.do_ptr = dout.data_ptr();
|
||||
@@ -244,11 +269,13 @@ void set_params_dgrad(Flash_bwd_params &params,
void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split_kernel=false) {
    FP16_SWITCH(!params.is_bf16, [&] {
        HEADDIM_SWITCH(params.d, [&] {
            if (params.num_splits <= 1 && !force_split_kernel) {  // If we don't set it num_splits == 0
                run_mha_fwd_<elem_type, kHeadDim>(params, stream);
            } else {
                run_mha_fwd_splitkv_dispatch<elem_type, kHeadDim>(params, stream);
            }
            BOOL_SWITCH(params.is_causal, Is_causal, [&] {
                if (params.num_splits <= 1 && !force_split_kernel) {  // If we don't set it num_splits == 0
                    run_mha_fwd_<elem_type, kHeadDim, Is_causal>(params, stream);
                } else {
                    run_mha_fwd_splitkv_dispatch<elem_type, kHeadDim, Is_causal>(params, stream);
                }
            });
        });
    });
}
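The dispatch change above threads `params.is_causal` through a compile-time boolean switch, so `run_mha_fwd_` / `run_mha_fwd_splitkv_dispatch` gain an `Is_causal` template parameter. A minimal sketch of that switch pattern (the real macro lives in the flash-attention `static_switch.h` header and may differ in details):

```cpp
// Sketch of a BOOL_SWITCH-style macro: turn a runtime bool into a constexpr
// template parameter by instantiating both branches and picking one at runtime.
#define BOOL_SWITCH_SKETCH(COND, CONST_NAME, ...)   \
  [&] {                                             \
    if (COND) {                                     \
      constexpr bool CONST_NAME = true;             \
      return __VA_ARGS__();                         \
    } else {                                        \
      constexpr bool CONST_NAME = false;            \
      return __VA_ARGS__();                         \
    }                                               \
  }()

// Usage mirrors the call site above:
//   BOOL_SWITCH_SKETCH(params.is_causal, Is_causal, [&] {
//     run_mha_fwd_<elem_type, kHeadDim, Is_causal>(params, stream);
//   });
```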
@ -357,6 +384,7 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
|
||||
bool is_causal,
|
||||
int window_size_left,
|
||||
int window_size_right,
|
||||
const float softcap,
|
||||
const bool return_softmax,
|
||||
std::optional<at::Generator> gen_) {
|
||||
|
||||
@ -398,6 +426,8 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
|
||||
TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
|
||||
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
|
||||
|
||||
if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }
|
||||
|
||||
if (window_size_left >= seqlen_k) { window_size_left = -1; }
|
||||
if (window_size_right >= seqlen_k) { window_size_right = -1; }
|
||||
|
||||
@@ -443,7 +473,7 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head

    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
    const int head_size = round_multiple(head_size_og, 8);
    const int head_size_rounded = round_multiple(head_size, 32);
    const int head_size_rounded = round_multiple(head_size, 32) < 224 ? round_multiple(head_size, 32) : 256;
    const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
    const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
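The rounding change above affects head dims whose 32-multiple is 224: they are now padded all the way to 256. A small worked example of the new rule, reusing the same rounding logic as the code above (names here are illustrative, not from the PR):

```cpp
// round_multiple(x, m) rounds x up to the next multiple of m.
constexpr int round_multiple(int x, int m) { return (x + m - 1) / m * m; }

// New rule: if rounding head_size to a multiple of 32 gives 224 or more,
// pad to 256; otherwise keep the 32-multiple.
constexpr int head_size_rounded(int head_size) {
  return round_multiple(head_size, 32) < 224 ? round_multiple(head_size, 32) : 256;
}

static_assert(head_size_rounded(160) == 160);  // below 224: unchanged behavior
static_assert(head_size_rounded(200) == 256);  // previously 224, now padded to 256
static_assert(head_size_rounded(256) == 256);
```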
@ -478,7 +508,9 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
|
||||
p_dropout,
|
||||
softmax_scale,
|
||||
window_size_left,
|
||||
window_size_right);
|
||||
window_size_right,
|
||||
softcap
|
||||
);
|
||||
|
||||
|
||||
// Keep references to these tensors to extend their lifetime
|
||||
@ -486,10 +518,9 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
|
||||
head_size, seqlen_k, seqlen_q,
|
||||
head_size_rounded, p_dropout, /*num_splits*/0, dprops, opts);
|
||||
|
||||
// We want to checkpoint and save the RNG state for backward if dropout
|
||||
// We get the default generator and return the seed and offset which will
|
||||
// be used in the backward function
|
||||
at::Tensor seed_t, offset_t;
|
||||
// See [Note] BC breaking change to flash seed/offset
|
||||
auto rng_state = at::empty({2}, at::dtype(c10::kUInt64).device(at::kCUDA));
|
||||
auto _unused = at::empty({}, at::dtype(c10::kUInt64).device(at::kCUDA));
|
||||
if (p_dropout > 0.0) {
|
||||
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(std::nullopt, at::cuda::detail::getDefaultCUDAGenerator());
|
||||
// number of times random will be generated per thread, to offset philox counter in thc random
|
||||
@ -499,26 +530,9 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
|
||||
// See Note [Acquire lock when using random generators]
|
||||
std::lock_guard<std::mutex> lock(gen->mutex_);
|
||||
at::PhiloxCudaState philox_state = gen->philox_cuda_state(counter_offset);
|
||||
if (at::cuda::currentStreamCaptureStatus() == at::cuda::CaptureStatus::None) {
|
||||
auto [seed, offset] = at::cuda::philox::unpack(philox_state);
|
||||
seed_t = at::scalar_tensor(at::Scalar(static_cast<int64_t>(seed)), at::dtype(at::kLong));
|
||||
offset_t = at::scalar_tensor(at::Scalar(static_cast<int64_t>(offset)), at::dtype(at::kLong));
|
||||
} else {
|
||||
seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
|
||||
offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
|
||||
params.seed = seed_t.data_ptr<int64_t>();
|
||||
params.extragraph_offset = offset_t.data_ptr<int64_t>();
|
||||
}
|
||||
rng_state = at::empty({2}, at::TensorOptions().dtype(c10::kUInt64).device(at::kCUDA));
|
||||
params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
|
||||
params.philox_args = philox_state;
|
||||
} else {
|
||||
if (at::cuda::currentStreamCaptureStatus() != at::cuda::CaptureStatus::None) {
|
||||
seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
|
||||
offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
|
||||
} else {
|
||||
seed_t = at::empty({}, at::dtype(at::kLong));
|
||||
offset_t = at::empty({}, at::dtype(at::kLong));
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
|
||||
@ -537,7 +551,7 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
|
||||
q_padded = q_padded.transpose(1, 2).reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
|
||||
softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1});
|
||||
}
|
||||
return {out, q_padded, k_padded, v_padded, softmax_lse, seed_t, offset_t, p};
|
||||
return {out, q_padded, k_padded, v_padded, softmax_lse, rng_state, _unused, p};
|
||||
}
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
|
||||
@ -558,6 +572,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
|
||||
bool is_causal,
|
||||
int window_size_left,
|
||||
int window_size_right,
|
||||
const float softcap,
|
||||
const bool return_softmax,
|
||||
std::optional<at::Generator> gen_) {
|
||||
|
||||
@ -608,6 +623,8 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
|
||||
const int head_size_og = sizes[2];
|
||||
const int num_heads_k = paged_KV ? k.size(2) : k.size(1);
|
||||
|
||||
if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); }
|
||||
|
||||
const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
|
||||
const int num_blocks = !paged_KV ? 0 : k.size(0);
|
||||
const int page_block_size = !paged_KV ? 1 : k.size(1);
|
||||
@ -671,7 +688,6 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
|
||||
TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
|
||||
CHECK_DEVICE(out);
|
||||
TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
|
||||
CHECK_SHAPE(out, total_q, num_heads, head_size_og);
|
||||
CHECK_SHAPE(out, sizes[0], sizes[1], head_size_og);
|
||||
if (seqlenq_ngroups_swapped) {
|
||||
out = out.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size_og});
|
||||
@ -683,7 +699,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
|
||||
|
||||
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
|
||||
const int head_size = round_multiple(head_size_og, 8);
|
||||
const int head_size_rounded = round_multiple(head_size, 32);
|
||||
const int head_size_rounded = round_multiple(head_size, 32) < 224 ? round_multiple(head_size, 32) : 256;
|
||||
const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128);
|
||||
const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128);
|
||||
|
||||
@ -693,7 +709,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
|
||||
|
||||
auto opts = q.options();
|
||||
|
||||
auto softmax_lse = at::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
|
||||
auto softmax_lse = at::empty({num_heads, total_q}, opts.dtype(at::kFloat));
|
||||
at::Tensor p;
|
||||
// Only return softmax if there's dropout to reduce compilation time
|
||||
if (return_softmax) {
|
||||
@ -724,7 +740,10 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
|
||||
softmax_scale,
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
seqlenq_ngroups_swapped);
|
||||
softcap,
|
||||
seqlenq_ngroups_swapped,
|
||||
/*unpadded_lse*/true);
|
||||
params.total_q = total_q;
|
||||
if (paged_KV) {
|
||||
params.block_table = block_table.data_ptr<int>();
|
||||
params.block_table_batch_stride = block_table.stride(0);
|
||||
@ -741,12 +760,14 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
|
||||
head_size_rounded, p_dropout, /*num_splits*/0, dprops, opts);
|
||||
}
|
||||
|
||||
// We want to checkpoint and save the RNG state for backward if dropout
|
||||
// We get the default generator and return the seed and offset which will
|
||||
// be used in the backward function
|
||||
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(std::nullopt, at::cuda::detail::getDefaultCUDAGenerator());
|
||||
at::Tensor seed_t, offset_t;
|
||||
// [Note] BC breaking change to flash seed/offset
|
||||
// Previously: Used separate tensors for philox_seed and philox_offset, sometimes on CPU, sometimes on CUDA
|
||||
// FlashAttention change: Now uses a single uint64_t[2] tensor on device containing both seed and offset
|
||||
// Implementation: Renamed "seed" → "rng_state" (contains both seed+offset) and "offset" → "_unused"
|
||||
auto rng_state = at::empty({2}, at::dtype(c10::kUInt64).device(at::kCUDA));
|
||||
auto _unused = at::empty({}, at::dtype(c10::kUInt64).device(at::kCUDA));
|
||||
if (p_dropout > 0.0) {
|
||||
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(std::nullopt, at::cuda::detail::getDefaultCUDAGenerator());
|
||||
// number of times random will be generated per thread, to offset philox counter in thc random
|
||||
// state
|
||||
// We use a custom RNG that increases the offset by batch_size * nheads * 32.
|
||||
@ -754,26 +775,9 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
|
||||
// See Note [Acquire lock when using random generators]
|
||||
std::lock_guard<std::mutex> lock(gen->mutex_);
|
||||
at::PhiloxCudaState philox_state = gen->philox_cuda_state(counter_offset);
|
||||
if (at::cuda::currentStreamCaptureStatus() == at::cuda::CaptureStatus::None) {
|
||||
auto [seed, offset] = at::cuda::philox::unpack(philox_state);
|
||||
seed_t = at::scalar_tensor(at::Scalar(static_cast<int64_t>(seed)), at::dtype(at::kLong));
|
||||
offset_t = at::scalar_tensor(at::Scalar(static_cast<int64_t>(offset)), at::dtype(at::kLong));
|
||||
} else {
|
||||
seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
|
||||
offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
|
||||
params.seed = seed_t.data_ptr<int64_t>();
|
||||
params.extragraph_offset = offset_t.data_ptr<int64_t>();
|
||||
}
|
||||
rng_state = at::empty({2}, at::TensorOptions().dtype(c10::kUInt64).device(at::kCUDA));
|
||||
params.rng_state = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
|
||||
params.philox_args = philox_state;
|
||||
} else {
|
||||
if (at::cuda::currentStreamCaptureStatus() != at::cuda::CaptureStatus::None) {
|
||||
seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
|
||||
offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
|
||||
} else {
|
||||
seed_t = at::empty({}, at::dtype(at::kLong));
|
||||
offset_t = at::empty({}, at::dtype(at::kLong));
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
|
||||
@ -792,16 +796,18 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
|
||||
std::array<int64_t, 3> size_after = {batch_size, num_heads_k * max_seqlen_q, head_size_og};
|
||||
out = out.reshape(size_before).transpose(1, 2).reshape(size_after);
|
||||
q_padded = q_padded.reshape(size_before).transpose(1, 2).reshape(size_after);
|
||||
softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * max_seqlen_q, 1});
|
||||
softmax_lse = softmax_lse.reshape({num_heads * max_seqlen_q, batch_size});
|
||||
}
|
||||
|
||||
return {out, q_padded, k_padded, v_padded, softmax_lse, seed_t, offset_t, p};
|
||||
return {out, q_padded, k_padded, v_padded, softmax_lse, rng_state, _unused, p};
|
||||
}
|
||||
|
||||
void run_mha_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
||||
FP16_SWITCH(!params.is_bf16, [&] {
|
||||
HEADDIM_SWITCH(params.d, [&] {
|
||||
run_mha_bwd_<elem_type, kHeadDim>(params, stream);
|
||||
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
||||
run_mha_bwd_<elem_type, kHeadDim, Is_causal>(params, stream);
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
@ -822,6 +828,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
|
||||
const bool is_causal,
|
||||
int window_size_left,
|
||||
int window_size_right,
|
||||
const float softcap,
|
||||
const bool deterministic,
|
||||
const at::Tensor philox_seed,
|
||||
const at::Tensor philox_offset) {
|
||||
@ -883,7 +890,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
|
||||
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
|
||||
|
||||
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
|
||||
const int head_size_rounded = round_multiple(head_size, 32);
|
||||
const int head_size_rounded = round_multiple(head_size, 32) < 224 ? round_multiple(head_size, 32) : 256;
|
||||
const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
|
||||
const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
|
||||
|
||||
@ -982,21 +989,17 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
|
||||
softmax_scale,
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
deterministic);
|
||||
softcap,
|
||||
deterministic,
|
||||
/*unpadded_lse*/false);
|
||||
params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0);
|
||||
|
||||
auto launch = &run_mha_bwd;
|
||||
|
||||
at::PhiloxCudaState philox_args;
|
||||
|
||||
if (is_dropout) {
|
||||
if (at::cuda::currentStreamCaptureStatus() ==
|
||||
at::cuda::CaptureStatus::None)
|
||||
{
|
||||
philox_args = at::PhiloxCudaState(*philox_seed.data_ptr<int64_t>(), *philox_offset.data_ptr<int64_t>());
|
||||
} else { // dropout + capture
|
||||
philox_args = at::PhiloxCudaState(
|
||||
philox_seed.data_ptr<int64_t>(), philox_offset.data_ptr<int64_t>(), 0);
|
||||
}
|
||||
params.rng_state = philox_seed.data_ptr<uint64_t>();
|
||||
}
|
||||
params.philox_args = philox_args;
|
||||
|
||||
@ -1025,7 +1028,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
|
||||
const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
|
||||
const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
|
||||
const at::Tensor &out, // total_q x num_heads x head_size
|
||||
const at::Tensor &softmax_lse, // b x h x s softmax logsumexp
|
||||
const at::Tensor &softmax_lse, // h x total_q, softmax logsumexp
|
||||
std::optional<at::Tensor> &dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
|
||||
std::optional<at::Tensor> &dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
|
||||
std::optional<at::Tensor> &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
|
||||
@ -1040,6 +1043,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
|
||||
const bool is_causal,
|
||||
int window_size_left,
|
||||
int window_size_right,
|
||||
const float softcap,
|
||||
const bool deterministic,
|
||||
const at::Tensor philox_seed,
|
||||
const at::Tensor philox_offset)
|
||||
@ -1107,7 +1111,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
|
||||
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
|
||||
|
||||
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
|
||||
const int head_size_rounded = round_multiple(head_size, 32);
|
||||
const int head_size_rounded = round_multiple(head_size, 32) < 224 ? round_multiple(head_size, 32) : 256;
|
||||
const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128);
|
||||
const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128);
|
||||
|
||||
@ -1162,7 +1166,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
|
||||
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
|
||||
|
||||
auto opts = q.options();
|
||||
auto softmax_d = at::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
|
||||
auto softmax_d = at::empty({num_heads, total_q + 128 * batch_size}, opts.dtype(at::kFloat));
|
||||
at::Tensor dq_accum;
|
||||
if (loop) {
|
||||
// We don't want to allocate dq_accum of size (batch, seqlen_q_rounded, num_heads, head_size_rounded)
|
||||
@ -1173,6 +1177,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
|
||||
// cu_seqlens[i + 1] * 128 * i - 1. This ensures that the i-th sequence and (i + 1)-th sequence will
|
||||
// be at least 128 apart. It's ok for us to do atomicAdds up to 128 rows beyond what we're normally
|
||||
// allowed to do. So we won't have to do any bound checking, and performance should stay the same.
|
||||
// Same holds for softmax_d, since LSE is stored in unpadded format.
|
||||
if (!deterministic) {
|
||||
dq_accum = at::empty({total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
|
||||
} else {
|
||||
@ -1218,21 +1223,17 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
|
||||
softmax_scale,
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
deterministic);
|
||||
softcap,
|
||||
deterministic,
|
||||
/*unpadded_lse*/true);
|
||||
params.total_q = total_q;;
|
||||
params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0);
|
||||
|
||||
auto launch = &run_mha_bwd;
|
||||
|
||||
at::PhiloxCudaState philox_args;
|
||||
if (is_dropout) {
|
||||
if (at::cuda::currentStreamCaptureStatus() ==
|
||||
at::cuda::CaptureStatus::None)
|
||||
{
|
||||
philox_args = at::PhiloxCudaState(*philox_seed.data_ptr<int64_t>(), *philox_offset.data_ptr<int64_t>());
|
||||
} else { // dropout + capture
|
||||
philox_args = at::PhiloxCudaState(
|
||||
philox_seed.data_ptr<int64_t>(), philox_offset.data_ptr<int64_t>(), 0);
|
||||
}
|
||||
params.rng_state = philox_seed.data_ptr<uint64_t>();
|
||||
}
|
||||
params.philox_args = philox_args;
|
||||
|
||||
@ -1273,6 +1274,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
|
||||
bool is_causal,
|
||||
int window_size_left,
|
||||
int window_size_right,
|
||||
const float softcap,
|
||||
bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
|
||||
int num_splits
|
||||
) {
|
||||
@ -1385,7 +1387,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
|
||||
|
||||
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
|
||||
const int head_size = round_multiple(head_size_og, 8);
|
||||
const int head_size_rounded = round_multiple(head_size, 32);
|
||||
const int head_size_rounded = round_multiple(head_size, 32) < 224 ? round_multiple(head_size, 32) : 256;
|
||||
const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
|
||||
const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
|
||||
|
||||
@ -1413,7 +1415,9 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
|
||||
/*p_dropout=*/0.f,
|
||||
softmax_scale,
|
||||
window_size_left,
|
||||
window_size_right);
|
||||
window_size_right,
|
||||
softcap
|
||||
);
|
||||
|
||||
at::Tensor k, v, k_padded, v_padded;
|
||||
if (k_.has_value()) {
|
||||
|
@ -1,10 +1,11 @@
|
||||
#pragma once
|
||||
#include <cstddef>
|
||||
|
||||
#include <namespace_config.h>
|
||||
#include <ATen/core/Tensor.h>
|
||||
#include <c10/util/Exception.h>
|
||||
|
||||
namespace pytorch_flash {
|
||||
namespace FLASH_NAMESPACE {
|
||||
|
||||
TORCH_API
|
||||
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
|
||||
@ -18,6 +19,7 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
|
||||
bool is_causal,
|
||||
int window_size_left,
|
||||
int window_size_right,
|
||||
const float softcap,
|
||||
const bool return_softmax,
|
||||
std::optional<at::Generator> gen_);
|
||||
|
||||
@ -39,6 +41,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
|
||||
bool is_causal,
|
||||
int window_size_left,
|
||||
int window_size_right,
|
||||
const float softcap,
|
||||
const bool return_softmax,
|
||||
std::optional<at::Generator> gen_);
|
||||
|
||||
@ -59,6 +62,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
|
||||
const bool is_causal,
|
||||
int window_size_left,
|
||||
int window_size_right,
|
||||
const float softcap,
|
||||
const bool deterministic,
|
||||
const at::Tensor philox_seed,
|
||||
const at::Tensor philox_offset);
|
||||
@ -84,8 +88,9 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
|
||||
const bool is_causal,
|
||||
int window_size_left,
|
||||
int window_size_right,
|
||||
const float softcap,
|
||||
const bool deterministic,
|
||||
const at::Tensor philox_seed,
|
||||
const at::Tensor philox_offset);
|
||||
|
||||
} // namespace pytorch_flash
|
||||
} // namespace FLASH_NAMESPACE
|
||||
|
@ -1,827 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024, Tri Dao.
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ATen/cuda/PhiloxUtils.cuh>
|
||||
#include <cute/algorithm/copy.hpp>
|
||||
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/array.h>
|
||||
#include <cutlass/numeric_types.h>
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/block_info.h>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/kernel_traits.h>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/utils.h>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/softmax.h>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/mask.h>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/dropout.h>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/alibi.h>
|
||||
|
||||
namespace pytorch_flash {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <int MMA_N,
|
||||
class... Args,
|
||||
class TiledMMA>
|
||||
CUTE_HOST_DEVICE
|
||||
auto
|
||||
make_tiled_copy_B_warpcontiguousN(Copy_Atom<Args...> const& copy_atom,
|
||||
TiledMMA const& tiled_mma) {
|
||||
constexpr int TileShape_N = decltype(tiled_mma.template tile_size_mnk<1>())::value;
|
||||
constexpr int TileShape_K = decltype(tiled_mma.template tile_size_mnk<2>())::value;
|
||||
using AtomShape_MNK = typename TiledMMA::AtomShape_MNK;
|
||||
constexpr int AtomShape_N = decltype(size<1>(AtomShape_MNK{}))::value;
|
||||
// Divide by 2 because right now we always use 2 for the ValLayout
|
||||
constexpr int kNWarpsN = TileShape_N / AtomShape_N / 2;
|
||||
constexpr int MMAStride_N = MMA_N * AtomShape_N * 2;
|
||||
// This gives the correct layout, idk why.
|
||||
// auto t = make_tile(Layout<Shape<Shape<_8, _2>, _2>,
|
||||
// Stride<Stride<_1, _64>, _8> >{},
|
||||
// auto t = make_tile(Layout<Shape<_8, _2, _2>,
|
||||
// Stride<_1, _64, _8> >{},
|
||||
auto t = make_tile(Layout<Shape<Int<AtomShape_N>, Int<kNWarpsN>, _2>, // (8, 2, 2) or (8, 4, 2)
|
||||
Stride<_1, Int<MMAStride_N>, _8> >{}, // (1, 64, 8) or (1, 32, 8)
|
||||
make_layout(Int<TileShape_K>{}));
|
||||
// if (cute::thread0()) {printf("make_tiled_copy_B_warpcontiguousN "); print(t); printf("\n"); }
|
||||
return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutB_TV(), t);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <int MMA_N,
|
||||
class... Args,
|
||||
class TiledMMA>
|
||||
CUTE_HOST_DEVICE
|
||||
auto
|
||||
make_tiled_copy_C_warpcontiguousN(Copy_Atom<Args...> const& copy_atom,
|
||||
TiledMMA const& tiled_mma) {
|
||||
constexpr int TileShape_M = decltype(tiled_mma.template tile_size_mnk<0>())::value;
|
||||
constexpr int TileShape_N = decltype(tiled_mma.template tile_size_mnk<1>())::value;
|
||||
using AtomShape_MNK = typename TiledMMA::AtomShape_MNK;
|
||||
constexpr int AtomShape_N = decltype(size<1>(AtomShape_MNK{}))::value;
|
||||
// Divide by 2 because right now we always use 2 for the ValLayout
|
||||
constexpr int kNWarpsN = TileShape_N / AtomShape_N / 2;
|
||||
constexpr int MMAStride_N = MMA_N * AtomShape_N * 2;
|
||||
auto t = make_tile(make_layout(Int<TileShape_M>{}),
|
||||
Layout<Shape<Int<AtomShape_N>, Int<kNWarpsN>, _2>, // (8, 2, 2) or (8, 4, 2)
|
||||
Stride<_1, Int<MMAStride_N>, _8> >{}); // (1, 64, 8) or (1, 32, 8)
|
||||
// if (cute::thread0()) {printf("make_tiled_copy_C_warpcontiguousN "); print(t); printf("\n"); }
|
||||
return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutC_TV(), t);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_first, bool Is_last, bool Seq_parallel=false, typename Params>
|
||||
inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const int bidb, const int bidh, const int n_block) {
|
||||
|
||||
using Element = typename Kernel_traits::Element;
|
||||
using ElementAccum = typename Kernel_traits::ElementAccum;
|
||||
using index_t = typename Kernel_traits::index_t;
|
||||
|
||||
// Shared memory.
|
||||
extern __shared__ char smem_[];
|
||||
|
||||
// The thread index.
|
||||
const int tidx = threadIdx.x;
|
||||
|
||||
constexpr int kBlockM = Kernel_traits::kBlockM;
|
||||
constexpr int kBlockN = Kernel_traits::kBlockN;
|
||||
constexpr int kHeadDim = Kernel_traits::kHeadDim;
|
||||
constexpr int MMA_N_SdP = kBlockN / decltype(typename Kernel_traits::TiledMmaSdP{}.template tile_size_mnk<1>())::value;
|
||||
constexpr int AtomLayoutMS = Kernel_traits::AtomLayoutMSdP;
|
||||
constexpr bool Double_buffer = !Kernel_traits::No_double_buffer;
|
||||
|
||||
const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
|
||||
if (n_block * kBlockN >= binfo.actual_seqlen_k) return;
|
||||
|
||||
int m_block_max = cute::ceil_div(binfo.actual_seqlen_q, kBlockM);
|
||||
if (Is_local) {
|
||||
m_block_max = std::min(m_block_max, cute::ceil_div((n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k + params.window_size_left, kBlockM));
|
||||
}
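// This kernel handles a single column block of K/V (index n_block) and iterates over the Q row
// blocks from m_block_max - 1 down to m_block_min, so all the Q/dO/O/dQ offsets below start at
// the last M block and are walked backward inside the main loop.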
|
||||
|
||||
const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)
|
||||
+ (m_block_max - 1) * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
|
||||
const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb)
|
||||
+ n_block * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
|
||||
const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb)
|
||||
+ n_block * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;
|
||||
const index_t row_offset_do = binfo.q_offset(params.do_batch_stride, params.do_row_stride, bidb)
|
||||
+ (m_block_max - 1) * kBlockM * params.do_row_stride + bidh * params.do_head_stride;
|
||||
const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
|
||||
+ (m_block_max - 1) * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
|
||||
const index_t row_offset_dq = binfo.q_offset(params.dq_batch_stride, params.dq_row_stride, bidb)
|
||||
+ (m_block_max - 1) * kBlockM * params.dq_row_stride + bidh * params.dq_head_stride;
|
||||
const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb)
|
||||
+ ((m_block_max - 1) * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded
|
||||
// If deterministic, each thread block will do atomicAdd to a different dQ_accum buffer.
|
||||
+ (!params.deterministic ? 0 : blockIdx.x * params.dq_accum_split_stride);
|
||||
const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q
|
||||
+ (m_block_max - 1) * kBlockM;
|
||||
const index_t row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded
|
||||
+ (m_block_max - 1) * kBlockM;
|
||||
|
||||
Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
|
||||
Shape<Int<kBlockM>, Int<kHeadDim>>{},
|
||||
make_stride(params.q_row_stride, _1{}));
|
||||
Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
|
||||
Shape<Int<kBlockN>, Int<kHeadDim>>{},
|
||||
make_stride(params.k_row_stride, _1{}));
|
||||
Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.v_ptr) + row_offset_v),
|
||||
Shape<Int<kBlockN>, Int<kHeadDim>>{},
|
||||
make_stride(params.v_row_stride, _1{}));
|
||||
Tensor gdO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.do_ptr) + row_offset_do),
|
||||
Shape<Int<kBlockM>, Int<kHeadDim>>{},
|
||||
make_stride(params.do_row_stride, _1{}));
|
||||
Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.o_ptr) + row_offset_o),
|
||||
Shape<Int<kBlockM>, Int<kHeadDim>>{},
|
||||
make_stride(params.o_row_stride, _1{}));
|
||||
Tensor gdQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dq_ptr) + row_offset_dq),
|
||||
Shape<Int<kBlockM>, Int<kHeadDim>>{},
|
||||
make_stride(params.dq_row_stride, _1{}));
|
||||
Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dq_accum_ptr) + row_offset_dq_accum),
|
||||
Shape<Int<kBlockM>, Int<kHeadDim>>{},
|
||||
make_stride(params.h * params.d_rounded, _1{}));
|
||||
Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
|
||||
Shape<Int<kBlockM>>{}, Stride<_1>{});
|
||||
Tensor gdPsum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dsoftmax_sum) + row_offset_dpsum),
|
||||
Shape<Int<kBlockM>>{}, Stride<_1>{});
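// Shared-memory plan for the tensors created below: sQ (double-buffered unless
// Kernel_traits::No_double_buffer), then sdO, sK, sV, then sdS and sP (when V is kept in
// registers, sdS reuses sV's slot). sdQ aliases sP's storage, and in the epilogue sdK/sdV
// reuse the sK/sV region.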
|
||||
|
||||
Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
|
||||
typename Kernel_traits::SmemLayoutQdO{});
|
||||
Tensor sQt = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutQdOtransposed{});
|
||||
Tensor sQtNoSwizzle = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutQdOtransposedNoSwizzle{});
|
||||
// Double buffer for sQ
|
||||
Tensor sdO = make_tensor(sQ.data() + (Double_buffer ? 2 : 1) * size(sQ), typename Kernel_traits::SmemLayoutQdO{});
|
||||
Tensor sdOt = make_tensor(sdO.data(), typename Kernel_traits::SmemLayoutQdOtransposed{});
|
||||
Tensor sdOtransposedNoSwizzle = make_tensor(sdO.data(),
|
||||
typename Kernel_traits::SmemLayoutQdOtransposedNoSwizzle{});
|
||||
Tensor sK = make_tensor(sdO.data() + size(sdO), typename Kernel_traits::SmemLayoutKV{});
|
||||
Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{});
|
||||
Tensor sKt = make_tensor(sK.data(), typename Kernel_traits::SmemLayoutKtransposed{});
|
||||
Tensor sKtNoSwizzle = make_tensor(sK.data(), typename Kernel_traits::SmemLayoutKtransposedNoSwizzle{});
|
||||
Tensor sdS = make_tensor(!Kernel_traits::Is_V_in_regs ? sV.data() + size(sV) : sK.data() + size(sK),
|
||||
typename Kernel_traits::SmemLayoutPdS{});
|
||||
Tensor sdSt = make_tensor(sdS.data(), typename Kernel_traits::SmemLayoutPdStransposed{});
|
||||
Tensor sdStNoSwizzle = make_tensor(sdS.data(), typename Kernel_traits::SmemLayoutPdStransposedNoSwizzle{});
|
||||
Tensor sP = make_tensor(sdS.data() + size(sdS), typename Kernel_traits::SmemLayoutPdS{});
|
||||
Tensor sPt = make_tensor(sP.data(), typename Kernel_traits::SmemLayoutPdStransposed{});
|
||||
Tensor sPtNoSwizzle = make_tensor(sP.data(), typename Kernel_traits::SmemLayoutPdStransposedNoSwizzle{});
|
||||
// sP and sdQ share the same memory so be careful
|
||||
Tensor sdQ = make_tensor(sP.data(), typename Kernel_traits::SmemLayoutdQ{});
|
||||
|
||||
typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
|
||||
auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
|
||||
using GmemTiledCopydO = std::conditional_t<
|
||||
Is_first,
|
||||
typename Kernel_traits::GmemTiledCopydO,
|
||||
typename Kernel_traits::GmemTiledCopyQKV
|
||||
>;
|
||||
GmemTiledCopydO gmem_tiled_copy_dO;
|
||||
auto gmem_thr_copy_dO = gmem_tiled_copy_dO.get_thread_slice(tidx);
|
||||
typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dQ;
|
||||
auto gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_thread_slice(tidx);
|
||||
using GmemLayoutAtomdQaccum = std::conditional_t<
|
||||
!Seq_parallel,
|
||||
typename Kernel_traits::GmemTiledCopydQaccum,
|
||||
typename Kernel_traits::GmemTiledCopydQaccumAtomicAdd
|
||||
>;
|
||||
GmemLayoutAtomdQaccum gmem_tiled_copy_dQaccum;
|
||||
auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(tidx);
|
||||
|
||||
Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
|
||||
Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
|
||||
Tensor tdOgdO = gmem_thr_copy_dO.partition_S(gdO);
|
||||
Tensor tdOsdO = gmem_thr_copy_dO.partition_D(sdO);
|
||||
Tensor tdOgO = gmem_thr_copy_dO.partition_S(gO);
|
||||
Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K)
|
||||
Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);
|
||||
Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K)
|
||||
Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);
|
||||
Tensor tdQsdQ = gmem_thr_copy_dQ.partition_S(sdQ); // ((Atom,AtomNum),ATOM_M,ATOM_N)
|
||||
Tensor tdQgdQ = gmem_thr_copy_dQ.partition_D(gdQ);
|
||||
Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_D(gdQaccum);
|
||||
// if (cute::thread0()) { print(tdQgdQaccum.layout()); printf("\n"); }
|
||||
// __syncthreads();
|
||||
// if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx < 64) {
|
||||
// printf("tidx = %d, tdQgdQaccum = 0x%p\n", tidx, tdQgdQaccum.data());
|
||||
// }
|
||||
|
||||
typename Kernel_traits::TiledMmaSdP tiled_mma_sdp;
|
||||
auto thr_mma_sdp = tiled_mma_sdp.get_thread_slice(tidx);
|
||||
Tensor tSrQ = thr_mma_sdp.partition_fragment_A(sQ); // (MMA,MMA_N,MMA_K)
|
||||
Tensor tSrK = thr_mma_sdp.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K)
|
||||
Tensor tdPrdO = thr_mma_sdp.partition_fragment_A(sdO); // (MMA,MMA_N,MMA_K)
|
||||
Tensor tdPrV = thr_mma_sdp.partition_fragment_B(sV); // (MMA,MMA_N,MMA_K)
|
||||
|
||||
typename Kernel_traits::TiledMmadKV tiled_mma_dkv;
|
||||
auto thr_mma_dkv = tiled_mma_dkv.get_thread_slice(tidx);
|
||||
Tensor tdKrdSt = thr_mma_dkv.partition_fragment_A(sdStNoSwizzle); // (MMA, MMA_N, MMA_N)
|
||||
Tensor tdKrQt = thr_mma_dkv.partition_fragment_B(sQtNoSwizzle); // (MMA, MMA_K, MMA_N)
|
||||
Tensor tdVrPt = thr_mma_dkv.partition_fragment_A(sPtNoSwizzle); // (MMA, MMA_N, MMA_N)
|
||||
Tensor tdVrdO = thr_mma_dkv.partition_fragment_B(sdOtransposedNoSwizzle); // (MMA, MMA_K, MMA_N)
|
||||
|
||||
typename Kernel_traits::TiledMmadQ tiled_mma_dq;
|
||||
auto thr_mma_dq = tiled_mma_dq.get_thread_slice(tidx);
|
||||
Tensor tdQrdS = thr_mma_dq.partition_fragment_A(sdS); // (MMA, MMA_N, MMA_N)
|
||||
Tensor tdQrKt = thr_mma_dq.partition_fragment_B(sKtNoSwizzle); // (MMA, MMA_K, MMA_N)
|
||||
|
||||
Tensor acc_dk = partition_fragment_C(tiled_mma_dkv, Shape<Int<kBlockN>, Int<kHeadDim>>{}); // MMA, MMA_N, MMA_K
|
||||
Tensor acc_dv = partition_fragment_C(tiled_mma_dkv, Shape<Int<kBlockN>, Int<kHeadDim>>{}); // MMA, MMA_N, MMA_K
|
||||
|
||||
//
|
||||
// Copy Atom retiling
|
||||
//
|
||||
|
||||
auto smem_tiled_copy_QdO = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_sdp);
|
||||
auto smem_thr_copy_QdO = smem_tiled_copy_QdO.get_thread_slice(tidx);
|
||||
Tensor tSsQ = smem_thr_copy_QdO.partition_S(sQ);
|
||||
Tensor tdPsdO = smem_thr_copy_QdO.partition_S(sdO);
|
||||
|
||||
// auto smem_thr_copy_KV = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_sdp).get_thread_slice(tidx);
|
||||
auto smem_tiled_copy_KV = make_tiled_copy_B_warpcontiguousN<MMA_N_SdP>(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_sdp);
|
||||
auto smem_thr_copy_KV = smem_tiled_copy_KV.get_thread_slice(tidx);
|
||||
Tensor tSsK = smem_thr_copy_KV.partition_S(sK);
|
||||
// if (cute::thread(0, 0) && n_block == 0) { printf("sK layout: "); print(sK.layout()); printf("\n"); }
|
||||
// if (cute::thread(0, 0) && n_block == 0) { print(tSsK.layout()); printf("\n"); }
|
||||
Tensor tdPsV = smem_thr_copy_KV.partition_S(sV);
|
||||
|
||||
// Partition sP and sdS to match the accumulator partitioning
|
||||
// This has to be tiled_mma_sdp, not tiled_mma_dkv
|
||||
// auto smem_thr_copy_PdS = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomPdS{}, tiled_mma_sdp).get_thread_slice(tidx);
|
||||
auto smem_tiled_copy_PdS = make_tiled_copy_C_warpcontiguousN<MMA_N_SdP>(typename Kernel_traits::SmemCopyAtomPdS{}, tiled_mma_sdp);
|
||||
auto smem_thr_copy_PdS = smem_tiled_copy_PdS.get_thread_slice(tidx);
|
||||
Tensor tPsP = smem_thr_copy_PdS.partition_D(sP); // ((Atom,AtomNum),PIPE_M,PIPE_N)
|
||||
// if (cute::thread(0, 0) && n_block == 0) { printf("sP layout: "); print(sP.layout()); printf("\n"); }
|
||||
// if (cute::thread(0, 0) && n_block == 0) { print(tPsP.layout()); printf("\n"); }
|
||||
// if (n_block == 0 && blockIdx.x == 0 && blockIdx.y == 0 && tidx < 64) {
|
||||
// printf("tidx=%d, tPsP = 0x%p\n", tidx, tPsP.data());
|
||||
// }
|
||||
Tensor tdSsdS = smem_thr_copy_PdS.partition_D(sdS); // ((Atom,AtomNum),PIPE_M,PIPE_N)
|
||||
|
||||
auto smem_tiled_copy_PdSt = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma_dkv);
|
||||
auto smem_thr_copy_PdSt = smem_tiled_copy_PdSt.get_thread_slice(tidx);
|
||||
Tensor tdVsPt = smem_thr_copy_PdSt.partition_S(sPt);
|
||||
Tensor tdKsdSt = smem_thr_copy_PdSt.partition_S(sdSt);
|
||||
|
||||
auto smem_tiled_copy_QdOt = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma_dkv);
|
||||
auto smem_thr_copy_QdOt = smem_tiled_copy_QdOt.get_thread_slice(tidx);
|
||||
Tensor tdVsdOt = smem_thr_copy_QdOt.partition_S(sdOt);
|
||||
Tensor tdKsQt = smem_thr_copy_QdOt.partition_S(sQt);
|
||||
|
||||
auto smem_tiled_copy_dS = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma_dq);
|
||||
auto smem_thr_copy_dS = smem_tiled_copy_dS.get_thread_slice(tidx);
|
||||
Tensor tdQsdS = smem_thr_copy_dS.partition_S(sdS);
|
||||
|
||||
auto smem_tiled_copy_Kt = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma_dq);
|
||||
auto smem_thr_copy_Kt = smem_tiled_copy_Kt.get_thread_slice(tidx);
|
||||
Tensor tdQsKt = smem_thr_copy_Kt.partition_S(sKt);
|
||||
|
||||
auto smem_tiled_copy_dQ = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdQ{}, tiled_mma_dq);
|
||||
auto smem_thr_copy_dQ = smem_tiled_copy_dQ.get_thread_slice(tidx);
|
||||
Tensor taccdQsdQ = smem_thr_copy_dQ.partition_D(sdQ); // ((Atom,AtomNum),PIPE_M,PIPE_N)
|
||||
|
||||
//
|
||||
// PREDICATES
|
||||
//
|
||||
|
||||
Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
|
||||
Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k)
|
||||
Tensor tQcQ = gmem_thr_copy_QKV.partition_D(cQ);
|
||||
Tensor tKVcKV = gmem_thr_copy_QKV.partition_D(cKV);
|
||||
|
||||
// Allocate predicate tensors for k
|
||||
Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));
|
||||
Tensor tKVpKV = make_tensor<bool>(make_shape(size<2>(tKsK)));
|
||||
|
||||
// Set predicates for k bounds
|
||||
if (!Is_even_K) {
|
||||
#pragma unroll
|
||||
for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; }
|
||||
#pragma unroll
|
||||
for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; }
|
||||
}
|
||||
|
||||
// Prologue
|
||||
|
||||
// We'll advance gdQ and gdQaccum before the 1st read/write.
|
||||
tdQgdQ.data() = tdQgdQ.data() + kBlockM * params.dq_row_stride;
|
||||
tdQgdQaccum.data() = tdQgdQaccum.data() + kBlockM * params.h * params.d_rounded;
|
||||
|
||||
int m_block = m_block_max - 1;
|
||||
int m_block_min = (!Is_causal && !Is_local)
|
||||
? 0
|
||||
: std::max(0, (n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k - params.window_size_right) / kBlockM);
|
||||
// If not local, we're guaranteed that m_block_min <= m_block:
// We checked earlier that n_block * kBlockN < actual_seqlen_k, so in the causal case,
// n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k < actual_seqlen_q.
// So m_block_min <= (actual_seqlen_q - 1) / kBlockM.
// Recall that m_block_max = cute::ceil_div(binfo.actual_seqlen_q, kBlockM) = (actual_seqlen_q + kBlockM - 1) / kBlockM.
// So m_block_max - 1 = (actual_seqlen_q - 1) / kBlockM.
// We conclude that m_block_min <= m_block, so we will always have at least 1 iteration of the for loop.
// However, if local, then it is possible to have some blocks of K & V not attending to any query.
// We might need to exit early and write 0 to dK and dV for those blocks.
// Otherwise we would get the wrong result for the case where we don't enter the for loop,
// and we might read OOB elements from gQ and gdO.
// This also covers the case where actual_seqlen_q == 0.
if ((Is_local || !Is_even_MN) && m_block < m_block_min) {
|
||||
const index_t row_offset_dk = binfo.k_offset(params.dk_batch_stride, params.dk_row_stride, bidb)
|
||||
+ n_block * kBlockN * params.dk_row_stride + bidh * params.dk_head_stride;
|
||||
const index_t row_offset_dv = binfo.k_offset(params.dv_batch_stride, params.dv_row_stride, bidb)
|
||||
+ n_block * kBlockN * params.dv_row_stride + bidh * params.dv_head_stride;
|
||||
Tensor gdK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dk_ptr) + row_offset_dk),
|
||||
Shape<Int<kBlockN>, Int<kHeadDim>>{},
|
||||
make_stride(params.dk_row_stride, _1{}));
|
||||
Tensor gdV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dv_ptr) + row_offset_dv),
|
||||
Shape<Int<kBlockN>, Int<kHeadDim>>{},
|
||||
make_stride(params.dv_row_stride, _1{}));
|
||||
typename Kernel_traits::GmemTiledCopydKV gmem_tiled_copy_dKV;
|
||||
auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(tidx);
|
||||
Tensor tdKgdK = gmem_thr_copy_dKV.partition_D(gdK);
|
||||
Tensor tdVgdV = gmem_thr_copy_dKV.partition_D(gdV);
|
||||
Tensor tdKrdK = make_tensor<Element>(shape(tdKgdK));
|
||||
Tensor tdVrdV = make_tensor<Element>(shape(tdVgdV));
|
||||
clear(tdKrdK);
|
||||
clear(tdVrdV);
|
||||
Tensor cdKV = make_identity_tensor(make_shape(size<0>(gdK), size<1>(gdK))); // (BLK_N,BLK_K) -> (blk_n,blk_k)
|
||||
Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV);
|
||||
Tensor tdKVpdKV = make_tensor<bool>(make_shape(size<2>(tdKgdK)));
|
||||
#pragma unroll
|
||||
for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; }
|
||||
// Clear_OOB_K must be false since we don't want to write zeros to gmem
|
||||
pytorch_flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
|
||||
gmem_tiled_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
|
||||
);
|
||||
pytorch_flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
|
||||
gmem_tiled_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
if (Double_buffer && m_block % 2 == 1) { // Double buffer for sQ
|
||||
tQsQ.data() = tQsQ.data() + size(sQ);
|
||||
tSsQ.data() = tSsQ.data() + size(sQ);
|
||||
tdKsQt.data() = tdKsQt.data() + size(sQ);
|
||||
}
|
||||
|
||||
if ((!Is_first && !Seq_parallel) || params.deterministic) { __syncthreads(); }
|
||||
|
||||
if (Kernel_traits::Is_V_in_regs) {
|
||||
// Clear the smem tiles to account for predicated off loads
|
||||
pytorch_flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
|
||||
gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
|
||||
);
|
||||
pytorch_flash::cp_async_fence();
|
||||
}
|
||||
|
||||
Tensor tdOrdO = make_fragment_like(tdOgdO);
|
||||
Tensor tdOrO = make_fragment_like(tdOgO);
|
||||
if (!Is_first) {
|
||||
// Clear the smem tiles to account for predicated off loads
|
||||
pytorch_flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
|
||||
gmem_tiled_copy_dO, tdOgdO, tdOsdO, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM
|
||||
);
|
||||
} else {
|
||||
pytorch_flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
|
||||
gmem_tiled_copy_dO, tdOgdO, tdOrdO, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM
|
||||
);
|
||||
pytorch_flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
|
||||
gmem_tiled_copy_dO, tdOgO, tdOrO, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM
|
||||
);
|
||||
}
|
||||
pytorch_flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
|
||||
gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM
|
||||
);
|
||||
|
||||
Tensor caccS = make_identity_tensor(Shape<Int<kBlockM>, Int<kBlockN>>{}); // (BLK_M,BLK_N) -> (blk_m,blk_n)
|
||||
Tensor taccScS = thr_mma_sdp.partition_C(caccS); // (MMA,MMA_N,MMA_N)
|
||||
static_assert(decltype(size<0>(taccScS))::value == 4);
|
||||
// Convert to ((2, 2), MMA_N, MMA_N) then take only the row indices.
|
||||
Tensor taccScS_row = logical_divide(taccScS, Shape<_2>{})(make_coord(0, _), _, 0);
|
||||
Tensor lse = make_tensor<ElementAccum>(Shape<Int<decltype(size(taccScS_row))::value>>{});
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size(lse); ++mi) {
|
||||
const int row = get<0>(taccScS_row(mi));
|
||||
lse(mi) = Is_even_MN || row < binfo.actual_seqlen_q - m_block * kBlockM ? gLSE(row) : INFINITY;
|
||||
}
|
||||
// We want LSE = inf if the row is OOB. In that case Q would be zero, K would be zero,
// and scores would be zero. With LSE = 0, probs would be all 1's, and when we multiply
// with V (which would be zero), we're fine. However, with ALiBi, we might modify these
// scores, and probs can become NaN. Instead, if we set LSE = inf for OOB rows, probs are always 0.
|
||||
|
||||
// Tensor tKrK = make_fragment_like(tKsK);
|
||||
// // cute::copy(gmem_tiled_copy_QKV, tKgK(_, _, _, 0), tKrK);
|
||||
// cute::copy(gmem_tiled_copy_QKV, tKgK, tKrK);
|
||||
// // if (cute::thread(1, 0)) { print(tKrK); }
|
||||
|
||||
pytorch_flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
|
||||
gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
|
||||
);
|
||||
if (!Kernel_traits::Is_V_in_regs) {
|
||||
pytorch_flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
|
||||
gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
|
||||
);
|
||||
}
|
||||
pytorch_flash::cp_async_fence();
|
||||
|
||||
// if (cute::thread0()) { print(tdOgdO.layout()); printf("\n"); print(tdOrdO); print(tdOrO); }
|
||||
if (Is_first) {
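// On the first pass over this M range, dO and O were loaded straight into registers above;
// stage dO into smem here and compute the per-row sum D(i) = sum_j dO(i, j) * O(i, j) via
// dot_do_o (with params.p_dropout passed last as the scale), writing it to gdPsum for reuse
// as dP_sum in the softmax backward below.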
|
||||
cute::copy(tdOrdO, tdOsdO);
|
||||
dot_do_o<Kernel_traits::kGmemThreadsPerRow>(tdOrdO, tdOrO, gdPsum,
|
||||
Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout);
|
||||
}
|
||||
|
||||
if (Kernel_traits::Is_V_in_regs) {
|
||||
cute::cp_async_wait<1>();
|
||||
__syncthreads();
|
||||
Tensor tdPrV_copy_view = smem_thr_copy_KV.retile_D(tdPrV);
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tdPsV) == size<1>(tdPrV_copy_view)); // M
|
||||
cute::copy(smem_tiled_copy_KV, tdPsV, tdPrV_copy_view);
|
||||
}
|
||||
|
||||
const auto [seed, offset] = at::cuda::philox::unpack(params.philox_args);
|
||||
pytorch_flash::Dropout dropout(seed, offset, params.p_dropout_in_uint8_t,
|
||||
bidb, bidh, tidx, params.h);
|
||||
|
||||
clear(acc_dv);
|
||||
clear(acc_dk);
|
||||
|
||||
const float alibi_slope = !Has_alibi || params.alibi_slopes_ptr == nullptr ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
|
||||
pytorch_flash::Alibi<Is_causal> alibi(alibi_slope, binfo.actual_seqlen_k, binfo.actual_seqlen_q);
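// Note (reading of the code above): the ALiBi slope is pre-divided by params.scale_softmax,
// presumably because the bias is added to the raw Q*K^T scores, which are only multiplied by
// the softmax scale later in scale_apply_exp2, so dividing here keeps the effective bias at
// roughly slope * distance.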
|
||||
|
||||
for (; m_block >= m_block_min; --m_block) {
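// Each iteration processes one kBlockM x kBlockN tile, working backward over the M blocks:
// recompute P = exp(scale * Q K^T - LSE), compute dP = dO V^T, form dS = P * (dP - D),
// then accumulate dV += P^T dO, dK += dS^T Q, and dQ += dS K (dQ kept in fp32 via gdQaccum
// unless this is the last pass over the column blocks).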
|
||||
Tensor acc_s = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_N, MMA_N)
|
||||
clear(acc_s);
|
||||
cute::cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
|
||||
Tensor dP_sum = make_fragment_like(lse);
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size(lse); ++mi) { dP_sum(mi) = gdPsum(get<0>(taccScS_row(mi))); }
|
||||
|
||||
// if (cute::thread0()) { print(sK); }
|
||||
// Tensor tSrK_copy_view = smem_thr_copy_KV.retile_D(tSrK);
|
||||
// #pragma unroll
|
||||
// for (int k = 0; k < size<2>(tSrK_copy_view); ++k) {
|
||||
// cute::copy(smem_tiled_copy_KV, tSsK(_, _, k), tSrK_copy_view(_, _, k));
|
||||
// }
|
||||
// if (cute::thread0()) { print(tSrK); }
|
||||
pytorch_flash::gemm(acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma_sdp,
|
||||
smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV);
|
||||
|
||||
// Reshape acc_s from (MMA=4, MMA_N, MMA_N) to (col=(2, MMA_N), row=(2, MMA_N))
|
||||
Tensor scores = make_tensor(acc_s.data(), pytorch_flash::convert_layout_acc_rowcol(acc_s.layout()));
|
||||
// if (cute::thread(32, 0)) { print(scores); }
|
||||
|
||||
if (Has_alibi) {
|
||||
alibi.apply_alibi(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
|
||||
m_block * kBlockM + get<0>(taccScS_row(0)), AtomLayoutMS * 16);
|
||||
}
|
||||
|
||||
// TD [2023-07-29]: I was thinking that we don't need to mask out the elements beyond
|
||||
// actual_seqlen_k, because acc_s would be some finite value for those indices.
|
||||
// In the end when we multiply with K to get dQ, the corresponding values of K would be 0,
|
||||
// so the result would still be correct.
|
||||
// However, it's possible that the values in acc_s are so large that they overflow
|
||||
// when we multiply with dP and convert to fp16, resulting in Inf in dS and NaNs in dQ.
|
||||
// So we need to mask out the elements beyond actual_seqlen_k.
|
||||
if (!Is_causal && !Is_local) {
|
||||
if (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k) {
|
||||
pytorch_flash::apply_mask(scores, binfo.actual_seqlen_k,
|
||||
n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16);
|
||||
}
|
||||
} else if (Is_causal) {
|
||||
// Putting this causal masking right after acc_s is *much* slower for some reason.
|
||||
// TD [2023-08-16]: We need the 2nd condition because if seqlen_q is long and seqlen_k is short
// (e.g., 256 and 2), then for the 2nd block of seqlen_q (rows 128 to 255) we're not doing causal masking.
// But we still want to mask out elements beyond actual_seqlen_k.
|
||||
if (m_block * kBlockM < (n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k
|
||||
|| (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k)) {
|
||||
pytorch_flash::apply_mask_causal(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
|
||||
binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)),
|
||||
binfo.actual_seqlen_q,
|
||||
// binfo.actual_seqlen_k, m_block * kBlockM + (tidx / 32) % AtomLayoutMS * 16 + (tidx % 32) / 4,
|
||||
AtomLayoutMS * 16);
|
||||
}
|
||||
} else if (Is_local) {
|
||||
if (m_block * kBlockM < (n_block + 1) * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k - params.window_size_right
|
||||
|| (m_block + 1) * kBlockM >= n_block * kBlockN + binfo.actual_seqlen_q - binfo.actual_seqlen_k + params.window_size_left
|
||||
|| (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k)) {
|
||||
pytorch_flash::apply_mask_local(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
|
||||
binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)),
|
||||
binfo.actual_seqlen_q, AtomLayoutMS * 16,
|
||||
params.window_size_left, params.window_size_right);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// if (cute::thread(32, 0)) { print(scores); }
|
||||
// Compute the exponential value.
|
||||
pytorch_flash::scale_apply_exp2</*scale_max=*/false>(scores, lse, params.scale_softmax_log2);
|
||||
if constexpr (Is_dropout) {
|
||||
int warp_id = tidx / 32;
|
||||
int block_row_idx = m_block * (kBlockM / 16) + warp_id % AtomLayoutMS;
|
||||
// Need col to be a multiple of 32, since we're doing dropout with blocks of 16 x 32
|
||||
static_assert(MMA_N_SdP % 2 == 0);
|
||||
int block_col_idx = n_block * (kBlockN / 32) + (warp_id / AtomLayoutMS) * (MMA_N_SdP / 2);
|
||||
dropout.template apply_dropout</*encode_dropout_in_sign_bit=*/true>(
|
||||
acc_s, block_row_idx, block_col_idx, AtomLayoutMS
|
||||
);
|
||||
}
|
||||
// Convert scores from fp32 to fp16/bf16
|
||||
Tensor rP = !Is_dropout
|
||||
? pytorch_flash::convert_type<Element>(acc_s)
|
||||
: pytorch_flash::convert_type_relu<Element>(acc_s);
|
||||
// Reshape rP from (MMA=4, MMA_N, MMA_N) to ((4, 2), MMA_N, MMA_N / 2)
// if using m16n8k16, or (4, MMA_N, MMA_N) if using m16n8k8.
|
||||
Tensor tPrP = make_tensor(rP.data(), pytorch_flash::convert_layout_acc_Aregs<typename Kernel_traits::TiledMmaSdP>(rP.layout()));
|
||||
Tensor tPaP = smem_thr_copy_PdS.retile_S(tPrP); // ((Atom,AtomNum), MMA_N, MMA_N)
|
||||
cute::copy(smem_tiled_copy_PdS, tPaP, tPsP);
|
||||
// if (cute::thread0()) { print(tPaP); }
|
||||
// __syncthreads();
|
||||
// if (cute::thread0()) { print(sP); }
|
||||
|
||||
Tensor acc_dp = partition_fragment_C(tiled_mma_sdp, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_N, MMA_N)
|
||||
CUTE_STATIC_ASSERT_V(size<0>(acc_dp) == size<0>(acc_s)); // MMA
|
||||
CUTE_STATIC_ASSERT_V(size<1>(acc_dp) == size<1>(acc_s)); // MMA
|
||||
CUTE_STATIC_ASSERT_V(size<2>(acc_dp) == size<2>(acc_s)); // MMA
|
||||
|
||||
clear(acc_dp);
|
||||
// Tensor acc_dp_reshaped = make_tensor(acc_dp.data(), pytorch_flash::convert_layout_acc_rowcol(acc_dp.layout()));
|
||||
// #pragma unroll
|
||||
// for (int mi = 0; mi < size<0>(acc_dp_reshaped); ++mi) {
|
||||
// #pragma unroll
|
||||
// for (int ni = 0; ni < size<1>(acc_dp_reshaped); ++ni) {
|
||||
// acc_dp_reshaped(mi, ni) = -dP_sum(mi);
|
||||
// }
|
||||
// }
|
||||
|
||||
// if (cute::thread0()) { print(dP_sum); }
|
||||
|
||||
pytorch_flash::gemm</*A_in_regs=*/false, /*B_in_regs=*/Kernel_traits::Is_V_in_regs>(
|
||||
acc_dp, tdPrdO, tdPrV, tdPsdO, tdPsV, tiled_mma_sdp,
|
||||
smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV
|
||||
);
|
||||
|
||||
// Reshape acc_dp from (MMA=4, MMA_N, MMA_N) to (col=(2, MMA_N), row=(2, MMA_N))
|
||||
Tensor dS = make_tensor(acc_dp.data(), scores.layout());
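// Softmax backward for this tile: without dropout, dS(i, j) = P(i, j) * (dP(i, j) - D(i)),
// where D(i) = dP_sum(i) = sum_j dO(i, j) * O(i, j). With dropout, dropped entries were
// encoded via the sign bit of P above, and the branch inside pointwise_mult handles them.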
|
||||
auto pointwise_mult = [](float p, float dp, float d) {
|
||||
return p * (!Is_dropout || p >= 0 ? dp - d : d);
|
||||
};
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0>(dS); ++mi) {
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < size<1>(dS); ++ni) {
|
||||
dS(mi, ni) = pointwise_mult(scores(mi, ni), dS(mi, ni), dP_sum(mi));
|
||||
}
|
||||
}
|
||||
// if (cute::thread0()) { print(dS); }
|
||||
|
||||
Tensor acc_dq = partition_fragment_C(tiled_mma_dq, Shape<Int<kBlockM>, Int<kHeadDim>>{}); // MMA, MMA_N, MMA_K
|
||||
tdQgdQaccum.data() = tdQgdQaccum.data() + (-int(kBlockM * params.h * params.d_rounded));
|
||||
if (Is_first || Seq_parallel) {
|
||||
clear(acc_dq);
|
||||
} else {
|
||||
// Reshape acc_dq from (4, 1, 2) to (4, 2, 1) to write to gdQaccum
|
||||
Tensor acc_dq_reshaped = make_tensor(acc_dq.data(),
|
||||
make_layout(get<0>(acc_dq.layout()),
|
||||
get<2>(acc_dq.layout()),
|
||||
get<1>(acc_dq.layout())));
|
||||
cute::copy(gmem_tiled_copy_dQaccum, tdQgdQaccum, acc_dq_reshaped);
|
||||
}
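// In other words: on later passes over this M range (not first, not seq-parallel) the fp32 dQ
// partial sums accumulated by previous column blocks are reloaded from gdQaccum into acc_dq so
// the dS @ K contribution below adds onto them; otherwise acc_dq starts from zero.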
|
||||
|
||||
if (Double_buffer && m_block > m_block_min) {
|
||||
// Double buffer for sQ
|
||||
const int sQ_offset = m_block % 2 == 0 ? size(sQ) : -size(sQ);
|
||||
tQsQ.data() = tQsQ.data() + sQ_offset;
|
||||
tSsQ.data() = tSsQ.data() + sQ_offset;
|
||||
// Advance gQ
|
||||
tQgQ.data() = tQgQ.data() + (-int(kBlockM * params.q_row_stride));
|
||||
pytorch_flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ);
|
||||
pytorch_flash::cp_async_fence();
|
||||
}
|
||||
|
||||
Tensor dS_reshaped = make_tensor(dS.data(), acc_dp.layout());
|
||||
// Convert dS from fp32 to fp16
|
||||
Tensor tdSrdS = pytorch_flash::convert_type<Element>(dS_reshaped);
|
||||
// if (cute::thread0()) { print(tPrP); }
|
||||
Tensor tdSadS = smem_thr_copy_PdS.retile_S(tdSrdS); // ((Atom,AtomNum), MMA_N, MMA_N)
|
||||
cute::copy(smem_tiled_copy_PdS, tdSadS, tdSsdS);
|
||||
__syncthreads();
|
||||
|
||||
// Layout p_l = tPrP.layout();
|
||||
// Tensor tdVrPt = make_tensor(tPrP.data(), make_layout(get<0>(p_l), get<2>(p_l), get<1>(p_l)));
|
||||
// pytorch_flash::gemm_rs(acc_dv, tdVrPt, tdVrdO, tdVsdOt, tiled_mma_dkv, smem_thr_copy_QdOt);
|
||||
// Tensor tdKrdSt = make_tensor(tdSrdS.data(), tdVrPt.layout());
|
||||
// pytorch_flash::gemm_rs(acc_dk, tdKrdSt, tdKrQt, tdKsQt, tiled_mma_dkv, smem_thr_copy_QdOt);
|
||||
pytorch_flash::gemm(acc_dv, tdVrPt, tdVrdO, tdVsPt, tdVsdOt, tiled_mma_dkv,
|
||||
smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, smem_thr_copy_PdSt, smem_thr_copy_QdOt);
|
||||
// if (cute::thread0() && n_block == 0 && m_block == 0) { print(tdVrPt); }
|
||||
// if (cute::thread0()) { print(acc_dv); }
|
||||
|
||||
__syncthreads(); // Need syncthreads since we're writing to the same sdO location
|
||||
|
||||
if (m_block > m_block_min) {
|
||||
// Advance gdO
|
||||
tdOgdO.data() = tdOgdO.data() + (-int(kBlockM * params.do_row_stride));
|
||||
if (Is_first) {
|
||||
tdOgO.data() = tdOgO.data() + (-int(kBlockM * params.o_row_stride));
|
||||
pytorch_flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_dO, tdOgdO, tdOrdO, tQcQ, tQpQ);
|
||||
pytorch_flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_dO, tdOgO, tdOrO, tQcQ, tQpQ);
|
||||
} else {
|
||||
pytorch_flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_dO, tdOgdO, tdOsdO, tQcQ, tQpQ);
|
||||
pytorch_flash::cp_async_fence();
|
||||
}
|
||||
}
|
||||
|
||||
pytorch_flash::gemm(acc_dq, tdQrdS, tdQrKt, tdQsdS, tdQsKt, tiled_mma_dq,
|
||||
smem_tiled_copy_dS, smem_tiled_copy_Kt, smem_thr_copy_dS, smem_thr_copy_Kt);
|
||||
// if (cute::thread0()) { print(acc_dq); }
|
||||
|
||||
if (m_block > m_block_min) {
|
||||
gLSE.data() = gLSE.data() + (-int(kBlockM));
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size(lse); ++mi) { lse(mi) = gLSE(get<0>(taccScS_row(mi))); }
|
||||
gdPsum.data() = gdPsum.data() + (-int(kBlockM));
|
||||
}
|
||||
|
||||
if (!Is_last) {
|
||||
// Reshape acc_dq from (4, 1, 2) to (4, 2, 1) to write to gdQaccum
|
||||
Tensor acc_dq_reshaped = make_tensor(acc_dq.data(),
|
||||
make_layout(get<0>(acc_dq.layout()),
|
||||
get<2>(acc_dq.layout()),
|
||||
get<1>(acc_dq.layout())));
|
||||
if (!Seq_parallel) {
|
||||
cute::copy(gmem_tiled_copy_dQaccum, acc_dq_reshaped, tdQgdQaccum);
|
||||
} else {
|
||||
// if (cute::thread0()) { print(acc_dq.layout()); printf("\n"); print(acc_dq_reshaped.layout()); printf("\n"); print(tdQgdQaccum.layout()); printf("\n"); }
|
||||
CUTE_STATIC_ASSERT_V(size(acc_dq) == size(tdQgdQaccum));
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size(acc_dq); ++i) { atomicAdd(&tdQgdQaccum(i), acc_dq(i)); }
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) *= params.scale_softmax_rp_dropout; }
|
||||
// Convert acc_dq from fp32 to fp16
|
||||
Tensor rdQ = pytorch_flash::convert_type<Element>(acc_dq);
|
||||
Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ); // ((Atom,AtomNum), MMA_N, MMA_N)
|
||||
cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ);
|
||||
}
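// To summarize the dQ handling above: intermediate column-block passes keep dQ partial sums in
// fp32 in gdQaccum (plain stores, or atomicAdd when Seq_parallel); only the last pass (Is_last)
// rescales by scale_softmax_rp_dropout and converts to fp16/bf16 through sdQ.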
|
||||
|
||||
pytorch_flash::gemm(acc_dk, tdKrdSt, tdKrQt, tdKsdSt, tdKsQt, tiled_mma_dkv,
|
||||
smem_tiled_copy_PdSt, smem_tiled_copy_QdOt, smem_thr_copy_PdSt, smem_thr_copy_QdOt);
|
||||
// if (cute::thread0()) { print(acc_dk); }
|
||||
if (Double_buffer) { // Double buffer for sQ
|
||||
tdKsQt.data() = tdKsQt.data() + (m_block % 2 == 0 ? size(sQ) : -size(sQ));
|
||||
}
|
||||
if (!Double_buffer && m_block > m_block_min) {
|
||||
__syncthreads();
|
||||
// Advance gQ
|
||||
tQgQ.data() = tQgQ.data() + (-int(kBlockM * params.q_row_stride));
|
||||
pytorch_flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ);
|
||||
pytorch_flash::cp_async_fence();
|
||||
}
|
||||
|
||||
if (Is_first && m_block > m_block_min) {
|
||||
cute::copy(tdOrdO, tdOsdO);
|
||||
dot_do_o<Kernel_traits::kGmemThreadsPerRow>(tdOrdO, tdOrO, gdPsum,
|
||||
Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout);
|
||||
}
|
||||
|
||||
if (Is_last) {
|
||||
__syncthreads();
|
||||
Tensor tdQrdQ = make_tensor<Element>(shape(tdQgdQ));
|
||||
cute::copy(gmem_tiled_copy_dQ, tdQsdQ, tdQrdQ);
|
||||
tdQgdQ.data() = tdQgdQ.data() + (-int(kBlockM * params.dq_row_stride));
|
||||
Tensor cdQ = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
|
||||
Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cdQ);
|
||||
#pragma unroll
|
||||
for (int m = 0; m < size<1>(tdQgdQ); ++m) {
|
||||
if (Is_even_MN || get<0>(tdQcdQ(0, m, 0)) < binfo.actual_seqlen_q - m_block * kBlockM) {
|
||||
cute::copy(gmem_tiled_copy_dQ, tdQrdQ(_, m, _), tdQgdQ(_, m, _));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Epilogue
|
||||
|
||||
if (Is_dropout) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size(acc_dv); ++i) { acc_dv(i) *= params.rp_dropout; }
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size(acc_dk); ++i) { acc_dk(i) *= params.scale_softmax_rp_dropout; }
|
||||
|
||||
// Convert acc_dk and acc_dv from fp32 to fp16/bf16
|
||||
Tensor rdK = pytorch_flash::convert_type<Element>(acc_dk);
|
||||
Tensor rdV = pytorch_flash::convert_type<Element>(acc_dv);
|
||||
|
||||
Tensor sdK = make_tensor(sK.data(), typename Kernel_traits::SmemLayoutdKV{}); // (SMEM_N, SMEM_K)
|
||||
Tensor sdV = make_tensor(sdK.data() + size(sdK), typename Kernel_traits::SmemLayoutdKV{}); // (SMEM_N, SMEM_K)
|
||||
|
||||
// Partition sdV and sdK to match the accumulator partitioning
|
||||
auto smem_tiled_copy_dKV = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdKV{}, tiled_mma_dkv);
|
||||
auto smem_thr_copy_dKV = smem_tiled_copy_dKV.get_thread_slice(tidx);
|
||||
Tensor taccdKrdK = smem_thr_copy_dKV.retile_S(rdK); // ((Atom,AtomNum), MMA_N, MMA_N)
|
||||
Tensor taccdKsdK = smem_thr_copy_dKV.partition_D(sdK); // ((Atom,AtomNum),PIPE_M,PIPE_N)
|
||||
Tensor taccdVrdV = smem_thr_copy_dKV.retile_S(rdV); // ((Atom,AtomNum), MMA_N, MMA_N)
|
||||
Tensor taccdVsdV = smem_thr_copy_dKV.partition_D(sdV); // ((Atom,AtomNum),PIPE_M,PIPE_N)
|
||||
|
||||
// We need syncthreads here since we're writing to the same location as sK and sV.
|
||||
// Without syncthreads, some thread might modify the location of sK while another thread
|
||||
// is reading it for dQ gemm, leading to a race condition.
|
||||
// If Is_last, there's already a __syncthreads() at the end of the loop.
|
||||
if (!Is_last) { __syncthreads(); }
|
||||
|
||||
cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdK);
|
||||
cute::copy(smem_tiled_copy_dKV, taccdVrdV, taccdVsdV);
|
||||
|
||||
const index_t row_offset_dk = binfo.k_offset(params.dk_batch_stride, params.dk_row_stride, bidb)
|
||||
+ n_block * kBlockN * params.dk_row_stride + bidh * params.dk_head_stride;
|
||||
const index_t row_offset_dv = binfo.k_offset(params.dv_batch_stride, params.dv_row_stride, bidb)
|
||||
+ n_block * kBlockN * params.dv_row_stride + bidh * params.dv_head_stride;
|
||||
Tensor gdK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dk_ptr) + row_offset_dk),
|
||||
Shape<Int<kBlockN>, Int<kHeadDim>>{},
|
||||
make_stride(params.dk_row_stride, _1{}));
|
||||
Tensor gdV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dv_ptr) + row_offset_dv),
|
||||
Shape<Int<kBlockN>, Int<kHeadDim>>{},
|
||||
make_stride(params.dv_row_stride, _1{}));
|
||||
|
||||
typename Kernel_traits::GmemTiledCopydKV gmem_tiled_copy_dKV;
|
||||
auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(tidx);
|
||||
Tensor tdKsdK = gmem_thr_copy_dKV.partition_S(sdK); // ((Atom,AtomNum),ATOM_M,ATOM_N)
|
||||
Tensor tdKgdK = gmem_thr_copy_dKV.partition_D(gdK);
|
||||
Tensor tdVsdV = gmem_thr_copy_dKV.partition_S(sdV); // ((Atom,AtomNum),ATOM_M,ATOM_N)
|
||||
Tensor tdVgdV = gmem_thr_copy_dKV.partition_D(gdV);
|
||||
|
||||
__syncthreads();
|
||||
Tensor tdKrdK = make_tensor<Element>(shape(tdKgdK));
|
||||
cute::copy(gmem_tiled_copy_dKV, tdKsdK, tdKrdK);
|
||||
Tensor tdVrdV = make_tensor<Element>(shape(tdVgdV));
|
||||
cute::copy(gmem_tiled_copy_dKV, tdVsdV, tdVrdV);
|
||||
Tensor cdKV = make_identity_tensor(make_shape(size<0>(sdK), size<1>(sdK))); // (BLK_N,BLK_K) -> (blk_n,blk_k)
|
||||
Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV);
|
||||
Tensor tdKVpdKV = make_tensor<bool>(make_shape(size<2>(tdKgdK)));
|
||||
#pragma unroll
|
||||
for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; }
|
||||
// Clear_OOB_K must be false since we don't want to write zeros to gmem
|
||||
pytorch_flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
|
||||
gmem_tiled_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
|
||||
);
|
||||
pytorch_flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
|
||||
gmem_tiled_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
|
||||
);
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Has_alibi, bool Is_even_M, bool Is_even_K, typename Params>
|
||||
inline __device__ void compute_dq_dk_dv(const Params ¶ms) {
|
||||
|
||||
// The block index for the batch.
|
||||
const int bidb = blockIdx.x;
|
||||
// const int bidb = blockIdx.y;
|
||||
// The block index for the head.
|
||||
const int bidh = blockIdx.y;
|
||||
// const int bidh = blockIdx.z;
|
||||
// The thread index.
|
||||
const int tidx = threadIdx.x;
|
||||
|
||||
const int n_block_max = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN;
|
||||
if (n_block_max == 1) {
|
||||
compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Has_alibi, Is_even_M, Is_even_K, true, true>(params, bidb, bidh, 0);
|
||||
} else {
|
||||
// Iterating backward from n_block_max - 1 to 0 might save 1 register
|
||||
compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Has_alibi, Is_even_M, Is_even_K, true, false>(params, bidb, bidh, n_block_max - 1);
|
||||
for (int n_block = n_block_max - 2; n_block > 0; n_block--) {
|
||||
compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Has_alibi, Is_even_M, Is_even_K, false, false>(params, bidb, bidh, n_block);
|
||||
}
|
||||
compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Has_alibi, Is_even_M, Is_even_K, false, true>(params, bidb, bidh, 0);
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, typename Params>
|
||||
inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params ¶ms) {
|
||||
|
||||
// The block index for the batch.
|
||||
const int bidb = blockIdx.y;
|
||||
// The block index for the head.
|
||||
const int bidh = blockIdx.z;
|
||||
|
||||
// If deterministic, each thread block will do atomicAdd to a different dQ_accum buffer.
|
||||
for (int n_block = blockIdx.x; n_block < (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN; n_block += gridDim.x) {
|
||||
compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, false, false, /*Seq_parallel=*/true>(params, bidb, bidh, n_block);
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace pytorch_flash
@ -1,338 +0,0 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2024, Tri Dao.
|
||||
******************************************************************************/
#pragma once

#include <c10/cuda/CUDAException.h>

#include <ATen/native/transformers/cuda/flash_attn/static_switch.h>
#include <ATen/native/transformers/cuda/flash_attn/flash.h>
#include <ATen/native/transformers/cuda/flash_attn/flash_bwd_preprocess_kernel.h>
#include <ATen/native/transformers/cuda/flash_attn/flash_bwd_kernel.h>

namespace pytorch_flash {

// Determine if the architecture supports FLASH and define a macro to handle parameter modifiers
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
#define ARCH_SUPPORTS_FLASH
|
||||
#endif
|
||||
|
||||
#if defined(ARCH_SUPPORTS_FLASH) && defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 11 && \
|
||||
defined(__CUDACC_VER_MINOR__) && __CUDACC_VER_MINOR__ >= 8
|
||||
#define KERNEL_PARAM_MODIFIER __grid_constant__
|
||||
#else
|
||||
#define KERNEL_PARAM_MODIFIER
|
||||
#endif
|
||||
|
||||
// Define a macro for unsupported architecture handling to centralize the error message
|
||||
#define FLASH_UNSUPPORTED_ARCH printf("FATAL: FlashAttention requires building with sm version sm80-sm90, but was built for < 8.0!");
|
||||
|
||||
// Use a macro to clean up kernel definitions
|
||||
#define DEFINE_FLASH_BACKWARD_KERNEL(kernelName, ...) \
|
||||
template<typename Kernel_traits, __VA_ARGS__> \
|
||||
__global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_bwd_params params)
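// For instance, DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_kernel, bool Is_dropout, bool Is_causal, ...)
// expands to
//   template<typename Kernel_traits, bool Is_dropout, bool Is_causal, ...>
//   __global__ void flash_bwd_dq_dk_dv_loop_kernel(KERNEL_PARAM_MODIFIER const Flash_bwd_params params)
// i.e. the variadic arguments become the extra non-type template parameters of the kernel.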
|
||||
|
||||
DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_kernel, bool Is_dropout, bool Is_causal, bool Has_alibi, bool Is_even_M, bool Is_even_K) {
|
||||
#if defined(ARCH_SUPPORTS_FLASH)
|
||||
pytorch_flash::compute_dq_dk_dv<Kernel_traits, Is_dropout, Is_causal, Has_alibi, Is_even_M, Is_even_K>(params);
|
||||
#else
|
||||
FLASH_UNSUPPORTED_ARCH
|
||||
#endif
|
||||
}
|
||||
|
||||
DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K) {
|
||||
#if defined(ARCH_SUPPORTS_FLASH)
|
||||
static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false
|
||||
pytorch_flash::compute_dq_dk_dv_seqk_parallel<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K>(params);
|
||||
#else
|
||||
FLASH_UNSUPPORTED_ARCH
|
||||
#endif
|
||||
}
|
||||
|
||||
template<bool Clear_dQaccum=true, typename Kernel_traits>
|
||||
__global__ void flash_bwd_dot_do_o_kernel(const Flash_bwd_params params) {
|
||||
pytorch_flash::compute_dot_do_o<Clear_dQaccum, Kernel_traits>(params);
|
||||
}
|
||||
|
||||
template<typename Kernel_traits>
|
||||
__global__ void flash_bwd_clear_dkvaccum_kernel(const Flash_bwd_params params) {
|
||||
pytorch_flash::clear_dKVaccum<Kernel_traits>(params);
|
||||
}
|
||||
|
||||
template<typename Kernel_traits>
|
||||
__global__ void flash_bwd_convert_dq_kernel(const Flash_bwd_params params, const int nsplits) {
|
||||
pytorch_flash::convert_dQ<Kernel_traits>(params, nsplits);
|
||||
}
|
||||
|
||||
template<typename Kernel_traits>
|
||||
__global__ void flash_bwd_convert_dkv_kernel(const Flash_bwd_params params) {
|
||||
pytorch_flash::convert_dKV<Kernel_traits>(params);
|
||||
}
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout>
|
||||
void run_flash_bwd_seqk_parallel(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
||||
const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
|
||||
dim3 grid_m(num_m_block, params.b, params.h);
|
||||
const int num_n_block = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN;
|
||||
int gridDimx = num_n_block;
|
||||
if (params.deterministic) {
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
gridDimx = (dprops->multiProcessorCount + params.b * params.h - 1) / (params.b * params.h);
|
||||
}
|
||||
dim3 grid_n(gridDimx, params.b, params.h);
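// When params.deterministic is set, gridDimx is capped at roughly #SMs / (b * h) so each
// (batch, head) pair uses a bounded number of thread blocks along seqlen_k; each of those blocks
// accumulates into its own dQaccum split (see dq_accum_split_stride in the kernel), and the
// convert_dQ kernel below reduces over the gridDimx splits, avoiding the non-deterministic
// ordering of atomicAdd into a single buffer.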
|
||||
|
||||
if (!params.deterministic) {
|
||||
flash_bwd_dot_do_o_kernel<true, Kernel_traits><<<grid_m, Kernel_traits::kNThreads, 0, stream>>>(params);
|
||||
} else {
|
||||
flash_bwd_dot_do_o_kernel<false, Kernel_traits><<<grid_m, Kernel_traits::kNThreads, 0, stream>>>(params);
|
||||
}
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
|
||||
// We want to specialize to is_even_MN and not just is_even_M, since in the case where N is not
|
||||
// a multiple of kBlockN, we'll need to apply mask in the loop.
|
||||
const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_q % Kernel_traits::kBlockM == 0 && params.seqlen_k % Kernel_traits::kBlockN == 0;
|
||||
const bool is_even_K = params.d == Kernel_traits::kHeadDim;
|
||||
constexpr int smem_size_dq_dk_dv = Kernel_traits::kSmemSize1colblock;
|
||||
// printf("smem_size_dq_dk_dv = %d\n", smem_size_dq_dk_dv);
|
||||
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
||||
BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
|
||||
EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
|
||||
LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !params.is_causal, Is_local, [&] {
|
||||
ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
|
||||
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
|
||||
// If head dim > 128, set IsEvenMNConst to false to reduce number of templates
|
||||
// If Is_local, set Is_causal to false
|
||||
auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst>;
|
||||
// auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, false, Is_causal, false, false, true, true>;
|
||||
if (smem_size_dq_dk_dv >= 48 * 1024) {
|
||||
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
|
||||
}
|
||||
kernel<<<grid_n, Kernel_traits::kNThreads, smem_size_dq_dk_dv, stream>>>(params);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
auto kernel_dq = &flash_bwd_convert_dq_kernel<Kernel_traits>;
|
||||
if (Kernel_traits::kSmemdQSize >= 48 * 1024) {
|
||||
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
kernel_dq, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemdQSize));
|
||||
}
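// The convert-dQ launch below passes the number of dQaccum splits to reduce: 1 in the default
// path (a single buffer accumulated with atomicAdd), or gridDimx buffers in the deterministic path.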
|
||||
kernel_dq<<<grid_m, Kernel_traits::kNThreads, Kernel_traits::kSmemdQSize, stream>>>(params, !params.deterministic ? 1 : gridDimx);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
}
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout>
|
||||
void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
||||
#ifndef FLASHATTENTION_DISABLE_BACKWARD
|
||||
run_flash_bwd_seqk_parallel<Kernel_traits, Is_dropout>(params, stream);
|
||||
#endif
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void run_mha_bwd_hdim32(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
||||
constexpr static int Headdim = 32;
|
||||
int device;
|
||||
cudaGetDevice(&device);
|
||||
int max_smem_per_block;
|
||||
cudaError status_ = cudaDeviceGetAttribute(
|
||||
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
|
||||
if (status_ != cudaSuccess) {
|
||||
C10_CUDA_CHECK(status_);
|
||||
}
|
||||
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
if (max_smem_per_block >= 2 * ((3 * 128 + 2 * 128) * Headdim + 2 * 128 * 128)) { // 104 KB
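// Sanity check on the "104 KB" figure: 2 * ((3 * 128 + 2 * 128) * 32 + 2 * 128 * 128)
// = 2 * (640 * 32 + 32768) = 2 * 53248 = 106496 bytes = 104 KB.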
|
||||
if constexpr(!Is_dropout) { // We can afford more registers to keep V in registers
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream);
|
||||
} else {
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
|
||||
}
|
||||
} else { // 96 KB
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void run_mha_bwd_hdim64(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
||||
constexpr static int Headdim = 64;
|
||||
int device;
|
||||
cudaGetDevice(&device);
|
||||
int max_smem_per_block;
|
||||
cudaError status_ = cudaDeviceGetAttribute(
|
||||
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
|
||||
if (status_ != cudaSuccess) {
|
||||
C10_CUDA_CHECK(status_);
|
||||
}
|
||||
// printf("max_smem_per_block = %d\n", max_smem_per_block);
|
||||
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
// Changing AtomLayoutMdQ from 2 to 4 takes the same time
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>>(params, stream);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, true, false, T>>(params, stream);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 2, 4, 4, false, false, T>>(params, stream);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, false, false, T>, Is_dropout>(params, stream);
|
||||
// This is slightly faster. We want to split M more so we need fewer registers to store LSE.
|
||||
if (max_smem_per_block >= 144 * 1024) {
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
|
||||
// This has a lot of register spilling
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream);
|
||||
} else {
|
||||
// if (params.h == params.h_k) {
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_dropout>(params, stream);
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, false, false, T>, Is_dropout>(params, stream);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, true, false, T>, Is_dropout>(params, stream);
|
||||
// } else {
|
||||
// }
|
||||
}
|
||||
});
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, true, false, T>>(params, stream);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 4, 2, 2, 2, true, false, T>>(params, stream);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 128, 4, 1, 4, 1, false, false, T>>(params, stream);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 16, 128, 4, 1, 4, 1, false, false, T>>(params, stream);
|
||||
// M=128, N=64 is quite slow, I think because we need to read/write dQaccum twice as many times
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 2, 2, 2, false, T>>(params, stream);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, false, T>>(params, stream);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 4, false, T>>(params, stream);
|
||||
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 4, 4, 2, 4, false, false, T>>(params, stream);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void run_mha_bwd_hdim96(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
||||
constexpr static int Headdim = 96;
|
||||
int device;
|
||||
cudaGetDevice(&device);
|
||||
int max_smem_per_block;
|
||||
cudaError status_ = cudaDeviceGetAttribute(
|
||||
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
|
||||
if (status_ != cudaSuccess) {
|
||||
C10_CUDA_CHECK(status_);
|
||||
}
|
||||
// printf("max_smem_per_block = %d\n", max_smem_per_block);
|
||||
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
if (max_smem_per_block >= 116 * 1024) {
|
||||
if constexpr(!Is_dropout) { // 92KB
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream);
|
||||
} else { // 116 KB
|
||||
// This is faster for dropout since we don't have many registers to spare
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_dropout>(params, stream);
|
||||
}
|
||||
} else {
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template<typename T>
void run_mha_bwd_hdim128(Flash_bwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 128;
int device;
cudaGetDevice(&device);
int max_smem_per_block;
cudaError status_ = cudaDeviceGetAttribute(
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
if (status_ != cudaSuccess) {
C10_CUDA_CHECK(status_);
}
// printf("max_smem_per_block = %d\n", max_smem_per_block);
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 128, 8, 2, 2, 2, false, false, T>>(params, stream);
// This is faster, in the case of sequence-parallel bwd (where we need fewer registers).
// Out of these three, the 2nd one is slightly faster (2% faster than the first). Idk why.
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 2, 2, false, false, T>>(params, stream);
if (max_smem_per_block >= 144 * 1024) {
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>, Is_dropout>(params, stream);
// run_flash_bwd_seqk_parallel<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
// run_flash_bwd_seqk_parallel<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, true, T>, Is_dropout>(params, stream);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, true, false, T>, Is_dropout>(params, stream);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream);
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 2, true, false, T>, Is_dropout>(params, stream);
} else {
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream);
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, false, T>, Is_dropout>(params, stream);
}
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>>(params, stream);

// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 4, 4, false, false, T>>(params, stream);
});
}

template<typename T>
void run_mha_bwd_hdim160(Flash_bwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 160;
int device;
cudaGetDevice(&device);
int max_smem_per_block;
cudaError status_ = cudaDeviceGetAttribute(
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
if (status_ != cudaSuccess) {
C10_CUDA_CHECK(status_);
}
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
if (max_smem_per_block >= 116 * 1024) {
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
} else {
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, true, T>, Is_dropout>(params, stream);
}
});
}

template<typename T>
void run_mha_bwd_hdim192(Flash_bwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 192;
int device;
cudaGetDevice(&device);
int max_smem_per_block;
cudaError status_ = cudaDeviceGetAttribute(
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
if (status_ != cudaSuccess) {
C10_CUDA_CHECK(status_);
}
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
if (max_smem_per_block >= 136 * 1024) {
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream);
} else {
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, true, T>, Is_dropout>(params, stream);
}
});
}

template<typename T>
void run_mha_bwd_hdim224(Flash_bwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 224;
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
});
}

template<typename T>
void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream) {
constexpr static int Headdim = 256;
int device;
cudaGetDevice(&device);
int max_smem_per_block;
cudaError status_ = cudaDeviceGetAttribute(
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
if (status_ != cudaSuccess) {
C10_CUDA_CHECK(status_);
}
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
if (max_smem_per_block >= 176 * 1024) { // H100
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream);
} else if (max_smem_per_block >= 144 * 1024) { // A100, we don't do double buffering to save smem
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, true, T>, Is_dropout>(params, stream);
} else { // sm86 and sm89, max smem is 99 KB. Only works without dropout. V in regs and no double buffering.
if constexpr (!Is_dropout) {
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 32, 8, 4, 1, 2, true, true, T>, false>(params, stream);
}
}
});
}

}; // namespace pytorch_flash
@@ -1,377 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2024, Tri Dao.
******************************************************************************/

#pragma once

#include <cute/tensor.hpp>

#include <cutlass/cutlass.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>

#include <ATen/native/transformers/cuda/flash_attn/block_info.h>
#include <ATen/native/transformers/cuda/flash_attn/kernel_traits.h>
#include <ATen/native/transformers/cuda/flash_attn/utils.h>

namespace pytorch_flash {

using namespace cute;

////////////////////////////////////////////////////////////////////////////////////////////////////

template <int THREADS_PER_ROW, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
inline __device__ void dot_do_o(Tensor<Engine0, Layout0> const &do_, Tensor<Engine0, Layout0> const &o,
Tensor<Engine1, Layout1> &dP_sum, const int gdP_col_stride, const float scale) {
static_assert(Layout0::rank == 3, "Only support 3D Tensor");
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
CUTE_STATIC_ASSERT_V(do_.layout() == o.layout());
// Reshape do_ and o from (8, kBlockM / 32, kHeadDim / 64) to (kBlockM / 32, 8 * kHeadDim / 64)
// The last coordinate is the "page".
Tensor do_reshaped = make_tensor(do_.data(), make_layout(get<1>(do_.layout()),
make_layout(get<0>(do_.layout()),
get<2>(do_.layout()))));
Tensor o_reshaped = make_tensor(o.data(), do_reshaped.layout());
Tensor do_fp32 = pytorch_flash::convert_type<float>(do_reshaped);
Tensor o_fp32 = pytorch_flash::convert_type<float>(o_reshaped);
#pragma unroll
for (int mi = 0; mi < size<0>(do_reshaped); ++mi) {
float dP_sum_cur = do_fp32(mi, 0) * o_fp32(mi, 0);
#pragma unroll
for (int ni = 1; ni < size<1>(do_reshaped); ni++) {
dP_sum_cur += do_fp32(mi, ni) * o_fp32(mi, ni);
}
pytorch_flash::SumOp<float> sum_op;
dP_sum_cur = pytorch_flash::Allreduce<THREADS_PER_ROW>::run(dP_sum_cur, sum_op) * scale;
if (threadIdx.x % THREADS_PER_ROW == 0) {
dP_sum(mi * gdP_col_stride + threadIdx.x / THREADS_PER_ROW) = dP_sum_cur;
}
}
}

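What dot_do_o computes, once the CuTe layouts and the cross-lane Allreduce over THREADS_PER_ROW threads are stripped away, is just the per-row dot product of dO and O scaled by `scale`. A scalar host-side reference (illustrative only; assumes row-major fp32 inputs):

```
#include <vector>
#include <cstddef>

// dP_sum[row] = scale * sum_d dO[row][d] * O[row][d], i.e. the softmax_d term of
// the attention backward pass. The kernel passes scale = p_dropout so that
// (dP - dP_sum) stays on the same scale (see the comment in compute_dot_do_o below).
std::vector<float> reference_dot_do_o(const std::vector<float>& dO,
                                      const std::vector<float>& O,
                                      std::size_t rows, std::size_t head_dim,
                                      float scale) {
  std::vector<float> dP_sum(rows, 0.f);
  for (std::size_t r = 0; r < rows; ++r) {
    float acc = 0.f;
    for (std::size_t d = 0; d < head_dim; ++d) {
      acc += dO[r * head_dim + d] * O[r * head_dim + d];
    }
    dP_sum[r] = acc * scale;
  }
  return dP_sum;
}
```
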
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Just compute dot(do, o) and write the result (softmax_d) to global memory as a separate kernel.
|
||||
// This is used in the case where we want to parallelize the backward across seqlen_k.
|
||||
template<bool Clear_dQaccum=true, typename Kernel_traits, typename Params>
|
||||
inline __device__ void compute_dot_do_o(const Params &params) {
|
||||
using Element = typename Kernel_traits::Element;
|
||||
using ElementAccum = typename Kernel_traits::ElementAccum;
|
||||
using index_t = typename Kernel_traits::index_t;
|
||||
|
||||
const int m_block = blockIdx.x;
|
||||
// The block index for the batch.
|
||||
const int bidb = blockIdx.y;
|
||||
// The block index for the head.
|
||||
const int bidh = blockIdx.z;
|
||||
// The thread index.
|
||||
const int tidx = threadIdx.x;
|
||||
|
||||
constexpr int kBlockM = Kernel_traits::kBlockM;
|
||||
constexpr int kHeadDim = Kernel_traits::kHeadDim;
|
||||
|
||||
const BlockInfo binfo(params, bidb);
|
||||
if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
|
||||
|
||||
const index_t row_offset_do = binfo.q_offset(params.do_batch_stride, params.do_row_stride, bidb)
|
||||
+ m_block * kBlockM * params.do_row_stride + bidh * params.do_head_stride;
|
||||
const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
|
||||
+ m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
|
||||
const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb)
|
||||
+ (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded;
|
||||
const index_t row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM;
|
||||
|
||||
Tensor gdO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.do_ptr) + row_offset_do),
|
||||
Shape<Int<kBlockM>, Int<kHeadDim>>{},
|
||||
make_stride(params.do_row_stride, _1{}));
|
||||
Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.o_ptr) + row_offset_o),
|
||||
Shape<Int<kBlockM>, Int<kHeadDim>>{},
|
||||
make_stride(params.o_row_stride, _1{}));
|
||||
Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dq_accum_ptr) + row_offset_dq_accum),
|
||||
Shape<Int<kBlockM>, Int<kHeadDim>>{},
|
||||
make_stride(params.h * params.d_rounded, _1{}));
|
||||
Tensor dP_sum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dsoftmax_sum) + row_offset_dpsum),
|
||||
Shape<Int<kBlockM>>{}, Stride<_1>{});
|
||||
|
||||
typename Kernel_traits::GmemTiledCopydO gmem_tiled_copy_dO;
|
||||
auto gmem_thr_copy_dO = gmem_tiled_copy_dO.get_thread_slice(tidx);
|
||||
// TODO: careful, we're zeroing out dQaccum with type float4, but when
|
||||
// we do atomicAdds, we use type float. The layouts are different. Check this.
|
||||
typename Kernel_traits::GmemTiledCopydQaccum gmem_tiled_copy_dQaccum;
|
||||
auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(tidx);
|
||||
|
||||
Tensor tdOgdO = gmem_thr_copy_dO.partition_S(gdO);
|
||||
Tensor tdOgO = gmem_thr_copy_dO.partition_S(gO);
|
||||
Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_D(gdQaccum);
|
||||
|
||||
Tensor cdO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
|
||||
Tensor tdOcdO = gmem_thr_copy_dO.partition_S(cdO);
|
||||
|
||||
// Allocate predicate tensors for k
|
||||
Tensor tdOpdO = make_tensor<bool>(make_shape(size<2>(tdOgdO)));
|
||||
// Set predicates for k bounds
|
||||
#pragma unroll
|
||||
for (int k = 0; k < size(tdOpdO); ++k) {tdOpdO(k) = get<1>(tdOcdO(0, 0, k)) < params.d;}
|
||||
|
||||
Tensor tdOrdO = make_fragment_like(tdOgdO);
|
||||
Tensor tdOrO = make_fragment_like(tdOgO);
|
||||
pytorch_flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/true>(
|
||||
gmem_tiled_copy_dO, tdOgdO, tdOrdO, tdOcdO, tdOpdO, binfo.actual_seqlen_q - m_block * kBlockM
|
||||
);
|
||||
pytorch_flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/true>(
|
||||
gmem_tiled_copy_dO, tdOgO, tdOrO, tdOcdO, tdOpdO, binfo.actual_seqlen_q - m_block * kBlockM
|
||||
);
|
||||
// By right we need to scale dP up by 1/p_dropout, but instead we don't and only scale the final
|
||||
// results (dQ and dK) by 1/p_dropout. So we need to keep dP_sum scaled down by p_dropout here,
|
||||
// so that (dP - dP_sum) is on the same scale.
|
||||
dot_do_o<Kernel_traits::kGmemThreadsPerRow>(tdOrdO, tdOrO, dP_sum,
|
||||
Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout);
|
||||
if (Clear_dQaccum) {
|
||||
// We're actually not zero'ing out all of dQaccum, but only the part that we're going to
|
||||
// do atomicAdds on.
|
||||
Tensor zero = make_fragment_like(tdQgdQaccum);
|
||||
clear(zero);
|
||||
cute::copy(gmem_tiled_copy_dQaccum, zero, tdQgdQaccum);
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename Kernel_traits, typename Params>
|
||||
inline __device__ void clear_dKVaccum(const Params &params) {
|
||||
using ElementAccum = typename Kernel_traits::ElementAccum;
|
||||
using index_t = typename Kernel_traits::index_t;
|
||||
|
||||
const int n_block = blockIdx.x;
|
||||
// The block index for the batch.
|
||||
const int bidb = blockIdx.y;
|
||||
// The block index for the head.
|
||||
const int bidh = blockIdx.z;
|
||||
// The thread index.
|
||||
const int tidx = threadIdx.x;
|
||||
|
||||
constexpr int kBlockN = Kernel_traits::kBlockN;
|
||||
constexpr int kHeadDim = Kernel_traits::kHeadDim;
|
||||
|
||||
const BlockInfo binfo(params, bidb);
|
||||
if (n_block * kBlockN >= binfo.actual_seqlen_k) return;
|
||||
|
||||
const index_t row_offset_dkv_accum = ((bidb * params.h_k + bidh) * params.seqlen_k_rounded + n_block * kBlockN) * params.d_rounded;
|
||||
|
||||
Tensor gdKaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dk_accum_ptr) + row_offset_dkv_accum),
|
||||
Shape<Int<kBlockN>, Int<kHeadDim>>{}, Stride<Int<kHeadDim>, _1>{});
|
||||
Tensor gdVaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dv_accum_ptr) + row_offset_dkv_accum),
|
||||
Shape<Int<kBlockN>, Int<kHeadDim>>{}, Stride<Int<kHeadDim>, _1>{});
|
||||
|
||||
typename Kernel_traits::GmemTiledCopydQaccum gmem_tiled_copy_dKVaccum;
|
||||
auto gmem_thr_copy_dKVaccum = gmem_tiled_copy_dKVaccum.get_thread_slice(tidx);
|
||||
Tensor tdKgdKaccum = gmem_thr_copy_dKVaccum.partition_D(gdKaccum);
|
||||
Tensor tdVgdVaccum = gmem_thr_copy_dKVaccum.partition_D(gdVaccum);
|
||||
Tensor zero = make_fragment_like(tdKgdKaccum);
|
||||
clear(zero);
|
||||
cute::copy(gmem_tiled_copy_dKVaccum, zero, tdKgdKaccum);
|
||||
cute::copy(gmem_tiled_copy_dKVaccum, zero, tdVgdVaccum);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Convert dQ from dQaccum (in float) to fp16/bf16.
|
||||
// This is used in the case where we want to parallelize the backward across seqlen_k.
|
||||
template<typename Kernel_traits, typename Params>
|
||||
inline __device__ void convert_dQ(const Params &params, const int nsplits) {
|
||||
using Element = typename Kernel_traits::Element;
|
||||
using ElementAccum = typename Kernel_traits::ElementAccum;
|
||||
using index_t = typename Kernel_traits::index_t;
|
||||
|
||||
// Shared memory.
|
||||
extern __shared__ char smem_[];
|
||||
|
||||
const int m_block = blockIdx.x;
|
||||
// The block index for the batch.
|
||||
const int bidb = blockIdx.y;
|
||||
// The block index for the head.
|
||||
const int bidh = blockIdx.z;
|
||||
// The thread index.
|
||||
const int tidx = threadIdx.x;
|
||||
|
||||
constexpr int kBlockM = Kernel_traits::kBlockM;
|
||||
constexpr int kHeadDim = Kernel_traits::kHeadDim;
|
||||
|
||||
const BlockInfo binfo(params, bidb);
|
||||
if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
|
||||
|
||||
const index_t row_offset_dq = binfo.q_offset(params.dq_batch_stride, params.dq_row_stride, bidb)
|
||||
+ m_block * kBlockM * params.dq_row_stride + bidh * params.dq_head_stride;
|
||||
const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb)
|
||||
+ (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded;
|
||||
|
||||
Tensor gdQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dq_ptr) + row_offset_dq),
|
||||
Shape<Int<kBlockM>, Int<kHeadDim>>{},
|
||||
make_stride(params.dq_row_stride, _1{}));
|
||||
Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dq_accum_ptr) + row_offset_dq_accum),
|
||||
Shape<Int<kBlockM>, Int<kHeadDim>>{},
|
||||
make_stride(params.h * params.d_rounded, _1{}));
|
||||
|
||||
Tensor sdQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
|
||||
typename Kernel_traits::SmemLayoutdQ{});
|
||||
|
||||
typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dQ;
|
||||
auto gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_thread_slice(tidx);
|
||||
typename Kernel_traits::GmemTiledCopydQaccumAtomicAdd gmem_tiled_copy_dQaccum;
|
||||
auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(tidx);
|
||||
|
||||
typename Kernel_traits::TiledMmadQ tiled_mma_dq;
|
||||
auto smem_tiled_copy_dQ = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdQ{}, tiled_mma_dq);
|
||||
auto smem_thr_copy_dQ = smem_tiled_copy_dQ.get_thread_slice(tidx);
|
||||
Tensor taccdQsdQ = smem_thr_copy_dQ.partition_D(sdQ); // ((Atom,AtomNum),PIPE_M,PIPE_N)
|
||||
|
||||
Tensor tdQsdQ = gmem_thr_copy_dQ.partition_S(sdQ); // ((Atom,AtomNum),ATOM_M,ATOM_N)
|
||||
Tensor tdQgdQ = gmem_thr_copy_dQ.partition_D(gdQ);
|
||||
Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_S(gdQaccum);
|
||||
|
||||
Tensor acc_dq = partition_fragment_C(tiled_mma_dq, Shape<Int<kBlockM>, Int<kHeadDim>>{}); // MMA, MMA_N, MMA_K
|
||||
CUTE_STATIC_ASSERT_V(size(acc_dq) == size(tdQgdQaccum));
|
||||
|
||||
Tensor tdQrdQaccum = make_fragment_like(tdQgdQaccum);
|
||||
clear(acc_dq);
|
||||
for (int s = 0; s < nsplits; ++s) {
|
||||
cute::copy(gmem_tiled_copy_dQaccum, tdQgdQaccum, tdQrdQaccum);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) += tdQrdQaccum(i); }
|
||||
tdQgdQaccum.data() = tdQgdQaccum.data() + params.dq_accum_split_stride;
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) *= params.scale_softmax_rp_dropout; }
|
||||
// Convert acc_dq from fp32 to fp16
|
||||
Tensor rdQ = pytorch_flash::convert_type<Element>(acc_dq);
|
||||
Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ); // ((Atom,AtomNum), MMA_N, MMA_N)
|
||||
cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ);
|
||||
__syncthreads();
|
||||
Tensor tdQrdQ = make_tensor<Element>(shape(tdQgdQ));
|
||||
cute::copy(gmem_tiled_copy_dQ, tdQsdQ, tdQrdQ);
|
||||
|
||||
Tensor cdQ = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
|
||||
Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cdQ);
|
||||
Tensor tdQpdQ = make_tensor<bool>(make_shape(size<2>(tdQgdQ)));
|
||||
#pragma unroll
|
||||
for (int k = 0; k < size(tdQpdQ); ++k) { tdQpdQ(k) = get<1>(tdQcdQ(0, 0, k)) < params.d; }
|
||||
// Clear_OOB_K must be false since we don't want to write zeros to gmem
|
||||
pytorch_flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
|
||||
gmem_tiled_copy_dQ, tdQrdQ, tdQgdQ, tdQcdQ, tdQpdQ, binfo.actual_seqlen_q - m_block * kBlockM
|
||||
);
|
||||
}
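convert_dQ's core loop sums the nsplits fp32 partial dQaccum buffers, rescales by scale_softmax_rp_dropout, and only then narrows to fp16/bf16 through shared memory. A scalar sketch of that accumulation (illustrative; split_stride stands in for params.dq_accum_split_stride, and the fp32-to-Element conversion is left out):

```
#include <cstdint>
#include <vector>

std::vector<float> reference_convert_dq(const float* dq_accum, int nsplits,
                                        std::int64_t split_stride,
                                        std::int64_t num_elems,
                                        float scale_softmax_rp_dropout) {
  std::vector<float> dq(num_elems, 0.f);
  for (std::int64_t i = 0; i < num_elems; ++i) {
    float acc = 0.f;
    for (int s = 0; s < nsplits; ++s) {
      acc += dq_accum[s * split_stride + i];  // sum the per-split fp32 partial dQ
    }
    dq[i] = acc * scale_softmax_rp_dropout;   // fold in softmax scale and 1/(1 - p_dropout)
  }
  return dq;
}
```
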
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Convert dK and dV from dKaccum and dVaccum (in float) to fp16/bf16.
|
||||
// This is used in the case where we want to parallelize the backward across seqlen_q.
|
||||
template<typename Kernel_traits, typename Params>
|
||||
inline __device__ void convert_dKV(const Params &params) {
|
||||
using Element = typename Kernel_traits::Element;
|
||||
using ElementAccum = typename Kernel_traits::ElementAccum;
|
||||
using index_t = typename Kernel_traits::index_t;
|
||||
|
||||
// Shared memory.
|
||||
extern __shared__ char smem_[];
|
||||
|
||||
const int n_block = blockIdx.x;
|
||||
// The block index for the batch.
|
||||
const int bidb = blockIdx.y;
|
||||
// The block index for the head.
|
||||
const int bidh = blockIdx.z;
|
||||
// The thread index.
|
||||
const int tidx = threadIdx.x;
|
||||
|
||||
constexpr int kBlockN = Kernel_traits::kBlockN;
|
||||
constexpr int kHeadDim = Kernel_traits::kHeadDim;
|
||||
|
||||
const BlockInfo binfo(params, bidb);
|
||||
if (n_block * kBlockN >= binfo.actual_seqlen_k) return;
|
||||
|
||||
const index_t row_offset_dk = binfo.k_offset(params.dk_batch_stride, params.dk_row_stride, bidb)
|
||||
+ n_block * kBlockN * params.dk_row_stride + bidh * params.dk_head_stride;
|
||||
const index_t row_offset_dv = binfo.k_offset(params.dv_batch_stride, params.dv_row_stride, bidb)
|
||||
+ n_block * kBlockN * params.dv_row_stride + bidh * params.dv_head_stride;
|
||||
const index_t row_offset_dkv_accum = ((bidb * params.h_k + bidh) * params.seqlen_k_rounded
|
||||
+ n_block * kBlockN) * params.d_rounded;
|
||||
|
||||
Tensor gdK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dk_ptr) + row_offset_dk),
|
||||
Shape<Int<kBlockN>, Int<kHeadDim>>{},
|
||||
make_stride(params.dk_row_stride, _1{}));
|
||||
Tensor gdV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dv_ptr) + row_offset_dv),
|
||||
Shape<Int<kBlockN>, Int<kHeadDim>>{},
|
||||
make_stride(params.dv_row_stride, _1{}));
|
||||
Tensor gdKaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dk_accum_ptr) + row_offset_dkv_accum),
|
||||
Shape<Int<kBlockN>, Int<kHeadDim>>{},
|
||||
Stride<Int<kHeadDim>, _1>{});
|
||||
Tensor gdVaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dv_accum_ptr) + row_offset_dkv_accum),
|
||||
Shape<Int<kBlockN>, Int<kHeadDim>>{},
|
||||
Stride<Int<kHeadDim>, _1>{});
|
||||
|
||||
Tensor sdK = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
|
||||
typename Kernel_traits::SmemLayoutdKV{});
|
||||
Tensor sdV = make_tensor(sdK.data() + size(sdK), typename Kernel_traits::SmemLayoutdKV{}); // (SMEM_N, SMEM_K)
|
||||
|
||||
typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dKV;
|
||||
auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(tidx);
|
||||
typename Kernel_traits::GmemTiledCopydQaccumAtomicAdd gmem_tiled_copy_dKVaccum;
|
||||
auto gmem_thr_copy_dKVaccum = gmem_tiled_copy_dKVaccum.get_thread_slice(tidx);
|
||||
|
||||
typename Kernel_traits::TiledMmadKV tiled_mma_dkv;
|
||||
auto smem_tiled_copy_dKV = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdKV{}, tiled_mma_dkv);
|
||||
auto smem_thr_copy_dKV = smem_tiled_copy_dKV.get_thread_slice(tidx);
|
||||
Tensor taccdKsdK = smem_thr_copy_dKV.partition_D(sdK); // ((Atom,AtomNum),PIPE_M,PIPE_N)
|
||||
Tensor taccdVsdV = smem_thr_copy_dKV.partition_D(sdV); // ((Atom,AtomNum),PIPE_M,PIPE_N)
|
||||
|
||||
Tensor tdKsdK = gmem_thr_copy_dKV.partition_S(sdK); // ((Atom,AtomNum),ATOM_M,ATOM_N)
|
||||
Tensor tdKgdK = gmem_thr_copy_dKV.partition_D(gdK);
|
||||
Tensor tdVsdV = gmem_thr_copy_dKV.partition_S(sdV); // ((Atom,AtomNum),ATOM_M,ATOM_N)
|
||||
Tensor tdVgdV = gmem_thr_copy_dKV.partition_D(gdV);
|
||||
Tensor tdKgdKaccum = gmem_thr_copy_dKVaccum.partition_S(gdKaccum);
|
||||
Tensor tdVgdVaccum = gmem_thr_copy_dKVaccum.partition_S(gdVaccum);
|
||||
|
||||
Tensor acc_dk = partition_fragment_C(tiled_mma_dkv, Shape<Int<kBlockN>, Int<kHeadDim>>{}); // MMA, MMA_N, MMA_K
|
||||
Tensor acc_dv = partition_fragment_C(tiled_mma_dkv, Shape<Int<kBlockN>, Int<kHeadDim>>{}); // MMA, MMA_N, MMA_K
|
||||
CUTE_STATIC_ASSERT_V(size(acc_dk) == size(tdKgdKaccum));
|
||||
CUTE_STATIC_ASSERT_V(size(acc_dv) == size(tdVgdVaccum));
|
||||
|
||||
Tensor tdKrdKaccum = make_fragment_like(tdKgdKaccum);
|
||||
Tensor tdVrdVaccum = make_fragment_like(tdVgdVaccum);
|
||||
cute::copy(gmem_tiled_copy_dKVaccum, tdKgdKaccum, tdKrdKaccum);
|
||||
cute::copy(gmem_tiled_copy_dKVaccum, tdVgdVaccum, tdVrdVaccum);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size(acc_dk); ++i) {
|
||||
acc_dk(i) = tdKrdKaccum(i) * params.scale_softmax_rp_dropout;
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size(acc_dv); ++i) {
|
||||
acc_dv(i) = tdVrdVaccum(i) * params.rp_dropout;
|
||||
}
|
||||
// Convert acc_dk from fp32 to fp16
|
||||
Tensor rdK = pytorch_flash::convert_type<Element>(acc_dk);
|
||||
Tensor rdV = pytorch_flash::convert_type<Element>(acc_dv);
|
||||
Tensor taccdKrdK = smem_thr_copy_dKV.retile_S(rdK); // ((Atom,AtomNum), MMA_N, MMA_N)
|
||||
Tensor taccdVrdV = smem_thr_copy_dKV.retile_S(rdV); // ((Atom,AtomNum), MMA_N, MMA_N)
|
||||
cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdK);
|
||||
cute::copy(smem_tiled_copy_dKV, taccdVrdV, taccdVsdV);
|
||||
__syncthreads();
|
||||
Tensor tdKrdK = make_tensor<Element>(shape(tdKgdK));
|
||||
Tensor tdVrdV = make_tensor<Element>(shape(tdVgdV));
|
||||
cute::copy(gmem_tiled_copy_dKV, tdKsdK, tdKrdK);
|
||||
cute::copy(gmem_tiled_copy_dKV, tdVsdV, tdVrdV);
|
||||
|
||||
Tensor cdKV = make_identity_tensor(Shape<Int<kBlockN>, Int<kHeadDim>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
|
||||
Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV);
|
||||
Tensor tdKVpdKV = make_tensor<bool>(make_shape(size<2>(tdKgdK)));
|
||||
#pragma unroll
|
||||
for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; }
|
||||
// Clear_OOB_K must be false since we don't want to write zeros to gmem
|
||||
pytorch_flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
|
||||
gmem_tiled_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
|
||||
);
|
||||
pytorch_flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
|
||||
gmem_tiled_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
|
||||
);
|
||||
}
|
||||
|
||||
} // namespace flash
|
File diff suppressed because it is too large
@@ -1,378 +0,0 @@
/******************************************************************************
|
||||
* Copyright (c) 2023, Tri Dao.
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ATen/cuda/CUDAContextLight.h>
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash.h>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/static_switch.h>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_fwd_kernel.h>
|
||||
|
||||
namespace pytorch_flash {
|
||||
|
||||
// Determine if the architecture supports FLASH and define a macro to handle parameter modifiers
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
#define ARCH_SUPPORTS_FLASH
|
||||
#endif
|
||||
|
||||
#if defined(ARCH_SUPPORTS_FLASH) && defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 11 && \
|
||||
defined(__CUDACC_VER_MINOR__) && __CUDACC_VER_MINOR__ >= 8
|
||||
#define KERNEL_PARAM_MODIFIER __grid_constant__
|
||||
#else
|
||||
#define KERNEL_PARAM_MODIFIER
|
||||
#endif
|
||||
|
||||
// Define a macro for unsupported architecture handling to centralize the error message
|
||||
#define FLASH_UNSUPPORTED_ARCH printf("FATAL: FlashAttention requires building with sm version sm80-sm90, but was built for < 8.0!");
|
||||
|
||||
// Use a macro to clean up kernel definitions
|
||||
#define DEFINE_FLASH_FORWARD_KERNEL(kernelName, ...) \
|
||||
template<typename Kernel_traits, __VA_ARGS__> \
|
||||
__global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_fwd_params params)
|
||||
|
||||
DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax) {
|
||||
#if defined(ARCH_SUPPORTS_FLASH)
|
||||
static_assert(!(Is_causal && Is_local)); // Enforce constraints
|
||||
pytorch_flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Return_softmax>(params);
|
||||
#else
|
||||
FLASH_UNSUPPORTED_ARCH
|
||||
#endif
|
||||
}
|
||||
|
||||
DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV) {
|
||||
#if defined(ARCH_SUPPORTS_FLASH)
|
||||
pytorch_flash::compute_attn_splitkv<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Split, Append_KV>(params);
|
||||
#else
|
||||
FLASH_UNSUPPORTED_ARCH
|
||||
#endif
|
||||
}
|
||||
|
||||
DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_combine_kernel, int kBlockM, int Log_max_splits, bool Is_even_K) {
|
||||
static_assert(Log_max_splits >= 1);
|
||||
pytorch_flash::combine_attn_seqk_parallel<Kernel_traits, kBlockM, Log_max_splits, Is_even_K>(params);
|
||||
}
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
|
||||
void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
|
||||
constexpr size_t smem_size = Kernel_traits::kSmemSize;
|
||||
// printf("smem_size = %d\n", smem_size);
|
||||
|
||||
// Work-around for gcc 7. It doesn't like nested BOOL_SWITCH.
|
||||
// https://github.com/kokkos/kokkos-kernels/issues/349
|
||||
// https://github.com/HazyResearch/flash-attention/issues/21
|
||||
|
||||
const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
|
||||
dim3 grid(num_m_block, params.b, params.h);
|
||||
const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0;
|
||||
const bool is_even_K = params.d == Kernel_traits::kHeadDim;
|
||||
const bool return_softmax = params.p_ptr != nullptr;
|
||||
BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
|
||||
EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
|
||||
LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
|
||||
BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
|
||||
ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
|
||||
// Will only return softmax if dropout, to reduce compilation time.
|
||||
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
|
||||
// If return_softmax, set IsEvenMNConst to false to reduce number of templates
|
||||
// If head dim > 128, set IsEvenMNConst to false to reduce number of templates
|
||||
// If Is_local, set Is_causal to false
|
||||
auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, ReturnSoftmaxConst && Is_dropout>;
|
||||
// auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, false, true, true, false>;
|
||||
// printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
|
||||
// auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;
|
||||
if (smem_size >= 48 * 1024) {
|
||||
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
|
||||
}
|
||||
// int ctas_per_sm;
|
||||
// cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
// &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
|
||||
// printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
|
||||
kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
}
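run_flash_fwd resolves all of its boolean template switches to a single kernel pointer, raises the dynamic shared-memory cap when the tile needs more than the default 48 KB, and launches on the caller's stream. A stripped-down sketch of that opt-in-and-launch pattern with a stand-in kernel (illustrative only; error checking omitted):

```
#include <cuda_runtime.h>

// Toy kernel using dynamic shared memory; assumes smem_size >= blockDim.x * sizeof(float).
__global__ void example_kernel(const float* in, float* out) {
  extern __shared__ float smem[];
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  smem[threadIdx.x] = in[i];
  __syncthreads();
  out[i] = smem[threadIdx.x];
}

void launch_with_large_smem(const float* in, float* out, dim3 grid, int threads,
                            size_t smem_size, cudaStream_t stream) {
  if (smem_size >= 48 * 1024) {
    // Same guard as run_flash_fwd: only opt in when the default 48 KB limit is exceeded.
    cudaFuncSetAttribute(example_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
                         static_cast<int>(smem_size));
  }
  example_kernel<<<grid, threads, smem_size, stream>>>(in, out);
}
```
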
template<typename Kernel_traits>
|
||||
void run_flash_splitkv_fwd(Flash_fwd_params &params, cudaStream_t stream) {
|
||||
static_assert(!Kernel_traits::Is_Q_in_regs, "SplitKV implementation does not support Is_Q_in_regs");
|
||||
static_assert(!Kernel_traits::Share_Q_K_smem, "SplitKV implementation does not support Share_Q_K_smem");
|
||||
constexpr size_t smem_size = Kernel_traits::kSmemSize;
|
||||
const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
|
||||
dim3 grid(num_m_block, params.num_splits > 1 ? params.num_splits : params.b, params.num_splits > 1 ? params.b * params.h : params.h);
|
||||
const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0;
|
||||
const bool is_even_K = params.d == Kernel_traits::kHeadDim;
|
||||
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
||||
BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
|
||||
EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
|
||||
LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
|
||||
BOOL_SWITCH(params.num_splits > 1, Split, [&] {
|
||||
BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] {
|
||||
ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
|
||||
// If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
|
||||
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
|
||||
// If Is_local, set Is_causal to false
|
||||
auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Split, Append_KV>;
|
||||
// auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, true, Split, Append_KV>;
|
||||
// auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>;
|
||||
if (smem_size >= 48 * 1024) {
|
||||
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
|
||||
}
|
||||
kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
if (params.num_splits > 1) {
|
||||
// We want kBlockM to be as small as possible for more parallelism.
|
||||
// With 128 threads we can load 512 elements at a time, so if headdim is divisible by 128, kBlockM = 4.
|
||||
// If headdim is divisible by 64, then we set kBlockM = 8, etc.
|
||||
constexpr static int kBlockM = Kernel_traits::kHeadDim % 128 == 0 ? 4 : (Kernel_traits::kHeadDim % 64 == 0 ? 8 : 16);
|
||||
dim3 grid_combine((params.b * params.h * params.seqlen_q + kBlockM - 1) / kBlockM);
|
||||
EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
|
||||
if (params.num_splits <= 2) {
|
||||
flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 1, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
|
||||
} else if (params.num_splits <= 4) {
|
||||
flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 2, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
|
||||
} else if (params.num_splits <= 8) {
|
||||
flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 3, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
|
||||
} else if (params.num_splits <= 16) {
|
||||
flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 4, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
|
||||
} else if (params.num_splits <= 32) {
|
||||
flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 5, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
|
||||
} else if (params.num_splits <= 64) {
|
||||
flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 6, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
|
||||
} else if (params.num_splits <= 128) {
|
||||
flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 7, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
|
||||
}
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
});
|
||||
}
|
||||
}
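The combine step above dispatches on two derived quantities: kBlockM, chosen so 128 threads can sweep a head-dim row per pass, and a Log_max_splits template argument equal to ceil(log2(num_splits)), capped at 7 since only up to 128 splits are handled. A host-side mirror of that arithmetic (illustrative, not part of the file):

```
// kBlockM selection: headdim % 128 == 0 -> 4, headdim % 64 == 0 -> 8, otherwise 16.
constexpr int combine_kblock_m(int headdim) {
  return headdim % 128 == 0 ? 4 : (headdim % 64 == 0 ? 8 : 16);
}

// Log_max_splits: 1 for <= 2 splits, 2 for <= 4, ..., 7 for <= 128.
constexpr int combine_log_max_splits(int num_splits) {
  int log2_splits = 1;
  while ((1 << log2_splits) < num_splits) { ++log2_splits; }
  return log2_splits;
}

static_assert(combine_kblock_m(128) == 4 && combine_kblock_m(96) == 16, "");
static_assert(combine_log_max_splits(2) == 1 && combine_log_max_splits(16) == 4, "");
```
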
template<typename T, int Headdim>
|
||||
void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream) {
|
||||
constexpr static int kBlockM = 64; // Fixed for all head dimensions
|
||||
// TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256,
|
||||
// and for headdim 192 with block size 64 x 128.
|
||||
// Also for headdim 160 with block size 64 x 128 after the rotary addition.
|
||||
constexpr static int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64);
|
||||
run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, kBlockM, kBlockN, 4, false, false, T>>(params, stream);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void run_mha_fwd_hdim32(Flash_fwd_params &params, cudaStream_t stream) {
|
||||
constexpr static int Headdim = 32;
|
||||
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void run_mha_fwd_hdim64(Flash_fwd_params &params, cudaStream_t stream) {
|
||||
constexpr static int Headdim = 64;
|
||||
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
||||
if constexpr(!Is_dropout) {
|
||||
// Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower
|
||||
// Using block size (64 x 256) is 27% slower for seqlen=2k
|
||||
// Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
|
||||
} else {
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void run_mha_fwd_hdim96(Flash_fwd_params &params, cudaStream_t stream) {
|
||||
constexpr static int Headdim = 96;
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
|
||||
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
||||
// For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
|
||||
if (is_sm8x) {
|
||||
if constexpr(!Is_causal) {
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
} else {
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
}
|
||||
} else {
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
}
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
|
||||
// These two are always slower
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 128, 4, true, T>>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<96, 64, 128, 4, true, T>>(params, stream);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
|
||||
constexpr static int Headdim = 128;
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
|
||||
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
||||
if constexpr(!Is_dropout) {
|
||||
// For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
|
||||
// and 128 x 32 (48 KB smem) is the fastest for non-causal since we get 2 CTAs per SM.
|
||||
if (is_sm8x) {
|
||||
if constexpr(!Is_causal) {
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
} else {
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
}
|
||||
} else {
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
}
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
// Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
// 1st ones are good for H100, A100
|
||||
// 2nd one is good for A6000 bc we get slightly better occupancy
|
||||
} else {
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void run_mha_fwd_hdim160(Flash_fwd_params &params, cudaStream_t stream) {
|
||||
constexpr static int Headdim = 160;
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
|
||||
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
||||
// For A100, H100, 128 x 32 is the fastest.
|
||||
// For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
|
||||
// and 128 x 64 with 8 warps is the fastest for non-causal.
|
||||
if (is_sm8x) {
|
||||
if constexpr(!Is_causal) {
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
} else {
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
}
|
||||
} else {
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
}
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, true, T>, Is_dropout, Is_causal>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, T>>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, T>>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, T>>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, T>>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, T>>(params, stream);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void run_mha_fwd_hdim192(Flash_fwd_params &params, cudaStream_t stream) {
|
||||
constexpr static int Headdim = 192;
|
||||
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
||||
if constexpr(!Is_dropout) {
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
} else {
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
}
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, T>>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, T>>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, T>>(params, stream);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void run_mha_fwd_hdim224(Flash_fwd_params &params, cudaStream_t stream) {
|
||||
constexpr static int Headdim = 224;
|
||||
int device;
|
||||
cudaGetDevice(&device);
|
||||
int max_smem_per_block;
|
||||
cudaError status_ = cudaDeviceGetAttribute(
|
||||
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
|
||||
if (status_ != cudaSuccess) {
|
||||
C10_CUDA_CHECK(status_);
|
||||
}
|
||||
// printf("max_smem_per_block = %d\n", max_smem_per_block);
|
||||
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
||||
if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64)) { // 112 KB
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
} else {
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
}
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
// We can't do 128 x 32 with 8 warps because with headdim 224, kBlockKSmem = 32.
|
||||
// If we have N = 32, there are only 1024 elements to load at once, where each load
|
||||
// is 8 elements. This means we can only use 128 threads and not 256 threads.
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
});
|
||||
});
|
||||
}
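The `2 * Headdim * (128 + 2 * 64)` bound above is the forward tile's shared-memory footprint: a 128 x Headdim Q tile plus one 64 x Headdim tile each for K and V, at 2 bytes per fp16/bf16 element. Worked out for Headdim = 224 it is exactly the 112 KB named in the comment (illustrative check, not part of the file):

```
constexpr int kSmemBytes224 =
    2 /*bytes per element*/ * 224 /*Headdim*/ * (128 /*kBlockM for Q*/ + 2 * 64 /*kBlockN for K and V*/);
static_assert(kSmemBytes224 == 112 * 1024, "128x64 tile at hdim 224 needs 112 KB of smem");
```
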
template<typename T>
|
||||
void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
|
||||
constexpr static int Headdim = 256;
|
||||
int device;
|
||||
cudaGetDevice(&device);
|
||||
int max_smem_per_sm, max_smem_per_block;
|
||||
cudaError status_ = cudaDeviceGetAttribute(
|
||||
&max_smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, device);
|
||||
status_ = cudaDeviceGetAttribute(
|
||||
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
|
||||
if (status_ != cudaSuccess) {
|
||||
C10_CUDA_CHECK(status_);
|
||||
}
|
||||
// printf("max_smem_per_sm = %d, max_smem_per_block = %d\n", max_smem_per_sm, max_smem_per_block);
|
||||
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
||||
// For A100, we want to run with 128 x 64 (128KB smem).
|
||||
// For H100 we want to run with 64 x 64 (96KB smem) since then we can get 2 CTAs per SM.
|
||||
if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64) && max_smem_per_sm < 4 * Headdim * (64 + 2 * 64)) {
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
} else {
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
}
|
||||
// 64 KB
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
// 96 KB
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
}; // namespace pytorch_flash
|
@@ -1,347 +0,0 @@
/******************************************************************************
|
||||
* Copyright (c) 2024, Tri Dao.
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cute/algorithm/copy.hpp>
|
||||
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/layout/layout.h>
|
||||
#include <cutlass/numeric_types.h>
|
||||
|
||||
namespace pytorch_flash{
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename elem_type=cutlass::half_t>
|
||||
struct Flash_kernel_traits {
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
using Element = elem_type;
|
||||
static constexpr bool Has_cp_async = true;
|
||||
#else
|
||||
using Element = cutlass::half_t;
|
||||
static constexpr bool Has_cp_async = false;
|
||||
#endif
|
||||
|
||||
using ElementAccum = float;
|
||||
using index_t = int64_t;
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
using MMA_Atom_Arch = std::conditional_t<
|
||||
std::is_same_v<elem_type, cutlass::half_t>,
|
||||
MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>,
|
||||
MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>
|
||||
>;
|
||||
#else
|
||||
using MMA_Atom_Arch = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>;
|
||||
#endif
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
|
||||
using SmemCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, elem_type>;
|
||||
using SmemCopyAtomTransposed = Copy_Atom<SM75_U16x8_LDSM_T, elem_type>;
|
||||
#else
|
||||
using SmemCopyAtom = Copy_Atom<DefaultCopy, elem_type>;
|
||||
using SmemCopyAtomTransposed = Copy_Atom<DefaultCopy, elem_type>;
|
||||
#endif
|
||||
};
|
||||
|
||||
// If Share_Q_K_smem is true, that forces Is_Q_in_regs to be true
|
||||
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, bool Is_Q_in_regs_=false, bool Share_Q_K_smem_=false, typename elem_type=cutlass::half_t,
|
||||
typename Base=Flash_kernel_traits<kHeadDim_, kBlockM_, kBlockN_, kNWarps_, elem_type> >
|
||||
struct Flash_fwd_kernel_traits : public Base {
|
||||
using Element = typename Base::Element;
|
||||
using ElementAccum = typename Base::ElementAccum;
|
||||
using index_t = typename Base::index_t;
|
||||
static constexpr bool Has_cp_async = Base::Has_cp_async;
|
||||
using SmemCopyAtom = typename Base::SmemCopyAtom;
|
||||
using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed;
|
||||
|
||||
static constexpr bool Share_Q_K_smem = Share_Q_K_smem_;
|
||||
static constexpr bool Is_Q_in_regs = Is_Q_in_regs_ || Share_Q_K_smem;
|
||||
|
||||
// The number of threads.
|
||||
static constexpr int kNWarps = kNWarps_;
|
||||
static constexpr int kNThreads = kNWarps * 32;
|
||||
|
||||
static constexpr int kBlockM = kBlockM_;
|
||||
static constexpr int kBlockN = kBlockN_;
|
||||
static constexpr int kHeadDim = kHeadDim_;
|
||||
static_assert(kHeadDim % 32 == 0);
|
||||
static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
|
||||
static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32);
|
||||
static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;
|
||||
|
||||
using TiledMma = TiledMMA<
|
||||
typename Base::MMA_Atom_Arch,
|
||||
Layout<Shape<Int<kNWarps>,_1,_1>>, // 4x1x1 or 8x1x1 thread group
|
||||
Tile<Int<16 * kNWarps>, _16, _16>>;
|
||||
|
||||
using SmemLayoutAtomQ = decltype(
|
||||
composition(Swizzle<kSwizzle, 3, 3>{},
|
||||
// This has to be kBlockKSmem, using kHeadDim gives wrong results for d=128
|
||||
Layout<Shape<_8, Int<kBlockKSmem>>,
|
||||
Stride<Int<kBlockKSmem>, _1>>{}));
|
||||
using SmemLayoutQ = decltype(tile_to_shape(
|
||||
SmemLayoutAtomQ{},
|
||||
Shape<Int<kBlockM>, Int<kHeadDim>>{}));
|
||||
|
||||
using SmemLayoutKV = decltype(tile_to_shape(
|
||||
SmemLayoutAtomQ{},
|
||||
Shape<Int<kBlockN>, Int<kHeadDim>>{}));
|
||||
|
||||
// https://github.com/ColfaxResearch/cutlass-kernels/blob/a222587e6d59b93ba704853d3946fb686d8b8892/src/fmha/fmha_forward.cu#L434
|
||||
using SmemLayoutVtransposed = decltype(
|
||||
composition(SmemLayoutKV{}, make_layout(Shape<Int<kHeadDim>, Int<kBlockN>>{}, GenRowMajor{})));
|
||||
using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{}));
|
||||
|
||||
using SmemLayoutAtomO = decltype(
|
||||
composition(Swizzle<kSwizzle, 3, 3>{},
|
||||
Layout<Shape<Int<8>, Int<kBlockKSmem>>,
|
||||
Stride<Int<kBlockKSmem>, _1>>{}));
|
||||
using SmemLayoutO = decltype(tile_to_shape(
|
||||
SmemLayoutAtomO{},
|
||||
Shape<Int<kBlockM>, Int<kHeadDim>>{}));
|
||||
using SmemCopyAtomO = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>;
|
||||
using SmemCopyAtomOaccum = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>;
|
||||
|
||||
static constexpr int kSmemQSize = size(SmemLayoutQ{}) * sizeof(Element);
|
||||
static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element);
|
||||
static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize;
|
||||
|
||||
static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
|
||||
static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
|
||||
// Using kBlockKSmem here is 6-10% faster than kBlockKGmem for d=128 because of bank conflicts.
|
||||
// For example, for d=128, smem is split into 2 "pages", each page takes care of columns
|
||||
// 0-63 and 64-127. If we have 16 threads per row for gmem read, when we write to smem,
|
||||
// thread 0 - 7 will write to the first page and thread 8 - 15 will write to the second page,
|
||||
// to the same banks.
|
||||
static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
|
||||
static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
|
||||
using GmemLayoutAtom = Layout<Shape <Int<kNThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
|
||||
Stride<Int<kGmemThreadsPerRow>, _1>>;
|
||||
|
||||
// We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading
|
||||
// from the same address by the same threadblock. This is slightly faster.
|
||||
using Gmem_copy_struct = std::conditional_t<
|
||||
Has_cp_async,
|
||||
SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
|
||||
AutoVectorizingCopyWithAssumedAlignment<128>
|
||||
>;
|
||||
using GmemTiledCopyQKV = decltype(
|
||||
make_tiled_copy(Copy_Atom<Gmem_copy_struct, Element>{},
|
||||
GmemLayoutAtom{},
|
||||
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
|
||||
using GmemTiledCopyO = decltype(
|
||||
make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
|
||||
GmemLayoutAtom{},
|
||||
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store
|
||||
|
||||
using GmemLayoutAtomOaccum = std::conditional_t<
|
||||
kBlockKSmem == 32,
|
||||
Layout<Shape <_16, _8>, // Thread layout, 8 threads per row
|
||||
Stride< _8, _1>>,
|
||||
Layout<Shape <_8, _16>, // Thread layout, 16 threads per row
|
||||
Stride< _16, _1>>
|
||||
>;
|
||||
using GmemTiledCopyOaccum = decltype(
|
||||
make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
|
||||
GmemLayoutAtomOaccum{},
|
||||
Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per store
|
||||
using GmemLayoutAtomRotcossin = GmemLayoutAtom;
|
||||
using GmemTiledCopyRotcossin = decltype(
|
||||
make_tiled_copy(Copy_Atom<UniversalCopy<uint64_t>, Element>{},
|
||||
GmemLayoutAtomRotcossin{},
|
||||
Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per load
|
||||
using GmemTiledCopyRotcossinCont = decltype(
|
||||
make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
|
||||
GmemLayoutAtomRotcossin{},
|
||||
Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per load
|
||||
};
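kSmemQSize and kSmemKVSize above reduce to block_m * headdim and 2 * block_n * headdim elements of 2 bytes each, and kSmemSize is their sum unless Q shares smem with K/V. A small illustrative check against tile sizes the forward launchers actually pick (not part of the file; ignores any padding from tile_to_shape, which is exact for these shapes):

```
constexpr int fwd_smem_bytes(int headdim, int block_m, int block_n, bool share_q_k) {
  const int q_bytes  = 2 * block_m * headdim;      // Q tile, fp16/bf16
  const int kv_bytes = 2 * 2 * block_n * headdim;  // K and V tiles
  return share_q_k ? (q_bytes > kv_bytes ? q_bytes : kv_bytes) : q_bytes + kv_bytes;
}

static_assert(fwd_smem_bytes(128, 128, 32, false) == 48 * 1024,
              "hdim128, 128x32: fits the default 48 KB limit (2 CTAs per SM on sm86/sm89)");
static_assert(fwd_smem_bytes(256, 128, 64, false) == 128 * 1024,
              "hdim256, 128x64: the 128 KB A100 configuration mentioned in run_mha_fwd_hdim256");
```
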
// Is_V_in_regs is an option to reduce smem usage, but will increase register pressue.
|
||||
// No_double_buffer is another option to reduce smem usage, but will slow things down.
|
||||
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_,
|
||||
int AtomLayoutMSdP_=1, int AtomLayoutNdKV=2, int AtomLayoutMdQ=2,
|
||||
bool Is_V_in_regs_=false, bool No_double_buffer_=false, typename elem_type=cutlass::half_t,
|
||||
typename Base=Flash_kernel_traits<kHeadDim_, kBlockM_, kBlockN_, kNWarps_, elem_type> >
|
||||
struct Flash_bwd_kernel_traits : public Base {
|
||||
using Element = typename Base::Element;
|
||||
using ElementAccum = typename Base::ElementAccum;
|
||||
using index_t = typename Base::index_t;
|
||||
static constexpr bool Has_cp_async = Base::Has_cp_async;
|
||||
using SmemCopyAtom = typename Base::SmemCopyAtom;
|
||||
using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed;
|
||||
|
||||
static constexpr bool Is_V_in_regs = Is_V_in_regs_;
|
||||
static constexpr bool No_double_buffer = No_double_buffer_;
|
||||
|
||||
// The number of threads.
|
||||
static constexpr int kNWarps = kNWarps_;
|
||||
static constexpr int kNThreads = kNWarps * 32;
|
||||
|
||||
static constexpr int kBlockM = kBlockM_;
|
||||
static constexpr int kBlockN = kBlockN_;
|
||||
static constexpr int kHeadDim = kHeadDim_;
|
||||
static_assert(kHeadDim % 32 == 0);
|
||||
static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
|
||||
static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32);
|
||||
static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;
|
||||
|
||||
static constexpr int AtomLayoutMSdP = AtomLayoutMSdP_;
|
||||
static_assert(kNWarps % AtomLayoutMSdP == 0);
|
||||
static_assert(kNWarps % AtomLayoutNdKV == 0);
|
||||
static_assert(kNWarps % AtomLayoutMdQ == 0);
|
||||
|
||||
using TiledMmaSdP = TiledMMA<
|
||||
typename Base::MMA_Atom_Arch,
|
||||
Layout<Shape<Int<AtomLayoutMSdP>, Int<kNWarps / AtomLayoutMSdP>, _1>>,
|
||||
Tile<Int<16 * AtomLayoutMSdP>, Int<16 * kNWarps / AtomLayoutMSdP>, _16>>;
|
||||
|
||||
using TiledMmadKV = TiledMMA<
|
||||
typename Base::MMA_Atom_Arch,
|
||||
Layout<Shape<Int<AtomLayoutNdKV>, Int<kNWarps / AtomLayoutNdKV>, _1>>,
|
||||
Tile<Int<16 * AtomLayoutNdKV>, Int<16 * kNWarps / AtomLayoutNdKV>, _16>>;
|
||||
|
||||
using TiledMmadQ = TiledMMA<
|
||||
typename Base::MMA_Atom_Arch,
|
||||
Layout<Shape<Int<AtomLayoutMdQ>, Int<kNWarps / AtomLayoutMdQ>, _1>>, // 2x4x1 or 4x2x1 thread group
|
||||
Tile<Int<16 * AtomLayoutMdQ>, Int<16 * kNWarps / AtomLayoutMdQ>, _16>>;
|
||||
|
||||
using SmemLayoutAtomQdO = decltype(
|
||||
composition(Swizzle<kSwizzle, 3, 3>{},
|
||||
Layout<Shape<_8, Int<kBlockKSmem>>,
|
||||
Stride<Int<kBlockKSmem>, _1>>{}));
|
||||
using SmemLayoutQdO = decltype(tile_to_shape(
|
||||
SmemLayoutAtomQdO{},
|
||||
make_shape(Int<kBlockM>{}, Int<kHeadDim>{})));
|
||||
|
||||
using SmemLayoutAtomKV = decltype(
|
||||
composition(Swizzle<kSwizzle, 3, 3>{},
|
||||
Layout<Shape<Int<kBlockM / kNWarps>, Int<kBlockKSmem>>,
|
||||
Stride<Int<kBlockKSmem>, _1>>{}));
|
||||
using SmemLayoutKV = decltype(tile_to_shape(
|
||||
// SmemLayoutAtomQdO{},
|
||||
SmemLayoutAtomKV{},
|
||||
make_shape(Int<kBlockN>{}, Int<kHeadDim>{})));
|
||||
|
||||
using SmemLayoutKtransposed = decltype(
|
||||
composition(SmemLayoutKV{}, make_layout(Shape<Int<kHeadDim>, Int<kBlockN>>{}, GenRowMajor{})));
|
||||
using SmemLayoutKtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutKtransposed{}));
|
||||
|
||||
// TODO: generalize to other values of kBlockN
|
||||
// TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2
|
||||
// static constexpr int kPBlockN = kBlockN;
|
||||
// Temporarily disabling this for hdim 256 on sm86 and sm89
|
||||
// static_assert(kBlockN >= 64);
|
||||
static_assert(kBlockN >= 32);
|
||||
// TD [2023-03-19]: It is not clear why kPBlockN = 16 and kSwizzlePdS = 3 is the fastest combination.
|
||||
static constexpr int kPBlockN = kBlockN >= 64 ? 64 : 32;
|
||||
static_assert(kPBlockN == 16 || kPBlockN == 32 || kPBlockN == 64);
|
||||
// static constexpr int kSwizzlePdS = kPBlockN == 16 ? 1 : (kPBlockN == 32 ? 2 : 3);
|
||||
static constexpr int kSwizzlePdS = 3;
|
||||
using SmemLayoutAtomPdS = decltype(
|
||||
composition(Swizzle<kSwizzlePdS, 3, 3>{},
|
||||
Layout<Shape<Int<kBlockM>, Int<kPBlockN>>,
|
||||
Stride<Int<kPBlockN>, _1>>{}));
|
||||
using SmemLayoutPdS = decltype(tile_to_shape(
|
||||
SmemLayoutAtomPdS{},
|
||||
make_shape(Int<kBlockM>{}, Int<kBlockN>{})));
|
||||
using SmemLayoutPdStransposed = decltype(
|
||||
composition(SmemLayoutPdS{}, make_layout(Shape<Int<kBlockN>, Int<kBlockM>>{}, GenRowMajor{})));
|
||||
using SmemLayoutPdStransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutPdStransposed{}));
|
||||
|
||||
using SmemCopyAtomPdS = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>;
|
||||
|
||||
using SmemLayoutQdOtransposed = decltype(
|
||||
composition(SmemLayoutQdO{}, make_layout(Shape<Int<kHeadDim>, Int<kBlockM>>{}, GenRowMajor{})));
|
||||
using SmemLayoutQdOtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutQdOtransposed{}));
|
||||
|
||||
using SmemLayoutAtomdKV = decltype(
|
||||
composition(Swizzle<kSwizzle, 3, 3>{},
|
||||
Layout<Shape<_8, Int<kBlockKSmem>>,
|
||||
Stride<Int<kBlockKSmem>, _1>>{}));
|
||||
using SmemLayoutdKV = decltype(tile_to_shape(
|
||||
SmemLayoutAtomdKV{},
|
||||
make_shape(Int<kBlockN>{}, Int<kHeadDim>{})));
|
||||
using SmemCopyAtomdKV = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>;
|
||||
|
||||
using SmemLayoutAtomdQ = decltype(
|
||||
composition(Swizzle<kSwizzle, 3, 3>{},
|
||||
Layout<Shape<_8, Int<kBlockKSmem>>,
|
||||
Stride<Int<kBlockKSmem>, _1>>{}));
|
||||
using SmemLayoutdQ = decltype(tile_to_shape(
|
||||
SmemLayoutAtomdQ{},
|
||||
make_shape(Int<kBlockM>{}, Int<kHeadDim>{})));
|
||||
using SmemCopyAtomdQ = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>;
|
||||
|
||||
// Double buffer for sQ
|
||||
static constexpr int kSmemQdOSize = size(SmemLayoutQdO{}) * (No_double_buffer ? 2 : 3) * sizeof(Element);
|
||||
static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element);
|
||||
static constexpr int kSmemdSSize = size(SmemLayoutPdS{}) * sizeof(Element);
|
||||
static constexpr int kSmemPSize = size(SmemLayoutPdS{}) * sizeof(Element);
|
||||
static constexpr int kSmemdQSize = size(SmemLayoutdQ{}) * sizeof(Element);
|
||||
static constexpr int kSmemSize = kSmemQdOSize
|
||||
+ (!Is_V_in_regs
|
||||
? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize)
|
||||
: std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize)));
|
||||
static constexpr int kSmemSize1colblock = kSmemQdOSize
|
||||
+ (!Is_V_in_regs
|
||||
? kSmemKVSize + kSmemdSSize + kSmemPSize
|
||||
: std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + kSmemPSize));
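// Worked example of the smem budget (a sketch, assuming fp16, kHeadDim = 64, kBlockM = 64,
// kBlockN = 128, Is_V_in_regs = false, No_double_buffer = false; illustrative values only):
//   kSmemQdOSize = 64*64  * 3 * 2 B = 24 KiB   (double-buffered Q plus dO)
//   kSmemKVSize  = 128*64 * 2 * 2 B = 32 KiB   (K and V)
//   kSmemdSSize  = kSmemPSize = 64*128 * 2 B = 16 KiB,  kSmemdQSize = 64*64 * 2 B = 8 KiB
//   kSmemSize    = 24 + 32 + 16 + max(16, 8) = 88 KiB
// which is above the default 48 KiB static limit, so budgets like this have to opt in to larger
// dynamic shared memory at launch time.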
|
||||
|
||||
static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
|
||||
static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
|
||||
// Using kBlockKSmem instead of kHeadDim here to avoid bank conflicts, but doesn't seem
|
||||
// to affect speed in practice.
|
||||
static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
|
||||
static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
|
||||
using GmemLayoutAtom = Layout<Shape <Int<kNThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
|
||||
Stride<Int<kGmemThreadsPerRow>, _1>>;
|
||||
|
||||
// We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading
|
||||
// from the same address by the same threadblock. This is slightly faster.
|
||||
using Gmem_copy_struct = std::conditional_t<
|
||||
Has_cp_async,
|
||||
SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
|
||||
AutoVectorizingCopyWithAssumedAlignment<128>
|
||||
>;
|
||||
using GmemTiledCopyQKV = decltype(
|
||||
make_tiled_copy(Copy_Atom<Gmem_copy_struct, elem_type>{},
|
||||
GmemLayoutAtom{},
|
||||
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
|
||||
using GmemTiledCopydO = decltype(
|
||||
make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>{},
|
||||
GmemLayoutAtom{},
|
||||
Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store
|
||||
using GmemTiledCopydKV = decltype(
|
||||
make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>{},
|
||||
GmemLayoutAtom{},
|
||||
Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store
|
||||
using GmemTiledCopydQ = decltype(
|
||||
make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, elem_type>{},
|
||||
GmemLayoutAtom{},
|
||||
Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store
|
||||
using GmemLayoutAtomdQaccum = std::conditional_t<
|
||||
kBlockKSmem == 32,
|
||||
Layout<Shape <_32, _8>, // Thread layout, 8 threads per row
|
||||
Stride< _8, _1>>,
|
||||
Layout<Shape <_16, _16>, // Thread layout, 16 threads per row
|
||||
Stride< _16, _1>>
|
||||
>;
|
||||
using GmemTiledCopydQaccum = decltype(
|
||||
make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
|
||||
GmemLayoutAtomdQaccum{},
|
||||
Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per store
|
||||
|
||||
using GmemTiledCopydQaccumAtomicAdd = decltype(
|
||||
make_tiled_copy(Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
|
||||
Layout<Shape <_8, _32>, // Thread layout, 8 threads per row
|
||||
Stride<_32, _1>>{},
|
||||
Layout<Shape < _1, _1>>{})); // Val layout, 1 val per store
|
||||
|
||||
};
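// A minimal instantiation sketch of the traits above (hypothetical tile sizes, not the tuned values
// from the launch templates):
//   using BwdTraits = Flash_bwd_kernel_traits</*kHeadDim_=*/64, /*kBlockM_=*/64, /*kBlockN_=*/128,
//                                             /*kNWarps_=*/8>;
//   static_assert(BwdTraits::kNThreads == 256, "8 warps of 32 threads");
//   static_assert(BwdTraits::kBlockKSmem == 64, "kHeadDim is a multiple of 64");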
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
} // namespace pytorch_flash
|
@ -1,14 +0,0 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_bwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
template<>
|
||||
void run_mha_bwd_<cutlass::bfloat16_t, 128>(Flash_bwd_params &params, cudaStream_t stream) {
|
||||
run_mha_bwd_hdim128<cutlass::bfloat16_t>(params, stream);
|
||||
}
|
||||
} // namespace pytorch_flash
|
@ -1,14 +0,0 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_bwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
template<>
|
||||
void run_mha_bwd_<cutlass::half_t, 128>(Flash_bwd_params &params, cudaStream_t stream) {
|
||||
run_mha_bwd_hdim128<cutlass::half_t>(params, stream);
|
||||
}
|
||||
} // namespace pytorch_flash
|
@ -1,14 +0,0 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_bwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
template<>
|
||||
void run_mha_bwd_<cutlass::bfloat16_t, 160>(Flash_bwd_params &params, cudaStream_t stream) {
|
||||
run_mha_bwd_hdim160<cutlass::bfloat16_t>(params, stream);
|
||||
}
|
||||
} // namespace pytorch_flash
|
@ -1,14 +0,0 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_bwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
template<>
|
||||
void run_mha_bwd_<cutlass::half_t, 160>(Flash_bwd_params &params, cudaStream_t stream) {
|
||||
run_mha_bwd_hdim160<cutlass::half_t>(params, stream);
|
||||
}
|
||||
} // namespace pytorch_flash
|
@ -1,14 +0,0 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_bwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
template<>
|
||||
void run_mha_bwd_<cutlass::bfloat16_t, 192>(Flash_bwd_params &params, cudaStream_t stream) {
|
||||
run_mha_bwd_hdim192<cutlass::bfloat16_t>(params, stream);
|
||||
}
|
||||
} // namespace pytorch_flash
|
@ -1,14 +0,0 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_bwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
template<>
|
||||
void run_mha_bwd_<cutlass::half_t, 192>(Flash_bwd_params &params, cudaStream_t stream) {
|
||||
run_mha_bwd_hdim192<cutlass::half_t>(params, stream);
|
||||
}
|
||||
} // namespace pytorch_flash
|
@ -1,14 +0,0 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_bwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
template<>
|
||||
void run_mha_bwd_<cutlass::bfloat16_t, 224>(Flash_bwd_params &params, cudaStream_t stream) {
|
||||
run_mha_bwd_hdim224<cutlass::bfloat16_t>(params, stream);
|
||||
}
|
||||
} // namespace pytorch_flash
|
@ -1,14 +0,0 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_bwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
template<>
|
||||
void run_mha_bwd_<cutlass::half_t, 224>(Flash_bwd_params &params, cudaStream_t stream) {
|
||||
run_mha_bwd_hdim224<cutlass::half_t>(params, stream);
|
||||
}
|
||||
} // namespace pytorch_flash
|
@ -1,14 +0,0 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_bwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
template<>
|
||||
void run_mha_bwd_<cutlass::bfloat16_t, 256>(Flash_bwd_params &params, cudaStream_t stream) {
|
||||
run_mha_bwd_hdim256<cutlass::bfloat16_t>(params, stream);
|
||||
}
|
||||
} // namespace pytorch_flash
|
@ -1,14 +0,0 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_bwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
template<>
|
||||
void run_mha_bwd_<cutlass::half_t, 256>(Flash_bwd_params &params, cudaStream_t stream) {
|
||||
run_mha_bwd_hdim256<cutlass::half_t>(params, stream);
|
||||
}
|
||||
} // namespace pytorch_flash
|
@ -1,14 +0,0 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_bwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
template<>
|
||||
void run_mha_bwd_<cutlass::bfloat16_t, 32>(Flash_bwd_params &params, cudaStream_t stream) {
|
||||
run_mha_bwd_hdim32<cutlass::bfloat16_t>(params, stream);
|
||||
}
|
||||
} // namespace pytorch_flash
|
@ -1,14 +0,0 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_bwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
template<>
|
||||
void run_mha_bwd_<cutlass::half_t, 32>(Flash_bwd_params &params, cudaStream_t stream) {
|
||||
run_mha_bwd_hdim32<cutlass::half_t>(params, stream);
|
||||
}
|
||||
} // namespace pytorch_flash
|
@ -1,14 +0,0 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_bwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
template<>
|
||||
void run_mha_bwd_<cutlass::bfloat16_t, 64>(Flash_bwd_params &params, cudaStream_t stream) {
|
||||
run_mha_bwd_hdim64<cutlass::bfloat16_t>(params, stream);
|
||||
}
|
||||
} // namespace pytorch_flash
|
@ -1,14 +0,0 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_bwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
template<>
|
||||
void run_mha_bwd_<cutlass::half_t, 64>(Flash_bwd_params &params, cudaStream_t stream) {
|
||||
run_mha_bwd_hdim64<cutlass::half_t>(params, stream);
|
||||
}
|
||||
} // namespace pytorch_flash
|
@ -1,14 +0,0 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_bwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
template<>
|
||||
void run_mha_bwd_<cutlass::bfloat16_t, 96>(Flash_bwd_params &params, cudaStream_t stream) {
|
||||
run_mha_bwd_hdim96<cutlass::bfloat16_t>(params, stream);
|
||||
}
|
||||
} // namespace pytorch_flash
|
@ -1,14 +0,0 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_bwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
template<>
|
||||
void run_mha_bwd_<cutlass::half_t, 96>(Flash_bwd_params &params, cudaStream_t stream) {
|
||||
run_mha_bwd_hdim96<cutlass::half_t>(params, stream);
|
||||
}
|
||||
} // namespace pytorch_flash
|
@ -1,14 +0,0 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_fwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
template<>
|
||||
void run_mha_fwd_<cutlass::bfloat16_t, 128>(Flash_fwd_params &params, cudaStream_t stream) {
|
||||
run_mha_fwd_hdim128<cutlass::bfloat16_t>(params, stream);
|
||||
}
|
||||
} // namespace pytorch_flash
|
@ -1,14 +0,0 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_fwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
template<>
|
||||
void run_mha_fwd_<cutlass::half_t, 128>(Flash_fwd_params &params, cudaStream_t stream) {
|
||||
run_mha_fwd_hdim128<cutlass::half_t>(params, stream);
|
||||
}
|
||||
} // namespace pytorch_flash
|
@ -1,14 +0,0 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_fwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
template<>
|
||||
void run_mha_fwd_<cutlass::bfloat16_t, 160>(Flash_fwd_params &params, cudaStream_t stream) {
|
||||
run_mha_fwd_hdim160<cutlass::bfloat16_t>(params, stream);
|
||||
}
|
||||
} // namespace pytorch_flash
|
@ -1,14 +0,0 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_fwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
template<>
|
||||
void run_mha_fwd_<cutlass::half_t, 160>(Flash_fwd_params &params, cudaStream_t stream) {
|
||||
run_mha_fwd_hdim160<cutlass::half_t>(params, stream);
|
||||
}
|
||||
} // namespace pytorch_flash
|
@ -1,14 +0,0 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_fwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
template<>
|
||||
void run_mha_fwd_<cutlass::bfloat16_t, 192>(Flash_fwd_params &params, cudaStream_t stream) {
|
||||
run_mha_fwd_hdim192<cutlass::bfloat16_t>(params, stream);
|
||||
}
|
||||
} // namespace pytorch_flash
|
@ -1,14 +0,0 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_fwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
template<>
|
||||
void run_mha_fwd_<cutlass::half_t, 192>(Flash_fwd_params &params, cudaStream_t stream) {
|
||||
run_mha_fwd_hdim192<cutlass::half_t>(params, stream);
|
||||
}
|
||||
} // namespace pytorch_flash
|
@ -1,14 +0,0 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_fwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
template<>
|
||||
void run_mha_fwd_<cutlass::bfloat16_t, 224>(Flash_fwd_params &params, cudaStream_t stream) {
|
||||
run_mha_fwd_hdim224<cutlass::bfloat16_t>(params, stream);
|
||||
}
|
||||
} // namespace pytorch_flash
|
@ -1,14 +0,0 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_fwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
template<>
|
||||
void run_mha_fwd_<cutlass::half_t, 224>(Flash_fwd_params &params, cudaStream_t stream) {
|
||||
run_mha_fwd_hdim224<cutlass::half_t>(params, stream);
|
||||
}
|
||||
} // namespace pytorch_flash
|
@ -1,14 +0,0 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_fwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
template<>
|
||||
void run_mha_fwd_<cutlass::bfloat16_t, 256>(Flash_fwd_params &params, cudaStream_t stream) {
|
||||
run_mha_fwd_hdim256<cutlass::bfloat16_t>(params, stream);
|
||||
}
|
||||
} // namespace pytorch_flash
|
@ -1,14 +0,0 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_fwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
template<>
|
||||
void run_mha_fwd_<cutlass::half_t, 256>(Flash_fwd_params &params, cudaStream_t stream) {
|
||||
run_mha_fwd_hdim256<cutlass::half_t>(params, stream);
|
||||
}
|
||||
} // namespace pytorch_flash
|
@ -1,14 +0,0 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_fwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
template<>
|
||||
void run_mha_fwd_<cutlass::bfloat16_t, 32>(Flash_fwd_params &params, cudaStream_t stream) {
|
||||
run_mha_fwd_hdim32<cutlass::bfloat16_t>(params, stream);
|
||||
}
|
||||
} // namespace pytorch_flash
|
@ -1,14 +0,0 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_fwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
template<>
|
||||
void run_mha_fwd_<cutlass::half_t, 32>(Flash_fwd_params &params, cudaStream_t stream) {
|
||||
run_mha_fwd_hdim32<cutlass::half_t>(params, stream);
|
||||
}
|
||||
} // namespace pytorch_flash
|
@ -1,14 +0,0 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_fwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
template<>
|
||||
void run_mha_fwd_<cutlass::bfloat16_t, 64>(Flash_fwd_params &params, cudaStream_t stream) {
|
||||
run_mha_fwd_hdim64<cutlass::bfloat16_t>(params, stream);
|
||||
}
|
||||
} // namespace pytorch_flash
|
@ -1,14 +0,0 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_fwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
template<>
|
||||
void run_mha_fwd_<cutlass::half_t, 64>(Flash_fwd_params &params, cudaStream_t stream) {
|
||||
run_mha_fwd_hdim64<cutlass::half_t>(params, stream);
|
||||
}
|
||||
} // namespace pytorch_flash
|
@ -1,14 +0,0 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_fwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
template<>
|
||||
void run_mha_fwd_<cutlass::bfloat16_t, 96>(Flash_fwd_params &params, cudaStream_t stream) {
|
||||
run_mha_fwd_hdim96<cutlass::bfloat16_t>(params, stream);
|
||||
}
|
||||
} // namespace pytorch_flash
|
@ -1,14 +0,0 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_fwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
template<>
|
||||
void run_mha_fwd_<cutlass::half_t, 96>(Flash_fwd_params &params, cudaStream_t stream) {
|
||||
run_mha_fwd_hdim96<cutlass::half_t>(params, stream);
|
||||
}
|
||||
} // namespace pytorch_flash
|
@ -1,12 +0,0 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_fwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
|
||||
template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 128>(Flash_fwd_params &params, cudaStream_t stream);
|
||||
} // namespace pytorch_flash
|
@ -1,12 +0,0 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_fwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
|
||||
template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 128>(Flash_fwd_params &params, cudaStream_t stream);
|
||||
} // namespace pytorch_flash
|
@ -1,12 +0,0 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_fwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
|
||||
template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 160>(Flash_fwd_params &params, cudaStream_t stream);
|
||||
} // namespace pytorch_flash
|
@ -1,12 +0,0 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_fwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
|
||||
template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 160>(Flash_fwd_params &params, cudaStream_t stream);
|
||||
} // namespace pytorch_flash
|
@ -1,12 +0,0 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_fwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
|
||||
template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 192>(Flash_fwd_params &params, cudaStream_t stream);
|
||||
} // namespace pytorch_flash
|
@ -1,12 +0,0 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_fwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
|
||||
template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 192>(Flash_fwd_params &params, cudaStream_t stream);
|
||||
} // namespace pytorch_flash
|
@ -1,12 +0,0 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_fwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
|
||||
template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 224>(Flash_fwd_params &params, cudaStream_t stream);
|
||||
} // namespace pytorch_flash
|
@ -1,12 +0,0 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_fwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
|
||||
template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 224>(Flash_fwd_params &params, cudaStream_t stream);
|
||||
} // namespace pytorch_flash
|
@ -1,12 +0,0 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_fwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
|
||||
template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 256>(Flash_fwd_params &params, cudaStream_t stream);
|
||||
} // namespace pytorch_flash
|
@ -1,12 +0,0 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_fwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
|
||||
template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 256>(Flash_fwd_params &params, cudaStream_t stream);
|
||||
} // namespace pytorch_flash
|
@ -1,12 +0,0 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_fwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
|
||||
template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 32>(Flash_fwd_params &params, cudaStream_t stream);
|
||||
} // namespace pytorch_flash
|
@ -1,12 +0,0 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_fwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
|
||||
template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 32>(Flash_fwd_params &params, cudaStream_t stream);
|
||||
} // namespace pytorch_flash
|
@ -1,12 +0,0 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_fwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
|
||||
template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 64>(Flash_fwd_params &params, cudaStream_t stream);
|
||||
} // namespace pytorch_flash
|
@ -1,12 +0,0 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_fwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
|
||||
template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 64>(Flash_fwd_params &params, cudaStream_t stream);
|
||||
} // namespace pytorch_flash
|
@ -1,12 +0,0 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_fwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
|
||||
template void run_mha_fwd_splitkv_dispatch<cutlass::bfloat16_t, 96>(Flash_fwd_params &params, cudaStream_t stream);
|
||||
} // namespace pytorch_flash
|
@ -1,12 +0,0 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_fwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
|
||||
template void run_mha_fwd_splitkv_dispatch<cutlass::half_t, 96>(Flash_fwd_params &params, cudaStream_t stream);
|
||||
} // namespace pytorch_flash
|
@ -1,109 +0,0 @@
|
||||
# This file is run to generate the kernel instantiations for the flash_attn kernels
|
||||
# They are written to several files in order to speed up compilation
|
||||
|
||||
import argparse
|
||||
import itertools
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Iterator, Optional
|
||||
|
||||
|
||||
DTYPE_MAP = {
|
||||
"fp16": "cutlass::half_t",
|
||||
"bf16": "cutlass::bfloat16_t",
|
||||
}
|
||||
|
||||
SM = [80] # Sm80 kernels support up to
|
||||
HEAD_DIMENSIONS = [32, 64, 96, 128, 160, 192, 224, 256]
|
||||
KERNEL_IMPL_TEMPLATE_FWD = """
|
||||
template<>
|
||||
void run_mha_fwd_<{DTYPE}, {HEAD_DIM}>(Flash_fwd_params &params, cudaStream_t stream) {{
|
||||
run_mha_fwd_hdim{HEAD_DIM}<{DTYPE}>(params, stream);
|
||||
}}
|
||||
"""
|
||||
KERNEL_IMPL_TEMPLATE_FWD_SPLIT = """
|
||||
|
||||
template void run_mha_fwd_splitkv_dispatch<{DTYPE}, {HEAD_DIM}>(Flash_fwd_params &params, cudaStream_t stream);
|
||||
"""
|
||||
|
||||
KERNEL_IMPL_TEMPLATE_BWD = """
|
||||
template<>
|
||||
void run_mha_bwd_<{DTYPE}, {HEAD_DIM}>(Flash_bwd_params &params, cudaStream_t stream) {{
|
||||
run_mha_bwd_hdim{HEAD_DIM}<{DTYPE}>(params, stream);
|
||||
}}
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class Kernel:
|
||||
sm: int
|
||||
dtype: str
|
||||
head_dim: int
|
||||
direction: str
|
||||
|
||||
@property
|
||||
def template(self) -> str:
|
||||
if self.direction == "fwd":
|
||||
return KERNEL_IMPL_TEMPLATE_FWD.format(
|
||||
DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim
|
||||
)
|
||||
elif self.direction == "bwd":
|
||||
return KERNEL_IMPL_TEMPLATE_BWD.format(
|
||||
DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim
|
||||
)
|
||||
else:
|
||||
return KERNEL_IMPL_TEMPLATE_FWD_SPLIT.format(
|
||||
DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim
|
||||
)
|
||||
|
||||
@property
|
||||
def filename(self) -> str:
|
||||
return f"flash_{self.direction}_hdim{self.head_dim}_{self.dtype}_sm{self.sm}.cu"
|
||||
|
||||
|
||||
def get_all_kernels() -> Iterator[Kernel]:
|
||||
for dtype, head_dim, sm in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS, SM):
|
||||
for direction in ["fwd", "bwd", "fwd_split"]:
|
||||
yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, direction=direction)
|
||||
|
||||
|
||||
def write_kernel(kernel: Kernel, autogen_dir: Path) -> None:
|
||||
prelude = """
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"\n
|
||||
"""
|
||||
launch_template_str = kernel.direction if kernel.direction != "fwd_split" else "fwd"
|
||||
include = f"#include <ATen/native/transformers/cuda/flash_attn/flash_{launch_template_str}_launch_template.h>\n"
|
||||
namespace = "namespace pytorch_flash{\n"
|
||||
namespace_end = "} // namespace pytorch_flash\n"
|
||||
(autogen_dir / kernel.filename).write_text(
|
||||
prelude + include + namespace + kernel.template + namespace_end
|
||||
)
|
||||
|
||||
|
||||
def main(output_dir: Optional[str]) -> None:
|
||||
if output_dir is None:
|
||||
output_dir = Path(__file__).parent
|
||||
else:
|
||||
output_dir = Path(output_dir)
|
||||
|
||||
for kernel in get_all_kernels():
|
||||
write_kernel(kernel, output_dir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="generate_kernels",
|
||||
description="Generate the flash_attention kernels template instantiations",
|
||||
)
|
||||
# Set an optional output directory
|
||||
parser.add_argument(
|
||||
"-o",
|
||||
"--output_dir",
|
||||
required=False,
|
||||
help="Where to generate the kernels; defaults to the current directory",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
main(args.output_dir)
|
@ -1,213 +0,0 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2024, Tri Dao.
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
|
||||
namespace pytorch_flash {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template <typename Engine, typename Layout>
|
||||
__forceinline__ __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const int max_seqlen_k,
|
||||
const int col_idx_offset_ = 0) {
|
||||
// tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
|
||||
static_assert(Layout::rank == 2, "Only support 2D Tensor");
|
||||
const int lane_id = threadIdx.x % 32;
|
||||
const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
|
||||
#pragma unroll
|
||||
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
|
||||
const int col_idx_base = col_idx_offset + nj * 8;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < size<1, 0>(tensor); ++j) {
|
||||
const int col_idx = col_idx_base + j;
|
||||
if (col_idx >= max_seqlen_k) {
|
||||
// Without the "make_coord" we get wrong results
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0>(tensor); ++mi) {
|
||||
tensor(mi, make_coord(j, nj)) = -INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
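// Worked example of the index arithmetic above (a sketch, assuming the usual SM80 accumulator
// ownership): lane_id % 4 selects the column pair owned by a thread, so lane 5 starts at
// col_idx_offset = col_idx_offset_ + (5 % 4) * 2 and, with nj stepping in strides of 8, visits
// columns {+2, +3, +10, +11, ...}; any of those at or beyond max_seqlen_k is set to -INFINITY so
// it contributes nothing after the softmax.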
|
||||
|
||||
template <bool HasWSLeft=true, typename Engine, typename Layout>
|
||||
__forceinline__ __device__ void apply_mask_local(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
|
||||
const int max_seqlen_k, const int row_idx_offset,
|
||||
const int max_seqlen_q, const int warp_row_stride,
|
||||
const int window_size_left, const int window_size_right) {
|
||||
// tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
|
||||
static_assert(Layout::rank == 2, "Only support 2D Tensor");
|
||||
const int lane_id = threadIdx.x % 32;
|
||||
const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
|
||||
const int row_idx_base = row_idx_offset + mi * warp_row_stride;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size<0, 0>(tensor); ++i) {
|
||||
const int row_idx = row_idx_base + i * 8;
|
||||
const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left);
|
||||
const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right);
|
||||
#pragma unroll
|
||||
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
|
||||
const int col_idx_base = col_idx_offset + nj * 8;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < size<1, 0>(tensor); ++j) {
|
||||
const int col_idx = col_idx_base + j;
|
||||
if (col_idx >= col_idx_limit_right || (HasWSLeft && col_idx < col_idx_limit_left)) {
|
||||
tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
// if (cute::thread0()) {
|
||||
// printf("mi = %d, i = %d, row_idx = %d, max_seqlen_k = %d\n", mi, i, row_idx, max_seqlen_k);
|
||||
// print(tensor(make_coord(i, mi), _));
|
||||
// // print(tensor(_, j + nj * size<1, 0>(tensor)));
|
||||
// }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Engine, typename Layout>
|
||||
__forceinline__ __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
|
||||
const int max_seqlen_k, const int row_idx_offset,
|
||||
const int max_seqlen_q, const int warp_row_stride) {
|
||||
// Causal masking is equivalent to local masking with window_size_left = infinity and window_size_right = 0
|
||||
apply_mask_local</*HasWSLeft=*/false>(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset,
|
||||
max_seqlen_q, warp_row_stride, -1, 0);
|
||||
}
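// Worked example (a sketch, not taken from the original comments): with max_seqlen_q == max_seqlen_k
// and window_size_right = 0, the limit above reduces to
//   col_idx_limit_right = min(max_seqlen_k, row_idx + 1),
// so every col_idx > row_idx is set to -INFINITY, i.e. the standard causal mask; HasWSLeft = false
// skips the left-window check, which is what "window_size_left = infinity" means here.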
|
||||
|
||||
template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
|
||||
__forceinline__ __device__ void apply_mask_causal_w_idx(
|
||||
Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &idx_rowcol,
|
||||
const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset)
|
||||
{
|
||||
// tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N))
|
||||
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
|
||||
static_assert(Layout1::rank == 2, "Only support 2D Tensor");
|
||||
CUTE_STATIC_ASSERT_V(size<0>(tensor) == size<0>(idx_rowcol));
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tensor) == size<1>(idx_rowcol));
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0>(tensor); ++mi) {
|
||||
const int col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset + get<0>(idx_rowcol(mi, 0)));
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < size<1, 1>(tensor); ++ni) {
|
||||
if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) {
|
||||
tensor(mi, ni) = -INFINITY;
|
||||
}
|
||||
}
|
||||
// if (cute::thread0()) {
|
||||
// printf("ni = %d, j = %d, col_idx = %d, max_seqlen_k = %d\n", ni, j, col_idx, max_seqlen_k);
|
||||
// print(tensor(_, make_coord(j, ni)));
|
||||
// // print(tensor(_, j + ni * size<1, 0>(tensor)));
|
||||
// }
|
||||
}
|
||||
}
|
||||
|
||||
template <bool Is_causal, bool Is_local, bool Has_alibi>
|
||||
struct Mask {
|
||||
|
||||
const int max_seqlen_k, max_seqlen_q;
|
||||
const int window_size_left, window_size_right;
|
||||
const float alibi_slope;
|
||||
|
||||
__forceinline__ __device__ Mask(const int max_seqlen_k, const int max_seqlen_q,
|
||||
const int window_size_left, const int window_size_right,
|
||||
const float alibi_slope=0.f)
|
||||
: max_seqlen_k(max_seqlen_k)
|
||||
, max_seqlen_q(max_seqlen_q)
|
||||
, window_size_left(window_size_left)
|
||||
, window_size_right(window_size_right)
|
||||
, alibi_slope(!Has_alibi ? 0.0 : alibi_slope) {
|
||||
};
|
||||
|
||||
// Causal_mask: whether this particular iteration needs causal masking
|
||||
template <bool Causal_mask=false, bool Is_even_MN=true, typename Engine, typename Layout>
|
||||
__forceinline__ __device__ void apply_mask(Tensor<Engine, Layout> &tensor_,
|
||||
const int col_idx_offset_,
|
||||
const int row_idx_offset,
|
||||
const int warp_row_stride) {
|
||||
static_assert(!(Causal_mask && Is_local), "Cannot be both causal and local");
|
||||
static_assert(Layout::rank == 3, "Only support 3D Tensor");
|
||||
static_assert(decltype(size<0>(tensor_))::value == 4, "First dimension must be 4");
|
||||
static constexpr bool Need_masking = Has_alibi || Causal_mask || Is_local || !Is_even_MN;
|
||||
// if (cute::thread0()) { printf("Has_alibi = %d, Causal_mask=%d, Is_local=%d, Is_even_MN = %d, Need_masking = %d\n", Has_alibi, Causal_mask, Is_local, Is_even_MN, Need_masking); }
|
||||
if constexpr (Need_masking) {
|
||||
// Reshape tensor_ from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
|
||||
Tensor tensor = make_tensor(tensor_.data(), pytorch_flash::convert_layout_acc_rowcol(tensor_.layout()));
|
||||
// Do we need both row and column indices, or just column indices?
|
||||
static constexpr bool Col_idx_only = !(Has_alibi && !Is_causal) && !Is_local && !Causal_mask;
|
||||
const int lane_id = threadIdx.x % 32;
|
||||
const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
|
||||
if constexpr (Col_idx_only) {
|
||||
#pragma unroll
|
||||
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
|
||||
const int col_idx_base = col_idx_offset + nj * 8;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < size<1, 0>(tensor); ++j) {
|
||||
const int col_idx = col_idx_base + j;
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0>(tensor); ++mi) {
|
||||
// No causal, no local
|
||||
if constexpr (Has_alibi) {
|
||||
tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx;
|
||||
}
|
||||
if constexpr (!Is_even_MN) {
|
||||
if (col_idx >= max_seqlen_k) { tensor(mi, make_coord(j, nj)) = -INFINITY; }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
|
||||
const int row_idx_base = row_idx_offset + mi * warp_row_stride;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size<0, 0>(tensor); ++i) {
|
||||
const int row_idx = row_idx_base + i * 8;
|
||||
const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left);
|
||||
const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right);
|
||||
#pragma unroll
|
||||
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
|
||||
const int col_idx_base = col_idx_offset + nj * 8;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < size<1, 0>(tensor); ++j) {
|
||||
const int col_idx = col_idx_base + j;
|
||||
if constexpr (Has_alibi) {
|
||||
if constexpr (Is_causal) {
|
||||
tensor(make_coord(i, mi), make_coord(j, nj)) += alibi_slope * col_idx;
|
||||
} else {
|
||||
tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx);
|
||||
|
||||
}
|
||||
}
|
||||
if constexpr (Causal_mask) {
|
||||
if (col_idx >= col_idx_limit_right) {
|
||||
tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
|
||||
}
|
||||
}
|
||||
if constexpr (Is_local) {
|
||||
if (col_idx >= col_idx_limit_right || col_idx < col_idx_limit_left) {
|
||||
tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
|
||||
}
|
||||
}
|
||||
if constexpr (!Causal_mask && !Is_local && !Is_even_MN) {
|
||||
// Causal and Local already handle MN masking
|
||||
if (col_idx >= max_seqlen_k) {
|
||||
tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
};
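// Usage sketch (hypothetical variable names, not from the original source):
//   Mask</*Is_causal=*/true, /*Is_local=*/false, /*Has_alibi=*/false>
//       mask(max_seqlen_k, max_seqlen_q, /*window_size_left=*/-1, /*window_size_right=*/0);
//   mask.template apply_mask</*Causal_mask=*/true, /*Is_even_MN=*/true>(
//       acc_s, col_idx_offset, row_idx_offset, warp_row_stride);
// When none of alibi, causal, local, or uneven-MN masking applies, Need_masking is false and the
// whole call compiles to a no-op.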
|
||||
|
||||
} // namespace pytorch_flash
|
@ -1,152 +0,0 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2024, Tri Dao.
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cute/algorithm/copy.hpp>
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/utils.h>
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace pytorch_flash {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <bool Is_even_K=true, bool Clear_OOB_K=true,
|
||||
typename Engine0, typename Layout0, typename Engine1, typename Layout1,
|
||||
typename Engine2, typename Layout2, typename Engine3, typename Layout3>
|
||||
__forceinline__ __device__ void copy_rotary_interleaved(Tensor<Engine0, Layout0> const &S,
|
||||
Tensor<Engine1, Layout1> &D,
|
||||
Tensor<Engine2, Layout2> const &Cos,
|
||||
Tensor<Engine2, Layout2> const &Sin,
|
||||
Tensor<Engine3, Layout3> const &identity_MN,
|
||||
const int max_MN, const int min_MN,
|
||||
const int dim, const int rotary_dim) {
|
||||
CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
|
||||
CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
|
||||
CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA
|
||||
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M
|
||||
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K
|
||||
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos)); // MMA_M
|
||||
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos)); // MMA_K
|
||||
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin)); // MMA_M
|
||||
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin)); // MMA_K
|
||||
CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin)); // MMA_K
|
||||
static_assert(decltype(size<0>(S))::value == decltype(size<0>(Cos))::value * 2);
|
||||
static_assert(decltype(size<0>(Cos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32
|
||||
Tensor rCos = make_fragment_like(Cos);
|
||||
Tensor rSin = make_fragment_like(Sin);
|
||||
Tensor rS = make_fragment_like(S);
|
||||
#pragma unroll
|
||||
for (int m = 0; m < size<1>(S); ++m) {
|
||||
if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) {
|
||||
#pragma unroll
|
||||
for (int k = 0; k < size<2>(S); ++k) {
|
||||
if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) {
|
||||
cute::copy(S(_, m, k), rS(_, m, k));
|
||||
if (get<1>(identity_MN(0, 0, k)) < rotary_dim) {
|
||||
cute::copy(Cos(_, m, k), rCos(_, m, k));
|
||||
cute::copy(Sin(_, m, k), rSin(_, m, k));
|
||||
Tensor S_fp32 = convert_type<float>(rS(_, m, k));
|
||||
Tensor cos_fp32 = convert_type<float>(rCos(_, m, k));
|
||||
Tensor sin_fp32 = convert_type<float>(rSin(_, m, k));
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size<0>(rS) / 2; ++i) {
|
||||
float real = S_fp32(2 * i) * cos_fp32(i) - S_fp32(2 * i + 1) * sin_fp32(i);
|
||||
float imag = S_fp32(2 * i) * sin_fp32(i) + S_fp32(2 * i + 1) * cos_fp32(i);
|
||||
S_fp32(2 * i) = real;
|
||||
S_fp32(2 * i + 1) = imag;
|
||||
}
|
||||
// A copy is needed here for convert_type to work correctly
|
||||
Tensor S_fp32_copy = make_fragment_like(S_fp32);
|
||||
cute::copy(S_fp32, S_fp32_copy);
|
||||
using T = typename Engine0::value_type;
|
||||
Tensor S_og_type = convert_type<T>(S_fp32_copy);
|
||||
cute::copy(S_og_type, rS(_, m, k));
|
||||
}
|
||||
cute::copy(rS(_, m, k), D(_, m, k));
|
||||
} else if (Clear_OOB_K) {
|
||||
cute::clear(D(_, m, k));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
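// A minimal sketch of the per-pair rotation above (hypothetical helper, not part of the kernel):
// each interleaved pair (x0, x1) is treated as the complex number x0 + i*x1 and multiplied by
// cos(theta) + i*sin(theta).
//   inline void rotate_interleaved_pair(float &x0, float &x1, float c, float s) {
//       const float real = x0 * c - x1 * s;
//       const float imag = x0 * s + x1 * c;
//       x0 = real;
//       x1 = imag;
//   }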
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <bool Is_even_K=true, bool Clear_OOB_K=true,
|
||||
typename Engine0, typename Layout0, typename Engine1, typename Layout1,
|
||||
typename Engine2, typename Layout2, typename Engine3, typename Layout3>
|
||||
__forceinline__ __device__ void copy_rotary_contiguous(Tensor<Engine0, Layout0> const &S,
|
||||
Tensor<Engine1, Layout1> &D,
|
||||
Tensor<Engine2, Layout2> const &Cos,
|
||||
Tensor<Engine2, Layout2> const &Sin,
|
||||
Tensor<Engine3, Layout3> const &identity_MN,
|
||||
const int max_MN, const int min_MN,
|
||||
const int dim, const int rotary_dim) {
|
||||
CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
|
||||
CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
|
||||
CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA
|
||||
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M
|
||||
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K
|
||||
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos)); // MMA_M
|
||||
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos)); // MMA_K
|
||||
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin)); // MMA_M
|
||||
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin)); // MMA_K
|
||||
CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(Cos)); // MMA
|
||||
CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin));
|
||||
static_assert(decltype(size<0>(Cos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32
|
||||
Tensor rCos = make_fragment_like(Cos);
|
||||
Tensor rSin = make_fragment_like(Sin);
|
||||
Tensor rS = make_fragment_like(S);
|
||||
Tensor rS_other = make_fragment_like(rS(_, 0, 0));
|
||||
#pragma unroll
|
||||
for (int m = 0; m < size<1>(S); ++m) {
|
||||
if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) {
|
||||
#pragma unroll
|
||||
for (int k = 0; k < size<2>(S); ++k) {
|
||||
if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) {
|
||||
cute::copy(S(_, m, k), rS(_, m, k));
|
||||
if (get<1>(identity_MN(0, 0, k)) < rotary_dim) {
|
||||
const bool is_left = get<1>(identity_MN(0, 0, k)) < rotary_dim / 2;
|
||||
Tensor gS_other = make_tensor(S(_, m, k).data() + (is_left ? rotary_dim / 2 : -rotary_dim / 2), S(_, m, k).layout());
|
||||
cute::copy(gS_other, rS_other);
|
||||
// if (cute::thread0()) { print_tensor(rS(_, m, k)); print_tensor(rS_other); }
|
||||
Tensor gCos = make_tensor(Cos(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Cos(_, m, k).layout());
|
||||
Tensor gSin = make_tensor(Sin(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Sin(_, m, k).layout());
|
||||
cute::copy(gCos, rCos(_, m, k));
|
||||
cute::copy(gSin, rSin(_, m, k));
|
||||
// if (cute::thread0()) { print_tensor(rCos(_, m, k)); print_tensor(rSin(_, m, k)); }
|
||||
Tensor S_fp32 = convert_type<float>(rS(_, m, k));
|
||||
Tensor S_other_fp32 = convert_type<float>(rS_other);
|
||||
Tensor cos_fp32 = convert_type<float>(rCos(_, m, k));
|
||||
Tensor sin_fp32 = convert_type<float>(rSin(_, m, k));
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size<0>(rS); ++i) {
|
||||
S_fp32(i) = S_fp32(i) * cos_fp32(i) + S_other_fp32(i) * (is_left ? -sin_fp32(i) : sin_fp32(i));
|
||||
}
|
||||
// A copy is needed here for convert_type to work correctly.
|
||||
Tensor S_fp32_copy = make_fragment_like(S_fp32);
|
||||
cute::copy(S_fp32, S_fp32_copy);
|
||||
using T = typename Engine0::value_type;
|
||||
Tensor S_og_type = convert_type<T>(S_fp32_copy);
|
||||
cute::copy(S_og_type, rS(_, m, k));
|
||||
// if (cute::thread0()) { print_tensor(rS(_, m, k)); }
|
||||
}
|
||||
cute::copy(rS(_, m, k), D(_, m, k));
|
||||
} else if (Clear_OOB_K) {
|
||||
cute::clear(D(_, m, k));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace pytorch_flash
|
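For orientation, the rotary path above pairs each element in the first half of the rotary dimension with its counterpart in the second half and rotates the pair by (cos, sin); elements past rotary_dim pass through unchanged. Below is a small host-side Python sketch of that per-row math (hypothetical helper names, not the kernel's API), assuming the contiguous, non-interleaved layout used by copy_rotary_contiguous and an even rotary_dim.

```python
# Illustrative sketch of the rotation applied per row by copy_rotary_contiguous:
# x[i] (first half) is paired with x[i + rotary_dim/2] (second half) and both
# are rotated by the angle encoded in (cos[i], sin[i]).
import math

def apply_rotary_contiguous(x, cos, sin, rotary_dim):
    """x: one row as a list of floats; cos/sin: lists of length rotary_dim // 2."""
    half = rotary_dim // 2
    out = list(x)                       # entries beyond rotary_dim are copied as-is
    for i in range(half):
        left, right = x[i], x[i + half]
        out[i] = left * cos[i] - right * sin[i]        # "is_left" branch above
        out[i + half] = right * cos[i] + left * sin[i] # "!is_left" branch above
    return out

# Tiny usage example: rotate the first two pairs of a 4-wide row by 90 degrees.
theta = math.pi / 2
print(apply_rotary_contiguous([1.0, 2.0, 3.0, 4.0],
                              [math.cos(theta)] * 2, [math.sin(theta)] * 2, 4))
```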
@ -1,186 +0,0 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2024, Tri Dao.
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cmath>
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
|
||||
#include <cutlass/numeric_types.h>
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/philox.cuh>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/utils.h>
|
||||
|
||||
namespace pytorch_flash {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
#define UNFUSE_FMA
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
|
||||
__device__ __forceinline__ void thread_reduce_(Tensor<Engine0, Layout0> const &tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
|
||||
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
|
||||
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
|
||||
CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor));
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0>(tensor); mi++) {
|
||||
summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0));
|
||||
#pragma unroll
|
||||
for (int ni = 1; ni < size<1>(tensor); ni++) {
|
||||
summary(mi) = op(summary(mi), tensor(mi, ni));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
|
||||
__device__ __forceinline__ void quad_allreduce_(Tensor<Engine0, Layout0> &dst, Tensor<Engine1, Layout1> &src, Operator &op) {
|
||||
CUTE_STATIC_ASSERT_V(size(dst) == size(src));
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size(dst); i++){
|
||||
dst(i) = Allreduce<4>::run(src(i), op);
|
||||
}
|
||||
}
|
||||
|
||||
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
|
||||
__device__ __forceinline__ void reduce_(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
|
||||
thread_reduce_<zero_init>(tensor, summary, op);
|
||||
quad_allreduce_(summary, summary, op);
|
||||
}
|
||||
|
||||
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
|
||||
__device__ __forceinline__ void reduce_max(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &max){
|
||||
MaxOp<float> max_op;
|
||||
reduce_<zero_init>(tensor, max, max_op);
|
||||
}
|
||||
|
||||
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
|
||||
__device__ __forceinline__ void reduce_sum(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &sum){
|
||||
SumOp<float> sum_op;
|
||||
thread_reduce_<zero_init>(tensor, sum, sum_op);
|
||||
}
|
||||
|
||||
// Apply the exp to all the elements.
|
||||
template <bool Scale_max=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
|
||||
__forceinline__ __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &max, const float scale) {
|
||||
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
|
||||
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
|
||||
CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0>(tensor); ++mi) {
|
||||
// If max is -inf, then all elements must have been -inf (possibly due to masking).
|
||||
// We don't want (-inf - (-inf)) since that would give NaN.
|
||||
// If M_LOG2E is not wrapped in float(), the multiplication is done in fp64.
|
||||
const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * (Scale_max ? scale : float(M_LOG2E));
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < size<1>(tensor); ++ni) {
|
||||
// Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
|
||||
// max * log_2(e)) This allows the compiler to use the ffma
|
||||
// instruction instead of fadd and fmul separately.
|
||||
#ifdef UNFUSE_FMA
|
||||
tensor(mi, ni) = exp2f(__fmul_rn(tensor(mi, ni), scale) - max_scaled);
|
||||
#else
|
||||
tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Apply the exp to all the elements.
|
||||
template <bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
|
||||
__forceinline__ __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> &max, Tensor<Engine1, Layout1> &sum, const float scale) {
|
||||
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
|
||||
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
|
||||
CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0>(tensor); ++mi) {
|
||||
MaxOp<float> max_op;
|
||||
max(mi) = zero_init ? tensor(mi, 0) : max_op(max(mi), tensor(mi, 0));
|
||||
#pragma unroll
|
||||
for (int ni = 1; ni < size<1>(tensor); ni++) {
|
||||
max(mi) = max_op(max(mi), tensor(mi, ni));
|
||||
}
|
||||
max(mi) = Allreduce<4>::run(max(mi), max_op);
|
||||
// If max is -inf, then all elements must have been -inf (possibly due to masking).
|
||||
// We don't want (-inf - (-inf)) since that would give NaN.
|
||||
const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * scale;
|
||||
sum(mi) = 0;
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < size<1>(tensor); ++ni) {
|
||||
// Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
|
||||
// max * log_2(e)) This allows the compiler to use the ffma
|
||||
// instruction instead of fadd and fmul separately.
|
||||
tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
|
||||
sum(mi) += tensor(mi, ni);
|
||||
}
|
||||
SumOp<float> sum_op;
|
||||
sum(mi) = Allreduce<4>::run(sum(mi), sum_op);
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <int kNRows>
|
||||
struct Softmax {
|
||||
|
||||
using TensorT = decltype(make_tensor<float>(Shape<Int<kNRows>>{}));
|
||||
TensorT row_max, row_sum;
|
||||
|
||||
__forceinline__ __device__ Softmax() {};
|
||||
|
||||
template<bool Is_first, bool Check_inf=false, typename Tensor0, typename Tensor1>
|
||||
__forceinline__ __device__ void softmax_rescale_o(Tensor0 &acc_s, Tensor1 &acc_o, float softmax_scale_log2) {
|
||||
// Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
|
||||
Tensor scores = make_tensor(acc_s.data(), pytorch_flash::convert_layout_acc_rowcol(acc_s.layout()));
|
||||
static_assert(decltype(size<0>(scores))::value == kNRows);
|
||||
if (Is_first) {
|
||||
pytorch_flash::template reduce_max</*zero_init=*/true>(scores, row_max);
|
||||
pytorch_flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
|
||||
pytorch_flash::reduce_sum</*zero_init=*/true>(scores, row_sum);
|
||||
} else {
|
||||
Tensor scores_max_prev = make_fragment_like(row_max);
|
||||
cute::copy(row_max, scores_max_prev);
|
||||
pytorch_flash::template reduce_max</*zero_init=*/false>(scores, row_max);
|
||||
// Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
|
||||
Tensor acc_o_rowcol = make_tensor(acc_o.data(), pytorch_flash::convert_layout_acc_rowcol(acc_o.layout()));
|
||||
static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size(row_max); ++mi) {
|
||||
float scores_max_cur = !Check_inf
|
||||
? row_max(mi)
|
||||
: (row_max(mi) == -INFINITY ? 0.0f : row_max(mi));
|
||||
float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
|
||||
row_sum(mi) *= scores_scale;
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; }
|
||||
}
|
||||
pytorch_flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
|
||||
// We don't do the reduce across threads here since we don't need to use the row_sum.
|
||||
// We do that reduce at the end when we need to normalize the softmax.
|
||||
pytorch_flash::reduce_sum</*zero_init=*/false>(scores, row_sum);
|
||||
}
|
||||
};
|
||||
|
||||
template<bool Is_dropout=false, bool Split=false, typename Tensor0>
|
||||
__forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale, float rp_dropout=1.0) {
|
||||
SumOp<float> sum_op;
|
||||
quad_allreduce_(row_sum, row_sum, sum_op);
|
||||
TensorT lse = make_fragment_like(row_sum);
|
||||
Tensor acc_o_rowcol = make_tensor(acc_o.data(), pytorch_flash::convert_layout_acc_rowcol(acc_o.layout()));
|
||||
static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
|
||||
float sum = row_sum(mi);
|
||||
float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
|
||||
lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum);
|
||||
float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout;
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; }
|
||||
}
|
||||
return lse;
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace pytorch_flash
|
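The Softmax struct above implements the streaming (online) softmax used by flash attention: as each new block of scores arrives, the running row max is updated, the previously accumulated sum and output are rescaled by exp2((m_prev - m_new) * softmax_scale_log2), and the final logsumexp is recovered as row_max * scale + log(row_sum). A hedged, host-side Python sketch of that bookkeeping on a single row follows (names are illustrative; a scalar value per score stands in for the P @ V accumulation).

```python
# Illustrative online-softmax update mirroring softmax_rescale_o /
# normalize_softmax_lse above, operating on one row of plain floats.
import math

def online_softmax_row(score_blocks, value_blocks, scale):
    scale_log2 = scale * math.log2(math.e)      # analogue of softmax_scale_log2
    m, s, acc = -math.inf, 0.0, 0.0             # row_max, row_sum, accumulator
    for block, vals in zip(score_blocks, value_blocks):
        m_new = max(m, max(block))
        # Rescale previously accumulated sum/output when the max grows.
        rescale = 0.0 if m == -math.inf else 2.0 ** ((m - m_new) * scale_log2)
        # exp2(x*scale_log2 - m_new*scale_log2) == exp(scale * (x - m_new))
        p = [2.0 ** (x * scale_log2 - m_new * scale_log2) for x in block]
        s = s * rescale + sum(p)
        acc = acc * rescale + sum(pi * vi for pi, vi in zip(p, vals))
        m = m_new
    lse = m * scale + math.log(s)               # as in normalize_softmax_lse
    return acc / s, lse

print(online_softmax_row([[0.1, 0.4], [2.0, -1.0]],
                         [[1.0, 2.0], [3.0, 4.0]], scale=1.0))
```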
@ -1,394 +0,0 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2023, Tri Dao.
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cassert>
|
||||
#include <cstdint>
|
||||
#include <cstdlib>
|
||||
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
#include <cuda_bf16.h>
|
||||
#endif
|
||||
|
||||
#include <cute/algorithm/copy.hpp>
|
||||
#include <cute/algorithm/gemm.hpp>
|
||||
|
||||
#include <cutlass/array.h>
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/numeric_conversion.h>
|
||||
#include <cutlass/numeric_types.h>
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace pytorch_flash {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename T>
|
||||
__forceinline__ __device__ uint32_t relu2(const uint32_t x);
|
||||
|
||||
template<>
|
||||
__forceinline__ __device__ uint32_t relu2<cutlass::half_t>(const uint32_t x) {
|
||||
uint32_t res;
|
||||
const uint32_t zero = 0u;
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
asm volatile("max.f16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero));
|
||||
#else
|
||||
asm volatile( \
|
||||
"{\n" \
|
||||
"\t .reg .f16x2 sela;\n" \
|
||||
"\t set.gtu.u32.f16x2 sela, %1, %2;\n" \
|
||||
"\t and.b32 %0, sela, %1;\n"
|
||||
"}\n" : "=r"(res) : "r"(x), "r"(zero));
|
||||
#endif
|
||||
return res;
|
||||
}
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
template<>
|
||||
__forceinline__ __device__ uint32_t relu2<cutlass::bfloat16_t>(const uint32_t x) {
|
||||
uint32_t res;
|
||||
const uint32_t zero = 0u;
|
||||
asm volatile("max.bf16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero));
|
||||
return res;
|
||||
}
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
|
||||
template<typename T>
|
||||
__forceinline__ __device__ uint32_t convert_relu2(const float2 x);
|
||||
|
||||
template<>
|
||||
__forceinline__ __device__ uint32_t convert_relu2<cutlass::half_t>(const float2 x) {
|
||||
uint32_t res;
|
||||
const uint32_t a = reinterpret_cast<const uint32_t&>(x.x);
|
||||
const uint32_t b = reinterpret_cast<const uint32_t&>(x.y);
|
||||
asm volatile("cvt.rn.relu.f16x2.f32 %0, %1, %2;\n" : "=r"(res) : "r"(b), "r"(a));
|
||||
return res;
|
||||
}
|
||||
|
||||
template<>
|
||||
__forceinline__ __device__ uint32_t convert_relu2<cutlass::bfloat16_t>(const float2 x) {
|
||||
uint32_t res;
|
||||
const uint32_t a = reinterpret_cast<const uint32_t&>(x.x);
|
||||
const uint32_t b = reinterpret_cast<const uint32_t&>(x.y);
|
||||
asm volatile("cvt.rn.relu.bf16x2.f32 %0, %1, %2;\n" : "=r"(res) : "r"(b), "r"(a));
|
||||
return res;
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename T>
|
||||
struct MaxOp {
|
||||
__device__ __forceinline__ T operator()(T const & x, T const & y) { return x > y ? x : y; }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct MaxOp<float> {
|
||||
// This is slightly faster
|
||||
__device__ __forceinline__ float operator()(float const &x, float const &y) { return max(x, y); }
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename T>
|
||||
struct SumOp {
|
||||
__device__ __forceinline__ T operator()(T const & x, T const & y) { return x + y; }
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<int THREADS>
|
||||
struct Allreduce {
|
||||
static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
|
||||
template<typename T, typename Operator>
|
||||
static __device__ __forceinline__ T run(T x, Operator &op) {
|
||||
constexpr int OFFSET = THREADS / 2;
|
||||
x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
|
||||
return Allreduce<OFFSET>::run(x, op);
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<>
|
||||
struct Allreduce<2> {
|
||||
template<typename T, typename Operator>
|
||||
static __device__ __forceinline__ T run(T x, Operator &op) {
|
||||
x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
|
||||
return x;
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<bool A_in_regs=false, bool B_in_regs=false, typename Tensor0, typename Tensor1,
|
||||
typename Tensor2, typename Tensor3, typename Tensor4,
|
||||
typename TiledMma, typename TiledCopyA, typename TiledCopyB,
|
||||
typename ThrCopyA, typename ThrCopyB>
|
||||
__forceinline__ __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA,
|
||||
Tensor4 const& tCsB, TiledMma tiled_mma,
|
||||
TiledCopyA smem_tiled_copy_A, TiledCopyB smem_tiled_copy_B,
|
||||
ThrCopyA smem_thr_copy_A, ThrCopyB smem_thr_copy_B) {
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N
|
||||
CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K
|
||||
Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA);
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // M
|
||||
Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N
|
||||
if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); }
|
||||
if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{})); }
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size<2>(tCrA); ++i) {
|
||||
if (i < size<2>(tCrA) - 1) {
|
||||
if (!A_in_regs) { cute::copy(smem_tiled_copy_A, tCsA(_, _, i + 1), tCrA_copy_view(_, _, i + 1)); }
|
||||
if (!B_in_regs) { cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1)); }
|
||||
}
|
||||
cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3,
|
||||
typename TiledMma, typename TiledCopy, typename ThrCopy>
|
||||
__forceinline__ __device__ void gemm_rs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB,
|
||||
TiledMma tiled_mma, TiledCopy smem_tiled_copy_B,
|
||||
ThrCopy smem_thr_copy_B) {
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N
|
||||
CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K
|
||||
Tensor tCrB_copy_view = smem_thr_copy_B.retile_D(tCrB);
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrB_copy_view)); // N
|
||||
cute::copy(smem_tiled_copy_B, tCsB(_, _, _0{}), tCrB_copy_view(_, _, _0{}));
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size<2>(tCrA); ++i) {
|
||||
if (i < size<2>(tCrA) - 1) {
|
||||
cute::copy(smem_tiled_copy_B, tCsB(_, _, i + 1), tCrB_copy_view(_, _, i + 1));
|
||||
}
|
||||
cute::gemm(tiled_mma, tCrA(_, _, i), tCrB(_, _, i), acc);
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
|
||||
template<typename Layout>
|
||||
__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
|
||||
static_assert(decltype(size<0>(acc_layout))::value == 4);
|
||||
static_assert(decltype(rank(acc_layout))::value == 3);
|
||||
auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N)
|
||||
return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
|
||||
// if using m16n8k16, or to (4, MMA_M, MMA_N) if using m16n8k8.
|
||||
template<typename MMA_traits, typename Layout>
|
||||
__forceinline__ __device__ auto convert_layout_acc_Aregs(Layout acc_layout) {
|
||||
using X = Underscore;
|
||||
static_assert(decltype(size<0>(acc_layout))::value == 4);
|
||||
static_assert(decltype(rank(acc_layout))::value == 3);
|
||||
constexpr int mma_shape_K = get<2>(typename MMA_traits::Shape_MNK{});
|
||||
static_assert(mma_shape_K == 8 || mma_shape_K == 16);
|
||||
if constexpr (mma_shape_K == 8) {
|
||||
return acc_layout;
|
||||
} else {
|
||||
auto l = logical_divide(acc_layout, Shape<X, X, _2>{}); // (4, MMA_M, (2, MMA_N / 2)))
|
||||
return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l));
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
|
||||
template<typename Layout>
|
||||
__forceinline__ __device__ auto convert_layout_acc_dropout(Layout acc_layout) {
|
||||
using X = Underscore;
|
||||
static_assert(decltype(size<0>(acc_layout))::value == 4);
|
||||
static_assert(decltype(rank(acc_layout))::value == 3);
|
||||
auto l = logical_divide(acc_layout, Shape<X, X, _2>{}); // (4, MMA_M, (2, MMA_N / 2)))
|
||||
return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l));
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename To_type, typename Engine, typename Layout>
|
||||
__forceinline__ __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) {
|
||||
using From_type = typename Engine::value_type;
|
||||
constexpr int numel = decltype(size(tensor))::value;
|
||||
cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
|
||||
// HACK: this requires tensor to be "contiguous"
|
||||
auto frag = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
|
||||
return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Engine, typename Layout>
|
||||
__forceinline__ __device__ void relu_(Tensor<Engine, Layout> &tensor) {
|
||||
constexpr int numel = decltype(size(tensor))::value;
|
||||
static_assert(numel % 2 == 0);
|
||||
using value_t = typename Engine::value_type;
|
||||
// HACK: this requires tensor to be "contiguous"
|
||||
Tensor tensor_uint32 = recast<uint32_t>(tensor);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size(tensor_uint32); ++i) {
|
||||
tensor_uint32(i) = relu2<value_t>(tensor_uint32(i));
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// On SM80 and above, we can fuse fp32 -> fp16/bf16 conversion and relu into 1 instruction
|
||||
template <typename To_type, typename Engine, typename Layout>
|
||||
__forceinline__ __device__ auto convert_type_relu(Tensor<Engine, Layout> const &tensor) {
|
||||
using From_type = typename Engine::value_type;
|
||||
static_assert(std::is_same_v<To_type, cutlass::half_t> || std::is_same_v<To_type, cutlass::bfloat16_t>);
|
||||
static_assert(std::is_same_v<float, From_type>);
|
||||
constexpr int numel = decltype(size(tensor))::value;
|
||||
static_assert(numel % 2 == 0);
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
// HACK: this requires tensor to be "contiguous"
|
||||
Tensor tensor_float2 = recast<float2>(tensor);
|
||||
Tensor out_uint32 = make_tensor<uint32_t>(tensor_float2.layout());
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size(out_uint32); ++i) {
|
||||
out_uint32(i) = convert_relu2<To_type>(tensor_float2(i));
|
||||
}
|
||||
Tensor out = make_tensor(make_rmem_ptr<To_type>(out_uint32.data()), tensor.layout());
|
||||
#else
|
||||
Tensor out = pytorch_flash::convert_type<To_type>(tensor);
|
||||
pytorch_flash::relu_(out);
|
||||
#endif
|
||||
return out;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Blocks until all but N previous cp.async.commit_group operations have committed.
|
||||
// This differs from cute::cp_async_wait in that when N = 0 we don't call cp.async.wait_all
|
||||
// (which is equivalent to commit_group then wait_group 0).
|
||||
// Instead we just call cp.async.wait_group 0, which is slightly faster.
|
||||
// https://github.com/NVIDIA/cutlass/blob/master/include/cute/arch/copy_sm80.hpp#L113
|
||||
template <int N>
|
||||
CUTE_HOST_DEVICE
|
||||
void cp_async_wait() {
|
||||
#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED)
|
||||
asm volatile("cp.async.wait_group %0;\n" :: "n"(N));
|
||||
#endif
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true,
|
||||
typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
|
||||
typename Engine2, typename Layout2, typename Engine3, typename Layout3>
|
||||
__forceinline__ __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S,
|
||||
Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
|
||||
Tensor<Engine3, Layout3> const &predicate_K, const int max_MN=0) {
|
||||
CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
|
||||
CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
|
||||
CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA
|
||||
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M
|
||||
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K
|
||||
// There's no case where !Clear_OOB_K && Clear_OOB_MN
|
||||
static_assert(!(Clear_OOB_MN && !Clear_OOB_K));
|
||||
#pragma unroll
|
||||
for (int m = 0; m < size<1>(S); ++m) {
|
||||
if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
|
||||
#pragma unroll
|
||||
for (int k = 0; k < size<2>(S); ++k) {
|
||||
if (Is_even_K || predicate_K(k)) {
|
||||
cute::copy(tiled_copy, S(_, m, k), D(_, m, k));
|
||||
} else if (Clear_OOB_K) {
|
||||
cute::clear(D(_, m, k));
|
||||
}
|
||||
}
|
||||
} else if (Clear_OOB_MN) {
|
||||
cute::clear(D(_, m, _));
|
||||
}
|
||||
}
|
||||
// TD [2023-04-13]: Strange that the code below can cause race condition.
|
||||
// I think it's because the copies are under an if statement.
|
||||
// if (Is_even_K) {
|
||||
// #pragma unroll
|
||||
// for (int m = 0; m < size<1>(S); ++m) {
|
||||
// if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
|
||||
// copy(tiled_copy, S(_, m, _), D(_, m, _));
|
||||
// } else if (Clear_OOB_MN) {
|
||||
// clear(D(_, m, _));
|
||||
// }
|
||||
// }
|
||||
// } else { // It's slightly faster in this case if iterate over K first
|
||||
// #pragma unroll
|
||||
// for (int k = 0; k < size<2>(S); ++k) {
|
||||
// if (predicate_K(k)) {
|
||||
// #pragma unroll
|
||||
// for (int m = 0; m < size<1>(S); ++m) {
|
||||
// if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
|
||||
// copy(tiled_copy, S(_, m, k), D(_, m, k));
|
||||
// } else if (Clear_OOB_MN) {
|
||||
// clear(D(_, m, k));
|
||||
// }
|
||||
// }
|
||||
// } else if (Clear_OOB_K) { // There's no case where !Clear_OOB_K && Clear_OOB_MN
|
||||
// if (Clear_OOB_MN || Is_even_MN) {
|
||||
// clear(D(_, _, k));
|
||||
// } else {
|
||||
// #pragma unroll
|
||||
// for (int m = 0; m < size<1>(S); ++m) {
|
||||
// if (!(Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN)) {
|
||||
// clear(D(_, m, k));
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <bool Is_even_K=true,
|
||||
typename Engine0, typename Layout0, typename Engine1, typename Layout1,
|
||||
typename Engine2, typename Layout2, typename Engine3, typename Layout3>
|
||||
__forceinline__ __device__ void copy_w_min_idx(Tensor<Engine0, Layout0> const &S,
|
||||
Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
|
||||
Tensor<Engine3, Layout3> const &predicate_K,
|
||||
const int max_MN=0, const int min_MN=0) {
|
||||
CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
|
||||
CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
|
||||
CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA
|
||||
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M
|
||||
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K
|
||||
// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, max_MN = %d, min_MN = %d\n", blockIdx.y, max_MN, min_MN); }
|
||||
#pragma unroll
|
||||
for (int m = 0; m < size<1>(S); ++m) {
|
||||
// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); }
|
||||
if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) {
|
||||
// if (threadIdx.x == 0 && blockIdx.z == 0) { printf("Inner loop, blockIdx.y = %d, m = %d\n", blockIdx.y, get<0>(identity_MN(0, m, 0))); }
|
||||
#pragma unroll
|
||||
for (int k = 0; k < size<2>(S); ++k) {
|
||||
if (Is_even_K || predicate_K(k)) {
|
||||
cute::copy(S(_, m, k), D(_, m, k));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace pytorch_flash
|
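The Allreduce helper above is a butterfly (XOR-shuffle) reduction: starting from THREADS lanes, each round combines a lane's value with the value of the lane whose index differs by the current offset, halving the offset until every lane holds the fully reduced result. A small Python simulation of that exchange pattern (purely illustrative; the device code uses __shfl_xor_sync):

```python
# Simulates Allreduce<THREADS>::run for a group of lanes, one value per lane.
# After log2(n) XOR-exchange rounds every lane holds op applied to all values.
def butterfly_allreduce(values, op):
    n = len(values)                      # THREADS; must be a power of two
    assert n & (n - 1) == 0 and n > 0
    vals = list(values)
    offset = n // 2
    while offset >= 1:
        # All lanes exchange simultaneously, so read from the previous round.
        vals = [op(vals[lane], vals[lane ^ offset]) for lane in range(n)]
        offset //= 2
    return vals

print(butterfly_allreduce([3.0, 1.0, 4.0, 1.5], max))  # -> [4.0, 4.0, 4.0, 4.0]
```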
@ -267,6 +267,7 @@ mha_fwd(
|
||||
bool is_causal,
|
||||
int window_size_left,
|
||||
int window_size_right,
|
||||
const float softcap,
|
||||
const bool return_softmax,
|
||||
std::optional<at::Generator> gen_) {
|
||||
#if defined(USE_CK_FLASH_ATTENTION)
|
||||
@ -351,6 +352,7 @@ mha_varlen_fwd(
|
||||
bool is_causal,
|
||||
int window_size_left,
|
||||
int window_size_right,
|
||||
const float softcap,
|
||||
const bool return_softmax,
|
||||
std::optional<at::Generator> gen_) {
|
||||
#if defined(USE_CK_FLASH_ATTENTION)
|
||||
@ -441,6 +443,7 @@ inline std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_bwd(
|
||||
const bool is_causal,
|
||||
int window_size_left,
|
||||
int window_size_right,
|
||||
const float softcap,
|
||||
const bool deterministic,
|
||||
const at::Tensor philox_seed,
|
||||
const at::Tensor philox_offset) {
|
||||
@ -541,6 +544,7 @@ inline std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_varlen_bwd
|
||||
const bool is_causal,
|
||||
int window_size_left,
|
||||
int window_size_right,
|
||||
const float softcap,
|
||||
const bool deterministic,
|
||||
const at::Tensor philox_seed,
|
||||
const at::Tensor philox_offset) {
|
||||
|
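For context, the mha_* entry points above thread a softcap argument through to the kernels. This diff does not define its semantics; a common formulation of attention-logit soft-capping (an assumption here, not taken from this PR) squashes the raw scores through tanh before the softmax, with 0 conventionally disabling the cap:

```python
# Hypothetical illustration of logit soft-capping; `softcap` and `scores`
# are illustrative names, not the C++ parameters declared above.
import math

def softcap_scores(scores, softcap):
    if softcap <= 0.0:                  # 0 conventionally means "no capping"
        return list(scores)
    return [softcap * math.tanh(s / softcap) for s in scores]

print(softcap_scores([0.5, 10.0, -30.0], softcap=5.0))
```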
@ -980,7 +980,16 @@ elseif(USE_CUDA)
|
||||
target_compile_definitions(torch_cuda PRIVATE USE_UCC)
|
||||
endif()
|
||||
if(USE_FLASH_ATTENTION)
|
||||
target_compile_definitions(torch_cuda PRIVATE USE_FLASH_ATTENTION)
|
||||
target_compile_definitions(torch_cuda PRIVATE
|
||||
USE_FLASH_ATTENTION
|
||||
FLASHATTENTION_DISABLE_ALIBI # Disable alibi attention as it's not currently used
|
||||
FLASHATTENTION_DISABLE_SOFTCAP
|
||||
FLASH_NAMESPACE=pytorch_flash
|
||||
UNFUSE_FMA # Addressing issue #121558
|
||||
)
|
||||
target_include_directories(torch_cuda PRIVATE
|
||||
${PROJECT_SOURCE_DIR}/third_party/flash-attention/csrc/flash_attn/src/
|
||||
)
|
||||
endif()
|
||||
if(USE_MEM_EFF_ATTENTION)
|
||||
target_compile_definitions(torch_cuda PRIVATE USE_MEM_EFF_ATTENTION)
|
||||
@ -1372,7 +1381,13 @@ if(USE_ROCM)
|
||||
${ROCM_SOURCE_DIR}/include/rccl/
|
||||
)
|
||||
if(USE_FLASH_ATTENTION)
|
||||
target_compile_definitions(torch_hip PRIVATE USE_FLASH_ATTENTION)
|
||||
target_compile_definitions(torch_hip PRIVATE
|
||||
USE_FLASH_ATTENTION
|
||||
FLASHATTENTION_DISABLE_ALIBI # Disable alibi attention as it's not currently used
|
||||
FLASHATTENTION_DISABLE_SOFTCAP
|
||||
FLASH_NAMESPACE=pytorch_flash
|
||||
UNFUSE_FMA # Addressing issue #121558
|
||||
)
|
||||
endif()
|
||||
if(USE_MEM_EFF_ATTENTION)
|
||||
target_compile_definitions(torch_hip PRIVATE USE_MEM_EFF_ATTENTION)
|
||||
|
@ -2873,17 +2873,17 @@
|
||||
output_differentiability: [True, False, False, False]
|
||||
query, key, value, attn_bias: _scaled_dot_product_efficient_attention_backward(grad, query, key, value, attn_bias, output, log_sumexp, philox_seed, philox_offset, dropout_p, grad_input_mask, is_causal, scale)
|
||||
|
||||
- name: _scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
|
||||
- name: _scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor rng_state, Tensor unused, Tensor debug_attn_mask)
|
||||
output_differentiability: [True, False, False, False, False, False, False, False, False]
|
||||
query, key, value: _scaled_dot_product_flash_attention_backward_symint(grad, query, key, value, output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, philox_seed, philox_offset, scale)
|
||||
query, key, value: _scaled_dot_product_flash_attention_backward_symint(grad, query, key, value, output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, rng_state, unused, scale)
|
||||
|
||||
- name: _scaled_dot_product_flash_attention_for_cpu(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, *, Tensor? attn_mask=None, float? scale=None) -> (Tensor output, Tensor logsumexp)
|
||||
output_differentiability: [True, False]
|
||||
query, key, value: _scaled_dot_product_flash_attention_for_cpu_backward(grad, query, key, value, output, logsumexp, dropout_p, is_causal, attn_mask, scale)
|
||||
|
||||
- name: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None, SymInt? window_size_left=None, SymInt? window_size_right=None, Tensor? seqused_k=None, Tensor? alibi_slopes=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
|
||||
- name: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None, SymInt? window_size_left=None, SymInt? window_size_right=None, Tensor? seqused_k=None, Tensor? alibi_slopes=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor rng_state, Tensor unused, Tensor debug_attn_mask)
|
||||
output_differentiability: [True, False, False, False, False]
|
||||
query, key, value: _flash_attention_backward_symint(grad, query, key, value, output, softmax_logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, philox_seed, philox_offset, scale, window_size_left, window_size_right)
|
||||
query, key, value: _flash_attention_backward_symint(grad, query, key, value, output, softmax_logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, rng_state, unused, scale, window_size_left, window_size_right)
|
||||
|
||||
- name: _efficient_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, SymInt? max_seqlen_q, SymInt? max_seqlen_k, float dropout_p, int custom_mask_type, bool compute_log_sumexp=False, *, float? scale=None, Tensor? seqlen_k=None, int? window_size=None) -> (Tensor output, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, SymInt max_seqlen_batch_q, SymInt max_seqlen_batch_k)
|
||||
output_differentiability: [True, False, False, False, False, False]
|
||||
|
@ -5545,6 +5545,14 @@ def meta__scaled_dot_product_flash_attention(
|
||||
# capturing or not, but at the time of tracing we don't know if we
|
||||
# are going to use cudagraphs or not, so we return meta tensors here
|
||||
# it's possible we'll need to have some special handling in inductor for sdpa
|
||||
# See [Note] BC breaking change to flash seed/offset
|
||||
if torch.version.hip and torch.cuda.is_available():
|
||||
# Maintain old path on AMD
|
||||
seed = torch.empty((), dtype=torch.long, device="meta")
|
||||
offset = torch.empty((), dtype=torch.long, device="meta")
|
||||
else:
|
||||
seed = torch.empty((2), dtype=torch.uint64, device="meta")
|
||||
offset = torch.empty((), dtype=torch.uint64, device="meta")
|
||||
|
||||
return (
|
||||
attention,
|
||||
@ -5553,8 +5561,8 @@ def meta__scaled_dot_product_flash_attention(
|
||||
None,
|
||||
max_seqlen_batch_q,
|
||||
max_seqlen_batch_k,
|
||||
torch.empty((), dtype=torch.long, device="meta"),
|
||||
torch.empty((), dtype=torch.long, device="meta"),
|
||||
seed,
|
||||
offset,
|
||||
debug_mask,
|
||||
)
|
||||
|
||||
@ -5878,11 +5886,17 @@ def meta__flash_attention_forward(
|
||||
|
||||
# Cuda Path
|
||||
attention = torch.empty_like(query)
|
||||
logsumexp = torch.empty(
|
||||
(batch_size, num_heads, max_seqlen_batch_q),
|
||||
dtype=torch.float,
|
||||
device=query.device,
|
||||
)
|
||||
if cum_seq_q is None:
|
||||
logsumexp = torch.empty(
|
||||
(batch_size, num_heads, max_seqlen_batch_q),
|
||||
dtype=torch.float,
|
||||
device=query.device,
|
||||
)
|
||||
else:
|
||||
total_q = query.size(0)
|
||||
logsumexp = torch.empty(
|
||||
(num_heads, total_q), dtype=torch.float, device=query.device
|
||||
)
|
||||
|
||||
if return_debug_mask:
|
||||
blocksize_c = 128 if head_dim > 64 else 256
|
||||
@ -5899,12 +5913,21 @@ def meta__flash_attention_forward(
|
||||
else:
|
||||
debug_mask = torch.empty(0, dtype=query.dtype, device=query.device)
|
||||
|
||||
# See Note [Seed and Offset]:
|
||||
# See Note [Seed and Offset]
|
||||
# See [Note] BC breaking change to flash seed/offset
|
||||
seed, offset = None, None
|
||||
if torch.version.hip and torch.cuda.is_available():
|
||||
# Maintain old path on AMD
|
||||
seed = torch.empty((), dtype=torch.long, device="meta")
|
||||
offset = torch.empty((), dtype=torch.long, device="meta")
|
||||
else:
|
||||
seed = torch.empty((2), dtype=torch.uint64, device="meta")
|
||||
offset = torch.empty((), dtype=torch.uint64, device="meta")
|
||||
return (
|
||||
attention,
|
||||
logsumexp,
|
||||
torch.empty((), dtype=torch.long, device="meta"),
|
||||
torch.empty((), dtype=torch.long, device="meta"),
|
||||
seed,
|
||||
offset,
|
||||
debug_mask,
|
||||
)
|
||||
|
||||
|
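The meta registrations above encode the new RNG-state convention: on the CUDA path the first returned RNG tensor is a two-element uint64 rng_state with an empty uint64 placeholder beside it, while the ROCm path keeps the old pair of scalar long tensors. A small helper that builds the matching meta tensors is sketched below (hypothetical convenience function mirroring the branches above, not part of this PR):

```python
import torch

def flash_rng_meta_tensors():
    # Mirrors the seed/offset branch used in the meta functions above.
    if torch.version.hip and torch.cuda.is_available():
        # Old convention kept on AMD: separate scalar seed / offset.
        seed = torch.empty((), dtype=torch.long, device="meta")
        offset = torch.empty((), dtype=torch.long, device="meta")
    else:
        # New convention: packed 2-element uint64 rng_state plus a placeholder.
        seed = torch.empty((2,), dtype=torch.uint64, device="meta")
        offset = torch.empty((), dtype=torch.uint64, device="meta")
    return seed, offset
```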
@ -28,7 +28,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__embedding_bag_forward_only(Ate
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__embedding_bag_per_sample_weights_backward(AtenTensorHandle grad, AtenTensorHandle weight, AtenTensorHandle indices, AtenTensorHandle offsets, AtenTensorHandle offset2bag, int64_t mode, int64_t padding_idx, AtenTensorHandle* ret0);
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__fft_c2c(AtenTensorHandle self, const int64_t* dim, int64_t dim_len_, int64_t normalization, int32_t forward, AtenTensorHandle* ret0);
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__fft_r2c(AtenTensorHandle self, const int64_t* dim, int64_t dim_len_, int64_t normalization, int32_t onesided, AtenTensorHandle* ret0);
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__flash_attention_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle cum_seq_q, AtenTensorHandle cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, AtenTensorHandle philox_seed, AtenTensorHandle philox_offset, double* scale, int64_t* window_size_left, int64_t* window_size_right, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2);
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__flash_attention_backward(AtenTensorHandle grad_out, AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle out, AtenTensorHandle logsumexp, AtenTensorHandle cum_seq_q, AtenTensorHandle cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, AtenTensorHandle rng_state, AtenTensorHandle unused, double* scale, int64_t* window_size_left, int64_t* window_size_right, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2);
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__flash_attention_forward(AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, AtenTensorHandle* cum_seq_q, AtenTensorHandle* cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, int32_t is_causal, int32_t return_debug_mask, double* scale, int64_t* window_size_left, int64_t* window_size_right, AtenTensorHandle* seqused_k, AtenTensorHandle* alibi_slopes, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4);
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__fused_moving_avg_obs_fq_helper(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1);
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda__fused_moving_avg_obs_fq_helper_functional(AtenTensorHandle self, AtenTensorHandle observer_on, AtenTensorHandle fake_quant_on, AtenTensorHandle running_min, AtenTensorHandle running_max, AtenTensorHandle scale, AtenTensorHandle zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, int32_t per_row_fake_quant, int32_t symmetric_quant, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3, AtenTensorHandle* ret4, AtenTensorHandle* ret5);
|
||||
|