mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Update flash_attention kernel from 2.3.6 to 2.5.5 (#118935)
# Summary Updates the FlashAttention kernel code from tag [2.3.6](https://github.com/Dao-AILab/flash-attention/releases/tag/v2.3.6) to tag [2.5.5](https://github.com/Dao-AILab/flash-attention/releases/tag/v2.5.5). The usual PyTorch-side changes were then re-rolled on top of the updated kernel: changing how dropout state is saved for backward, and removing the head_dim_pad, since padding would make the kernel mutate its input in place, which interacts badly with functionalization. Pull Request resolved: https://github.com/pytorch/pytorch/pull/118935 Approved by: https://github.com/cpuhrsch
This commit is contained in:
committed by PyTorch MergeBot
parent d49864f6a5
commit 2e6c08a14b
@@ -744,6 +744,12 @@ cmake_dependent_option(
Will be disabled if not supported by the platform" ON
"USE_CUDA AND NOT MSVC" OFF)

# We are currently not using alibi attention for Flash
# So we disable this feature by default
# We don't currently document this feature because we don't
# suspect users building from source will need this
add_definitions(-DFLASHATTENTION_DISABLE_ALIBI)

# CAVEAT: Again, do not check USE_ROCM here
# Flash Attention2 will error while building for sm52 while Mem Eff Attention won't
cmake_dependent_option(
@@ -50,7 +50,6 @@
#include <ATen/ops/scalar_tensor.h>
#include <ATen/ops/scaled_dot_product_attention.h>
#include <ATen/ops/split_native.h>
#include <ATen/ops/narrow_native.h>
#include <ATen/ops/zeros.h>
#endif

@@ -65,7 +64,6 @@
#include <ATen/native/transformers/attention.h>
#include <ATen/native/nested/NestedTensorUtils.h>
#include <ATen/native/nested/NestedTensorTransformerFunctions.h>
#include <ATen/native/nested/NestedTensorUtils.h>
#include <ATen/native/transformers/cuda/sdp_utils.h>
#include <ATen/native/transformers/sdp_utils_cpp.h>

@@ -852,6 +850,7 @@ _flash_attention_forward(
// of the tensor. This is useful for kv cache scenarios but for now
// we will not support in this PR.
c10::optional<Tensor> seqused_k = c10::nullopt;
c10::optional<Tensor> alibi_slopes = c10::nullopt;

// We are going to have two paths:
// 1. The standard MHA path for dense tensors
@@ -880,6 +879,7 @@ _flash_attention_forward(
cumulative_sequence_length_q.value(),
cumulative_sequence_length_k.value(),
seqused_k, /*seqused_k*/
alibi_slopes, /*alibi_slopes*/
max_seqlen_batch_q,
max_seqlen_batch_k,
dropout_p,
@@ -905,6 +905,7 @@ _flash_attention_forward(
key,
value,
out,
alibi_slopes,
dropout_p,
softmax_scale,
is_causal,
@@ -1,3 +1,4 @@
#include <string_view>
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <cstdint>
#include <type_traits>
@@ -41,9 +42,8 @@
#include <ATen/native/transformers/cuda/mem_eff_attention/gemm_kernel_utils.h>
#include <ATen/native/transformers/cuda/mem_eff_attention/pytorch_utils.h>
#endif
namespace at {

namespace native {
namespace at::native {

std::tuple<Tensor, Tensor, Tensor> _flash_attention_backward(
const Tensor& grad_out,
@@ -74,6 +74,21 @@ std::tuple<Tensor, Tensor, Tensor> _flash_attention_backward(
// The kernel computes it regardless; we will drop it for this function's return
Tensor grad_softmax;

// Currently unused args:
c10::optional<at::Tensor> alibi_slopes{c10::nullopt};

bool determinisitic{false};
auto& ctx = at::globalContext();
if (ctx.deterministicAlgorithms()) {
if (ctx.deterministicAlgorithmsWarnOnly()) {
TORCH_WARN_ONCE(
"Flash Attention defaults to a non-deterministic algorithm. ",
"To explicitly enable determinism call torch.use_deterministic_algorithms(True, warn_only=False).");
} else {
determinisitic = true;
}
}

// We check whether cumulative_sequence_length_q is defined
// in order to determine whether we are using varlen or dense forward
if (cumulative_sequence_length_q.defined()) {
@@ -90,6 +105,7 @@ std::tuple<Tensor, Tensor, Tensor> _flash_attention_backward(
dv,
cumulative_sequence_length_q,
cumulative_sequence_length_k,
alibi_slopes,
max_seqlen_batch_q,
max_seqlen_batch_k,
dropout_p,
@@ -98,6 +114,7 @@ std::tuple<Tensor, Tensor, Tensor> _flash_attention_backward(
is_causal,
-1, /*window_size_left*/
-1, /*window_size_right*/
determinisitic,
philox_seed,
philox_offset);
return std::make_tuple(dQuery, dKey, dValue);
@@ -113,11 +130,13 @@ std::tuple<Tensor, Tensor, Tensor> _flash_attention_backward(
dq,
dk,
dv,
alibi_slopes,
dropout_p,
softmax_scale,
is_causal,
-1, /*window_size_left*/
-1, /*window_size_right*/
determinisitic,
philox_seed,
philox_offset);
return std::make_tuple(std::move(dQuery), std::move(dKey), std::move(dValue));
@@ -630,5 +649,4 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> _scaled_dot_product_e
grad_q.transpose(1, 2), grad_k.transpose(1, 2), grad_v.transpose(1, 2), grad_bias);
}

} // namespace native
} // namespace at
} // namespace at::native
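Editor's note: the hunk above wires PyTorch's global determinism toggle into the FlashAttention backward (warn-only mode only prints the warning, strict mode flips the flag passed to the kernel). A minimal sketch of how a caller would opt in from C++, assuming the usual ATen context API; the Python equivalent named in the warning message is torch.use_deterministic_algorithms:

```cpp
#include <ATen/Context.h>

// Hedged sketch: request the deterministic FlashAttention backward path.
// With warn_only=false, _flash_attention_backward above passes
// deterministic=true down to the kernel instead of only warning.
void enable_deterministic_flash_backward() {
  at::globalContext().setDeterministicAlgorithms(/*b=*/true, /*warn_only=*/false);
}
```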
74	aten/src/ATen/native/transformers/cuda/flash_attn/alibi.h	Normal file
@@ -0,0 +1,74 @@
#include <cmath>

#include <cute/tensor.hpp>

#include <cutlass/cutlass.h>
#include <cutlass/array.h>

#include <ATen/native/transformers/cuda/flash_attn/utils.h>

namespace pytorch_flash {

using namespace cute;

////////////////////////////////////////////////////////////////////////////////////////////////////

template <bool Is_causal>
struct Alibi {

const float alibi_slope;
const int max_seqlen_k, max_seqlen_q;

__forceinline__ __device__ Alibi(const float alibi_slope, const int max_seqlen_k, const int max_seqlen_q)
: alibi_slope(alibi_slope)
, max_seqlen_k(max_seqlen_k)
, max_seqlen_q(max_seqlen_q) {
};

template <typename Engine, typename Layout>
__forceinline__ __device__ void apply_alibi(Tensor<Engine, Layout> &tensor,
const int col_idx_offset_,
const int row_idx_offset,
const int warp_row_stride) {
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
static_assert(Layout::rank == 2, "Only support 2D Tensor");
const int lane_id = threadIdx.x % 32;
const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
if constexpr (Is_causal) { // Simpler, we add the same bias vector to all rows
#pragma unroll
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
const int col_idx_base = col_idx_offset + nj * 8;
#pragma unroll
for (int j = 0; j < size<1, 0>(tensor); ++j) {
const int col_idx = col_idx_base + j;
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); ++mi) {
tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx;
}
}
}
} else { // Bias depends on both row_idx and col_idx
#pragma unroll
for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
const int row_idx_base = row_idx_offset + mi * warp_row_stride;
#pragma unroll
for (int i = 0; i < size<0, 0>(tensor); ++i) {
const int row_idx = row_idx_base + i * 8;
#pragma unroll
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
const int col_idx_base = col_idx_offset + nj * 8;
#pragma unroll
for (int j = 0; j < size<1, 0>(tensor); ++j) {
const int col_idx = col_idx_base + j;
tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx);
}
}
}
}
}
}

};

} // namespace pytorch_flash
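Editor's note: for reference, the bias that `Alibi::apply_alibi` applies per score can be written down on the host. A sketch (not part of the diff) of the same arithmetic for a single (row, col) position:

```cpp
#include <cmath>

// Host-side sketch of the ALiBi bias from the device code above.
// In the causal branch every row gets the same bias vector (slope * col, added);
// otherwise the bias depends on the key/query position distance (subtracted).
float alibi_bias(bool is_causal, float slope,
                 int row, int col, int max_seqlen_q, int max_seqlen_k) {
  if (is_causal) {
    return slope * col;  // added to the attention score
  }
  return -slope * std::abs(static_cast<float>(row + max_seqlen_k - max_seqlen_q - col));
}
```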
@@ -24,12 +24,12 @@ struct BlockInfo {
}

template <typename index_t>
inline __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
__forceinline__ __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride;
}

template <typename index_t>
inline __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
__forceinline__ __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride;
}

96	aten/src/ATen/native/transformers/cuda/flash_attn/dropout.h	Normal file
@@ -0,0 +1,96 @@
/******************************************************************************
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/

#pragma once

#include <ATen/native/transformers/cuda/flash_attn/philox.cuh>
#include <ATen/native/transformers/cuda/flash_attn/utils.h>

namespace pytorch_flash {

using namespace cute;

struct Dropout {

const unsigned long long seed, offset;
const uint8_t p_dropout_in_uint8_t;

__forceinline__ __device__ Dropout(const unsigned long long seed, const unsigned long long offset,
const uint8_t p_dropout_in_uint8_t,
const int bid, const int hid, const int tid, const int nheads)
: seed(seed)
, offset(offset + (bid * nheads + hid) * 32 + tid % 32)
, p_dropout_in_uint8_t(p_dropout_in_uint8_t) {
}

template <bool encode_dropout_in_sign_bit=false, typename Engine, typename Layout>
__forceinline__ __device__ void apply_dropout(Tensor<Engine, Layout> &tensor_,
int block_row_start, int block_col_start, int block_row_stride) {
// convert shape from (4, MMA_M, MMA_N) to (8, MMA_M, MMA_N / 2)
Tensor tensor = make_tensor(tensor_.data(), pytorch_flash::convert_layout_acc_dropout(tensor_.layout()));
using T = typename Engine::value_type;
auto encode_dropout = [](bool keep, T val) {
return keep ? val : (encode_dropout_in_sign_bit ? -val : T(0));
};
static_assert(decltype(size<2>(tensor))::value % 2 == 0);
const uint16_t p_dropout_8bit_in_uint16_t = uint16_t(p_dropout_in_uint8_t);
const uint32_t p_dropout_8bit_in_uint32_t = (uint32_t(p_dropout_8bit_in_uint16_t) << 16) | uint32_t(p_dropout_8bit_in_uint16_t);
// if (cute::thread0()) { printf("threshold2 = 0x%x\n", p_dropout_8bit_in_uint32_t); }
#pragma unroll
for (int m = 0; m < size<1>(tensor); ++m, block_row_start += block_row_stride) {
uint2 rowcol = make_uint2(block_row_start, block_col_start);
#pragma unroll
for (int n = 0; n < size<2>(tensor) / 2; ++n, ++rowcol.y) {
// if (cute::thread(32, 0)) { printf("m = %d, n = %d, row = %d, col = %d\n", m, n, int(rowcol.x), int(rowcol.y));}
uint4 random_uint4 = pytorch_flash::philox(seed, reinterpret_cast<unsigned long long&>(rowcol), offset);
// if (cute::thread0()) { printf("philox = %u, %d, %d, %d\n", random_uint4.x, random_uint4.y, random_uint4.z, random_uint4.w);}
uint8_t (&rnd_8)[16] = reinterpret_cast<uint8_t (&)[16]>(random_uint4);
// Special implementation for 16-bit types: we duplicate the threshold to the
// low and high 16 bits of a 32-bit value, then use the f16x2 comparison instruction
// to get a mask. The low 16 bits of the mask will be either 0xffff or 0x0000,
// and the high 16 bits will be either 0xffff or 0x0000, depending on whether
// the random value is less than the threshold.
// We then do a bit-wise AND between the mask and the original value (in 32-bit).
// We're exploiting the fact that floating point comparison is equivalent to integer
// comparison, since we're comparing unsigned integers whose top 8-bits are zero.
if (!encode_dropout_in_sign_bit
&& (std::is_same<T, cutlass::half_t>::value || std::is_same<T, cutlass::bfloat16_t>::value)) {
uint16_t rnd_16[16];
#pragma unroll
for (int i = 0; i < 16; i++) { rnd_16[i] = uint16_t(rnd_8[i]); }
uint32_t (&rnd_32)[8] = reinterpret_cast<uint32_t (&)[8]>(rnd_16);
#pragma unroll
for (int j = 0; j < 2; j++) {
Tensor tensor_uint32 = recast<uint32_t>(tensor(_, m, n * 2 + j));
// if (cute::thread0()) { printf("random = 0x%x, 0x%x, 0x%x, 0x%x\n", rnd_32[j * 4 + 0], rnd_32[j * 4 + 1], rnd_32[j * 4 + 2], rnd_32[j * 4 + 3]); }
// if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
#pragma unroll
for (int i = 0; i < 4; i++) {
uint32_t mask;
asm volatile("set.le.u32.f16x2 %0, %1, %2;\n" : "=r"(mask) : "r"(rnd_32[j * 4 + i]), "r"(p_dropout_8bit_in_uint32_t));
tensor_uint32(i) &= mask;
}
// if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
}
} else {
#pragma unroll
for (int j = 0; j < 2; j++) {
#pragma unroll
for (int i = 0; i < 8; i++) {
tensor(i, m, n * 2 + j) = encode_dropout(rnd_8[j * 8 + i] <= p_dropout_in_uint8_t, tensor(i, m, n * 2 + j));
}
Tensor tensor_uint32 = recast<uint32_t>(tensor(_, m, n * 2 + j));
// if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
}
}
// // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// // printf("n = %d, ph Philox: %u, %u, %u, %u\n", n, rnd_8.x, rnd_8.y, rnd_8.z, rnd_8.w);
// // }
}
}
}

};

} // namespace pytorch_flash
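Editor's note: the keep/drop decision above compares Philox-generated bytes against an 8-bit threshold. A host-side sketch of how such a threshold relates to the dropout probability; the exact packing used by this PR lives in `set_params_fprop`, and the 255 scale here is an assumption matching the uint8 comparison in `apply_dropout`:

```cpp
#include <cmath>
#include <cstdint>

// Sketch: derive an 8-bit keep-threshold so that "rnd_8 <= threshold" keeps an
// activation with probability roughly (1 - p_dropout). Assumes p_dropout < 1,
// as the diff's TORCH_CHECK in set_params_fprop also requires.
struct DropoutParamsSketch {
  float p_keep;            // probability of keeping an activation
  float rp_keep;           // 1 / p_keep, used to rescale the kept values
  std::uint8_t threshold;  // compared against uniform bytes from Philox
};

DropoutParamsSketch make_dropout_params(float p_dropout) {
  DropoutParamsSketch d;
  d.p_keep = 1.0f - p_dropout;
  d.rp_keep = 1.0f / d.p_keep;
  d.threshold = static_cast<std::uint8_t>(std::floor(d.p_keep * 255.0));
  return d;
}
```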
@@ -5,13 +5,15 @@
#pragma once

#include <cuda.h>
#include <vector>

#include <ATen/cuda/PhiloxUtils.cuh>

namespace pytorch_flash{

#ifdef OLD_GENERATOR_PATH
#include <ATen/CUDAGeneratorImpl.h>
#else
#include <ATen/cuda/CUDAGeneratorImpl.h>
#endif

#include <ATen/cuda/CUDAGraphsUtils.cuh> // For at::cuda::philox::unpack
namespace pytorch_flash {
constexpr int TOTAL_DIM = 0;
constexpr int H_DIM = 1;
constexpr int D_DIM = 2;
@@ -19,7 +21,7 @@ constexpr int D_DIM = 2;
////////////////////////////////////////////////////////////////////////////////////////////////////

struct Qkv_params {
using index_t = uint32_t;
using index_t = int64_t;
// The QKV matrices.
void *__restrict__ q_ptr;
void *__restrict__ k_ptr;
@@ -96,7 +98,12 @@ struct Flash_fwd_params : public Qkv_params {
void * __restrict__ rotary_sin_ptr;

// The indices to index into the KV cache.
int *__restrict__ cache_batch_idx;
int * __restrict__ cache_batch_idx;

// Paged KV cache
int * __restrict__ block_table;
index_t block_table_batch_stride;
int page_block_size;

// The dropout probability (probability of keeping an activation).
float p_dropout;
@@ -126,6 +133,9 @@ struct Flash_fwd_params : public Qkv_params {
bool is_rotary_interleaved;

int num_splits; // For split-KV version

void * __restrict__ alibi_slopes_ptr;
index_t alibi_slopes_batch_stride;
};

////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -165,6 +175,9 @@ struct Flash_bwd_params : public Flash_fwd_params {

// The pointer to the softmax d sum.
void *__restrict__ dsoftmax_sum;

bool deterministic;
index_t dq_accum_split_stride;
};

////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -172,7 +185,6 @@ struct Flash_bwd_params : public Flash_fwd_params {
template<typename T, int Headdim> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
template<typename T, int Headdim> void run_mha_fwd_splitkv_dispatch(Flash_fwd_params &params, cudaStream_t stream);

template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream, const bool configure);

template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);

} // namespace pytorch_flash
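Editor's note: the switch of `Qkv_params::index_t` from `uint32_t` to `int64_t` matters because batch strides of large inputs can overflow 32 bits. A small illustrative sketch (the numbers are an example, not from the diff):

```cpp
#include <cstdint>

// Sketch: a single Q batch stride in elements. With seqlen 131072, 64 heads and
// head_dim 256 this is already 2^31 elements, half the uint32_t range, so
// 64-bit indexing avoids silent wraparound for long-sequence inputs.
int64_t q_batch_stride_elems(int64_t seqlen_q, int64_t num_heads, int64_t head_dim) {
  return seqlen_q * num_heads * head_dim;
}
```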
@@ -1,29 +1,5 @@
/******************************************************************************
 * Copyright (c) 2022, Tri Dao.
 * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 * * Redistributions of source code must retain the above copyright
 * notice, this list of conditions and the following disclaimer.
 * * Redistributions in binary form must reproduce the above copyright
 * notice, this list of conditions and the following disclaimer in the
 * documentation and/or other materials provided with the distribution.
 * * Neither the name of the NVIDIA CORPORATION nor the
 * names of its contributors may be used to endorse or promote products
 * derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 * Copyright (c) 2024, Tri Dao.
 ******************************************************************************/
#include <c10/core/ScalarType.h>
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
@@ -50,6 +26,7 @@
#include <ATen/ops/slice.h>
#include <ATen/ops/narrow.h>
#include <ATen/ops/pad.h>
#include <ATen/ops/zeros.h>
#endif

@@ -93,11 +70,11 @@ void set_params_fprop(Flash_fwd_params &params,
float p_dropout,
float softmax_scale,
int window_size_left,
int window_size_right) {
int window_size_right,
bool seqlenq_ngroups_swapped=false) {

// Reset the parameters should be equivalent
// Reset the parameters
params = {};
// memset(&params, 0, sizeof(params));

params.is_bf16 = q.dtype() == at::kBFloat16;

@@ -121,6 +98,10 @@ void set_params_fprop(Flash_fwd_params &params,
params.k_batch_stride = k.stride(0);
params.v_batch_stride = v.stride(0);
params.o_batch_stride = out.stride(0);
if (seqlenq_ngroups_swapped) {
params.q_batch_stride *= seqlen_q;
params.o_batch_stride *= seqlen_q;
}
}

params.cu_seqlens_q = static_cast<int *>(cu_seqlens_q_d);
@@ -159,6 +140,9 @@ void set_params_fprop(Flash_fwd_params &params,
params.rp_dropout = 1.f / params.p_dropout;
params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax;
TORCH_CHECK(p_dropout < 1.f);
#ifdef FLASHATTENTION_DISABLE_DROPOUT
TORCH_CHECK(p_dropout == 0.0f, "This flash attention build does not support dropout.");
#endif

// Causal is the special case where window_size_right == 0 and window_size_left < 0.
// Local is the more general case where window_size_right >= 0 or window_size_left >= 0.
@@ -169,7 +153,16 @@ void set_params_fprop(Flash_fwd_params &params,
params.window_size_left = window_size_left;
params.window_size_right = window_size_right;

#ifdef FLASHATTENTION_DISABLE_LOCAL
TORCH_CHECK(params.is_causal || (window_size_left < 0 && window_size_right < 0),
"This flash attention build does not support local attention.");
#endif

params.is_seqlens_k_cumulative = true;

#ifdef FLASHATTENTION_DISABLE_UNEVEN_K
TORCH_CHECK(d == d_rounded, "This flash attention build does not support headdim not being a multiple of 32.");
#endif
}

void set_params_dgrad(Flash_bwd_params &params,
@@ -202,7 +195,8 @@ void set_params_dgrad(Flash_bwd_params &params,
float p_dropout,
float softmax_scale,
int window_size_left,
int window_size_right) {
int window_size_right,
bool deterministic) {

set_params_fprop(params,
b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded,
@@ -244,11 +238,13 @@ void set_params_dgrad(Flash_bwd_params &params,

// Softmax sum
params.dsoftmax_sum = dsoftmax_sum_d;

params.deterministic = deterministic;
}

void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split_kernel=false) {
FP16_SWITCH(!params.is_bf16, [&] {
FWD_HEADDIM_SWITCH(params.d, [&] {
HEADDIM_SWITCH(params.d, [&] {
if (params.num_splits <= 1 && !force_split_kernel) { // If we don't set it num_splits == 0
run_mha_fwd_<elem_type, kHeadDim>(params, stream);
} else {
@@ -300,16 +296,62 @@ inline int num_splits_heuristic(int batch_nheads_mblocks, int num_SMs, int num_n
return 1;
}

void set_params_splitkv(Flash_fwd_params &params, const int batch_size,
const int num_heads, const int head_size, const int max_seqlen_k, const int max_seqlen_q,
const int head_size_rounded, const float p_dropout,
const int num_splits, cudaDeviceProp *dprops, struct c10::TensorOptions opts) {

// This needs to match with run_mha_fwd_splitkv_dispatch
const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64);
const int num_n_blocks = (max_seqlen_k + block_n - 1) / block_n;
// Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel.
// In any case we don't expect seqlen_q to be larger than 64 for inference.
const int num_m_blocks = (max_seqlen_q + 64 - 1) / 64;
params.num_splits = num_splits;
if (p_dropout == 0.0f) { // SplitKV is not implemented for dropout
if (num_splits < 1) {
params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount, num_n_blocks, 128);
}
if (params.num_splits > 1) {
at::Tensor softmax_lse_accum = at::empty({params.num_splits, batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
at::Tensor out_accum = at::empty({params.num_splits, batch_size, num_heads, max_seqlen_q, head_size_rounded}, opts.dtype(at::kFloat));
params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
params.oaccum_ptr = out_accum.data_ptr();
}
TORCH_CHECK(params.num_splits <= 128, "num_splits > 128 not supported");
}
}

void set_params_alibi(Flash_fwd_params &params, c10::optional<at::Tensor> &alibi_slopes_, int batch_size, int num_heads){
#ifdef FLASHATTENTION_DISABLE_ALIBI
TORCH_CHECK(!alibi_slopes_.has_value(), "This flash attention build does not support alibi.");
params.alibi_slopes_ptr = nullptr;
#else
if (alibi_slopes_.has_value()) {
auto alibi_slopes = alibi_slopes_.value();
TORCH_CHECK(alibi_slopes.dtype() == at::kFloat, "ALiBi slopes must have dtype fp32");
CHECK_DEVICE(alibi_slopes);
TORCH_CHECK(alibi_slopes.stride(-1) == 1, "ALiBi slopes tensor must have contiguous last dimension");
TORCH_CHECK(alibi_slopes.sizes() == at::IntArrayRef({num_heads}) || alibi_slopes.sizes() == at::IntArrayRef({batch_size, num_heads}));
params.alibi_slopes_ptr = alibi_slopes.data_ptr();
params.alibi_slopes_batch_stride = alibi_slopes.dim() == 2 ? alibi_slopes.stride(0) : 0;
} else {
params.alibi_slopes_ptr = nullptr;
}
#endif
}

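Editor's note: `set_params_splitkv` above centralizes the split-KV heuristic that was previously inlined in `mha_fwd`. A standalone sketch of just the block-count arithmetic it feeds into `num_splits_heuristic`, with the tile sizes taken from the comments in the diff:

```cpp
// Sketch of the tile math inside set_params_splitkv: how many KV blocks and
// query blocks a split-KV launch covers for a given head size.
struct SplitKvBlocks { int num_n_blocks; int num_m_blocks; };

SplitKvBlocks splitkv_blocks(int head_size, int max_seqlen_k, int max_seqlen_q) {
  // Must match run_mha_fwd_splitkv_dispatch: kBlockN shrinks as head_size grows.
  const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64);
  const int num_n_blocks = (max_seqlen_k + block_n - 1) / block_n;
  // kBlockM = 64 for the split-KV kernels; decode rarely has seqlen_q > 64.
  const int num_m_blocks = (max_seqlen_q + 64 - 1) / 64;
  return {num_n_blocks, num_m_blocks};
}
```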
// return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p};
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
const float p_dropout,
const float softmax_scale,
bool is_causal,
const int window_size_left,
int window_size_left,
int window_size_right,
const bool return_softmax,
c10::optional<at::Generator> gen_) {
@@ -350,12 +392,16 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

if (seqlen_q == 1) { is_causal = false; } // causal=true is the same as causal=false in this case
if (window_size_left >= seqlen_k) { window_size_left = -1; }
if (window_size_right >= seqlen_k) { window_size_right = -1; }

// causal=true is the same as causal=false in this case
if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }
if (is_causal) { window_size_right = 0; }

// Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
// H/t Daniel Haziza
const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0;
const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 && !alibi_slopes_.has_value();
at::Tensor temp_q = q;
if (seqlenq_ngroups_swapped) {
const int ngroups = num_heads / num_heads_k;
@@ -369,9 +415,9 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_og);

at::Tensor q_padded, k_padded, v_padded;
q_padded = temp_q;
k_padded = k;
v_padded = v;
q_padded = temp_q;
k_padded = k;
v_padded = v;

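Editor's note: the `seqlenq_ngroups_swapped` branch above folds the GQA "groups" dimension into the (length-1) query sequence so the decode path gets more parallel query rows per KV head. A sketch of that reshape on its own, following the shapes described in the "H/t Daniel Haziza" comment:

```cpp
#include <ATen/ATen.h>

// Sketch of the seqlenq_ngroups_swapped trick: for decode (seqlen_q == 1) with
// grouped-query attention, view q as (b, ngroups, nheads_kv, d) instead of
// (b, 1, nheads_kv * ngroups, d) so each KV head receives ngroups query rows.
at::Tensor swap_seqlenq_ngroups(const at::Tensor& q, int64_t num_heads_k) {
  const int64_t b = q.size(0);
  const int64_t head_dim = q.size(3);
  const int64_t ngroups = q.size(2) / num_heads_k;
  return q.reshape({b, num_heads_k, ngroups, head_dim}).transpose(1, 2);
}
```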
at::Tensor out;
if (out_.has_value()) {
@@ -423,30 +469,17 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
window_size_left,
window_size_right);

// This needs to match with run_mha_fwd_splitkv_dispatch
const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64);
const int num_n_blocks = (seqlen_k + block_n - 1) / block_n;
// Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel.
// In any case we don't expect seqlen_q to be larger than 64 for inference.
const int num_m_blocks = (seqlen_q + 64 - 1) / 64;
params.num_splits = 1;
if (p_dropout == 0.0f) { // SplitKV is not implemented for dropout
params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount, num_n_blocks, 128);
if (params.num_splits > 1) {
at::Tensor softmax_lse_accum = at::empty({params.num_splits, batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
at::Tensor out_accum = at::empty({params.num_splits, batch_size, num_heads, seqlen_q, head_size_rounded}, opts.dtype(at::kFloat));
params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
params.oaccum_ptr = out_accum.data_ptr();
}
TORCH_CHECK(params.num_splits <= 128, "num_splits > 128 not supported");
}

set_params_splitkv(params, batch_size, num_heads,
head_size, seqlen_k, seqlen_q,
head_size_rounded, p_dropout, /*num_splits*/0, dprops, opts);

// We want to checkpoint and save the RNG state for backward if dropout
// We get the default generator and return the seed and offset which will
// be used in the backward function
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(c10::nullopt, at::cuda::detail::getDefaultCUDAGenerator());
at::Tensor seed_t, offset_t;
if (p_dropout > 0.0) {
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(c10::nullopt, at::cuda::detail::getDefaultCUDAGenerator());
// number of times random will be generated per thread, to offset philox counter in thc random
// state
// We use a custom RNG that increases the offset by batch_size * nheads * 32.
@@ -476,6 +509,8 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head

}

set_params_alibi(params, alibi_slopes_, batch_size, num_heads);

if (seqlen_k > 0) {
auto stream = at::cuda::getCurrentCUDAStream().stream();
run_mha_fwd(params, stream);
@@ -501,18 +536,18 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
const at::Tensor &cu_seqlens_q, // b+1
const at::Tensor &cu_seqlens_k, // b+1
c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
const int max_seqlen_q,
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
int max_seqlen_q,
const int max_seqlen_k,
const float p_dropout,
const float softmax_scale,
const bool zero_tensors,
const bool is_causal,
const int window_size_left,
bool is_causal,
int window_size_left,
int window_size_right,
const bool return_softmax,
c10::optional<at::Generator> gen_) {

if (is_causal) { window_size_right = 0; }
auto dprops = at::cuda::getCurrentDeviceProperties();
// bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
@@ -544,17 +579,39 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q

const auto sizes = q.sizes();

const int total_q = sizes[0];
const int batch_size = cu_seqlens_q.numel() - 1;
const int num_heads = sizes[1];
int num_heads = sizes[1];
const int head_size_og = sizes[2];
const int total_k = k.size(0);
const int num_heads_k = k.size(1);

if (max_seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; } // causal=true is the same as causal=false in this case
if (is_causal) { window_size_right = 0; }

void *cu_seqlens_q_d = cu_seqlens_q.data_ptr();

// Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
// H/t Daniel Haziza
const int seqlenq_ngroups_swapped = max_seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && p_dropout == 0.f && head_size_og % 8 == 0 && !alibi_slopes_.has_value();
at::Tensor temp_q = q;
if (seqlenq_ngroups_swapped) {
const int ngroups = num_heads / num_heads_k;
temp_q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size_og});
max_seqlen_q = ngroups;
num_heads = num_heads_k;
cu_seqlens_q_d = nullptr;
}

const int total_q = q.sizes()[0];

TORCH_CHECK(batch_size > 0, "batch size must be positive");
TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
TORCH_CHECK(head_size_og % 8 == 0, "head_size must be a multiple of 8, this is ensured by padding!")

if (window_size_left >= max_seqlen_k) { window_size_left = -1; }
if (window_size_right >= max_seqlen_k) { window_size_right = -1; }

CHECK_SHAPE(q, total_q, num_heads, head_size_og);
CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
@@ -569,7 +626,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
}

at::Tensor q_padded, k_padded, v_padded;
q_padded = q;
q_padded = temp_q;
k_padded = k;
v_padded = v;

@@ -619,7 +676,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
num_heads, num_heads_k,
head_size, head_size_rounded,
q_padded, k_padded, v_padded, out,
cu_seqlens_q.data_ptr(),
cu_seqlens_q_d,
cu_seqlens_k.data_ptr(),
seqused_k.has_value() ? seqused_k.value().data_ptr() : nullptr,
return_softmax ? p.data_ptr() : nullptr,
@@ -627,9 +684,16 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
p_dropout,
softmax_scale,
window_size_left,
window_size_right);
window_size_right,
seqlenq_ngroups_swapped);
if (seqlenq_ngroups_swapped) {
// Only apply split-k for decoding
set_params_splitkv(params, batch_size, num_heads,
head_size, max_seqlen_k, max_seqlen_q,
head_size_rounded, p_dropout, /*num_splits*/0, dprops, opts);
}

// We want to checkpoint and save the RNG state for backward if dropout
// We want to checkpoint and save the RNG state for backward if dropout
// We get the default generator and return the seed and offset which will
// be used in the backward function
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(c10::nullopt, at::cuda::detail::getDefaultCUDAGenerator());
@@ -664,31 +728,33 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q

}

auto stream = at::cuda::getCurrentCUDAStream().stream();
run_mha_fwd(params, stream);
set_params_alibi(params, alibi_slopes_, batch_size, num_heads);

if (max_seqlen_k > 0) {
auto stream = at::cuda::getCurrentCUDAStream().stream();
run_mha_fwd(params, stream);
} else {
// If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
out.zero_();
softmax_lse.fill_(std::numeric_limits<float>::infinity());
}

if (seqlenq_ngroups_swapped) {
std::array<int64_t, 4> size_before = {batch_size, max_seqlen_q, num_heads_k, head_size_og};
std::array<int64_t, 3> size_after = {batch_size, num_heads_k * max_seqlen_q, head_size_og};
out = out.reshape(size_before).transpose(1, 2).reshape(size_after);
q_padded = q_padded.reshape(size_before).transpose(1, 2).reshape(size_after);
softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * max_seqlen_q, 1});
}

return {out, q_padded, k_padded, v_padded, softmax_lse, seed_t, offset_t, p};
}

void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
void run_mha_bwd(Flash_bwd_params &params, cudaStream_t stream) {
FP16_SWITCH(!params.is_bf16, [&] {
if (params.d <= 32) {
run_mha_bwd_<elem_type, 32>(params, stream, configure);
} else if (params.d <= 64) {
run_mha_bwd_<elem_type, 64>(params, stream, configure);
} else if (params.d <= 96) {
run_mha_bwd_<elem_type, 96>(params, stream, configure);
} else if (params.d <= 128) {
run_mha_bwd_<elem_type, 128>(params, stream, configure);
} else if (params.d <= 160) {
run_mha_bwd_<elem_type, 160>(params, stream, configure);
} else if (params.d <= 192) {
run_mha_bwd_<elem_type, 192>(params, stream, configure);
} else if (params.d <= 224) {
run_mha_bwd_<elem_type, 224>(params, stream, configure);
} else if (params.d <= 256) {
run_mha_bwd_<elem_type, 256>(params, stream, configure);
}
HEADDIM_SWITCH(params.d, [&] {
run_mha_bwd_<elem_type, kHeadDim>(params, stream);
});
});
}

@@ -702,14 +768,19 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
c10::optional<at::Tensor> &dq_, // batch_size x seqlen_q x num_heads x head_size
c10::optional<at::Tensor> &dk_, // batch_size x seqlen_k x num_heads_k x head_size
c10::optional<at::Tensor> &dv_, // batch_size x seqlen_k x num_heads_k x head_size
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
const float p_dropout, // probability to drop
const float softmax_scale,
const bool is_causal,
const int window_size_left,
int window_size_left,
int window_size_right,
const bool deterministic,
const at::Tensor philox_seed,
const at::Tensor philox_offset) {

#ifdef FLASHATTENTION_DISABLE_BACKWARD
TORCH_CHECK(false, "This flash attention build does not support backward.");
#endif
if (is_causal) { window_size_right = 0; }
auto dprops = at::cuda::getCurrentDeviceProperties();
// bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
@@ -756,8 +827,8 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
TORCH_CHECK(head_size_og % 8 == 0, "head_size_og should be a multiple of 8, this is ensured by padding!");
TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256");
if (head_size > 192) {
TORCH_CHECK(is_sm80 || is_sm90, "FlashAttention backward for head dim > 192 requires A100/A800 or H100/H800");
if (head_size > 192 && (head_size <= 224 || is_dropout)) {
TORCH_CHECK(is_sm80 || is_sm90, "FlashAttention backward for head dim 256 with dropout, or head dim 224 with/without dropout requires A100/A800 or H100/H800");
}
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

@@ -768,6 +839,9 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si

TORCH_CHECK(head_size == round_multiple(head_size_og, 8), "head_size must be head_size_og rounded to a multiple of 8");

if (window_size_left >= seqlen_k) { window_size_left = -1; }
if (window_size_right >= seqlen_k) { window_size_right = -1; }

CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);
CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size);
@@ -803,8 +877,6 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
dv = at::empty_like(v);
}

// const at::Tensor& dout_padded = dout;

// bool loop = seqlen_k > blocksize_c;
// TODO: change later, for now set to true for simplicity
bool loop = true;
@@ -818,9 +890,14 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
at::Tensor dq_accum;
at::Tensor dk_accum, dv_accum;
if (loop) {
dq_accum = at::empty({batch_size, seqlen_q_rounded, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
// dk_accum = at::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat));
// dv_accum = at::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat));
if (!deterministic) {
dq_accum = at::empty({batch_size, seqlen_q_rounded, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
} else {
const int nsplits = (dprops->multiProcessorCount + batch_size * num_heads - 1) / (batch_size * num_heads);
dq_accum = at::zeros({nsplits, batch_size, seqlen_q_rounded, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
}
// dk_accum = torch::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat));
// dv_accum = torch::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat));
}

at::Tensor dk_expanded, dv_expanded;
@@ -854,10 +931,11 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
p_dropout,
softmax_scale,
window_size_left,
window_size_right);
window_size_right,
deterministic);
params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0);

auto launch = &run_mha_bwd;
// launch(params, stream, /*configure=*/true);

at::PhiloxCudaState philox_args;
if (is_dropout) {
@@ -872,12 +950,14 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
}
params.philox_args = philox_args;

set_params_alibi(params, alibi_slopes_, batch_size, num_heads);

if (seqlen_q > 0) {
launch(params, stream, /*configure=*/false);
launch(params, stream);
} else {
// If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
dk.zero_();
dv.zero_();
dk_expanded.zero_();
dv_expanded.zero_();
softmax_d.zero_();
}

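Editor's note: the deterministic backward above trades memory for reproducibility: instead of one dq accumulator updated with atomics, it zero-initializes `nsplits` copies so concurrently running thread blocks never race on the same buffer, and the slices are reduced afterwards (hence `dq_accum_split_stride`). A sketch of the split count on its own, copied from the arithmetic in the diff:

```cpp
// Sketch: number of dq_accum splits the deterministic backward allocates.
// Chosen so that the (batch * heads) grid, spread over all SMs, does not have
// two resident thread blocks accumulating into the same split.
int deterministic_dq_accum_splits(int multi_processor_count, int batch_size, int num_heads) {
  const int work_per_split = batch_size * num_heads;
  return (multi_processor_count + work_per_split - 1) / work_per_split;
}
```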
@@ -901,17 +981,24 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
c10::optional<at::Tensor> &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &cu_seqlens_q, // b+1
const at::Tensor &cu_seqlens_k, // b+1
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
const int max_seqlen_q,
const int max_seqlen_k, // max sequence length to choose the kernel
const float p_dropout, // probability to drop
const float softmax_scale,
const bool zero_tensors,
const bool is_causal,
const int window_size_left,
int window_size_left,
int window_size_right,
const bool deterministic,
const at::Tensor philox_seed,
const at::Tensor philox_offset)
{

#ifdef FLASHATTENTION_DISABLE_BACKWARD
TORCH_CHECK(false, "This flash attention build does not support backward.");
#endif

if (is_causal) { window_size_right = 0; }
auto dprops = at::cuda::getCurrentDeviceProperties();
// bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
@@ -925,7 +1012,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
auto stream = at::cuda::getCurrentCUDAStream().stream();

auto q_dtype = q.dtype();
TORCH_CHECK(q_dtype == at::kHalf|| q_dtype == at::kBFloat16,
TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16,
"FlashAttention only support fp16 and bf16 data type");
if (q_dtype == at::kBFloat16) {
TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
@@ -962,8 +1049,8 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
TORCH_CHECK(head_size_og % 8 == 0, "head_size_og should be a multiple of 8, this is ensured by padding!");
TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256");
if (head_size > 192) {
TORCH_CHECK(is_sm80 || is_sm90, "FlashAttention backward for head dim > 192 requires A100/A800 or H100/H800");
if (head_size > 192 && (head_size <= 224 || is_dropout)) {
TORCH_CHECK(is_sm80 || is_sm90, "FlashAttention backward for head dim 256 with dropout, or head dim 224 with/without dropout requires A100/A800 or H100/H800");
}
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

@@ -974,6 +1061,9 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size

TORCH_CHECK(head_size == round_multiple(head_size_og, 8), "head_size must be head_size_og rounded to a multiple of 8");

if (window_size_left >= max_seqlen_k) { window_size_left = -1; }
if (window_size_right >= max_seqlen_k) { window_size_right = -1; }

CHECK_SHAPE(q, total_q, num_heads, head_size);
CHECK_SHAPE(k, total_k, num_heads_k, head_size);
CHECK_SHAPE(v, total_k, num_heads_k, head_size);
@@ -1008,11 +1098,9 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
CHECK_SHAPE(dv, total_k, num_heads_k, head_size);
} else {
dv = at::empty_like(k);
dv = at::empty_like(v);
}

// const at::Tensor& dout_padded = dout;

// bool loop = max_seqlen_k > blocksize_c;
// TODO: change later, for now set to true for simplicity
bool loop = true;
@@ -1033,7 +1121,12 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
// cu_seqlens[i + 1] * 128 * i - 1. This ensures that the i-th sequence and (i + 1)-th sequence will
// be at least 128 apart. It's ok for us to do atomicAdds up to 128 rows beyond what we're normally
// allowed to do. So we won't have to do any bound checking, and performance should stay the same.
dq_accum = at::empty({total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
if (!deterministic) {
dq_accum = at::empty({total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
} else {
const int nsplits = (dprops->multiProcessorCount + batch_size * num_heads - 1) / (batch_size * num_heads);
dq_accum = at::zeros({nsplits, total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat));
}
}

at::Tensor dk_expanded, dv_expanded;
@@ -1072,10 +1165,11 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
p_dropout,
softmax_scale,
window_size_left,
window_size_right);
window_size_right,
deterministic);
params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0);

auto launch = &run_mha_bwd;
// launch(params, stream, /*configure=*/true);

at::PhiloxCudaState philox_args;
if (is_dropout) {
@@ -1090,7 +1184,16 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
}
params.philox_args = philox_args;

launch(params, stream, /*configure=*/false);
set_params_alibi(params, alibi_slopes_, batch_size, num_heads);

if (max_seqlen_q > 0) {
launch(params, stream);
} else {
// If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0.
dk_expanded.zero_();
dv_expanded.zero_();
softmax_d.zero_();
}

// For MQA/GQA we need to sum dK and dV across the groups
if (num_heads_k != num_heads) {
@@ -1103,18 +1206,20 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size

std::tuple<at::Tensor, at::Tensor>
|
||||
mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
|
||||
const at::Tensor &kcache, // batch_size_c x seqlen_k x num_heads_k x head_size
|
||||
const at::Tensor &vcache, // batch_size_c x seqlen_k x num_heads_k x head_size
|
||||
const at::Tensor &kcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
|
||||
const at::Tensor &vcache, // batch_size_c x seqlen_k x num_heads_k x head_size or num_blocks x page_block_size x num_heads_k x head_size if there's a block_table.
|
||||
c10::optional<const at::Tensor> &k_, // batch_size x seqlen_knew x num_heads_k x head_size
|
||||
c10::optional<const at::Tensor> &v_, // batch_size x seqlen_knew x num_heads_k x head_size
|
||||
c10::optional<const at::Tensor> &seqlens_k_, // batch_size
|
||||
c10::optional<const at::Tensor> &rotary_cos_, // seqlen_ro x (rotary_dim / 2)
|
||||
c10::optional<const at::Tensor> &rotary_sin_, // seqlen_ro x (rotary_dim / 2)
|
||||
c10::optional<const at::Tensor> &cache_batch_idx_, // indices to index into the KV cache
|
||||
c10::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
|
||||
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
|
||||
c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
|
||||
const float softmax_scale,
|
||||
bool is_causal,
|
||||
const int window_size_left,
|
||||
int window_size_left,
|
||||
int window_size_right,
|
||||
bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
|
||||
int num_splits
|
||||
@ -1143,25 +1248,41 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
|
||||
TORCH_CHECK(kcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
||||
TORCH_CHECK(vcache.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
||||
|
||||
at::Tensor block_table;
|
||||
const bool paged_KV = block_table_.has_value();
|
||||
if (paged_KV) {
|
||||
TORCH_CHECK(!cache_batch_idx_.has_value(), "Paged KVcache does not support cache_batch_idx");
|
||||
block_table = block_table_.value();
|
||||
CHECK_DEVICE(block_table);
|
||||
TORCH_CHECK(block_table.dtype() == at::kInt, "block_table must have dtype torch.int32");
|
||||
TORCH_CHECK(block_table.stride(-1) == 1, "block_table must have contiguous last dimension");
|
||||
}
|
||||
|
||||
const auto sizes = q.sizes();
|
||||
|
||||
const int batch_size = sizes[0];
|
||||
int seqlen_q = sizes[1];
|
||||
int num_heads = sizes[2];
|
||||
const int head_size_og = sizes[3];
|
||||
const int seqlen_k = kcache.size(1);
|
||||
|
||||
const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1);
|
||||
const int num_blocks = !paged_KV ? 0 : kcache.size(0);
|
||||
const int page_block_size = !paged_KV ? 1 : kcache.size(1);
|
||||
TORCH_CHECK(!paged_KV || page_block_size % 256 == 0, "Paged KV cache block size must be divisible by 256");
|
||||
const int seqlen_k = !paged_KV ? kcache.size(1) : max_num_blocks_per_seq * page_block_size;
|
||||
const int num_heads_k = kcache.size(2);
|
||||
const int batch_size_c = kcache.size(0);
|
||||
TORCH_CHECK(batch_size > 0, "batch size must be positive");
|
||||
const int batch_size_c = !paged_KV ? kcache.size(0) : batch_size;
|
||||
TORCH_CHECK(batch_size > 0, "batch size must be postive");
|
||||
TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
|
||||
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
|
||||
|
||||
if (seqlen_q == 1) { is_causal = false; } // causal=true is the same as causal=false in this case
|
||||
// causal=true is the same as causal=false in this case
|
||||
if (seqlen_q == 1 && !alibi_slopes_.has_value()) { is_causal = false; }
|
||||
if (is_causal) { window_size_right = 0; }
|
||||
|
||||
// Faster to transpose q from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d) in this case
|
||||
// H/t Daniel Haziza
|
||||
const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && head_size_og % 8 == 0;
|
||||
const int seqlenq_ngroups_swapped = seqlen_q == 1 && num_heads > num_heads_k && window_size_left < 0 && window_size_right < 0 && head_size_og % 8 == 0 && !alibi_slopes_.has_value();
|
||||
if (seqlenq_ngroups_swapped) {
|
||||
const int ngroups = num_heads / num_heads_k;
|
||||
q = q.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2);
|
||||
@ -1169,9 +1290,18 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
|
||||
num_heads = num_heads_k;
|
||||
}
|
||||
|
||||
if (window_size_left >= seqlen_k) { window_size_left = -1; }
|
||||
if (window_size_right >= seqlen_k) { window_size_right = -1; }
|
||||
|
||||
CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
|
||||
CHECK_SHAPE(kcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);
|
||||
CHECK_SHAPE(vcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);
|
||||
if (!paged_KV) {
|
||||
CHECK_SHAPE(kcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);
|
||||
CHECK_SHAPE(vcache, batch_size_c, seqlen_k, num_heads_k, head_size_og);
|
||||
} else {
|
||||
CHECK_SHAPE(kcache, num_blocks, page_block_size, num_heads_k, head_size_og);
|
||||
CHECK_SHAPE(vcache, num_blocks, page_block_size, num_heads_k, head_size_og);
|
||||
CHECK_SHAPE(block_table, batch_size, max_num_blocks_per_seq);
|
||||
}
|
||||
|
||||
at::Tensor q_padded, kcache_padded, vcache_padded;
|
||||
if (head_size_og % 8 != 0) {
|
||||
@ -1310,27 +1440,24 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
|
||||
TORCH_CHECK(cache_batch_idx.scalar_type() == at::kInt, "cache_batch_idx must have dtype int32");
|
||||
params.cache_batch_idx = reinterpret_cast<int *>(cache_batch_idx.data_ptr());
|
||||
}
|
||||
// This needs to match with run_mha_fwd_splitkv_dispatch
|
||||
const int block_n = head_size <= 64 ? 256 : (head_size <= 128 ? 128 : 64);
|
||||
const int num_n_blocks = (seqlen_k + block_n - 1) / block_n;
|
||||
// Technically kBlockM = 64 only for the splitKV kernels, not the standard kernel.
|
||||
// In any case we don't expect seqlen_q to be larger than 64 for inference.
|
||||
const int num_m_blocks = (seqlen_q + 64 - 1) / 64;
|
||||
params.num_splits = num_splits;
|
||||
if (num_splits < 1) {
|
||||
params.num_splits = num_splits_heuristic(batch_size * num_heads * num_m_blocks, dprops->multiProcessorCount, num_n_blocks, 128);
|
||||
}
|
||||
TORCH_CHECK(params.num_splits <= 128, "num_splits > 128 not supported");
|
||||
if (params.num_splits > 1) {
|
||||
at::Tensor softmax_lse_accum = at::empty({params.num_splits, batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
|
||||
at::Tensor out_accum = at::empty({params.num_splits, batch_size, num_heads, seqlen_q, head_size_rounded}, opts.dtype(at::kFloat));
|
||||
params.softmax_lseaccum_ptr = softmax_lse_accum.data_ptr();
|
||||
params.oaccum_ptr = out_accum.data_ptr();
|
||||
|
||||
set_params_splitkv(params, batch_size, num_heads,
|
||||
head_size, seqlen_k, seqlen_q,
|
||||
head_size_rounded, /*dropout*/0.f, num_splits, dprops, opts);
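set_params_splitkv now owns the split-count choice that the removed block computed inline. The sketch below is a simplified, assumption-laden version of that kind of occupancy heuristic; the function name and the 0.8 / 0.85 thresholds are illustrative and not the exact num_splits_heuristic.
#include <algorithm>
#include <cmath>
#include <vector>

// Illustrative only: choose the smallest split count whose last-wave utilization is
// close to the best achievable, in the spirit of num_splits_heuristic.
static int toy_num_splits_heuristic(int total_mblocks, int num_SMs,
                                    int num_n_blocks, int max_splits) {
    if (static_cast<float>(total_mblocks) >= 0.8f * num_SMs) { return 1; }  // GPU already busy
    max_splits = std::min({max_splits, num_SMs, num_n_blocks});
    std::vector<float> eff;
    float max_eff = 0.f;
    for (int s = 1; s <= max_splits; ++s) {
        const float waves = static_cast<float>(total_mblocks) * s / num_SMs;
        const float e = waves / std::ceil(waves);  // fraction of the last wave doing work
        eff.push_back(e);
        max_eff = std::max(max_eff, e);
    }
    for (int s = 1; s <= max_splits; ++s) {
        if (eff[s - 1] >= 0.85f * max_eff) { return s; }
    }
    return 1;
}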
|
||||
|
||||
if (paged_KV) {
|
||||
params.block_table = block_table.data_ptr<int>();
|
||||
params.block_table_batch_stride = block_table.stride(0);
|
||||
}
|
||||
params.page_block_size = page_block_size;
|
||||
|
||||
|
||||
set_params_alibi(params, alibi_slopes_, batch_size, num_heads);
|
||||
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
// Only split kernel supports appending to KV cache, or indexing to the cache with cache_batch_idx
|
||||
run_mha_fwd(params, stream, /*force_split_kernel=*/k_.has_value() || cache_batch_idx_.has_value());
|
||||
// Only split kernel supports appending to KV cache, or indexing to the cache with cache_batch_idx,
|
||||
// or paged KV cache
|
||||
run_mha_fwd(params, stream, /*force_split_kernel=*/k_.has_value() || cache_batch_idx_.has_value() || paged_KV);
|
||||
|
||||
if (head_size_og % 8 != 0) {
|
||||
// out = out.index({"...", at::indexing::Slice(at::indexing::None, head_size_og)});
|
||||
@ -1352,6 +1479,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
|
||||
}
|
||||
return {out, softmax_lse};
|
||||
}
|
||||
|
||||
} // namespace pytorch_fmha
|
||||
|
||||
#endif
|
||||
|
@ -12,10 +12,11 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
|
||||
const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
|
||||
const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
|
||||
c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
|
||||
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
|
||||
const float p_dropout,
|
||||
const float softmax_scale,
|
||||
bool is_causal,
|
||||
const int window_size_left,
|
||||
int window_size_left,
|
||||
int window_size_right,
|
||||
const bool return_softmax,
|
||||
c10::optional<at::Generator> gen_);
|
||||
@ -28,13 +29,14 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
|
||||
const at::Tensor &cu_seqlens_q, // b+1
|
||||
const at::Tensor &cu_seqlens_k, // b+1
|
||||
c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
|
||||
const int max_seqlen_q,
|
||||
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
|
||||
int max_seqlen_q,
|
||||
const int max_seqlen_k,
|
||||
const float p_dropout,
|
||||
const float softmax_scale,
|
||||
const bool zero_tensors,
|
||||
const bool is_causal,
|
||||
const int window_size_left,
|
||||
bool is_causal,
|
||||
int window_size_left,
|
||||
int window_size_right,
|
||||
const bool return_softmax,
|
||||
c10::optional<at::Generator> gen_);
|
||||
@ -50,11 +52,13 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
|
||||
c10::optional<at::Tensor> &dq_, // batch_size x seqlen_q x num_heads x head_size
|
||||
c10::optional<at::Tensor> &dk_, // batch_size x seqlen_k x num_heads_k x head_size
|
||||
c10::optional<at::Tensor> &dv_, // batch_size x seqlen_k x num_heads_k x head_size
|
||||
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
|
||||
const float p_dropout, // probability to drop
|
||||
const float softmax_scale,
|
||||
const bool is_causal,
|
||||
const int window_size_left,
|
||||
int window_size_left,
|
||||
int window_size_right,
|
||||
const bool deterministic,
|
||||
const at::Tensor philox_seed,
|
||||
const at::Tensor philox_offset);
|
||||
|
||||
@ -70,14 +74,16 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
|
||||
c10::optional<at::Tensor> &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
|
||||
const at::Tensor &cu_seqlens_q, // b+1
|
||||
const at::Tensor &cu_seqlens_k, // b+1
|
||||
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
|
||||
const int max_seqlen_q,
|
||||
const int max_seqlen_k, // max sequence length to choose the kernel
|
||||
const float p_dropout, // probability to drop
|
||||
const float softmax_scale,
|
||||
const bool zero_tensors,
|
||||
const bool is_causal,
|
||||
const int window_size_left,
|
||||
int window_size_left,
|
||||
int window_size_right,
|
||||
const bool deterministic,
|
||||
const at::Tensor philox_seed,
|
||||
const at::Tensor philox_offset);
|
||||
|
||||
|
File diff suppressed because it is too large
@ -1,4 +1,6 @@
|
||||
// Copyright (c) 2022, Tri Dao.
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2024, Tri Dao.
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
@ -6,58 +8,81 @@
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/static_switch.h>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash.h>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_bwd_preprocess_kernel.h>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_bwd_kernel.h>
|
||||
|
||||
namespace pytorch_flash {
|
||||
|
||||
// Determine if the architecture supports FLASH and define a macro to handle parameter modifiers
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
#define ARCH_SUPPORTS_FLASH
|
||||
#define KERNEL_PARAM_MODIFIER __grid_constant__
|
||||
#else
|
||||
#define KERNEL_PARAM_MODIFIER
|
||||
#endif
|
||||
|
||||
// Define a macro for unsupported architecture handling to centralize the error message
|
||||
#define FLASH_UNSUPPORTED_ARCH printf("FATAL: FlashAttention requires building with sm version sm80-sm90, but was built for < 8.0!");
|
||||
|
||||
// Use a macro to clean up kernel definitions
|
||||
#define DEFINE_FLASH_BACKWARD_KERNEL(kernelName, ...) \
|
||||
template<typename Kernel_traits, __VA_ARGS__> \
|
||||
__global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_bwd_params params)
|
||||
|
||||
DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_kernel, bool Is_dropout, bool Is_causal, bool Has_alibi, bool Is_even_M, bool Is_even_K) {
|
||||
#if defined(ARCH_SUPPORTS_FLASH)
|
||||
pytorch_flash::compute_dq_dk_dv<Kernel_traits, Is_dropout, Is_causal, Has_alibi, Is_even_M, Is_even_K>(params);
|
||||
#else
|
||||
FLASH_UNSUPPORTED_ARCH
|
||||
#endif
|
||||
}
|
||||
|
||||
DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K) {
|
||||
#if defined(ARCH_SUPPORTS_FLASH)
|
||||
static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false
|
||||
pytorch_flash::compute_dq_dk_dv_seqk_parallel<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K>(params);
|
||||
#else
|
||||
FLASH_UNSUPPORTED_ARCH
|
||||
#endif
|
||||
}
|
||||
|
||||
template<bool Clear_dQaccum=true, typename Kernel_traits>
|
||||
__global__ void flash_bwd_dot_do_o_kernel(Flash_bwd_params params) {
|
||||
__global__ void flash_bwd_dot_do_o_kernel(const Flash_bwd_params params) {
|
||||
pytorch_flash::compute_dot_do_o<Clear_dQaccum, Kernel_traits>(params);
|
||||
}
|
||||
|
||||
template<typename Kernel_traits>
|
||||
__global__ void flash_bwd_clear_dkvaccum_kernel(Flash_bwd_params params) {
|
||||
__global__ void flash_bwd_clear_dkvaccum_kernel(const Flash_bwd_params params) {
|
||||
pytorch_flash::clear_dKVaccum<Kernel_traits>(params);
|
||||
}
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_M, bool Is_even_K>
|
||||
__global__ void flash_bwd_dq_dk_dv_loop_kernel(Flash_bwd_params params) {
|
||||
pytorch_flash::compute_dq_dk_dv<Kernel_traits, Is_dropout, Is_causal, Is_even_M, Is_even_K>(params);
|
||||
}
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Is_even_MN, bool Is_even_K>
|
||||
__global__ void flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel(Flash_bwd_params params) {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false
|
||||
pytorch_flash::compute_dq_dk_dv_seqk_parallel<Kernel_traits, Is_dropout, Is_causal, Is_local, Is_even_MN, Is_even_K>(params);
|
||||
#else
|
||||
printf("FATAL: FlashAttention requires to be build with sm80-sm90, but was built for < 8.0!");
|
||||
#endif
|
||||
}
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K>
|
||||
__global__ void flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel(Flash_bwd_params params) {
|
||||
pytorch_flash::compute_dq_dk_dv_seqq_parallel<Kernel_traits, Is_dropout, Is_causal, Is_even_N, Is_even_K>(params);
|
||||
template<typename Kernel_traits>
|
||||
__global__ void flash_bwd_convert_dq_kernel(const Flash_bwd_params params, const int nsplits) {
|
||||
pytorch_flash::convert_dQ<Kernel_traits>(params, nsplits);
|
||||
}
|
||||
|
||||
template<typename Kernel_traits>
|
||||
__global__ void flash_bwd_convert_dq_kernel(Flash_bwd_params params) {
|
||||
pytorch_flash::convert_dQ<Kernel_traits>(params);
|
||||
}
|
||||
|
||||
template<typename Kernel_traits>
|
||||
__global__ void flash_bwd_convert_dkv_kernel(Flash_bwd_params params) {
|
||||
__global__ void flash_bwd_convert_dkv_kernel(const Flash_bwd_params params) {
|
||||
pytorch_flash::convert_dKV<Kernel_traits>(params);
|
||||
}
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout>
|
||||
void run_flash_bwd_seqk_parallel(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
void run_flash_bwd_seqk_parallel(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
||||
const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
|
||||
dim3 grid_m(num_m_block, params.b, params.h);
|
||||
const int num_n_block = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN;
|
||||
dim3 grid_n(num_n_block, params.b, params.h);
|
||||
int gridDimx = num_n_block;
|
||||
if (params.deterministic) {
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
gridDimx = (dprops->multiProcessorCount + params.b * params.h - 1) / (params.b * params.h);
|
||||
}
|
||||
dim3 grid_n(gridDimx, params.b, params.h);
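// (Hedged note on the branch above: in deterministic mode the n-block grid is capped at
// roughly ceil(num_SMs / (b * h)), so each of the gridDimx "splits" can own its own slice
// of the dQ accumulator; convert_dQ later sums those slices (see the nsplits argument
// passed to kernel_dq below) instead of relying on atomicAdd ordering, which would be
// run-to-run nondeterministic.)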
|
||||
|
||||
flash_bwd_dot_do_o_kernel<true, Kernel_traits><<<grid_m, Kernel_traits::kNThreads, 0, stream>>>(params);
|
||||
if (!params.deterministic) {
|
||||
flash_bwd_dot_do_o_kernel<true, Kernel_traits><<<grid_m, Kernel_traits::kNThreads, 0, stream>>>(params);
|
||||
} else {
|
||||
flash_bwd_dot_do_o_kernel<false, Kernel_traits><<<grid_m, Kernel_traits::kNThreads, 0, stream>>>(params);
|
||||
}
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
|
||||
// We want to specialize to is_even_MN and not just is_even_M, since in the case where N is not
|
||||
@ -66,21 +91,23 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params ¶ms, cudaStream_t stream,
|
||||
const bool is_even_K = params.d == Kernel_traits::kHeadDim;
|
||||
constexpr int smem_size_dq_dk_dv = Kernel_traits::kSmemSize1colblock;
|
||||
// printf("smem_size_dq_dk_dv = %d\n", smem_size_dq_dk_dv);
|
||||
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
||||
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
||||
BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
|
||||
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
|
||||
BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !params.is_causal, Is_local, [&] {
|
||||
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
|
||||
// If head dim > 128, set IsEvenMNConst to false to reduce number of templates
|
||||
// If Is_local, set Is_causal to false
|
||||
auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst>;
|
||||
// auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenMNConst, true>;
|
||||
if (smem_size_dq_dk_dv >= 48 * 1024) {
|
||||
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
|
||||
}
|
||||
kernel<<<grid_n, Kernel_traits::kNThreads, smem_size_dq_dk_dv, stream>>>(params);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
|
||||
LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !params.is_causal, Is_local, [&] {
|
||||
ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
|
||||
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
|
||||
// If head dim > 128, set IsEvenMNConst to false to reduce number of templates
|
||||
// If Is_local, set Is_causal to false
|
||||
auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst>;
|
||||
// auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, false, Is_causal, false, false, true, true>;
|
||||
if (smem_size_dq_dk_dv >= 48 * 1024) {
|
||||
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
|
||||
}
|
||||
kernel<<<grid_n, Kernel_traits::kNThreads, smem_size_dq_dk_dv, stream>>>(params);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
});
|
||||
});
|
||||
});
|
||||
});
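The DROPOUT_/EVENK_/LOCAL_/ALIBI_SWITCH wrappers used above follow the familiar BOOL_SWITCH pattern from static_switch.h, with the FLASHATTENTION_DISABLE_* build flags able to pin a branch to a compile-time constant. A minimal sketch of that pattern, deliberately named EXAMPLE_BOOL_SWITCH to mark it as illustrative:
// Turn a runtime bool into a constexpr template argument by instantiating both branches.
#define EXAMPLE_BOOL_SWITCH(COND, CONST_NAME, ...)   \
  [&] {                                              \
    if (COND) {                                      \
      constexpr static bool CONST_NAME = true;       \
      return __VA_ARGS__();                          \
    } else {                                         \
      constexpr static bool CONST_NAME = false;      \
      return __VA_ARGS__();                          \
    }                                                \
  }()

// Usage sketch:
// EXAMPLE_BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
//     run_kernel<Is_dropout>(params, stream);
// });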
|
||||
@ -91,58 +118,19 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params ¶ms, cudaStream_t stream,
|
||||
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
kernel_dq, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemdQSize));
|
||||
}
|
||||
kernel_dq<<<grid_m, Kernel_traits::kNThreads, Kernel_traits::kSmemdQSize, stream>>>(params);
|
||||
kernel_dq<<<grid_m, Kernel_traits::kNThreads, Kernel_traits::kSmemdQSize, stream>>>(params, !params.deterministic ? 1 : gridDimx);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
}
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout>
|
||||
void run_flash_bwd_seqq_parallel(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
const int num_n_block = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN;
|
||||
dim3 grid_n(num_n_block, params.b, params.h_k);
|
||||
flash_bwd_clear_dkvaccum_kernel<Kernel_traits><<<grid_n, Kernel_traits::kNThreads, 0, stream>>>(params);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
|
||||
const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
|
||||
dim3 grid_m(num_m_block, params.b, params.h);
|
||||
// We also use is_even_N to set Unpadded in the BlockInfo constructor, so we need to check
|
||||
// for cu_seqlens_k as well.
|
||||
const bool is_even_N = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0;
|
||||
const bool is_even_K = params.d == Kernel_traits::kHeadDim;
|
||||
constexpr int smem_size_dq_dk_dv = Kernel_traits::kSmemSize1rowblock;
|
||||
// printf("smem_size_dq_dk_dv = %d\n", smem_size_dq_dk_dv);
|
||||
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
||||
BOOL_SWITCH(is_even_N, IsEvenNConst, [&] {
|
||||
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
|
||||
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
|
||||
auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenNConst && IsEvenKConst, IsEvenKConst>;
|
||||
// auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel<Kernel_traits, false, false, IsEvenNConst, IsEvenKConst>;
|
||||
if (smem_size_dq_dk_dv >= 48 * 1024) {
|
||||
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
|
||||
}
|
||||
kernel<<<grid_m, Kernel_traits::kNThreads, smem_size_dq_dk_dv, stream>>>(params);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
auto kernel_dkv = &flash_bwd_convert_dkv_kernel<Kernel_traits>;
|
||||
if (Kernel_traits::kSmemKVSize >= 48 * 1024) {
|
||||
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
kernel_dkv, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemKVSize));
|
||||
}
|
||||
kernel_dkv<<<grid_n, Kernel_traits::kNThreads, Kernel_traits::kSmemKVSize, stream>>>(params);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
}
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout>
|
||||
void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
if (configure) return;
|
||||
run_flash_bwd_seqk_parallel<Kernel_traits, Is_dropout>(params, stream, configure);
|
||||
void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
||||
#ifndef FLASHATTENTION_DISABLE_BACKWARD
|
||||
run_flash_bwd_seqk_parallel<Kernel_traits, Is_dropout>(params, stream);
|
||||
#endif
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void run_mha_bwd_hdim32(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
void run_mha_bwd_hdim32(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
||||
constexpr static int Headdim = 32;
|
||||
int device;
|
||||
cudaGetDevice(&device);
|
||||
@ -152,21 +140,21 @@ void run_mha_bwd_hdim32(Flash_bwd_params ¶ms, cudaStream_t stream, const boo
|
||||
if (status_ != cudaSuccess) {
|
||||
C10_CUDA_CHECK(status_);
|
||||
}
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
if (max_smem_per_block >= 2 * ((3 * 128 + 2 * 128) * Headdim + 2 * 128 * 128)) { // 104 KB
|
||||
if constexpr(!Is_dropout) { // We can afford more registers to keep V in registers
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream, configure);
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream);
|
||||
} else {
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
|
||||
}
|
||||
} else { // 96 KB
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream, configure);
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void run_mha_bwd_hdim64(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
void run_mha_bwd_hdim64(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
||||
constexpr static int Headdim = 64;
|
||||
int device;
|
||||
cudaGetDevice(&device);
|
||||
@ -177,42 +165,41 @@ void run_mha_bwd_hdim64(Flash_bwd_params ¶ms, cudaStream_t stream, const boo
|
||||
C10_CUDA_CHECK(status_);
|
||||
}
|
||||
// printf("max_smem_per_block = %d\n", max_smem_per_block);
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
// Changing AtomLayoutMdQ from 2 to 4 takes the same time
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>>(params, stream, configure);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, true, false, T>>(params, stream, configure);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 2, 4, 4, false, false, T>>(params, stream, configure);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, false, false, T>, Is_dropout>(params, stream, configure);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>>(params, stream);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, true, false, T>>(params, stream);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 2, 4, 4, false, false, T>>(params, stream);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, false, false, T>, Is_dropout>(params, stream);
|
||||
// This is slightly faster. We want to split M more so we need fewer registers to store LSE.
|
||||
if (max_smem_per_block >= 144 * 1024) {
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
|
||||
// This has a lot of register spilling
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream, configure);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream);
|
||||
} else {
|
||||
// if (params.h == params.h_k) {
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream, configure);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, false, false, T>, Is_dropout>(params, stream, configure);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, true, false, T>, Is_dropout>(params, stream, configure);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_dropout>(params, stream);
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, false, false, T>, Is_dropout>(params, stream);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, true, false, T>, Is_dropout>(params, stream);
|
||||
// } else {
|
||||
// run_flash_bwd_seqq_parallel<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, false, false, T>, Is_dropout>(params, stream, configure);
|
||||
// }
|
||||
}
|
||||
});
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, true, false, T>>(params, stream, configure);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 4, 2, 2, 2, true, false, T>>(params, stream, configure);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 128, 4, 1, 4, 1, false, false, T>>(params, stream, configure);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 16, 128, 4, 1, 4, 1, false, false, T>>(params, stream, configure);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, true, false, T>>(params, stream);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 4, 2, 2, 2, true, false, T>>(params, stream);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 128, 4, 1, 4, 1, false, false, T>>(params, stream);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 16, 128, 4, 1, 4, 1, false, false, T>>(params, stream);
|
||||
// M=128, N=64 is quite slow, I think because we need to read/write dQaccum twice as many times
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 2, 2, 2, false, T>>(params, stream, configure);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, false, T>>(params, stream, configure);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 4, false, T>>(params, stream, configure);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 2, 2, 2, false, T>>(params, stream);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, false, T>>(params, stream);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 4, false, T>>(params, stream);
|
||||
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 4, 4, 2, 4, false, false, T>>(params, stream, configure);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 4, 4, 2, 4, false, false, T>>(params, stream);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void run_mha_bwd_hdim96(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
void run_mha_bwd_hdim96(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
||||
constexpr static int Headdim = 96;
|
||||
int device;
|
||||
cudaGetDevice(&device);
|
||||
@ -223,26 +210,22 @@ void run_mha_bwd_hdim96(Flash_bwd_params ¶ms, cudaStream_t stream, const boo
|
||||
C10_CUDA_CHECK(status_);
|
||||
}
|
||||
// printf("max_smem_per_block = %d\n", max_smem_per_block);
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
// if (params.h == params.h_k) {
|
||||
if (max_smem_per_block >= 116 * 1024) {
|
||||
if constexpr(!Is_dropout) { // 92KB
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream, configure);
|
||||
} else { // 116 KB
|
||||
// This is faster for dropout since we don't have many registers to spare
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
|
||||
}
|
||||
} else {
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream, configure);
|
||||
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
if (max_smem_per_block >= 116 * 1024) {
|
||||
if constexpr(!Is_dropout) { // 92KB
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream);
|
||||
} else { // 116 KB
|
||||
// This is faster for dropout since we don't have many registers to spare
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_dropout>(params, stream);
|
||||
}
|
||||
// } else {
|
||||
// run_flash_bwd_seqq_parallel<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 4, 4, false, false, T>>(params, stream, configure);
|
||||
// }
|
||||
} else {
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void run_mha_bwd_hdim128(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
void run_mha_bwd_hdim128(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
||||
constexpr static int Headdim = 128;
|
||||
int device;
|
||||
cudaGetDevice(&device);
|
||||
@ -253,35 +236,30 @@ void run_mha_bwd_hdim128(Flash_bwd_params ¶ms, cudaStream_t stream, const bo
|
||||
C10_CUDA_CHECK(status_);
|
||||
}
|
||||
// printf("max_smem_per_block = %d\n", max_smem_per_block);
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
// if (params.h == params.h_k) {
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 128, 8, 2, 2, 2, false, false, T>>(params, stream, configure);
|
||||
// This is faster, in the case of sequence-parallel bwd (where we need fewer registers).
|
||||
// Out of these three, the 2nd one is slightly faster (2% faster than the first). Idk why.
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 2, 2, false, false, T>>(params, stream, configure);
|
||||
if (max_smem_per_block >= 144 * 1024) {
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>, Is_dropout>(params, stream, configure);
|
||||
// run_flash_bwd_seqk_parallel<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
|
||||
// run_flash_bwd_seqk_parallel<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, true, T>, Is_dropout>(params, stream, configure);
|
||||
// run_flash_bwd_seqq_parallel<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, true, false, T>, Is_dropout>(params, stream, configure);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream, configure);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 2, true, false, T>, Is_dropout>(params, stream, configure);
|
||||
} else {
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream, configure);
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, false, T>, Is_dropout>(params, stream, configure);
|
||||
}
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>>(params, stream, configure);
|
||||
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 128, 8, 2, 2, 2, false, false, T>>(params, stream);
|
||||
// This is faster, in the case of sequence-parallel bwd (where we need fewer registers).
|
||||
// Out of these three, the 2nd one is slightly faster (2% faster than the first). Idk why.
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 2, 2, false, false, T>>(params, stream);
|
||||
if (max_smem_per_block >= 144 * 1024) {
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>, Is_dropout>(params, stream);
|
||||
// run_flash_bwd_seqk_parallel<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
|
||||
// run_flash_bwd_seqk_parallel<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, true, T>, Is_dropout>(params, stream);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, true, false, T>, Is_dropout>(params, stream);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 2, true, false, T>, Is_dropout>(params, stream);
|
||||
} else {
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream);
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, false, T>, Is_dropout>(params, stream);
|
||||
}
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>>(params, stream);
|
||||
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 4, 4, false, false, T>>(params, stream, configure);
|
||||
// } else {
|
||||
// run_flash_bwd_seqq_parallel<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 4, 4, false, false, T>>(params, stream, configure);
|
||||
// }
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 4, 4, false, false, T>>(params, stream);
|
||||
});
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void run_mha_bwd_hdim160(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
void run_mha_bwd_hdim160(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
||||
constexpr static int Headdim = 160;
|
||||
int device;
|
||||
cudaGetDevice(&device);
|
||||
@ -291,17 +269,17 @@ void run_mha_bwd_hdim160(Flash_bwd_params ¶ms, cudaStream_t stream, const bo
|
||||
if (status_ != cudaSuccess) {
|
||||
C10_CUDA_CHECK(status_);
|
||||
}
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
if (max_smem_per_block >= 116 * 1024) {
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
|
||||
} else {
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, true, T>, Is_dropout>(params, stream, configure);
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, true, T>, Is_dropout>(params, stream);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void run_mha_bwd_hdim192(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
void run_mha_bwd_hdim192(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
||||
constexpr static int Headdim = 192;
|
||||
int device;
|
||||
cudaGetDevice(&device);
|
||||
@ -311,25 +289,25 @@ void run_mha_bwd_hdim192(Flash_bwd_params ¶ms, cudaStream_t stream, const bo
|
||||
if (status_ != cudaSuccess) {
|
||||
C10_CUDA_CHECK(status_);
|
||||
}
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
if (max_smem_per_block >= 136 * 1024) {
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream, configure);
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream);
|
||||
} else {
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, true, T>, Is_dropout>(params, stream, configure);
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, true, T>, Is_dropout>(params, stream);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void run_mha_bwd_hdim224(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
void run_mha_bwd_hdim224(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
||||
constexpr static int Headdim = 224;
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
|
||||
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream);
|
||||
});
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void run_mha_bwd_hdim256(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
void run_mha_bwd_hdim256(Flash_bwd_params ¶ms, cudaStream_t stream) {
|
||||
constexpr static int Headdim = 256;
|
||||
int device;
|
||||
cudaGetDevice(&device);
|
||||
@ -339,14 +317,18 @@ void run_mha_bwd_hdim256(Flash_bwd_params ¶ms, cudaStream_t stream, const bo
|
||||
if (status_ != cudaSuccess) {
|
||||
C10_CUDA_CHECK(status_);
|
||||
}
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
if (max_smem_per_block >= 176 * 1024) { // H100
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream, configure);
|
||||
} else { // A100, we don't do double buffering to save smem
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, true, T>, Is_dropout>(params, stream, configure);
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream);
|
||||
} else if (max_smem_per_block >= 144 * 1024) { // A100, we don't do double buffering to save smem
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, true, T>, Is_dropout>(params, stream);
|
||||
} else { // sm86 and sm89, max smem is 99 KB. Only works without dropout. V in regs and no double buffering.
|
||||
if constexpr (!Is_dropout) {
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 32, 8, 4, 1, 2, true, true, T>, false>(params, stream);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
}; // namespace pytorch_fmha
|
||||
}; // namespace pytorch_flash
|
||||
|
@ -0,0 +1,377 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024, Tri Dao.
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cute/algorithm/copy.hpp>
|
||||
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/array.h>
|
||||
#include <cutlass/numeric_types.h>
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/block_info.h>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/kernel_traits.h>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/utils.h>
|
||||
|
||||
namespace pytorch_flash {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <int THREADS_PER_ROW, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
|
||||
inline __device__ void dot_do_o(Tensor<Engine0, Layout0> const &do_, Tensor<Engine0, Layout0> const &o,
|
||||
Tensor<Engine1, Layout1> &dP_sum, const int gdP_col_stride, const float scale) {
|
||||
static_assert(Layout0::rank == 3, "Only support 3D Tensor");
|
||||
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
|
||||
CUTE_STATIC_ASSERT_V(do_.layout() == o.layout());
|
||||
// Reshape do_ and o from (8, kBlockM / 32, kHeadDim / 64) to (kBlockM / 32, 8 * kHeadDim / 64)
|
||||
// The last coordinate is the "page".
|
||||
Tensor do_reshaped = make_tensor(do_.data(), make_layout(get<1>(do_.layout()),
|
||||
make_layout(get<0>(do_.layout()),
|
||||
get<2>(do_.layout()))));
|
||||
Tensor o_reshaped = make_tensor(o.data(), do_reshaped.layout());
|
||||
Tensor do_fp32 = pytorch_flash::convert_type<float>(do_reshaped);
|
||||
Tensor o_fp32 = pytorch_flash::convert_type<float>(o_reshaped);
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0>(do_reshaped); ++mi) {
|
||||
float dP_sum_cur = do_fp32(mi, 0) * o_fp32(mi, 0);
|
||||
#pragma unroll
|
||||
for (int ni = 1; ni < size<1>(do_reshaped); ni++) {
|
||||
dP_sum_cur += do_fp32(mi, ni) * o_fp32(mi, ni);
|
||||
}
|
||||
pytorch_flash::SumOp<float> sum_op;
|
||||
dP_sum_cur = pytorch_flash::Allreduce<THREADS_PER_ROW>::run(dP_sum_cur, sum_op) * scale;
|
||||
if (threadIdx.x % THREADS_PER_ROW == 0) {
|
||||
dP_sum(mi * gdP_col_stride + threadIdx.x / THREADS_PER_ROW) = dP_sum_cur;
|
||||
}
|
||||
}
|
||||
}
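Why this rowwise dot product is the softmax-backward term stored as softmax_d / dP_sum: with the usual FlashAttention notation (assumed here), P = softmax(S), O = P V and dP = dO V^T, so
$$ D_i \;=\; \sum_j P_{ij}\,\mathrm{d}P_{ij} \;=\; \sum_j P_{ij}\,(\mathrm{d}O_i \cdot V_j) \;=\; \mathrm{d}O_i \cdot \sum_j P_{ij} V_j \;=\; \mathrm{d}O_i \cdot O_i , $$
which is exactly what dot_do_o accumulates per row.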
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Just compute dot(do, o) and write the result (softmax_d) to global memory as a separate kernel.
|
||||
// This is used in the case where we want to parallelize the backward across seqlen_k.
|
||||
template<bool Clear_dQaccum=true, typename Kernel_traits, typename Params>
|
||||
inline __device__ void compute_dot_do_o(const Params ¶ms) {
|
||||
using Element = typename Kernel_traits::Element;
|
||||
using ElementAccum = typename Kernel_traits::ElementAccum;
|
||||
using index_t = typename Kernel_traits::index_t;
|
||||
|
||||
const int m_block = blockIdx.x;
|
||||
// The block index for the batch.
|
||||
const int bidb = blockIdx.y;
|
||||
// The block index for the head.
|
||||
const int bidh = blockIdx.z;
|
||||
// The thread index.
|
||||
const int tidx = threadIdx.x;
|
||||
|
||||
constexpr int kBlockM = Kernel_traits::kBlockM;
|
||||
constexpr int kHeadDim = Kernel_traits::kHeadDim;
|
||||
|
||||
const BlockInfo binfo(params, bidb);
|
||||
if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
|
||||
|
||||
const index_t row_offset_do = binfo.q_offset(params.do_batch_stride, params.do_row_stride, bidb)
|
||||
+ m_block * kBlockM * params.do_row_stride + bidh * params.do_head_stride;
|
||||
const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
|
||||
+ m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
|
||||
const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb)
|
||||
+ (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded;
|
||||
const index_t row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM;
|
||||
|
||||
Tensor gdO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.do_ptr) + row_offset_do),
|
||||
Shape<Int<kBlockM>, Int<kHeadDim>>{},
|
||||
make_stride(params.do_row_stride, _1{}));
|
||||
Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.o_ptr) + row_offset_o),
|
||||
Shape<Int<kBlockM>, Int<kHeadDim>>{},
|
||||
make_stride(params.o_row_stride, _1{}));
|
||||
Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dq_accum_ptr) + row_offset_dq_accum),
|
||||
Shape<Int<kBlockM>, Int<kHeadDim>>{},
|
||||
make_stride(params.h * params.d_rounded, _1{}));
|
||||
Tensor dP_sum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dsoftmax_sum) + row_offset_dpsum),
|
||||
Shape<Int<kBlockM>>{}, Stride<_1>{});
|
||||
|
||||
typename Kernel_traits::GmemTiledCopydO gmem_tiled_copy_dO;
|
||||
auto gmem_thr_copy_dO = gmem_tiled_copy_dO.get_thread_slice(tidx);
|
||||
// TODO: careful, we're zeroing out dQaccum with type float4, but when
|
||||
// we do atomicAdds, we use type float. The layouts are different. Check this.
|
||||
typename Kernel_traits::GmemTiledCopydQaccum gmem_tiled_copy_dQaccum;
|
||||
auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(tidx);
|
||||
|
||||
Tensor tdOgdO = gmem_thr_copy_dO.partition_S(gdO);
|
||||
Tensor tdOgO = gmem_thr_copy_dO.partition_S(gO);
|
||||
Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_D(gdQaccum);
|
||||
|
||||
Tensor cdO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
|
||||
Tensor tdOcdO = gmem_thr_copy_dO.partition_S(cdO);
|
||||
|
||||
// Allocate predicate tensors for k
|
||||
Tensor tdOpdO = make_tensor<bool>(make_shape(size<2>(tdOgdO)));
|
||||
// Set predicates for k bounds
|
||||
#pragma unroll
|
||||
for (int k = 0; k < size(tdOpdO); ++k) {tdOpdO(k) = get<1>(tdOcdO(0, 0, k)) < params.d;}
|
||||
|
||||
Tensor tdOrdO = make_fragment_like(tdOgdO);
|
||||
Tensor tdOrO = make_fragment_like(tdOgO);
|
||||
pytorch_flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/true>(
|
||||
gmem_tiled_copy_dO, tdOgdO, tdOrdO, tdOcdO, tdOpdO, binfo.actual_seqlen_q - m_block * kBlockM
|
||||
);
|
||||
pytorch_flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/true>(
|
||||
gmem_tiled_copy_dO, tdOgO, tdOrO, tdOcdO, tdOpdO, binfo.actual_seqlen_q - m_block * kBlockM
|
||||
);
|
||||
// By right we need to scale dP up by 1/p_dropout, but instead we don't and only scale the final
|
||||
// results (dQ and dK) by 1/p_dropout. So we need to keep dP_sum scaled down by p_dropout here,
|
||||
// so that (dP - dP_sum) is on the same scale.
|
||||
dot_do_o<Kernel_traits::kGmemThreadsPerRow>(tdOrdO, tdOrO, dP_sum,
|
||||
Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout);
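    // Hedged restatement of the comment above: the `scale` argument here is the keep
    // probability, so dP_sum stays on the same "not yet divided by the keep probability"
    // footing as dP; the single 1/keep-probability correction is applied at the end via
    // scale_softmax_rp_dropout in convert_dQ and rp_dropout in convert_dKV.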
|
||||
if (Clear_dQaccum) {
|
||||
// We're actually not zero'ing out all of dQaccum, but only the part that we're going to
|
||||
// do atomicAdds on.
|
||||
Tensor zero = make_fragment_like(tdQgdQaccum);
|
||||
clear(zero);
|
||||
cute::copy(gmem_tiled_copy_dQaccum, zero, tdQgdQaccum);
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename Kernel_traits, typename Params>
|
||||
inline __device__ void clear_dKVaccum(const Params ¶ms) {
|
||||
using ElementAccum = typename Kernel_traits::ElementAccum;
|
||||
using index_t = typename Kernel_traits::index_t;
|
||||
|
||||
const int n_block = blockIdx.x;
|
||||
// The block index for the batch.
|
||||
const int bidb = blockIdx.y;
|
||||
// The block index for the head.
|
||||
const int bidh = blockIdx.z;
|
||||
// The thread index.
|
||||
const int tidx = threadIdx.x;
|
||||
|
||||
constexpr int kBlockN = Kernel_traits::kBlockN;
|
||||
constexpr int kHeadDim = Kernel_traits::kHeadDim;
|
||||
|
||||
const BlockInfo binfo(params, bidb);
|
||||
if (n_block * kBlockN >= binfo.actual_seqlen_k) return;
|
||||
|
||||
const index_t row_offset_dkv_accum = ((bidb * params.h_k + bidh) * params.seqlen_k_rounded + n_block * kBlockN) * params.d_rounded;
|
||||
|
||||
Tensor gdKaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dk_accum_ptr) + row_offset_dkv_accum),
|
||||
Shape<Int<kBlockN>, Int<kHeadDim>>{}, Stride<Int<kHeadDim>, _1>{});
|
||||
Tensor gdVaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dv_accum_ptr) + row_offset_dkv_accum),
|
||||
Shape<Int<kBlockN>, Int<kHeadDim>>{}, Stride<Int<kHeadDim>, _1>{});
|
||||
|
||||
typename Kernel_traits::GmemTiledCopydQaccum gmem_tiled_copy_dKVaccum;
|
||||
auto gmem_thr_copy_dKVaccum = gmem_tiled_copy_dKVaccum.get_thread_slice(tidx);
|
||||
Tensor tdKgdKaccum = gmem_thr_copy_dKVaccum.partition_D(gdKaccum);
|
||||
Tensor tdVgdVaccum = gmem_thr_copy_dKVaccum.partition_D(gdVaccum);
|
||||
Tensor zero = make_fragment_like(tdKgdKaccum);
|
||||
clear(zero);
|
||||
cute::copy(gmem_tiled_copy_dKVaccum, zero, tdKgdKaccum);
|
||||
cute::copy(gmem_tiled_copy_dKVaccum, zero, tdVgdVaccum);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Convert dQ from dQaccum (in float) to fp16/bf16.
|
||||
// This is used in the case where we want to parallelize the backward across seqlen_k.
|
||||
template<typename Kernel_traits, typename Params>
|
||||
inline __device__ void convert_dQ(const Params ¶ms, const int nsplits) {
|
||||
using Element = typename Kernel_traits::Element;
|
||||
using ElementAccum = typename Kernel_traits::ElementAccum;
|
||||
using index_t = typename Kernel_traits::index_t;
|
||||
|
||||
// Shared memory.
|
||||
extern __shared__ char smem_[];
|
||||
|
||||
const int m_block = blockIdx.x;
|
||||
// The block index for the batch.
|
||||
const int bidb = blockIdx.y;
|
||||
// The block index for the head.
|
||||
const int bidh = blockIdx.z;
|
||||
// The thread index.
|
||||
const int tidx = threadIdx.x;
|
||||
|
||||
constexpr int kBlockM = Kernel_traits::kBlockM;
|
||||
constexpr int kHeadDim = Kernel_traits::kHeadDim;
|
||||
|
||||
const BlockInfo binfo(params, bidb);
|
||||
if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
|
||||
|
||||
const index_t row_offset_dq = binfo.q_offset(params.dq_batch_stride, params.dq_row_stride, bidb)
|
||||
+ m_block * kBlockM * params.dq_row_stride + bidh * params.dq_head_stride;
|
||||
const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb)
|
||||
+ (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded;
|
||||
|
||||
Tensor gdQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dq_ptr) + row_offset_dq),
|
||||
Shape<Int<kBlockM>, Int<kHeadDim>>{},
|
||||
make_stride(params.dq_row_stride, _1{}));
|
||||
Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dq_accum_ptr) + row_offset_dq_accum),
|
||||
Shape<Int<kBlockM>, Int<kHeadDim>>{},
|
||||
make_stride(params.h * params.d_rounded, _1{}));
|
||||
|
||||
Tensor sdQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
|
||||
typename Kernel_traits::SmemLayoutdQ{});
|
||||
|
||||
typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dQ;
|
||||
auto gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_thread_slice(tidx);
|
||||
typename Kernel_traits::GmemTiledCopydQaccumAtomicAdd gmem_tiled_copy_dQaccum;
|
||||
auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(tidx);
|
||||
|
||||
typename Kernel_traits::TiledMmadQ tiled_mma_dq;
|
||||
auto smem_tiled_copy_dQ = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdQ{}, tiled_mma_dq);
|
||||
auto smem_thr_copy_dQ = smem_tiled_copy_dQ.get_thread_slice(tidx);
|
||||
Tensor taccdQsdQ = smem_thr_copy_dQ.partition_D(sdQ); // ((Atom,AtomNum),PIPE_M,PIPE_N)
|
||||
|
||||
Tensor tdQsdQ = gmem_thr_copy_dQ.partition_S(sdQ); // ((Atom,AtomNum),ATOM_M,ATOM_N)
|
||||
Tensor tdQgdQ = gmem_thr_copy_dQ.partition_D(gdQ);
|
||||
Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_S(gdQaccum);
|
||||
|
||||
Tensor acc_dq = partition_fragment_C(tiled_mma_dq, Shape<Int<kBlockM>, Int<kHeadDim>>{}); // MMA, MMA_N, MMA_K
|
||||
CUTE_STATIC_ASSERT_V(size(acc_dq) == size(tdQgdQaccum));
|
||||
|
||||
Tensor tdQrdQaccum = make_fragment_like(tdQgdQaccum);
|
||||
clear(acc_dq);
|
||||
for (int s = 0; s < nsplits; ++s) {
|
||||
cute::copy(gmem_tiled_copy_dQaccum, tdQgdQaccum, tdQrdQaccum);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) += tdQrdQaccum(i); }
|
||||
tdQgdQaccum.data() = tdQgdQaccum.data() + params.dq_accum_split_stride;
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) *= params.scale_softmax_rp_dropout; }
|
||||
// Convert acc_dq from fp32 to fp16
|
||||
Tensor rdQ = pytorch_flash::convert_type<Element>(acc_dq);
|
||||
Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ); // ((Atom,AtomNum), MMA_N, MMA_N)
|
||||
cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ);
|
||||
__syncthreads();
|
||||
Tensor tdQrdQ = make_tensor<Element>(shape(tdQgdQ));
|
||||
cute::copy(gmem_tiled_copy_dQ, tdQsdQ, tdQrdQ);
|
||||
|
||||
Tensor cdQ = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
|
||||
Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cdQ);
|
||||
Tensor tdQpdQ = make_tensor<bool>(make_shape(size<2>(tdQgdQ)));
|
||||
#pragma unroll
|
||||
for (int k = 0; k < size(tdQpdQ); ++k) { tdQpdQ(k) = get<1>(tdQcdQ(0, 0, k)) < params.d; }
|
||||
// Clear_OOB_K must be false since we don't want to write zeros to gmem
|
||||
pytorch_flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
|
||||
gmem_tiled_copy_dQ, tdQrdQ, tdQgdQ, tdQcdQ, tdQpdQ, binfo.actual_seqlen_q - m_block * kBlockM
|
||||
);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Convert dK and dV from dKaccum and dVaccum (in float) to fp16/bf16.
|
||||
// This is used in the case where we want to parallelize the backward across seqlen_q.
|
||||
template<typename Kernel_traits, typename Params>
|
||||
inline __device__ void convert_dKV(const Params ¶ms) {
|
||||
using Element = typename Kernel_traits::Element;
|
||||
using ElementAccum = typename Kernel_traits::ElementAccum;
|
||||
using index_t = typename Kernel_traits::index_t;
|
||||
|
||||
// Shared memory.
|
||||
extern __shared__ char smem_[];
|
||||
|
||||
const int n_block = blockIdx.x;
|
||||
// The block index for the batch.
|
||||
const int bidb = blockIdx.y;
|
||||
// The block index for the head.
|
||||
const int bidh = blockIdx.z;
|
||||
// The thread index.
|
||||
const int tidx = threadIdx.x;
|
||||
|
||||
constexpr int kBlockN = Kernel_traits::kBlockN;
|
||||
constexpr int kHeadDim = Kernel_traits::kHeadDim;
|
||||
|
||||
const BlockInfo binfo(params, bidb);
|
||||
if (n_block * kBlockN >= binfo.actual_seqlen_k) return;
|
||||
|
||||
const index_t row_offset_dk = binfo.k_offset(params.dk_batch_stride, params.dk_row_stride, bidb)
|
||||
+ n_block * kBlockN * params.dk_row_stride + bidh * params.dk_head_stride;
|
||||
const index_t row_offset_dv = binfo.k_offset(params.dv_batch_stride, params.dv_row_stride, bidb)
|
||||
+ n_block * kBlockN * params.dv_row_stride + bidh * params.dv_head_stride;
|
||||
const index_t row_offset_dkv_accum = ((bidb * params.h_k + bidh) * params.seqlen_k_rounded
|
||||
+ n_block * kBlockN) * params.d_rounded;
|
||||
|
||||
Tensor gdK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dk_ptr) + row_offset_dk),
|
||||
Shape<Int<kBlockN>, Int<kHeadDim>>{},
|
||||
make_stride(params.dk_row_stride, _1{}));
|
||||
Tensor gdV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.dv_ptr) + row_offset_dv),
|
||||
Shape<Int<kBlockN>, Int<kHeadDim>>{},
|
||||
make_stride(params.dv_row_stride, _1{}));
|
||||
Tensor gdKaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dk_accum_ptr) + row_offset_dkv_accum),
|
||||
Shape<Int<kBlockN>, Int<kHeadDim>>{},
|
||||
Stride<Int<kHeadDim>, _1>{});
|
||||
Tensor gdVaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.dv_accum_ptr) + row_offset_dkv_accum),
|
||||
Shape<Int<kBlockN>, Int<kHeadDim>>{},
|
||||
Stride<Int<kHeadDim>, _1>{});
|
||||
|
||||
Tensor sdK = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
|
||||
typename Kernel_traits::SmemLayoutdKV{});
|
||||
Tensor sdV = make_tensor(sdK.data() + size(sdK), typename Kernel_traits::SmemLayoutdKV{}); // (SMEM_N, SMEM_K)
|
||||
|
||||
typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dKV;
|
||||
auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(tidx);
|
||||
typename Kernel_traits::GmemTiledCopydQaccumAtomicAdd gmem_tiled_copy_dKVaccum;
|
||||
auto gmem_thr_copy_dKVaccum = gmem_tiled_copy_dKVaccum.get_thread_slice(tidx);

typename Kernel_traits::TiledMmadKV tiled_mma_dkv;
auto smem_tiled_copy_dKV = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdKV{}, tiled_mma_dkv);
auto smem_thr_copy_dKV = smem_tiled_copy_dKV.get_thread_slice(tidx);
Tensor taccdKsdK = smem_thr_copy_dKV.partition_D(sdK); // ((Atom,AtomNum),PIPE_M,PIPE_N)
Tensor taccdVsdV = smem_thr_copy_dKV.partition_D(sdV); // ((Atom,AtomNum),PIPE_M,PIPE_N)

Tensor tdKsdK = gmem_thr_copy_dKV.partition_S(sdK); // ((Atom,AtomNum),ATOM_M,ATOM_N)
Tensor tdKgdK = gmem_thr_copy_dKV.partition_D(gdK);
Tensor tdVsdV = gmem_thr_copy_dKV.partition_S(sdV); // ((Atom,AtomNum),ATOM_M,ATOM_N)
Tensor tdVgdV = gmem_thr_copy_dKV.partition_D(gdV);
Tensor tdKgdKaccum = gmem_thr_copy_dKVaccum.partition_S(gdKaccum);
Tensor tdVgdVaccum = gmem_thr_copy_dKVaccum.partition_S(gdVaccum);

Tensor acc_dk = partition_fragment_C(tiled_mma_dkv, Shape<Int<kBlockN>, Int<kHeadDim>>{}); // MMA, MMA_N, MMA_K
Tensor acc_dv = partition_fragment_C(tiled_mma_dkv, Shape<Int<kBlockN>, Int<kHeadDim>>{}); // MMA, MMA_N, MMA_K
CUTE_STATIC_ASSERT_V(size(acc_dk) == size(tdKgdKaccum));
CUTE_STATIC_ASSERT_V(size(acc_dv) == size(tdVgdVaccum));

Tensor tdKrdKaccum = make_fragment_like(tdKgdKaccum);
Tensor tdVrdVaccum = make_fragment_like(tdVgdVaccum);
cute::copy(gmem_tiled_copy_dKVaccum, tdKgdKaccum, tdKrdKaccum);
cute::copy(gmem_tiled_copy_dKVaccum, tdVgdVaccum, tdVrdVaccum);
#pragma unroll
for (int i = 0; i < size(acc_dk); ++i) {
acc_dk(i) = tdKrdKaccum(i) * params.scale_softmax_rp_dropout;
}
#pragma unroll
for (int i = 0; i < size(acc_dv); ++i) {
acc_dv(i) = tdVrdVaccum(i) * params.rp_dropout;
}
// Convert acc_dk from fp32 to fp16
Tensor rdK = pytorch_flash::convert_type<Element>(acc_dk);
Tensor rdV = pytorch_flash::convert_type<Element>(acc_dv);
Tensor taccdKrdK = smem_thr_copy_dKV.retile_S(rdK); // ((Atom,AtomNum), MMA_N, MMA_N)
Tensor taccdVrdV = smem_thr_copy_dKV.retile_S(rdV); // ((Atom,AtomNum), MMA_N, MMA_N)
cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdK);
cute::copy(smem_tiled_copy_dKV, taccdVrdV, taccdVsdV);
__syncthreads();
Tensor tdKrdK = make_tensor<Element>(shape(tdKgdK));
Tensor tdVrdV = make_tensor<Element>(shape(tdVgdV));
cute::copy(gmem_tiled_copy_dKV, tdKsdK, tdKrdK);
cute::copy(gmem_tiled_copy_dKV, tdVsdV, tdVrdV);

Tensor cdKV = make_identity_tensor(Shape<Int<kBlockN>, Int<kHeadDim>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV);
Tensor tdKVpdKV = make_tensor<bool>(make_shape(size<2>(tdKgdK)));
#pragma unroll
for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; }
// Clear_OOB_K must be false since we don't want to write zeros to gmem
pytorch_flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
gmem_tiled_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
);
pytorch_flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/false, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
gmem_tiled_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
);
}
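// Sketch of the per-element math in the epilogue above (illustrative host code, not the
// kernel): the fp32 dK/dV accumulators are rescaled before the fp16/bf16 conversion and
// the bounds-checked write-back. dK absorbs the softmax scale together with the dropout
// reciprocal (params.scale_softmax_rp_dropout, assumed equal to scale_softmax * rp_dropout),
// while dV only needs rp_dropout = 1 / (1 - p_dropout).
#include <cstddef>

void finalize_dkv_accum(float* acc_dk, float* acc_dv, std::size_t n,
                        float scale_softmax_rp_dropout, float rp_dropout) {
    for (std::size_t i = 0; i < n; ++i) {
        acc_dk[i] *= scale_softmax_rp_dropout;  // dK also carries the softmax scale (1/sqrt(head_dim) by default)
        acc_dv[i] *= rp_dropout;                // dV only undoes dropout's expectation scaling
    }
    // The kernel then narrows these fp32 values to fp16/bf16 (convert_type<Element>) and
    // stages them through shared memory before the predicated global-memory copy.
}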
} // namespace flash
@ -1,23 +1,23 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2023, Tri Dao.
|
||||
* Copyright (c) 2024, Tri Dao.
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cmath>
|
||||
#include <cute/algorithm/copy.hpp>
|
||||
#include <cute/algorithm/gemm.hpp>
|
||||
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/array.h>
|
||||
#include <cutlass/numeric_types.h>
|
||||
#include <cutlass/numeric_conversion.h>
|
||||
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/block_info.h>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/kernel_traits.h>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/utils.h>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/softmax.h>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/philox.cuh>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/mask.h>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/dropout.h>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/rotary.h>
|
||||
|
||||
namespace pytorch_flash {
|
||||
|
||||
@ -25,57 +25,7 @@ using namespace cute;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<bool Is_first, bool Check_inf=false, typename Tensor0, typename Tensor1, typename Tensor2>
|
||||
inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, Tensor1 &scores_sum,
|
||||
Tensor2 &acc_o, float softmax_scale_log2) {
|
||||
if (Is_first) {
|
||||
pytorch_flash::template reduce_max</*zero_init=*/true>(scores, scores_max);
|
||||
pytorch_flash::scale_apply_exp2(scores, scores_max, softmax_scale_log2);
|
||||
pytorch_flash::reduce_sum(scores, scores_sum);
|
||||
} else {
|
||||
Tensor scores_max_prev = make_fragment_like(scores_max);
|
||||
cute::copy(scores_max, scores_max_prev);
|
||||
pytorch_flash::template reduce_max</*zero_init=*/false>(scores, scores_max);
|
||||
// Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
|
||||
Tensor acc_o_rowcol = make_tensor(acc_o.data(), pytorch_flash::convert_layout_acc_rowcol(acc_o.layout()));
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size(scores_max); ++mi) {
|
||||
float scores_max_cur = !Check_inf
|
||||
? scores_max(mi)
|
||||
: (scores_max(mi) == -INFINITY ? 0.0f : scores_max(mi));
|
||||
float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
|
||||
scores_sum(mi) *= scores_scale;
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; }
|
||||
}
|
||||
pytorch_flash::scale_apply_exp2(scores, scores_max, softmax_scale_log2);
|
||||
Tensor scores_sum_cur = make_fragment_like(scores_sum);
|
||||
pytorch_flash::reduce_sum(scores, scores_sum_cur);
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size(scores_sum); ++mi) { scores_sum(mi) += scores_sum_cur(mi); }
|
||||
}
};
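// A minimal host-side sketch of the online-softmax rescaling that softmax_rescale_o above
// implements (and that the new Softmax class takes over): keep a running row max and row
// sum, and whenever the max grows, shrink the previously accumulated output by
// exp(old_max - new_max). Scalar illustration only; the kernel works on register fragments
// and uses exp2f with a log2-scaled softmax scale.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

struct OnlineSoftmaxRow {
    float row_max = -INFINITY;
    float row_sum = 0.f;
    float acc     = 0.f;   // running sum of p_i * v_i for one output element

    void update(const std::vector<float>& scores, const std::vector<float>& v) {
        float new_max = row_max;
        for (float s : scores) new_max = std::max(new_max, s);
        float rescale = std::exp(row_max - new_max);   // rescale previously accumulated terms
        row_sum *= rescale;
        acc     *= rescale;
        for (std::size_t i = 0; i < scores.size(); ++i) {
            float p = std::exp(scores[i] - new_max);
            row_sum += p;
            acc     += p * v[i];
        }
        row_max = new_max;
    }
    float output() const { return row_sum > 0.f ? acc / row_sum : 0.f; }
};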
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename TiledCopy>
|
||||
inline __device__ void write_softmax_to_gmem(
|
||||
Tensor<Engine0, Layout0> const &tOrP, Tensor<Engine1, Layout1> &tPgP, TiledCopy gmem_tiled_copy_P
|
||||
) {
|
||||
// Reshape tOrP from (8, MMA_M, MMA_N) to (8, MMA_M * MMA_N)
|
||||
Layout l = tOrP.layout();
|
||||
Tensor tPrP = make_tensor(tOrP.data(), make_layout(get<0>(l), make_layout(get<1>(l), get<2>(l))));
|
||||
CUTE_STATIC_ASSERT_V(size<2>(tPgP) == _1{});
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tPrP) == size<1>(tPgP));
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<1>(tPrP); ++mi) {
|
||||
cute::copy(gmem_tiled_copy_P, tPrP(_, mi), tPgP(_, mi, 0));
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
|
||||
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
|
||||
inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bidb, const int bidh, const int m_block) {
|
||||
|
||||
using Element = typename Kernel_traits::Element;
|
||||
@ -93,6 +43,19 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi
|
||||
constexpr int kHeadDim = Kernel_traits::kHeadDim;
|
||||
constexpr int kNWarps = Kernel_traits::kNWarps;
|
||||
|
||||
auto seed_offset = at::cuda::philox::unpack(params.philox_args);
|
||||
pytorch_flash::Dropout dropout(std::get<0>(seed_offset), std::get<1>(seed_offset), params.p_dropout_in_uint8_t,
|
||||
bidb, bidh, tidx, params.h);
|
||||
|
||||
// Save seed and offset for backward. If we don't have this here, the 0-th thread block might
|
||||
// exit early and no one saves the rng state.
|
||||
if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) {
|
||||
if (params.philox_args.captured_) {
|
||||
*params.seed = std::get<0>(seed_offset);
|
||||
*params.extragraph_offset = std::get<1>(seed_offset);
|
||||
}
}
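// Why the (seed, offset) pair is written out: the dropout mask is never stored. Forward and
// backward both regenerate it from a counter-based RNG keyed by (seed, offset, block/thread
// position), so identical inputs give identical masks. Tiny illustration with a stand-in
// mixer (the kernel uses Philox, not this hash):
#include <cstdint>

inline uint64_t mix64(uint64_t x) {                 // splitmix64-style mixer, illustration only
    x += 0x9E3779B97F4A7C15ull;
    x = (x ^ (x >> 30)) * 0xBF58476D1CE4E5B9ull;
    x = (x ^ (x >> 27)) * 0x94D049BB133111EBull;
    return x ^ (x >> 31);
}

inline bool keep_element(uint64_t seed, uint64_t offset,
                         uint32_t row, uint32_t col, float p_dropout) {
    uint64_t counter = offset + ((uint64_t(row) << 32) | col);
    uint64_t h = mix64(seed ^ mix64(counter));
    float u = float(h >> 40) * 0x1p-24f;            // uniform in [0, 1)
    return u >= p_dropout;                          // drop with probability p_dropout
}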
|
||||
const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
|
||||
if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
|
||||
|
||||
@ -108,15 +71,6 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi
|
||||
// We exit early and write 0 to gO and gLSE. This also covers the case where actual_seqlen_k == 0.
|
||||
// Otherwise we might read OOB elements from gK and gV.
|
||||
if ((Is_causal || Is_local || !Is_even_MN) && n_block_max <= n_block_min) {
|
||||
// Save seed and offset for backward. If we don't have this here, the 0-th thread block might
|
||||
// exit early and no one saves the rng state.
|
||||
if (Is_dropout && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && tidx == 0) {
|
||||
auto seeds = at::cuda::philox::unpack(params.philox_args);
|
||||
if (params.philox_args.captured_) {
|
||||
*params.seed = std::get<0>(seeds);
|
||||
*params.extragraph_offset = std::get<1>(seeds);
|
||||
}
|
||||
}
|
||||
const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
|
||||
+ m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
|
||||
const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
|
||||
@ -191,8 +145,6 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi
|
||||
|
||||
typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
|
||||
auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
|
||||
typename Kernel_traits::GmemTiledCopyP gmem_tiled_copy_P;
|
||||
auto gmem_thr_copy_P = gmem_tiled_copy_P.get_thread_slice(tidx);
|
||||
|
||||
Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
|
||||
Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
|
||||
@ -200,7 +152,6 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi
|
||||
Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);
|
||||
Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K)
|
||||
Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);
|
||||
Tensor tPgP = gmem_thr_copy_P.partition_D(gP);
|
||||
|
||||
typename Kernel_traits::TiledMma tiled_mma;
|
||||
auto thr_mma = tiled_mma.get_thread_slice(tidx);
|
||||
@ -208,6 +159,8 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi
|
||||
Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K)
|
||||
Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N)
|
||||
|
||||
Tensor tSgS = thr_mma.partition_C(gP);
|
||||
|
||||
Tensor acc_o = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kHeadDim>>{}); // MMA, MMA_M, MMA_K
|
||||
|
||||
//
|
||||
@ -228,10 +181,6 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi
|
||||
auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
|
||||
Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);
|
||||
|
||||
// TODO: this might need to change if we change the mma instruction in SM70
|
||||
Tensor scores_max = make_tensor<ElementAccum>(Shape<Int<2 * size<1>(acc_o)>>{});
|
||||
Tensor scores_sum = make_fragment_like(scores_max);
|
||||
|
||||
//
|
||||
// PREDICATES
|
||||
//
|
||||
@ -274,16 +223,11 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi
|
||||
|
||||
// Prologue
|
||||
|
||||
Tensor tQrQ = make_fragment_like(tQgQ);
|
||||
// We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
|
||||
pytorch_flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
|
||||
binfo.actual_seqlen_q - m_block * kBlockM);
|
||||
if (Kernel_traits::Is_Q_in_regs) { cute::cp_async_fence(); }
|
||||
|
||||
// // Copy rmem to smem
|
||||
// // copy(tQrQ, tQsQ);
|
||||
// pytorch_flash::cp_async_wait<0>();
|
||||
// __syncthreads();
|
||||
// // if (cute::thread(1, 0)) { print(tQsQ); }
|
||||
// // Tensor sQNoSwizzle = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)), typename Kernel_traits::SmemLayoutQNoSwizzle{});
|
||||
// // if (cute::thread0()) { print(sQNoSwizzle); }
|
||||
@ -313,17 +257,13 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi
|
||||
cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
|
||||
}
|
||||
|
||||
auto seeds = at::cuda::philox::unpack(params.philox_args);
|
||||
if (params.philox_args.captured_) {
|
||||
*params.seed = std::get<0>(seeds);
|
||||
*params.extragraph_offset = std::get<1>(seeds);
|
||||
}
|
||||
|
||||
unsigned long long seed = std::get<0>(seeds);
|
||||
unsigned long long offset = std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32;
|
||||
|
||||
clear(acc_o);
|
||||
|
||||
pytorch_flash::Softmax<2 * size<1>(acc_o)> softmax;
|
||||
|
||||
const float alibi_slope = !Has_alibi || params.alibi_slopes_ptr == nullptr ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
pytorch_flash::Mask<Is_causal, Is_local, Has_alibi> mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, alibi_slope);
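// The ALiBi slope is read per (batch, head) and pre-divided by params.scale_softmax so that
// the later softmax scaling leaves an effective bias of slope times the relative distance.
// A plain restatement of the usual ALiBi bias (assumption: the Mask above applies an
// equivalent relative-position penalty; this is illustrative host code, not the kernel):
#include <cstdlib>

void add_alibi_bias(float* scores, int seqlen_q, int seqlen_k, float alibi_slope) {
    for (int q = 0; q < seqlen_q; ++q) {
        for (int k = 0; k < seqlen_k; ++k) {
            // align the last query row with the last key column, then penalize distance
            int dist = std::abs((q + seqlen_k - seqlen_q) - k);
            scores[q * seqlen_k + k] -= alibi_slope * dist;
        }
    }
}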
// For performance reason, we separate out two kinds of iterations:
|
||||
// those that need masking on S, and those that don't.
|
||||
// We need masking on S for the very last block when K and V has length not multiple of kBlockN.
|
||||
@ -360,37 +300,9 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi
|
||||
);
|
||||
// if (cute::thread0()) { print(acc_s); }
|
||||
|
||||
// Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
|
||||
Tensor scores = make_tensor(acc_s.data(), pytorch_flash::convert_layout_acc_rowcol(acc_s.layout()));
|
||||
// if (cute::thread0()) { print_tensor(scores); }
|
||||
// We don't put the masking before the matmul S = Q K^T because we don't clear sK
|
||||
// for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul
|
||||
// can produce Inf / NaN.
|
||||
if (!Is_causal && !Is_local) {
|
||||
if (!Is_even_MN) { pytorch_flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); }
|
||||
} else {
|
||||
// Tensor caccS = make_identity_tensor(Shape<Int<kBlockM>, Int<kBlockN>>{}); // (BLK_M,BLK_N) -> (blk_m,blk_n)
|
||||
// Tensor taccScS = thr_mma.partition_C(caccS); // (MMA,MMA_M,MMA_N)
|
||||
// static_assert(decltype(size<0>(taccScS))::value == 4);
|
||||
// // Convert to ((2, 2), MMA_M, MMA_N) then take only the row indices.
|
||||
// Tensor idx_row = logical_divide(taccScS, Shape<_2>{})(make_coord(0, _), _, 0);
|
||||
// Tensor idx_rowcol = make_tensor(taccScS.data(), pytorch_flash::convert_layout_acc_rowcol(taccScS.layout()));
|
||||
// pytorch_flash::apply_mask_causal_w_idx(scores, idx_rowcol, n_block * kBlockN, binfo.actual_seqlen_k,
|
||||
// m_block * kBlockM);
|
||||
// Idk why it's get<1> and not get<0> of the stride.
|
||||
// if (cute::thread0()) { print(idx_row.layout()); print(stride<1>(idx_row)); printf("stride = %d \n", get<1>(stride<1>(idx_row))); }
|
||||
// I can't get the stride from idx_row
|
||||
pytorch_flash::apply_mask_local</*HasWSLeft=*/Is_local>(
|
||||
scores, n_block * kBlockN, binfo.actual_seqlen_k,
|
||||
// m_block * kBlockM + get<0>(idx_row(0)),
|
||||
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
|
||||
binfo.actual_seqlen_q, kNWarps * 16,
|
||||
params.window_size_left, params.window_size_right
|
||||
// m_block * kBlockM + (tidx / 32) * 16, kNWarps * 16
|
||||
// m_block * kBlockM + (tidx / 32) * (kBlockM / kNWarps), 16
|
||||
);
|
||||
// if (cute::thread0()) { print_tensor(scores); }
|
||||
}
|
||||
mask.template apply_mask<Is_causal, Is_even_MN>(
|
||||
acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16
|
||||
);
|
||||
|
||||
pytorch_flash::cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
@ -405,33 +317,31 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi
|
||||
|
||||
// TODO: when we have key_padding_mask we'll need to Check_inf
|
||||
masking_step == 0
|
||||
? softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/Is_causal || Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2)
|
||||
: softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
|
||||
? softmax.template softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/Is_causal || Is_local>(acc_s, acc_o, params.scale_softmax_log2)
|
||||
: softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local>(acc_s, acc_o, params.scale_softmax_log2);
|
||||
|
||||
// Convert scores from fp32 to fp16/bf16
|
||||
Tensor rP = pytorch_flash::convert_type<Element>(scores);
|
||||
// Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
|
||||
// if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
|
||||
Tensor tOrP = make_tensor(rP.data(), pytorch_flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
|
||||
// Convert acc_s from fp32 to fp16/bf16
|
||||
Tensor rP = pytorch_flash::convert_type<Element>(acc_s);
|
||||
int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
|
||||
int block_col_idx = n_block * (kBlockN / 32);
|
||||
if (Return_softmax) {
|
||||
Tensor tOrP_copy = make_fragment_like(tOrP);
|
||||
cute::copy(tOrP, tOrP_copy);
|
||||
pytorch_flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>(
|
||||
tOrP_copy, params.p_dropout_in_uint8_t, seed, offset,
|
||||
block_row_idx, block_col_idx, kNWarps
|
||||
Tensor rP_drop = make_fragment_like(rP);
|
||||
cute::copy(rP, rP_drop);
|
||||
dropout.template apply_dropout</*encode_dropout_in_sign_bit=*/true>(
|
||||
rP_drop, block_row_idx, block_col_idx, kNWarps
|
||||
);
|
||||
pytorch_flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_tiled_copy_P);
|
||||
tPgP.data() = tPgP.data() + (-kBlockN);
|
||||
cute::copy(rP_drop, tSgS);
|
||||
tSgS.data() = tSgS.data() + (-kBlockN);
|
||||
}
|
||||
if (Is_dropout) {
|
||||
pytorch_flash::apply_dropout(tOrP, params.p_dropout_in_uint8_t, seed, offset,
|
||||
block_row_idx, block_col_idx, kNWarps);
|
||||
dropout.apply_dropout(rP, block_row_idx, block_col_idx, kNWarps);
|
||||
}
|
||||
// if (cute::thread0()) { print(tOrP); }
|
||||
|
||||
pytorch_flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
|
||||
// Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
|
||||
// if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
|
||||
Tensor tOrP = make_tensor(rP.data(), pytorch_flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
|
||||
// if (cute::thread0()) { print(tOrP); }
|
||||
pytorch_flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
|
||||
// if (cute::thread0()) { print(scores); }
|
||||
|
||||
// This check is at the end of the loop since we always have at least 1 iteration
|
||||
@ -468,58 +378,37 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi
|
||||
cute::cp_async_fence();
|
||||
}
|
||||
|
||||
// Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
|
||||
Tensor scores = make_tensor(acc_s.data(), pytorch_flash::convert_layout_acc_rowcol(acc_s.layout()));
|
||||
if (Is_local && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) {
|
||||
pytorch_flash::apply_mask_local(
|
||||
scores, n_block * kBlockN, binfo.actual_seqlen_k,
|
||||
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
|
||||
binfo.actual_seqlen_q, kNWarps * 16,
|
||||
params.window_size_left, params.window_size_right
|
||||
);
|
||||
}
|
||||
softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
|
||||
mask.template apply_mask</*Causal_mask=*/false>(
|
||||
acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16
|
||||
);
|
||||
|
||||
Tensor rP = pytorch_flash::convert_type<Element>(scores);
|
||||
// Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
|
||||
// if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
|
||||
Tensor tOrP = make_tensor(rP.data(), pytorch_flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
|
||||
softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_local>(acc_s, acc_o, params.scale_softmax_log2);
|
||||
|
||||
Tensor rP = pytorch_flash::convert_type<Element>(acc_s);
|
||||
int block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
|
||||
int block_col_idx = n_block * (kBlockN / 32);
|
||||
if (Return_softmax) {
|
||||
Tensor tOrP_copy = make_fragment_like(tOrP);
|
||||
cute::copy(tOrP, tOrP_copy);
|
||||
pytorch_flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>(
|
||||
tOrP_copy, params.p_dropout_in_uint8_t, seed, offset,
|
||||
block_row_idx, block_col_idx, kNWarps
|
||||
Tensor rP_drop = make_fragment_like(rP);
|
||||
cute::copy(rP, rP_drop);
|
||||
dropout.template apply_dropout</*encode_dropout_in_sign_bit=*/true>(
|
||||
rP_drop, block_row_idx, block_col_idx, kNWarps
|
||||
);
|
||||
pytorch_flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_tiled_copy_P);
|
||||
tPgP.data() = tPgP.data() + (-kBlockN);
|
||||
cute::copy(rP_drop, tSgS);
|
||||
tSgS.data() = tSgS.data() + (-kBlockN);
|
||||
}
|
||||
if (Is_dropout) {
|
||||
pytorch_flash::apply_dropout(tOrP, params.p_dropout_in_uint8_t, seed, offset,
|
||||
block_row_idx, block_col_idx, kNWarps);
|
||||
dropout.apply_dropout(rP, block_row_idx, block_col_idx, kNWarps);
|
||||
}
|
||||
|
||||
pytorch_flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
|
||||
// Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
|
||||
// if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
|
||||
Tensor tOrP = make_tensor(rP.data(), pytorch_flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
|
||||
pytorch_flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
|
||||
}
|
||||
|
||||
// Epilogue
|
||||
|
||||
// Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
|
||||
Tensor acc_o_rowcol = make_tensor(acc_o.data(), pytorch_flash::convert_layout_acc_rowcol(acc_o.layout()));
|
||||
Tensor lse = make_fragment_like(scores_sum);
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
|
||||
float sum = scores_sum(mi);
|
||||
float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
|
||||
lse(mi) = (sum == 0.f || sum != sum) ? INFINITY : scores_max(mi) * params.scale_softmax + __logf(sum);
|
||||
float scale = !Is_dropout ? inv_sum : inv_sum * params.rp_dropout;
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; }
|
||||
}
|
||||
|
||||
// if (cute::thread0()) { print(acc_o_rowcol); }
Tensor lse = softmax.template normalize_softmax_lse<Is_dropout>(acc_o, params.scale_softmax, params.rp_dropout);
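// Sketch of what normalize_softmax_lse replaces (the deleted per-row loop above): guard
// empty/NaN rows, emit the log-sum-exp needed by the backward pass, and scale the
// accumulated output row by 1/row_sum (and by 1/(1-p) under dropout). Host-side scalar
// restatement for one row, illustrative only:
#include <cmath>

float finalize_row(float* out_row, int head_dim, float row_max, float row_sum,
                   float softmax_scale, float rp_dropout, bool is_dropout) {
    bool degenerate = (row_sum == 0.f) || (row_sum != row_sum);      // empty row or NaN
    float inv_sum = degenerate ? 1.f : 1.f / row_sum;
    float lse = degenerate ? INFINITY : row_max * softmax_scale + std::log(row_sum);
    float scale = is_dropout ? inv_sum * rp_dropout : inv_sum;
    for (int i = 0; i < head_dim; ++i) out_row[i] *= scale;
    return lse;
}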
// Convert acc_o from fp32 to fp16/bf16
|
||||
Tensor rO = pytorch_flash::convert_type<Element>(acc_o);
|
||||
@ -585,7 +474,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV, typename Params>
|
||||
template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV, typename Params>
|
||||
inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int num_n_splits) {
|
||||
|
||||
using Element = typename Kernel_traits::Element;
|
||||
@ -673,10 +562,17 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
|
||||
+ m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
|
||||
// We move K and V to the last block.
|
||||
const int bidb_cache = params.cache_batch_idx == nullptr ? bidb : params.cache_batch_idx[bidb];
|
||||
const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb_cache)
|
||||
+ (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
|
||||
const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache)
|
||||
+ (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;
|
||||
const int *block_table = params.block_table == nullptr ? nullptr : params.block_table + bidb * params.block_table_batch_stride;
|
||||
const int block_table_idx = block_table == nullptr ? 0 : (n_block_max - 1) * kBlockN / params.page_block_size;
|
||||
const int block_table_offset = block_table == nullptr ? 0 : (n_block_max - 1) * kBlockN - block_table_idx * params.page_block_size;
|
||||
const index_t row_offset_k = block_table == nullptr
|
||||
? binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb_cache)
|
||||
+ (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride
|
||||
: block_table[block_table_idx] * params.k_batch_stride + block_table_offset * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
|
||||
const index_t row_offset_v = block_table == nullptr
|
||||
? binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache)
|
||||
+ (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride
: block_table[block_table_idx] * params.v_batch_stride + block_table_offset * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;
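// Paged KV cache addressing: with a block_table, the K/V rows for this batch element live
// in fixed-size pages scattered through the cache, so the logical token index is split into
// (page, offset-within-page) before the usual row/head strides apply. The same arithmetic
// drives the pointer advancement further down when stepping from one kBlockN tile to the
// previous one. Host-side restatement of the offset math above, illustrative only:
#include <cstdint>

int64_t paged_row_offset(const int* block_table, int token_idx, int page_block_size,
                         int64_t batch_stride, int64_t row_stride,
                         int64_t head_stride, int head_idx) {
    int page    = token_idx / page_block_size;
    int in_page = token_idx - page * page_block_size;
    return int64_t(block_table[page]) * batch_stride   // which physical page
         + int64_t(in_page) * row_stride               // row inside that page
         + int64_t(head_idx) * head_stride;            // head within the row
}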
Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
|
||||
Shape<Int<kBlockM>, Int<kHeadDim>>{},
|
||||
@ -730,11 +626,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
|
||||
auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
|
||||
Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);
|
||||
|
||||
// TODO: this might need to change if we change the mma instruction in SM70
|
||||
Tensor scores_max = make_tensor<ElementAccum>(Shape<Int<2 * size<1>(acc_o)>>{});
|
||||
Tensor scores_sum = make_fragment_like(scores_max);
|
||||
|
||||
//
|
||||
// PREDICATES
|
||||
//
|
||||
|
||||
@ -814,11 +705,12 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
|
||||
Tensor tVgVnew = gmem_thr_copy_QKV.partition_S(gVnew); // (VCPY, VCPY_N, VCPY_K)
|
||||
|
||||
const int n_block_copy_min = std::max(n_block_min, binfo.seqlen_k_cache / kBlockN);
|
||||
auto tKgK_data = tKgK.data();
|
||||
auto tVgV_data = tVgV.data();
|
||||
for (int n_block = n_block_max - 1; n_block >= n_block_copy_min; n_block--) {
|
||||
pytorch_flash::copy_w_min_idx<Is_even_K>(
|
||||
tVgVnew, tVgV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN, binfo.seqlen_k_cache - n_block * kBlockN
|
||||
);
|
||||
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
|
||||
tVgVnew.data() = tVgVnew.data() + (-int(kBlockN * params.vnew_row_stride));
|
||||
if (params.rotary_dim == 0) {
|
||||
pytorch_flash::copy_w_min_idx<Is_even_K>(
|
||||
@ -844,19 +736,30 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
|
||||
|
||||
}
|
||||
}
|
||||
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
|
||||
tKgKnew.data() = tKgKnew.data() + (-int(kBlockN * params.knew_row_stride));
|
||||
if (block_table == nullptr) {
|
||||
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
|
||||
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
|
||||
} else {
|
||||
if (n_block > n_block_copy_min) {
|
||||
const int block_table_idx_cur = n_block * kBlockN / params.page_block_size;
|
||||
const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size;
|
||||
const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size;
|
||||
const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size;
|
||||
const int table_diff = block_table[block_table_idx_next] - block_table[block_table_idx_cur];
|
||||
const int offset_diff = block_table_offset_next - block_table_offset_cur;
|
||||
tVgV.data() = tVgV.data() + table_diff * params.v_batch_stride + offset_diff * params.v_row_stride;
|
||||
tKgK.data() = tKgK.data() + table_diff * params.k_batch_stride + offset_diff * params.k_row_stride;
|
||||
}
|
||||
}
|
||||
}
|
||||
// Need this before we can read in K again, so that we'll see the updated K values.
|
||||
__syncthreads();
|
||||
if (n_block_max > n_block_copy_min) {
|
||||
tKgK.data() = tKgK.data() + (n_block_max - n_block_copy_min) * kBlockN * params.k_row_stride;
|
||||
tVgV.data() = tVgV.data() + (n_block_max - n_block_copy_min) * kBlockN * params.v_row_stride;
|
||||
}
|
||||
tKgK.data() = tKgK_data;
|
||||
tVgV.data() = tVgV_data;
|
||||
}
|
||||
|
||||
// Read Q from gmem to smem, optionally apply rotary embedding.
|
||||
Tensor tQrQ = make_fragment_like(tQgQ);
|
||||
if (!Append_KV || params.rotary_dim == 0) {
|
||||
// We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
|
||||
pytorch_flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
|
||||
@ -907,6 +810,11 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
|
||||
|
||||
clear(acc_o);
|
||||
|
||||
pytorch_flash::Softmax<2 * size<1>(acc_o)> softmax;
|
||||
|
||||
const float alibi_slope = !Has_alibi ? 0.0f : reinterpret_cast<float *>(params.alibi_slopes_ptr)[bidb * params.alibi_slopes_batch_stride + bidh] / params.scale_softmax;
|
||||
pytorch_flash::Mask<Is_causal, Is_local, Has_alibi> mask(binfo.actual_seqlen_k, binfo.actual_seqlen_q, params.window_size_left, params.window_size_right, alibi_slope);
|
||||
|
||||
// For performance reason, we separate out two kinds of iterations:
|
||||
// those that need masking on S, and those that don't.
|
||||
// We need masking on S for the very last block when K and V has length not multiple of kBlockN.
|
||||
@ -927,7 +835,15 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
|
||||
|
||||
// Advance gV
|
||||
if (masking_step > 0) {
|
||||
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
|
||||
if (block_table == nullptr) {
|
||||
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
|
||||
} else {
|
||||
const int block_table_idx_cur = (n_block + 1) * kBlockN / params.page_block_size;
|
||||
const int block_table_offset_cur = (n_block + 1) * kBlockN - block_table_idx_cur * params.page_block_size;
|
||||
const int block_table_idx_next = n_block * kBlockN / params.page_block_size;
|
||||
const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * params.page_block_size;
|
||||
tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.v_row_stride;
|
||||
}
|
||||
pytorch_flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
|
||||
} else {
|
||||
// Clear the smem tiles to account for predicated off loads
|
||||
@ -943,21 +859,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
|
||||
);
|
||||
// if (cute::thread0()) { print(acc_s); }
|
||||
|
||||
// Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
|
||||
Tensor scores = make_tensor(acc_s.data(), pytorch_flash::convert_layout_acc_rowcol(acc_s.layout()));
|
||||
// if (cute::thread0()) { print(scores); }
|
||||
// We don't put the masking before the matmul S = Q K^T because we don't clear sK
|
||||
// for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul
|
||||
// can produce Inf / NaN.
|
||||
if (!Is_causal && !Is_local) {
|
||||
if (!Is_even_MN) { pytorch_flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); }
|
||||
} else {
|
||||
pytorch_flash::apply_mask_local(scores, n_block * kBlockN, binfo.actual_seqlen_k,
|
||||
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
|
||||
binfo.actual_seqlen_q, kNWarps * 16,
|
||||
params.window_size_left, params.window_size_right
|
||||
);
|
||||
}
|
||||
mask.template apply_mask<Is_causal, Is_even_MN>(
|
||||
acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16
|
||||
);
|
||||
|
||||
pytorch_flash::cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
@ -966,7 +870,15 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
|
||||
|
||||
if (n_block > n_block_min) {
|
||||
// Advance gK
|
||||
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
|
||||
if (block_table == nullptr) {
|
||||
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
|
||||
} else {
|
||||
const int block_table_idx_cur = n_block * kBlockN / params.page_block_size;
|
||||
const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size;
|
||||
const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size;
|
||||
const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size;
|
||||
tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride;
|
||||
}
|
||||
pytorch_flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
|
||||
// This cp_async_fence needs to be in the if block, otherwise the synchronization
|
||||
// isn't right and we get race conditions.
|
||||
@ -975,18 +887,17 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
|
||||
|
||||
// We have key_padding_mask so we'll need to Check_inf
|
||||
masking_step == 0
|
||||
? softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/Is_causal || Is_local || !Is_even_MN>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2)
|
||||
: softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local || !Is_even_MN>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
|
||||
? softmax.template softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/Is_causal || Is_local || !Is_even_MN>(acc_s, acc_o, params.scale_softmax_log2)
|
||||
: softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal || Is_local || !Is_even_MN>(acc_s, acc_o, params.scale_softmax_log2);
|
||||
// if (cute::thread0()) { print(scores_max); print(scores_sum); print(scores); }
|
||||
|
||||
// Convert scores from fp32 to fp16/bf16
|
||||
Tensor rP = pytorch_flash::convert_type<Element>(scores);
|
||||
// Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
|
||||
// if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
|
||||
Tensor tOrP = make_tensor(rP.data(), pytorch_flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
|
||||
// Convert acc_s from fp32 to fp16/bf16
|
||||
Tensor rP = pytorch_flash::convert_type<Element>(acc_s);
|
||||
// Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
|
||||
// if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
|
||||
Tensor tOrP = make_tensor(rP.data(), pytorch_flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
|
||||
|
||||
pytorch_flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
|
||||
// if (cute::thread0()) { print(scores); }
|
||||
pytorch_flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
|
||||
|
||||
// This check is at the end of the loop since we always have at least 1 iteration
|
||||
if (n_masking_steps > 1 && n_block <= n_block_min) {
|
||||
@ -1002,7 +913,15 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
|
||||
pytorch_flash::cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
// Advance gV
|
||||
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
|
||||
if (block_table == nullptr) {
|
||||
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
|
||||
} else {
|
||||
const int block_table_idx_cur = (n_block + 1) * kBlockN / params.page_block_size;
|
||||
const int block_table_offset_cur = (n_block + 1) * kBlockN - block_table_idx_cur * params.page_block_size;
|
||||
const int block_table_idx_next = n_block * kBlockN / params.page_block_size;
|
||||
const int block_table_offset_next = n_block * kBlockN - block_table_idx_next * params.page_block_size;
|
||||
tVgV.data() = tVgV.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.v_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.v_row_stride;
|
||||
}
|
||||
pytorch_flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
|
||||
cute::cp_async_fence();
|
||||
|
||||
@ -1015,50 +934,38 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
|
||||
__syncthreads();
|
||||
if (n_block > n_block_min) {
|
||||
// Advance gK
|
||||
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
|
||||
if (block_table == nullptr) {
|
||||
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
|
||||
} else {
|
||||
const int block_table_idx_cur = n_block * kBlockN / params.page_block_size;
|
||||
const int block_table_offset_cur = n_block * kBlockN - block_table_idx_cur * params.page_block_size;
|
||||
const int block_table_idx_next = (n_block - 1) * kBlockN / params.page_block_size;
|
||||
const int block_table_offset_next = (n_block - 1) * kBlockN - block_table_idx_next * params.page_block_size;
|
||||
tKgK.data() = tKgK.data() + (block_table[block_table_idx_next] - block_table[block_table_idx_cur]) * params.k_batch_stride + (block_table_offset_next - block_table_offset_cur) * params.k_row_stride;
|
||||
}
|
||||
pytorch_flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
|
||||
// This cp_async_fence needs to be in the if block, otherwise the synchronization
|
||||
// isn't right and we get race conditions.
|
||||
cute::cp_async_fence();
|
||||
}
|
||||
|
||||
// Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
|
||||
Tensor scores = make_tensor(acc_s.data(), pytorch_flash::convert_layout_acc_rowcol(acc_s.layout()));
|
||||
if (Is_local && n_block * kBlockN < (m_block + 1) * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q + params.window_size_right) {
|
||||
pytorch_flash::apply_mask_local(
|
||||
scores, n_block * kBlockN, binfo.actual_seqlen_k,
|
||||
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
|
||||
binfo.actual_seqlen_q, kNWarps * 16,
|
||||
params.window_size_left, params.window_size_right
|
||||
);
|
||||
}
|
||||
softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_local>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
|
||||
mask.template apply_mask</*Causal_mask=*/false>(
|
||||
acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16
|
||||
);
|
||||
softmax.template softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_local>(acc_s, acc_o, params.scale_softmax_log2);
|
||||
|
||||
Tensor rP = pytorch_flash::convert_type<Element>(scores);
|
||||
// Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
|
||||
// if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
|
||||
Tensor tOrP = make_tensor(rP.data(), pytorch_flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
|
||||
Tensor rP = pytorch_flash::convert_type<Element>(acc_s);
|
||||
// Reshape rP from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
|
||||
// if using m16n8k16 or (4, MMA_M, MMA_N) if using m16n8k8.
|
||||
Tensor tOrP = make_tensor(rP.data(), pytorch_flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
|
||||
|
||||
pytorch_flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
|
||||
pytorch_flash::gemm_rs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
|
||||
}
|
||||
|
||||
// Epilogue
|
||||
|
||||
// Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
|
||||
Tensor acc_o_rowcol = make_tensor(acc_o.data(), pytorch_flash::convert_layout_acc_rowcol(acc_o.layout()));
|
||||
// if (cute::thread0()) { print(acc_o_rowcol); }
|
||||
Tensor lse = make_fragment_like(scores_sum);
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
|
||||
float sum = scores_sum(mi);
|
||||
float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
|
||||
lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : scores_max(mi) * params.scale_softmax + __logf(sum);
|
||||
float scale = inv_sum;
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; }
|
||||
}
|
||||
Tensor lse = softmax.template normalize_softmax_lse</*Is_dropout=*/false, Split>(acc_o, params.scale_softmax);
|
||||
// if (cute::thread0()) { print(lse); }
|
||||
// if (cute::thread0()) { print(acc_o_rowcol); }
|
||||
|
||||
Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast<ElementO *>(smem_)), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N)
|
||||
// Partition sO to match the accumulator partitioning
|
||||
@ -1135,7 +1042,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
|
||||
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax, typename Params>
|
||||
inline __device__ void compute_attn(const Params ¶ms) {
|
||||
const int m_block = blockIdx.x;
|
||||
// The block index for the batch.
|
||||
@ -1151,12 +1058,12 @@ inline __device__ void compute_attn(const Params ¶ms) {
|
||||
// the attention matrix. This way, as long as we have the batch, head, and the location of
|
||||
// the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern.
|
||||
|
||||
pytorch_flash::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Is_even_MN, Is_even_K, Return_softmax>(params, bidb, bidh, m_block);
|
||||
pytorch_flash::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Return_softmax>(params, bidb, bidh, m_block);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV, typename Params>
|
||||
template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV, typename Params>
|
||||
inline __device__ void compute_attn_splitkv(const Params ¶ms) {
|
||||
const int m_block = blockIdx.x;
|
||||
// The block index for the batch.
|
||||
@ -1165,7 +1072,7 @@ inline __device__ void compute_attn_splitkv(const Params ¶ms) {
|
||||
const int bidh = Split ? blockIdx.z - bidb * params.h : blockIdx.z;
|
||||
const int n_split_idx = Split ? blockIdx.y : 0;
|
||||
const int num_n_splits = Split ? gridDim.y : 1;
|
||||
pytorch_flash::compute_attn_1rowblock_splitkv<Kernel_traits, Is_causal, Is_local, Is_even_MN, Is_even_K, Split, Append_KV>(params, bidb, bidh, m_block, n_split_idx, num_n_splits);
|
||||
pytorch_flash::compute_attn_1rowblock_splitkv<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Split, Append_KV>(params, bidb, bidh, m_block, n_split_idx, num_n_splits);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -1330,6 +1237,4 @@ inline __device__ void combine_attn_seqk_parallel(const Params ¶ms) {
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace flash
|
||||
} // namespace pytorch_flash
|
||||
|
@ -12,27 +12,40 @@
|
||||
|
||||
namespace pytorch_flash {
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Is_even_MN, bool Is_even_K, bool Return_softmax>
|
||||
__global__ void flash_fwd_kernel(Flash_fwd_params params) {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false
|
||||
pytorch_flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_local, Is_even_MN, Is_even_K, Return_softmax>(params);
|
||||
// Determine if the architecture supports FLASH and define a macro to handle parameter modifiers
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
#define ARCH_SUPPORTS_FLASH
|
||||
#define KERNEL_PARAM_MODIFIER __grid_constant__
|
||||
#else
|
||||
#define KERNEL_PARAM_MODIFIER
|
||||
#endif
|
||||
|
||||
// Define a macro for unsupported architecture handling to centralize the error message
|
||||
#define FLASH_UNSUPPORTED_ARCH printf("FATAL: FlashAttention requires building with sm version sm80-sm90, but was built for < 8.0!");
|
||||
|
||||
// Use a macro to clean up kernel definitions
|
||||
#define DEFINE_FLASH_FORWARD_KERNEL(kernelName, ...) \
|
||||
template<typename Kernel_traits, __VA_ARGS__> \
__global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_fwd_params params)
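// The macro folds the template head and the __global__ signature into one definition site.
// Expanded for the first kernel below it reads (illustrative expansion):
//
//   template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local,
//            bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax>
//   __global__ void flash_fwd_kernel(KERNEL_PARAM_MODIFIER const Flash_fwd_params params)
//
// with KERNEL_PARAM_MODIFIER = __grid_constant__ on sm80+ builds (newer CUDA toolkits), so
// taking the address of the by-value params struct does not force a per-thread local copy;
// on older targets the modifier expands to nothing.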
DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax) {
|
||||
#if defined(ARCH_SUPPORTS_FLASH)
|
||||
static_assert(!(Is_causal && Is_local)); // Enforce constraints
|
||||
pytorch_flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Return_softmax>(params);
|
||||
#else
|
||||
printf("FATAL: FlashAttention requires to be build with sm80-sm90, but was built for < 8.0!");
|
||||
FLASH_UNSUPPORTED_ARCH
|
||||
#endif
|
||||
}
|
||||
|
||||
template<typename Kernel_traits, bool Is_causal, bool Is_local, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV>
|
||||
__global__ void flash_fwd_splitkv_kernel(Flash_fwd_params params) {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
pytorch_flash::compute_attn_splitkv<Kernel_traits, Is_causal, Is_local, Is_even_MN, Is_even_K, Split, Append_KV>(params);
|
||||
DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV) {
|
||||
#if defined(ARCH_SUPPORTS_FLASH)
|
||||
pytorch_flash::compute_attn_splitkv<Kernel_traits, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Split, Append_KV>(params);
|
||||
#else
|
||||
printf("FATAL: FlashAttention requires to be build with sm80-sm90, but was built for < 8.0!");
|
||||
FLASH_UNSUPPORTED_ARCH
|
||||
#endif
|
||||
}
|
||||
|
||||
template<typename Kernel_traits, int kBlockM, int Log_max_splits, bool Is_even_K>
|
||||
__global__ void flash_fwd_splitkv_combine_kernel(Flash_fwd_params params) {
|
||||
DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_combine_kernel, int kBlockM, int Log_max_splits, bool Is_even_K) {
|
||||
static_assert(Log_max_splits >= 1);
|
||||
pytorch_flash::combine_attn_seqk_parallel<Kernel_traits, kBlockM, Log_max_splits, Is_even_K>(params);
|
||||
}
|
||||
@ -52,27 +65,30 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
const bool is_even_K = params.d == Kernel_traits::kHeadDim;
|
||||
const bool return_softmax = params.p_ptr != nullptr;
|
||||
BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
|
||||
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
|
||||
BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
|
||||
EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
|
||||
LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
|
||||
BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
|
||||
// Will only return softmax if dropout, to reduce compilation time.
|
||||
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
|
||||
// If return_softmax, set IsEvenMNConst to false to reduce number of templates
|
||||
// If head dim > 128, set IsEvenMNConst to false to reduce number of templates
|
||||
// If Is_local, set Is_causal to false
|
||||
auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, ReturnSoftmaxConst && Is_dropout>;
|
||||
// printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
|
||||
// auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;
|
||||
if (smem_size >= 48 * 1024) {
|
||||
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
|
||||
}
|
||||
// int ctas_per_sm;
|
||||
// cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
// &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
|
||||
// printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
|
||||
kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
|
||||
// Will only return softmax if dropout, to reduce compilation time.
|
||||
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
|
||||
// If return_softmax, set IsEvenMNConst to false to reduce number of templates
|
||||
// If head dim > 128, set IsEvenMNConst to false to reduce number of templates
|
||||
// If Is_local, set Is_causal to false
|
||||
auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && !ReturnSoftmaxConst && Kernel_traits::kHeadDim <= 128, IsEvenKConst, ReturnSoftmaxConst && Is_dropout>;
|
||||
// auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, false, true, true, false>;
|
||||
// printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout));
|
||||
// auto kernel = &flash_fwd_kernel<Kernel_traits, false, Is_causal, false, true, true, false>;
|
||||
if (smem_size >= 48 * 1024) {
|
||||
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
|
||||
}
|
||||
// int ctas_per_sm;
|
||||
// cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
// &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
|
||||
// printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
|
||||
kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
});
|
||||
});
|
||||
});
});
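// The *_SWITCH macros used above (DROPOUT_SWITCH, EVENK_SWITCH, LOCAL_SWITCH, ALIBI_SWITCH,
// BOOL_SWITCH) all follow one pattern: turn a runtime bool into a compile-time template
// parameter by branching once and instantiating the caller's lambda in a scope where the
// flag is constexpr. The specialized variants are assumed to pin the flag to false when a
// feature is compiled out (e.g. FLASHATTENTION_DISABLE_ALIBI). A self-contained sketch:
#define DEMO_BOOL_SWITCH(COND, CONST_NAME, ...)           \
    [&] {                                                 \
        if (COND) {                                       \
            constexpr static bool CONST_NAME = true;      \
            return __VA_ARGS__();                         \
        } else {                                          \
            constexpr static bool CONST_NAME = false;     \
            return __VA_ARGS__();                         \
        }                                                 \
    }()

template <bool Has_alibi>
void launch_one_variant() { /* one template instantiation per flag combination */ }

void dispatch_demo(bool has_alibi) {
    DEMO_BOOL_SWITCH(has_alibi, Has_alibi, [&] { launch_one_variant<Has_alibi>(); });
}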
@ -90,22 +106,24 @@ void run_flash_splitkv_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
const bool is_even_K = params.d == Kernel_traits::kHeadDim;
|
||||
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
||||
BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
|
||||
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
|
||||
BOOL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
|
||||
EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
|
||||
LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] {
|
||||
BOOL_SWITCH(params.num_splits > 1, Split, [&] {
|
||||
BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] {
|
||||
// If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
|
||||
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
|
||||
// If Is_local, set Is_causal to false
|
||||
auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, Is_local && !Is_causal, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Split, Append_KV>;
|
||||
// auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, true, Split, Append_KV>;
|
||||
// auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>;
|
||||
if (smem_size >= 48 * 1024) {
|
||||
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
|
||||
}
|
||||
kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
|
||||
// If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr.
|
||||
// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
|
||||
// If Is_local, set Is_causal to false
|
||||
auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && !Append_KV && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Split, Append_KV>;
|
||||
// auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, true, Split, Append_KV>;
|
||||
// auto kernel = &flash_fwd_splitkv_kernel<Kernel_traits, Is_causal, false, IsEvenKConst>;
|
||||
if (smem_size >= 48 * 1024) {
|
||||
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
|
||||
}
|
||||
kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
@ -118,7 +136,7 @@ void run_flash_splitkv_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
// If headdim is divisible by 64, then we set kBlockM = 8, etc.
|
||||
constexpr static int kBlockM = Kernel_traits::kHeadDim % 128 == 0 ? 4 : (Kernel_traits::kHeadDim % 64 == 0 ? 8 : 16);
|
||||
dim3 grid_combine((params.b * params.h * params.seqlen_q + kBlockM - 1) / kBlockM);
|
||||
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
|
||||
EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
|
||||
if (params.num_splits <= 2) {
|
||||
flash_fwd_splitkv_combine_kernel<Kernel_traits, kBlockM, 1, IsEvenKConst><<<grid_combine, Kernel_traits::kNThreads, 0, stream>>>(params);
|
||||
} else if (params.num_splits <= 4) {
|
||||
@ -152,7 +170,7 @@ void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream)
|
||||
template<typename T>
|
||||
void run_mha_fwd_hdim32(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
constexpr static int Headdim = 32;
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
});
|
||||
@ -162,7 +180,7 @@ void run_mha_fwd_hdim32(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
template<typename T>
|
||||
void run_mha_fwd_hdim64(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
constexpr static int Headdim = 64;
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
||||
if constexpr(!Is_dropout) {
|
||||
// Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower
|
||||
@ -186,7 +204,7 @@ void run_mha_fwd_hdim96(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
constexpr static int Headdim = 96;
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
||||
// For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
|
||||
if (is_sm8x) {
|
||||
@ -212,7 +230,7 @@ void run_mha_fwd_hdim128(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
constexpr static int Headdim = 128;
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
||||
if constexpr(!Is_dropout) {
|
||||
// For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
|
||||
@ -249,7 +267,7 @@ void run_mha_fwd_hdim160(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
constexpr static int Headdim = 160;
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
||||
// For A100, H100, 128 x 32 is the fastest.
|
||||
// For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
|
||||
@ -277,7 +295,7 @@ void run_mha_fwd_hdim160(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
template<typename T>
|
||||
void run_mha_fwd_hdim192(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
constexpr static int Headdim = 192;
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
||||
if constexpr(!Is_dropout) {
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
@ -305,7 +323,7 @@ void run_mha_fwd_hdim224(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
C10_CUDA_CHECK(status_);
|
||||
}
|
||||
// printf("max_smem_per_block = %d\n", max_smem_per_block);
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
||||
if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64)) { // 112 KB
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
@ -336,7 +354,7 @@ void run_mha_fwd_hdim256(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
C10_CUDA_CHECK(status_);
|
||||
}
|
||||
// printf("max_smem_per_sm = %d, max_smem_per_block = %d\n", max_smem_per_sm, max_smem_per_block);
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
||||
// For A100, we want to run with 128 x 64 (128KB smem).
|
||||
// For H100 we want to run with 64 x 64 (96KB smem) since then we can get 2 CTAs per SM.
|
||||
@ -353,4 +371,4 @@ void run_mha_fwd_hdim256(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
});
|
||||
}
|
||||
|
||||
}; // namespace pytorch_fmha
|
||||
}; // namespace pytorch_flash
|
||||
|
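All of the launchers above follow the same dispatch idiom: a runtime condition (dropout enabled, causal masking, even sequence lengths) is mapped onto a compile-time boolean so the kernel body can prune the untaken branch. A minimal host-side sketch of that idiom, with illustrative names standing in for BOOL_SWITCH/DROPOUT_SWITCH and run_flash_fwd (this is an analogy, not the macros' actual definitions):

#include <cstdio>
#include <type_traits>
#include <utility>

// Stand-in for run_flash_fwd<Kernel_traits, Is_dropout, Is_causal>: the flags are
// template parameters, so untaken branches are removed at compile time.
template <bool IsDropout, bool IsCausal>
void run_kernel() {
    if constexpr (IsDropout) { std::puts("dropout path"); } else { std::puts("no-dropout path"); }
    if constexpr (IsCausal)  { std::puts("causal path");  } else { std::puts("full-attention path"); }
}

// Stand-in for the BOOL_SWITCH / DROPOUT_SWITCH macros: convert a runtime bool into a
// std::integral_constant and hand it to a generic lambda.
template <typename F>
void bool_switch(bool cond, F&& f) {
    if (cond) { std::forward<F>(f)(std::true_type{}); }
    else      { std::forward<F>(f)(std::false_type{}); }
}

int main() {
    float p_dropout = 0.8f;   // hypothetical runtime parameters
    bool  is_causal = true;
    bool_switch(p_dropout < 1.f, [&](auto is_dropout) {
        bool_switch(is_causal, [&](auto causal) {
            run_kernel<decltype(is_dropout)::value, decltype(causal)::value>();
        });
    });
}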
@ -1,5 +1,5 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2023, Tri Dao.
|
||||
* Copyright (c) 2024, Tri Dao.
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
@ -26,7 +26,7 @@ struct Flash_kernel_traits {
|
||||
#endif
|
||||
|
||||
using ElementAccum = float;
|
||||
using index_t = uint32_t;
|
||||
using index_t = int64_t;
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
using MMA_Atom_Arch = std::conditional_t<
|
||||
@ -91,20 +91,10 @@ struct Flash_fwd_kernel_traits : public Base {
|
||||
SmemLayoutAtomQ{},
|
||||
Shape<Int<kBlockN>, Int<kHeadDim>>{}));
|
||||
|
||||
// This has to be kBlockN and not 8, otherwise we get wrong results for d=128
|
||||
using SmemLayoutAtomVtransposedNoSwizzle = Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
|
||||
Stride<_1, Int<kBlockKSmem>>>;
|
||||
using SmemLayoutAtomVtransposed = decltype(
|
||||
composition(Swizzle<kSwizzle, 3, 3>{}, SmemLayoutAtomVtransposedNoSwizzle{}));
|
||||
using SmemLayoutVtransposed = decltype(tile_to_shape(
|
||||
SmemLayoutAtomVtransposed{},
|
||||
Shape<Int<kHeadDim>, Int<kBlockN>>{}));
|
||||
// Maybe the VtransposeNoSwizzle just needs to have the right shape
|
||||
// And the strides don't matter?
|
||||
using SmemLayoutVtransposedNoSwizzle = decltype(tile_to_shape(
|
||||
SmemLayoutAtomVtransposedNoSwizzle{},
|
||||
Shape<Int<kHeadDim>, Int<kBlockN>>{}));
|
||||
// using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn());
|
||||
// https://github.com/ColfaxResearch/cutlass-kernels/blob/a222587e6d59b93ba704853d3946fb686d8b8892/src/fmha/fmha_forward.cu#L434
|
||||
using SmemLayoutVtransposed = decltype(
|
||||
composition(SmemLayoutKV{}, make_layout(Shape<Int<kHeadDim>, Int<kBlockN>>{}, GenRowMajor{})));
|
||||
using SmemLayoutVtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutVtransposed{}));
|
||||
|
||||
using SmemLayoutAtomO = decltype(
|
||||
composition(Swizzle<kSwizzle, 3, 3>{},
|
||||
@ -116,10 +106,8 @@ struct Flash_fwd_kernel_traits : public Base {
|
||||
using SmemCopyAtomO = Copy_Atom<DefaultCopy, Element>;
|
||||
using SmemCopyAtomOaccum = Copy_Atom<DefaultCopy, ElementAccum>;
|
||||
|
||||
static constexpr int kSmemQCount = size(SmemLayoutQ{});
|
||||
static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2;
|
||||
static constexpr int kSmemQSize = kSmemQCount * sizeof(Element);
|
||||
static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element);
|
||||
static constexpr int kSmemQSize = size(SmemLayoutQ{}) * sizeof(Element);
|
||||
static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element);
|
||||
static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize;
|
||||
|
||||
static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
|
||||
@ -149,15 +137,6 @@ struct Flash_fwd_kernel_traits : public Base {
|
||||
make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
|
||||
GmemLayoutAtom{},
|
||||
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store
|
||||
static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad;
|
||||
static_assert(kNThreads % kGmemThreadsPerRowP == 0, "kNThreads must be a multiple of kGmemThreadsPerRowP");
|
||||
using GmemLayoutAtomP = Layout<Shape <Int<kNThreads / kGmemThreadsPerRowP>, Int<kGmemThreadsPerRowP>>,
|
||||
Stride<Int<kGmemThreadsPerRowP>, _1>>;
|
||||
|
||||
using GmemTiledCopyP = decltype(
|
||||
make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
|
||||
GmemLayoutAtomP{},
|
||||
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store
|
||||
|
||||
using GmemLayoutAtomOaccum = std::conditional_t<
|
||||
kBlockKSmem == 32,
|
||||
@ -244,26 +223,18 @@ struct Flash_bwd_kernel_traits : public Base {
|
||||
SmemLayoutAtomKV{},
|
||||
make_shape(Int<kBlockN>{}, Int<kHeadDim>{})));
|
||||
|
||||
using SmemLayoutAtomKtransposedNoSwizzle = Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
|
||||
Stride<_1, Int<kBlockKSmem>>>;
|
||||
using SmemLayoutAtomKtransposed = decltype(
|
||||
composition(Swizzle<kSwizzle, 3, 3>{}, SmemLayoutAtomKtransposedNoSwizzle{}));
|
||||
using SmemLayoutKtransposed = decltype(tile_to_shape(
|
||||
SmemLayoutAtomKtransposed{},
|
||||
make_shape(Int<kHeadDim>{}, Int<kBlockN>{})));
|
||||
// Maybe the KtransposeNoSwizzle just needs to have the right shape
|
||||
// And the strides don't matter?
|
||||
using SmemLayoutKtransposedNoSwizzle = decltype(tile_to_shape(
|
||||
SmemLayoutAtomKtransposedNoSwizzle{},
|
||||
make_shape(Int<kHeadDim>{}, Int<kBlockN>{})));
|
||||
// using SmemLayoutKtransposedNoSwizzle = decltype(SmemLayoutKtransposed{}.layout_fn());
|
||||
using SmemLayoutKtransposed = decltype(
|
||||
composition(SmemLayoutKV{}, make_layout(Shape<Int<kHeadDim>, Int<kBlockN>>{}, GenRowMajor{})));
|
||||
using SmemLayoutKtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutKtransposed{}));
|
||||
|
||||
// TODO: generalize to other values of kBlockN
|
||||
// TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2
|
||||
// static constexpr int kPBlockN = kBlockN;
|
||||
static_assert(kBlockN >= 64);
|
||||
// Temporarily disabling this for hdim 256 on sm86 and sm89
|
||||
// static_assert(kBlockN >= 64);
|
||||
static_assert(kBlockN >= 32);
|
||||
// TD [2023-03-19]: Idk why kPBlockN = 16 and kSwizzlePdS=3 is the fastest.
|
||||
static constexpr int kPBlockN = 64;
|
||||
static constexpr int kPBlockN = kBlockN >= 64 ? 64 : 32;
|
||||
static_assert(kPBlockN == 16 || kPBlockN == 32 || kPBlockN == 64);
|
||||
// static constexpr int kSwizzlePdS = kPBlockN == 16 ? 1 : (kPBlockN == 32 ? 2 : 3);
|
||||
static constexpr int kSwizzlePdS = 3;
|
||||
@ -274,30 +245,15 @@ struct Flash_bwd_kernel_traits : public Base {
|
||||
using SmemLayoutPdS = decltype(tile_to_shape(
|
||||
SmemLayoutAtomPdS{},
|
||||
make_shape(Int<kBlockM>{}, Int<kBlockN>{})));
|
||||
using SmemLayoutAtomPdStransposedNoSwizzle = Layout<Shape<Int<kPBlockN>, Int<kBlockM>>,
|
||||
Stride<_1, Int<kPBlockN>>>;
|
||||
using SmemLayoutAtomPdStransposed = decltype(
|
||||
composition(Swizzle<kSwizzlePdS, 3, 3>{}, SmemLayoutAtomPdStransposedNoSwizzle{}));
|
||||
using SmemLayoutPdStransposed = decltype(tile_to_shape(
|
||||
SmemLayoutAtomPdStransposed{},
|
||||
make_shape(Int<kBlockN>{}, Int<kBlockM>{})));
|
||||
using SmemLayoutPdStransposedNoSwizzle = decltype(tile_to_shape(
|
||||
SmemLayoutAtomPdStransposedNoSwizzle{},
|
||||
make_shape(Int<kBlockN>{}, Int<kBlockM>{})));
|
||||
// using SmemLayoutPdStransposedNoSwizzle = decltype(SmemLayoutPdStransposed{}.layout_fn());
|
||||
using SmemLayoutPdStransposed = decltype(
|
||||
composition(SmemLayoutPdS{}, make_layout(Shape<Int<kBlockN>, Int<kBlockM>>{}, GenRowMajor{})));
|
||||
using SmemLayoutPdStransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutPdStransposed{}));
|
||||
|
||||
using SmemCopyAtomPdS = Copy_Atom<DefaultCopy, elem_type>;
|
||||
|
||||
using SmemLayoutAtomQdOtransposedNoSwizzle = Layout<Shape<Int<kBlockKSmem>, Int<kBlockM>>,
|
||||
Stride<_1, Int<kBlockKSmem>>>;
|
||||
using SmemLayoutAtomQdOtransposed = decltype(
|
||||
composition(Swizzle<kSwizzle, 3, 3>{}, SmemLayoutAtomQdOtransposedNoSwizzle{}));
|
||||
using SmemLayoutQdOtransposed = decltype(tile_to_shape(
|
||||
SmemLayoutAtomQdOtransposed{},
|
||||
make_shape(Int<kHeadDim>{}, Int<kBlockM>{})));
|
||||
using SmemLayoutQdOtransposedNoSwizzle = decltype(tile_to_shape(
|
||||
SmemLayoutAtomQdOtransposedNoSwizzle{},
|
||||
make_shape(Int<kHeadDim>{}, Int<kBlockM>{})));
|
||||
// using SmemLayoutQdOtransposedNoSwizzle = decltype(SmemLayoutQdOtransposed{}.layout_fn());
|
||||
using SmemLayoutQdOtransposed = decltype(
|
||||
composition(SmemLayoutQdO{}, make_layout(Shape<Int<kHeadDim>, Int<kBlockM>>{}, GenRowMajor{})));
|
||||
using SmemLayoutQdOtransposedNoSwizzle = decltype(get_nonswizzle_portion(SmemLayoutQdOtransposed{}));
|
||||
|
||||
using SmemLayoutAtomdKV = decltype(
|
||||
composition(Swizzle<kSwizzle, 3, 3>{},
|
||||
@ -317,16 +273,12 @@ struct Flash_bwd_kernel_traits : public Base {
|
||||
make_shape(Int<kBlockM>{}, Int<kHeadDim>{})));
|
||||
using SmemCopyAtomdQ = Copy_Atom<DefaultCopy, elem_type>;
|
||||
|
||||
static constexpr int kSmemQdOCount = size(SmemLayoutQdO{}) * (No_double_buffer ? 2 : 3); // Double buffer for sQ
|
||||
static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2;
|
||||
static constexpr int kSmemdSCount = size(SmemLayoutPdS{});
|
||||
static constexpr int kSmemPCount = size(SmemLayoutPdS{});
|
||||
static constexpr int kSmemdQCount = size(SmemLayoutdQ{});
|
||||
static constexpr int kSmemQdOSize = kSmemQdOCount * sizeof(Element);
|
||||
static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element);
|
||||
static constexpr int kSmemdSSize = kSmemdSCount * sizeof(Element);
|
||||
static constexpr int kSmemPSize = kSmemPCount * sizeof(Element);
|
||||
static constexpr int kSmemdQSize = kSmemdQCount * sizeof(Element);
|
||||
// Double buffer for sQ
|
||||
static constexpr int kSmemQdOSize = size(SmemLayoutQdO{}) * (No_double_buffer ? 2 : 3) * sizeof(Element);
|
||||
static constexpr int kSmemKVSize = size(SmemLayoutKV{}) * 2 * sizeof(Element);
|
||||
static constexpr int kSmemdSSize = size(SmemLayoutPdS{}) * sizeof(Element);
|
||||
static constexpr int kSmemPSize = size(SmemLayoutPdS{}) * sizeof(Element);
|
||||
static constexpr int kSmemdQSize = size(SmemLayoutdQ{}) * sizeof(Element);
|
||||
static constexpr int kSmemSize = kSmemQdOSize
|
||||
+ (!Is_V_in_regs
|
||||
? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize)
|
||||
@ -335,9 +287,6 @@ struct Flash_bwd_kernel_traits : public Base {
|
||||
+ (!Is_V_in_regs
|
||||
? kSmemKVSize + kSmemdSSize + kSmemPSize
|
||||
: std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + kSmemPSize));
|
||||
static constexpr int kSmemSize1rowblock = kSmemQdOSize / 3 * 2 + kSmemKVSize / 2 * 3
|
||||
+ kSmemdSSize + kSmemPSize;
|
||||
|
||||
|
||||
static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
|
||||
static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
|
||||
|
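The rewritten kSmemQSize / kSmemKVSize / kSmemSize constants in the forward and backward kernel traits above are plain byte counts derived from the tile shapes. A sketch of the arithmetic, assuming fp16 elements and the hdim-32 forward tile shown earlier (Flash_fwd_kernel_traits<32, 128, 128, 4, ...>); other head dimensions follow the same pattern with different numbers:

#include <algorithm>
#include <cstdio>

int main() {
    constexpr int kHeadDim   = 32;
    constexpr int kBlockM    = 128;          // query rows per CTA
    constexpr int kBlockN    = 128;          // key/value rows per CTA
    constexpr int kElemBytes = 2;            // sizeof(cutlass::half_t)

    // size(SmemLayoutQ{})  == kBlockM * kHeadDim elements,
    // size(SmemLayoutKV{}) == kBlockN * kHeadDim elements, and K and V each need a tile.
    constexpr int kSmemQSize  = kBlockM * kHeadDim * kElemBytes;
    constexpr int kSmemKVSize = 2 * kBlockN * kHeadDim * kElemBytes;

    constexpr bool Share_Q_K_smem = false;   // illustrative choice
    constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize)
                                             : kSmemQSize + kSmemKVSize;

    std::printf("Q: %d B, KV: %d B, total: %d B\n", kSmemQSize, kSmemKVSize, kSmemSize);
    return 0;
}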
@ -1,161 +0,0 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2023, Tri Dao.
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cute/algorithm/copy.hpp>
|
||||
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/layout/layout.h>
|
||||
#include <cutlass/numeric_types.h>
|
||||
|
||||
namespace pytorch_flash{
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename elem_type=cutlass::half_t>
|
||||
struct Flash_kernel_traits_sm90 {
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
using Element = elem_type;
|
||||
static constexpr bool Has_cp_async = true;
|
||||
#else
|
||||
using Element = cutlass::half_t;
|
||||
static constexpr bool Has_cp_async = false;
|
||||
#endif
|
||||
|
||||
using ElementAccum = float;
|
||||
using index_t = uint32_t;
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
using MMA_Atom_Arch = std::conditional_t<
|
||||
std::is_same_v<elem_type, cutlass::half_t>,
|
||||
MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>,
|
||||
MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>
|
||||
>;
|
||||
using ValLayoutMNK = Layout<Shape<_1, _2, _1>>;
|
||||
#else
|
||||
using MMA_Atom_Arch = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>;
|
||||
using ValLayoutMNK = Layout<Shape<_1, _2, _2>>;
|
||||
#endif
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
|
||||
using SmemCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, elem_type>;
|
||||
using SmemCopyAtomTransposed = Copy_Atom<SM75_U16x8_LDSM_T, elem_type>;
|
||||
#else
|
||||
using SmemCopyAtom = Copy_Atom<DefaultCopy, elem_type>;
|
||||
using SmemCopyAtomTransposed = Copy_Atom<DefaultCopy, elem_type>;
|
||||
#endif
|
||||
};
|
||||
|
||||
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, bool Is_Q_in_regs_=false, bool Share_Q_K_smem_=false, typename elem_type=cutlass::half_t,
|
||||
typename Base=Flash_kernel_traits_sm90<kHeadDim_, kBlockM_, kBlockN_, kNWarps_, elem_type> >
|
||||
struct Flash_fwd_kernel_traits : public Base {
|
||||
using Element = typename Base::Element;
|
||||
using ElementAccum = typename Base::ElementAccum;
|
||||
using index_t = typename Base::index_t;
|
||||
static constexpr bool Has_cp_async = Base::Has_cp_async;
|
||||
using SmemCopyAtom = typename Base::SmemCopyAtom;
|
||||
using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed;
|
||||
|
||||
static constexpr bool Share_Q_K_smem = Share_Q_K_smem_;
|
||||
static constexpr bool Is_Q_in_regs = Is_Q_in_regs_ || Share_Q_K_smem;
|
||||
|
||||
// The number of threads.
|
||||
static constexpr int kNWarps = kNWarps_;
|
||||
static constexpr int kNThreads = kNWarps * 32;
|
||||
|
||||
static constexpr int kBlockM = kBlockM_;
|
||||
static constexpr int kBlockN = kBlockN_;
|
||||
static constexpr int kHeadDim = kHeadDim_;
|
||||
static_assert(kHeadDim % 32 == 0);
|
||||
static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
|
||||
static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32);
|
||||
static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;
|
||||
|
||||
using TiledMma = TiledMMA<
|
||||
typename Base::MMA_Atom_Arch,
|
||||
Layout<Shape<Int<kNWarps>,_1,_1>>, // 4x1x1 or 8x1x1 thread group
|
||||
typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
|
||||
|
||||
using SmemLayoutAtomQ = decltype(
|
||||
composition(Swizzle<kSwizzle, 3, 3>{},
|
||||
// This has to be kBlockKSmem, using kHeadDim gives wrong results for d=128
|
||||
Layout<Shape<_8, Int<kBlockKSmem>>,
|
||||
Stride<Int<kBlockKSmem>, _1>>{}));
|
||||
using SmemLayoutQ = decltype(tile_to_shape(
|
||||
SmemLayoutAtomQ{},
|
||||
Shape<Int<kBlockM>, Int<kHeadDim>>{}));
|
||||
|
||||
using SmemLayoutKV = decltype(tile_to_shape(
|
||||
SmemLayoutAtomQ{},
|
||||
Shape<Int<kBlockN>, Int<kHeadDim>>{}));
|
||||
|
||||
using SmemLayoutAtomVtransposed = decltype(
|
||||
composition(Swizzle<kSwizzle, 3, 3>{},
|
||||
// This has to be kBlockN and not 8, otherwise we get wrong results for d=128
|
||||
Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
|
||||
Stride<_1, Int<kBlockKSmem>>>{}));
|
||||
using SmemLayoutVtransposed = decltype(tile_to_shape(
|
||||
SmemLayoutAtomVtransposed{},
|
||||
Shape<Int<kHeadDim>, Int<kBlockN>>{}));
|
||||
// Maybe the VtransposeNoSwizzle just needs to have the right shape
|
||||
// And the strides don't matter?
|
||||
using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn());
|
||||
|
||||
using SmemLayoutAtomO = decltype(
|
||||
composition(Swizzle<kSwizzle, 3, 3>{},
|
||||
Layout<Shape<Int<8>, Int<kBlockKSmem>>,
|
||||
Stride<Int<kBlockKSmem>, _1>>{}));
|
||||
using SmemLayoutO = decltype(tile_to_shape(
|
||||
SmemLayoutAtomO{},
|
||||
Shape<Int<kBlockM>, Int<kHeadDim>>{}));
|
||||
using SmemCopyAtomO = Copy_Atom<DefaultCopy, elem_type>;
|
||||
|
||||
static constexpr int kSmemQCount = size(SmemLayoutQ{});
|
||||
static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2;
|
||||
static constexpr int kSmemQSize = kSmemQCount * sizeof(Element);
|
||||
static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element);
|
||||
static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize;
|
||||
|
||||
static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
|
||||
static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
|
||||
// Using kBlockKSmem here is 6-10% faster than kBlockKGmem for d=128 because of bank conflicts.
|
||||
// For example, for d=128, smem is split into 2 "pages", each page takes care of columns
|
||||
// 0-63 and 64-127. If we have 16 threads per row for gmem read, when we write to smem,
|
||||
// thread 0 - 7 will write to the first page and thread 8 - 15 will write to the second page,
|
||||
// to the same banks.
|
||||
static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
|
||||
static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
|
||||
using GmemLayoutAtom = Layout<Shape <Int<kNThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
|
||||
Stride<Int<kGmemThreadsPerRow>, _1>>;
|
||||
|
||||
// We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading
|
||||
// from the same address by the same threadblock. This is slightly faster.
|
||||
using Gmem_copy_struct = std::conditional_t<
|
||||
Has_cp_async,
|
||||
SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
|
||||
DefaultCopy
|
||||
>;
|
||||
using GmemTiledCopyQKV = decltype(
|
||||
make_tiled_copy(Copy_Atom<Gmem_copy_struct, elem_type>{},
|
||||
GmemLayoutAtom{},
|
||||
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
|
||||
using GmemTiledCopyO = decltype(
|
||||
make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
|
||||
GmemLayoutAtom{},
|
||||
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store
|
||||
static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad;
|
||||
static_assert(kNThreads % kGmemThreadsPerRowP == 0, "kNThreads must be a multiple of kGmemThreadsPerRowP");
|
||||
using GmemLayoutAtomP = Layout<Shape <Int<kNThreads / kGmemThreadsPerRowP>, Int<kGmemThreadsPerRowP>>,
|
||||
Stride<Int<kGmemThreadsPerRowP>, _1>>;
|
||||
|
||||
using GmemTiledCopyP = decltype(
|
||||
make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
|
||||
GmemLayoutAtomP{},
|
||||
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store
|
||||
|
||||
};
|
||||
} // namespace pytorch_flash
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
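Both the retained traits and the sm90 header deleted above size their global-memory copies the same way: each thread issues 128-bit (cute::uint128_t) loads, so kGmemElemsPerLoad = 16 B / sizeof(Element), and kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad threads cooperate on one row of the tile. A small sketch of that arithmetic, assuming fp16 and the kBlockKSmem = 64 case with an illustrative 4-warp CTA:

#include <cstdio>

int main() {
    constexpr int kBytesPerLoad = 16;                         // cute::uint128_t
    constexpr int kElemBytes    = 2;                          // fp16 / bf16
    constexpr int kBlockKSmem   = 64;                         // kHeadDim % 64 == 0 case
    constexpr int kNThreads     = 128;                        // 4 warps, illustrative

    constexpr int kGmemElemsPerLoad  = kBytesPerLoad / kElemBytes;      // 8 elements per load
    constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; // 8 threads per row
    constexpr int kRowsPerCopy       = kNThreads / kGmemThreadsPerRow;  // 16 rows per pass

    static_assert(kNThreads % kGmemThreadsPerRow == 0,
                  "kNThreads must be a multiple of kGmemThreadsPerRow");
    std::printf("%d elems/load, %d threads/row, %d rows/copy\n",
                kGmemElemsPerLoad, kGmemThreadsPerRow, kRowsPerCopy);
}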
@ -8,7 +8,7 @@  (the same signature change is applied in every run_mha_bwd_ instantiation unit)
namespace pytorch_flash{

template<>
void run_mha_bwd_<cutlass::bfloat16_t, 128>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {
    run_mha_bwd_hdim128<cutlass::bfloat16_t>(params, stream, configure);
void run_mha_bwd_<cutlass::bfloat16_t, 128>(Flash_bwd_params &params, cudaStream_t stream) {
    run_mha_bwd_hdim128<cutlass::bfloat16_t>(params, stream);
}
} // namespace pytorch_flash

The identical hunk, dropping the `const bool configure` parameter from run_mha_bwd_ and its forwarding call, appears in the companion instantiation files: cutlass::half_t at head dim 128, and both cutlass::bfloat16_t and cutlass::half_t at head dims 160, 192, 224, 256, 32, 64, and 96.
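These per-(dtype, head_dim) files exist mainly so the heavy kernel templates can be compiled in parallel, one explicit specialization per translation unit. A toy illustration of that pattern with made-up names (this only sketches the explicit-instantiation layout, not the real launch path):

#include <cstdio>

struct Params {};
using Stream = int;

template <int kHeadDim, typename T>
void run_mha_bwd_hdim(Params &, Stream) {            // stand-in for run_mha_bwd_hdim128<T>
    std::printf("bwd launcher for head_dim=%d\n", kHeadDim);
}

template <typename T, int kHeadDim>
void run_mha_bwd_(Params &params, Stream stream);    // primary template, specialized per unit

// One translation unit pins down a single combination and simply forwards to the
// head-dim launcher -- note there is no `configure` flag any more.
template <>
void run_mha_bwd_<float, 128>(Params &params, Stream stream) {
    run_mha_bwd_hdim<128, float>(params, stream);
}

int main() { Params p; run_mha_bwd_<float, 128>(p, 0); }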
@ -27,8 +27,8 @@ template void run_mha_fwd_splitkv_dispatch<{DTYPE}, {HEAD_DIM}>(Flash_fwd_params

KERNEL_IMPL_TEMPLATE_BWD = """
template<>
void run_mha_bwd_<{DTYPE}, {HEAD_DIM}>(Flash_bwd_params &params, cudaStream_t stream, const bool configure) {{
    run_mha_bwd_hdim{HEAD_DIM}<{DTYPE}>(params, stream, configure);
void run_mha_bwd_<{DTYPE}, {HEAD_DIM}>(Flash_bwd_params &params, cudaStream_t stream) {{
    run_mha_bwd_hdim{HEAD_DIM}<{DTYPE}>(params, stream);
}}
"""
aten/src/ATen/native/transformers/cuda/flash_attn/mask.h (new file, 213 lines)
@ -0,0 +1,213 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2024, Tri Dao.
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
|
||||
namespace pytorch_flash {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template <typename Engine, typename Layout>
|
||||
__forceinline__ __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const int max_seqlen_k,
|
||||
const int col_idx_offset_ = 0) {
|
||||
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
|
||||
static_assert(Layout::rank == 2, "Only support 2D Tensor");
|
||||
const int lane_id = threadIdx.x % 32;
|
||||
const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
|
||||
#pragma unroll
|
||||
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
|
||||
const int col_idx_base = col_idx_offset + nj * 8;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < size<1, 0>(tensor); ++j) {
|
||||
const int col_idx = col_idx_base + j;
|
||||
if (col_idx >= max_seqlen_k) {
|
||||
// Without the "make_coord" we get wrong results
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0>(tensor); ++mi) {
|
||||
tensor(mi, make_coord(j, nj)) = -INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <bool HasWSLeft=true, typename Engine, typename Layout>
|
||||
__forceinline__ __device__ void apply_mask_local(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
|
||||
const int max_seqlen_k, const int row_idx_offset,
|
||||
const int max_seqlen_q, const int warp_row_stride,
|
||||
const int window_size_left, const int window_size_right) {
|
||||
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
|
||||
static_assert(Layout::rank == 2, "Only support 2D Tensor");
|
||||
const int lane_id = threadIdx.x % 32;
|
||||
const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
|
||||
const int row_idx_base = row_idx_offset + mi * warp_row_stride;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size<0, 0>(tensor); ++i) {
|
||||
const int row_idx = row_idx_base + i * 8;
|
||||
const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left);
|
||||
const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right);
|
||||
#pragma unroll
|
||||
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
|
||||
const int col_idx_base = col_idx_offset + nj * 8;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < size<1, 0>(tensor); ++j) {
|
||||
const int col_idx = col_idx_base + j;
|
||||
if (col_idx >= col_idx_limit_right || (HasWSLeft && col_idx < col_idx_limit_left)) {
|
||||
tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
// if (cute::thread0()) {
|
||||
// printf("mi = %d, i = %d, row_idx = %d, max_seqlen_k = %d\n", mi, i, row_idx, max_seqlen_k);
|
||||
// print(tensor(make_coord(i, mi), _));
|
||||
// // print(tensor(_, j + nj * size<1, 0>(tensor)));
|
||||
// }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Engine, typename Layout>
|
||||
__forceinline__ __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
|
||||
const int max_seqlen_k, const int row_idx_offset,
|
||||
const int max_seqlen_q, const int warp_row_stride) {
|
||||
// Causal masking is equivalent to local masking with window_size_left = infinity and window_size_right = 0
|
||||
apply_mask_local</*HasWSLeft=*/false>(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset,
|
||||
max_seqlen_q, warp_row_stride, -1, 0);
|
||||
}
|
||||
|
||||
template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
|
||||
__forceinline__ __device__ void apply_mask_causal_w_idx(
|
||||
Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &idx_rowcol,
|
||||
const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset)
|
||||
{
|
||||
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
|
||||
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
|
||||
static_assert(Layout1::rank == 2, "Only support 2D Tensor");
|
||||
CUTE_STATIC_ASSERT_V(size<0>(tensor) == size<0>(idx_rowcol));
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tensor) == size<1>(idx_rowcol));
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0>(tensor); ++mi) {
|
||||
const int col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset + get<0>(idx_rowcol(mi, 0)));
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < size<1, 1>(tensor); ++ni) {
|
||||
if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) {
|
||||
tensor(mi, ni) = -INFINITY;
|
||||
}
|
||||
}
|
||||
// if (cute::thread0()) {
|
||||
// printf("ni = %d, j = %d, col_idx = %d, max_seqlen_k = %d\n", ni, j, col_idx, max_seqlen_k);
|
||||
// print(tensor(_, make_coord(j, ni)));
|
||||
// // print(tensor(_, j + ni * size<1, 0>(tensor)));
|
||||
// }
|
||||
}
|
||||
}
|
||||
|
||||
template <bool Is_causal, bool Is_local, bool Has_alibi>
|
||||
struct Mask {
|
||||
|
||||
const int max_seqlen_k, max_seqlen_q;
|
||||
const int window_size_left, window_size_right;
|
||||
const float alibi_slope;
|
||||
|
||||
__forceinline__ __device__ Mask(const int max_seqlen_k, const int max_seqlen_q,
|
||||
const int window_size_left, const int window_size_right,
|
||||
const float alibi_slope=0.f)
|
||||
: max_seqlen_k(max_seqlen_k)
|
||||
, max_seqlen_q(max_seqlen_q)
|
||||
, window_size_left(window_size_left)
|
||||
, window_size_right(window_size_right)
|
||||
, alibi_slope(!Has_alibi ? 0.0 : alibi_slope) {
|
||||
};
|
||||
|
||||
// Causal_mask: whether this particular iteration needs causal masking
|
||||
template <bool Causal_mask=false, bool Is_even_MN=true, typename Engine, typename Layout>
|
||||
__forceinline__ __device__ void apply_mask(Tensor<Engine, Layout> &tensor_,
|
||||
const int col_idx_offset_,
|
||||
const int row_idx_offset,
|
||||
const int warp_row_stride) {
|
||||
static_assert(!(Causal_mask && Is_local), "Cannot be both causal and local");
|
||||
static_assert(Layout::rank == 3, "Only support 3D Tensor");
|
||||
static_assert(decltype(size<0>(tensor_))::value == 4, "First dimension must be 4");
|
||||
static constexpr bool Need_masking = Has_alibi || Causal_mask || Is_local || !Is_even_MN;
|
||||
// if (cute::thread0()) { printf("Has_alibi = %d, Causal_mask=%d, Is_local=%d, Is_even_MN = %d, Need_masking = %d\n", Has_alibi, Causal_mask, Is_local, Is_even_MN, Need_masking); }
|
||||
if constexpr (Need_masking) {
|
||||
// Reshape tensor_ from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
|
||||
Tensor tensor = make_tensor(tensor_.data(), pytorch_flash::convert_layout_acc_rowcol(tensor_.layout()));
|
||||
// Do we need both row and column indices, or just column indices?
|
||||
static constexpr bool Col_idx_only = !(Has_alibi && !Is_causal) && !Is_local && !Causal_mask;
|
||||
const int lane_id = threadIdx.x % 32;
|
||||
const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
|
||||
if constexpr (Col_idx_only) {
|
||||
#pragma unroll
|
||||
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
|
||||
const int col_idx_base = col_idx_offset + nj * 8;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < size<1, 0>(tensor); ++j) {
|
||||
const int col_idx = col_idx_base + j;
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0>(tensor); ++mi) {
|
||||
// No causal, no local
|
||||
if constexpr (Has_alibi) {
|
||||
tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx;
|
||||
}
|
||||
if constexpr (!Is_even_MN) {
|
||||
if (col_idx >= max_seqlen_k) { tensor(mi, make_coord(j, nj)) = -INFINITY; }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
|
||||
const int row_idx_base = row_idx_offset + mi * warp_row_stride;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size<0, 0>(tensor); ++i) {
|
||||
const int row_idx = row_idx_base + i * 8;
|
||||
const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left);
|
||||
const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right);
|
||||
#pragma unroll
|
||||
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
|
||||
const int col_idx_base = col_idx_offset + nj * 8;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < size<1, 0>(tensor); ++j) {
|
||||
const int col_idx = col_idx_base + j;
|
||||
if constexpr (Has_alibi) {
|
||||
if constexpr (Is_causal) {
|
||||
tensor(make_coord(i, mi), make_coord(j, nj)) += alibi_slope * col_idx;
|
||||
} else {
|
||||
tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx);
|
||||
|
||||
}
|
||||
}
|
||||
if constexpr (Causal_mask) {
|
||||
if (col_idx >= col_idx_limit_right) {
|
||||
tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
|
||||
}
|
||||
}
|
||||
if constexpr (Is_local) {
|
||||
if (col_idx >= col_idx_limit_right || col_idx < col_idx_limit_left) {
|
||||
tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
|
||||
}
|
||||
}
|
||||
if constexpr (!Causal_mask && !Is_local && !Is_even_MN) {
|
||||
// Causal and Local already handles MN masking
|
||||
if (col_idx >= max_seqlen_k) {
|
||||
tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
};
|
||||
|
||||
} // namespace pytorch_flash
|
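The index arithmetic inside mask.h is easy to check on the host: query row r may attend to key columns in [r + seqlen_k - seqlen_q - window_left, r + seqlen_k - seqlen_q + window_right], causal masking is the window (infinite, 0) special case (apply_mask_causal forwards -1, 0 with HasWSLeft=false), and the non-causal ALiBi branch of Mask::apply_mask biases an unmasked score by -alibi_slope times the distance from the diagonal. A small sketch with arbitrary sizes, printing the bias/mask pattern for one head:

#include <algorithm>
#include <cmath>
#include <cstdio>

int main() {
    // Illustrative sizes only.
    const int seqlen_q = 4, seqlen_k = 6;
    const int window_left = 2, window_right = 0;       // local-attention window
    const float alibi_slope = 0.5f;                     // one head's slope

    for (int r = 0; r < seqlen_q; ++r) {
        const int diag  = r + seqlen_k - seqlen_q;      // key column aligned with query row r
        const int left  = std::max(0, diag - window_left);
        const int right = std::min(seqlen_k, diag + 1 + window_right);
        for (int c = 0; c < seqlen_k; ++c) {
            if (c >= right || c < left) {
                std::printf("    -inf");                // same effect as writing -INFINITY above
            } else {
                // Non-causal ALiBi term: bias by distance from the diagonal.
                std::printf("%8.2f", -alibi_slope * std::abs(float(diag - c)));
            }
        }
        std::printf("\n");
    }
}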
@ -11,7 +11,7 @@ struct ull2 {
|
||||
unsigned long long y;
|
||||
};
|
||||
|
||||
inline __device__ uint2 mulhilo32(const unsigned int a, const unsigned int b) {
|
||||
__forceinline__ __device__ uint2 mulhilo32(const unsigned int a, const unsigned int b) {
|
||||
uint2 *res;
|
||||
unsigned long long tmp;
|
||||
asm ("mul.wide.u32 %0, %1, %2;\n\t"
|
||||
@ -21,7 +21,7 @@ inline __device__ uint2 mulhilo32(const unsigned int a, const unsigned int b) {
|
||||
return *res;
|
||||
}
|
||||
|
||||
inline __device__ uint4 philox_single_round(const uint4 ctr, const uint2 key) {
|
||||
__forceinline__ __device__ uint4 philox_single_round(const uint4 ctr, const uint2 key) {
|
||||
constexpr unsigned long kPhiloxSA = 0xD2511F53;
|
||||
constexpr unsigned long kPhiloxSB = 0xCD9E8D57;
|
||||
uint2 res0 = mulhilo32(kPhiloxSA, ctr.x);
|
||||
@ -30,7 +30,7 @@ inline __device__ uint4 philox_single_round(const uint4 ctr, const uint2 key) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
inline __device__ uint4 philox(unsigned long long seed,
|
||||
__forceinline__ __device__ uint4 philox(unsigned long long seed,
|
||||
unsigned long long subsequence,
|
||||
unsigned long long offset) {
|
||||
constexpr unsigned long kPhilox10A = 0x9E3779B9;
|
||||
@ -51,117 +51,3 @@ inline __device__ uint4 philox(unsigned long long seed,
|
||||
}
|
||||
|
||||
} // namespace flash
|
||||
|
||||
namespace {
|
||||
|
||||
class Philox {
|
||||
public:
|
||||
__device__ inline Philox(unsigned long long seed,
|
||||
unsigned long long subsequence,
|
||||
unsigned long long offset)
|
||||
: STATE(0)
|
||||
, seed_(seed)
|
||||
, offset_(offset)
|
||||
, key(reinterpret_cast<const uint2&>(seed)) {
|
||||
//key.x = (unsigned int)seed;
|
||||
//key.y = (unsigned int)(seed >> 32);
|
||||
//counter = make_uint4(0, 0, 0, 0);
|
||||
//counter.z = (unsigned int)(subsequence);
|
||||
//counter.w = (unsigned int)(subsequence >> 32);
|
||||
//STATE = 0;
|
||||
//incr_n(offset / 4);
|
||||
|
||||
// key = reinterpret_cast<const uint2&>(seed);
|
||||
ull2 * tmp = reinterpret_cast<ull2*>(&counter);
|
||||
tmp->x = offset / 4;
|
||||
tmp->y = subsequence;
|
||||
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
|
||||
// printf("Philox counter: %d, %d, %d, %d\n", counter.x, counter.y, counter.z, counter.w);
|
||||
// }
|
||||
}
|
||||
__device__ inline uint4 operator()() {
|
||||
// // if (STATE == 0) {
|
||||
// uint4 counter_ = counter;
|
||||
// uint2 key_ = key;
|
||||
// // 7-round philox
|
||||
// #pragma unroll
|
||||
// for (int i = 0; i < 6; i++) {
|
||||
// counter_ = pytorch_flash::philox_single_round(counter_, key_);
|
||||
// key_.x += (kPhilox10A);
|
||||
// key_.y += (kPhilox10B);
|
||||
// }
|
||||
// // output = philox_single_round(counter_, key_);
|
||||
// uint4 output = pytorch_flash::philox_single_round(counter_, key_);
|
||||
// // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
|
||||
// // printf("Philox counter: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w);
|
||||
// // printf("Philox output: %u, %u, %u, %u\n", output.x, output.y, output.z, output.w);
|
||||
// // }
|
||||
// incr();
|
||||
// // }
|
||||
// // return a float4 directly
|
||||
// // unsigned long ret;
|
||||
// // switch(STATE) {
|
||||
// // case 0: ret = output.x; break;
|
||||
// // case 1: ret = output.y; break;
|
||||
// // case 2: ret = output.z; break;
|
||||
// // case 3: ret = output.w; break;
|
||||
// //}
|
||||
// // STATE = (STATE + 1) % 4;
|
||||
// return output;
|
||||
return pytorch_flash::philox(seed_, offset_, offset_);
|
||||
}
|
||||
|
||||
private:
|
||||
unsigned long long offset_, seed_;
|
||||
struct ull2 {
|
||||
uint64_t x;
|
||||
uint64_t y;
|
||||
};
|
||||
uint4 counter;
|
||||
// uint4 output;
|
||||
const uint2 key;
|
||||
unsigned int STATE;
|
||||
__device__ inline void incr_n(unsigned long long n) {
|
||||
unsigned int nlo = (unsigned int)(n);
|
||||
unsigned int nhi = (unsigned int)(n >> 32);
|
||||
counter.x += nlo;
|
||||
if (counter.x < nlo)
|
||||
nhi++;
|
||||
counter.y += nhi;
|
||||
if (nhi <= counter.y)
|
||||
return;
|
||||
if (++counter.z)
|
||||
return;
|
||||
++counter.w;
|
||||
}
|
||||
|
||||
__device__ uint4 incr128 (uint4 ctr)
|
||||
{
|
||||
uint4 res;
|
||||
asm ("add.cc.u32 %0, %4, %8;\n\t"
|
||||
"addc.cc.u32 %1, %5, %9;\n\t"
|
||||
"addc.cc.u32 %2, %6, %10;\n\t"
|
||||
"addc.u32 %3, %7, %11;\n\t"
|
||||
: "=r"(res.x), "=r"(res.y), "=r"(res.z), "=r"(res.w)
|
||||
: "r"(ctr.x), "r"(ctr.y), "r"(ctr.z), "r"(ctr.w),
|
||||
"n"(1), "n"(0), "n"(0), "n"(0));
|
||||
return res;
|
||||
}
|
||||
|
||||
__device__ inline void incr() {
|
||||
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
|
||||
// printf("Counter before: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w);
|
||||
// }
|
||||
counter = incr128(counter);
|
||||
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
|
||||
// printf("Counter after: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w);
|
||||
// }
|
||||
}
|
||||
|
||||
static const unsigned long kPhilox10A = 0x9E3779B9;
|
||||
static const unsigned long kPhilox10B = 0xBB67AE85;
|
||||
// static const unsigned long kPhiloxSA = 0xD2511F53;
|
||||
// static const unsigned long kPhiloxSB = 0xCD9E8D57;
|
||||
};
|
||||
|
||||
} // namespace pytorch_flash
|
||||
|
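The RNG above is a counter-based Philox-4x32 generator: rather than storing the dropout mask, the backward pass can replay the same (seed, subsequence, offset) counters. A simplified host-side sketch of one round, using the multipliers and key-bump constants that appear in the file; the 7 rounds follow the "7-round philox" comment in the removed wrapper class, and the exact counter packing in the kernel may differ in detail:

#include <cstdint>
#include <cstdio>

struct U4 { uint32_t x, y, z, w; };
struct U2 { uint32_t x, y; };

// Portable equivalent of mulhilo32: full 32x32 -> 64-bit product split into lo/hi halves.
static U2 mulhilo32(uint32_t a, uint32_t b) {
    uint64_t p = uint64_t(a) * uint64_t(b);
    return { uint32_t(p), uint32_t(p >> 32) };
}

// One Philox-4x32 round with the multipliers from philox_single_round above.
static U4 philox_round(U4 ctr, U2 key) {
    U2 r0 = mulhilo32(0xD2511F53u, ctr.x);   // kPhiloxSA
    U2 r1 = mulhilo32(0xCD9E8D57u, ctr.z);   // kPhiloxSB
    return { r1.y ^ ctr.y ^ key.x, r1.x, r0.y ^ ctr.w ^ key.y, r0.x };
}

int main() {
    uint64_t seed = 0x1234abcdULL, subsequence = 7, offset = 42;   // illustrative values
    U4 ctr = { uint32_t(offset / 4), uint32_t((offset / 4) >> 32), // counter = {offset/4, subsequence}
               uint32_t(subsequence), uint32_t(subsequence >> 32) };
    U2 key = { uint32_t(seed), uint32_t(seed >> 32) };
    for (int i = 0; i < 7; ++i) {
        ctr = philox_round(ctr, key);
        key.x += 0x9E3779B9u;                                      // kPhilox10A
        key.y += 0xBB67AE85u;                                      // kPhilox10B
    }
    std::printf("%08x %08x %08x %08x\n", ctr.x, ctr.y, ctr.z, ctr.w);
}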
aten/src/ATen/native/transformers/cuda/flash_attn/rotary.h (new file, 152 lines)
@ -0,0 +1,152 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2024, Tri Dao.
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cute/algorithm/copy.hpp>
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/utils.h>
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace pytorch_flash {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <bool Is_even_K=true, bool Clear_OOB_K=true,
|
||||
typename Engine0, typename Layout0, typename Engine1, typename Layout1,
|
||||
typename Engine2, typename Layout2, typename Engine3, typename Layout3>
|
||||
__forceinline__ __device__ void copy_rotary_interleaved(Tensor<Engine0, Layout0> const &S,
|
||||
Tensor<Engine1, Layout1> &D,
|
||||
Tensor<Engine2, Layout2> const &Cos,
|
||||
Tensor<Engine2, Layout2> const &Sin,
|
||||
Tensor<Engine3, Layout3> const &identity_MN,
|
||||
const int max_MN, const int min_MN,
|
||||
const int dim, const int rotary_dim) {
|
||||
CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
|
||||
CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
|
||||
CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA
|
||||
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M
|
||||
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K
|
||||
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos)); // MMA_M
|
||||
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos)); // MMA_K
|
||||
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin)); // MMA_M
|
||||
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin)); // MMA_K
|
||||
CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin)); // MMA_K
|
||||
static_assert(decltype(size<0>(S))::value == decltype(size<0>(Cos))::value * 2);
|
||||
static_assert(decltype(size<0>(Cos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32
|
||||
Tensor rCos = make_fragment_like(Cos);
|
||||
Tensor rSin = make_fragment_like(Sin);
|
||||
Tensor rS = make_fragment_like(S);
|
||||
#pragma unroll
|
||||
for (int m = 0; m < size<1>(S); ++m) {
|
||||
if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) {
|
||||
#pragma unroll
|
||||
for (int k = 0; k < size<2>(S); ++k) {
|
||||
if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) {
|
||||
cute::copy(S(_, m, k), rS(_, m, k));
|
||||
if (get<1>(identity_MN(0, 0, k)) < rotary_dim) {
|
||||
cute::copy(Cos(_, m, k), rCos(_, m, k));
|
||||
cute::copy(Sin(_, m, k), rSin(_, m, k));
|
||||
Tensor S_fp32 = convert_type<float>(rS(_, m, k));
|
||||
Tensor cos_fp32 = convert_type<float>(rCos(_, m, k));
|
||||
Tensor sin_fp32 = convert_type<float>(rSin(_, m, k));
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size<0>(rS) / 2; ++i) {
|
||||
float real = S_fp32(2 * i) * cos_fp32(i) - S_fp32(2 * i + 1) * sin_fp32(i);
|
||||
float imag = S_fp32(2 * i) * sin_fp32(i) + S_fp32(2 * i + 1) * cos_fp32(i);
|
||||
S_fp32(2 * i) = real;
|
||||
S_fp32(2 * i + 1) = imag;
|
||||
}
|
||||
// Idk but I need to copy for the convert_type to work
|
||||
Tensor S_fp32_copy = make_fragment_like(S_fp32);
|
||||
cute::copy(S_fp32, S_fp32_copy);
|
||||
using T = typename Engine0::value_type;
|
||||
Tensor S_og_type = convert_type<T>(S_fp32_copy);
|
||||
cute::copy(S_og_type, rS(_, m, k));
|
||||
}
|
||||
cute::copy(rS(_, m, k), D(_, m, k));
|
||||
} else if (Clear_OOB_K) {
|
||||
cute::clear(D(_, m, k));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <bool Is_even_K=true, bool Clear_OOB_K=true,
|
||||
typename Engine0, typename Layout0, typename Engine1, typename Layout1,
|
||||
typename Engine2, typename Layout2, typename Engine3, typename Layout3>
|
||||
__forceinline__ __device__ void copy_rotary_contiguous(Tensor<Engine0, Layout0> const &S,
|
||||
Tensor<Engine1, Layout1> &D,
|
||||
Tensor<Engine2, Layout2> const &Cos,
|
||||
Tensor<Engine2, Layout2> const &Sin,
|
||||
Tensor<Engine3, Layout3> const &identity_MN,
|
||||
const int max_MN, const int min_MN,
|
||||
const int dim, const int rotary_dim) {
|
||||
CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
|
||||
CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
|
||||
CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA
|
||||
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M
|
||||
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K
|
||||
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos)); // MMA_M
|
||||
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos)); // MMA_K
|
||||
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin)); // MMA_M
|
||||
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin)); // MMA_K
|
||||
CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(Cos)); // MMA
|
||||
CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin));
|
||||
static_assert(decltype(size<0>(Cos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32
|
||||
Tensor rCos = make_fragment_like(Cos);
|
||||
Tensor rSin = make_fragment_like(Sin);
|
||||
Tensor rS = make_fragment_like(S);
|
||||
Tensor rS_other = make_fragment_like(rS(_, 0, 0));
|
||||
#pragma unroll
|
||||
for (int m = 0; m < size<1>(S); ++m) {
|
||||
if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) {
|
||||
#pragma unroll
|
||||
for (int k = 0; k < size<2>(S); ++k) {
|
||||
if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) {
|
||||
cute::copy(S(_, m, k), rS(_, m, k));
|
||||
if (get<1>(identity_MN(0, 0, k)) < rotary_dim) {
|
||||
const bool is_left = get<1>(identity_MN(0, 0, k)) < rotary_dim / 2;
|
||||
Tensor gS_other = make_tensor(S(_, m, k).data() + (is_left ? rotary_dim / 2 : -rotary_dim / 2), S(_, m, k).layout());
|
||||
cute::copy(gS_other, rS_other);
|
||||
// if (cute::thread0()) { print_tensor(rS(_, m, k)); print_tensor(rS_other); }
|
||||
Tensor gCos = make_tensor(Cos(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Cos(_, m, k).layout());
|
||||
Tensor gSin = make_tensor(Sin(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Sin(_, m, k).layout());
|
||||
cute::copy(gCos, rCos(_, m, k));
|
||||
cute::copy(gSin, rSin(_, m, k));
|
||||
// if (cute::thread0()) { print_tensor(rCos(_, m, k)); print_tensor(rSin(_, m, k)); }
|
||||
Tensor S_fp32 = convert_type<float>(rS(_, m, k));
|
||||
Tensor S_other_fp32 = convert_type<float>(rS_other);
|
||||
Tensor cos_fp32 = convert_type<float>(rCos(_, m, k));
|
||||
Tensor sin_fp32 = convert_type<float>(rSin(_, m, k));
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size<0>(rS); ++i) {
|
||||
S_fp32(i) = S_fp32(i) * cos_fp32(i) + S_other_fp32(i) * (is_left ? -sin_fp32(i) : sin_fp32(i));
|
||||
}
|
||||
// Idk but I need to copy for the convert_type to work
|
||||
Tensor S_fp32_copy = make_fragment_like(S_fp32);
|
||||
cute::copy(S_fp32, S_fp32_copy);
|
||||
using T = typename Engine0::value_type;
|
||||
Tensor S_og_type = convert_type<T>(S_fp32_copy);
|
||||
cute::copy(S_og_type, rS(_, m, k));
|
||||
// if (cute::thread0()) { print_tensor(rS(_, m, k)); }
|
||||
}
|
||||
cute::copy(rS(_, m, k), D(_, m, k));
|
||||
} else if (Clear_OOB_K) {
|
||||
cute::clear(D(_, m, k));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace pytorch_flash
|
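copy_rotary_interleaved and copy_rotary_contiguous above implement the two common RoPE layouts: rotate adjacent element pairs (x[2i], x[2i+1]), or rotate element i against element i + rotary_dim/2. A scalar host-side sketch of the same arithmetic, assuming a precomputed cos/sin table of length rotary_dim/2 (names and angles are illustrative):

#include <cmath>
#include <cstdio>
#include <vector>

// Interleaved RoPE: pairs (x[2i], x[2i+1]) are rotated as complex numbers,
// matching the real/imag update in copy_rotary_interleaved.
void rotary_interleaved(std::vector<float>& x, const std::vector<float>& cosv,
                        const std::vector<float>& sinv, int rotary_dim) {
    for (int i = 0; i < rotary_dim / 2; ++i) {
        float re = x[2 * i] * cosv[i] - x[2 * i + 1] * sinv[i];
        float im = x[2 * i] * sinv[i] + x[2 * i + 1] * cosv[i];
        x[2 * i] = re;
        x[2 * i + 1] = im;
    }
}

// Contiguous (half-rotated) RoPE: element i pairs with element i + rotary_dim/2,
// matching the is_left / gS_other logic in copy_rotary_contiguous.
void rotary_contiguous(std::vector<float>& x, const std::vector<float>& cosv,
                       const std::vector<float>& sinv, int rotary_dim) {
    const int h = rotary_dim / 2;
    for (int i = 0; i < h; ++i) {
        float a = x[i], b = x[i + h];
        x[i]     = a * cosv[i] - b * sinv[i];
        x[i + h] = a * sinv[i] + b * cosv[i];
    }
}

int main() {
    std::vector<float> x = {1, 0, 0, 1, 2, 2, 3, 3};   // head_dim = 8, fully rotated
    std::vector<float> c(4), s(4);
    for (int i = 0; i < 4; ++i) {
        float theta = 0.1f * float(i + 1);             // toy angles; real RoPE uses 1/10000^(2i/d)
        c[i] = std::cos(theta); s[i] = std::sin(theta);
    }
    rotary_interleaved(x, c, s, 8);
    for (float v : x) std::printf("%.3f ", v);
    std::printf("\n");
}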
@ -1,34 +1,15 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright
|
||||
* notice, this list of conditions and the following disclaimer in the
|
||||
* documentation and/or other materials provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the
|
||||
* names of its contributors may be used to endorse or promote products
|
||||
* derived from this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
* Copyright (c) 2024, Tri Dao.
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cmath>
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
|
||||
#include <cutlass/numeric_types.h>
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/philox.cuh>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/utils.h>
|
||||
|
||||
@ -39,7 +20,7 @@ using namespace cute;
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
|
||||
__device__ inline void thread_reduce_(Tensor<Engine0, Layout0> const &tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
|
||||
__device__ __forceinline__ void thread_reduce_(Tensor<Engine0, Layout0> const &tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
|
||||
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
|
||||
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
|
CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor));
@ -54,7 +35,7 @@ __device__ inline void thread_reduce_(Tensor<Engine0, Layout0> const &tensor, Te
}

template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ inline void quad_allreduce_(Tensor<Engine0, Layout0> &dst, Tensor<Engine1, Layout1> &src, Operator &op) {
__device__ __forceinline__ void quad_allreduce_(Tensor<Engine0, Layout0> &dst, Tensor<Engine1, Layout1> &src, Operator &op) {
CUTE_STATIC_ASSERT_V(size(dst) == size(src));
#pragma unroll
for (int i = 0; i < size(dst); i++){
@ -63,26 +44,26 @@ __device__ inline void quad_allreduce_(Tensor<Engine0, Layout0> &dst, Tensor<Eng
}

template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ inline void reduce_(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
__device__ __forceinline__ void reduce_(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
thread_reduce_<zero_init>(tensor, summary, op);
quad_allreduce_(summary, summary, op);
}

template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__device__ inline void reduce_max(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &max){
__device__ __forceinline__ void reduce_max(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &max){
MaxOp<float> max_op;
reduce_<zero_init>(tensor, max, max_op);
}

template<typename Engine0, typename Layout0, typename Engine1, typename Layout1>
inline __device__ void reduce_sum(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &sum){
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__device__ __forceinline__ void reduce_sum(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &sum){
SumOp<float> sum_op;
reduce_(tensor, sum, sum_op);
thread_reduce_<zero_init>(tensor, sum, sum_op);
}

// Apply the exp to all the elements.
template <bool Scale_max=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
inline __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &max, const float scale) {
__forceinline__ __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &max, const float scale) {
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
@ -104,7 +85,7 @@ inline __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tensor, Tensor

// Apply the exp to all the elements.
template <bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
inline __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> &max, Tensor<Engine1, Layout1> &sum, const float scale) {
__forceinline__ __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> &max, Tensor<Engine1, Layout1> &sum, const float scale) {
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
@ -134,171 +115,67 @@ inline __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tens
}
}

template <typename Engine, typename Layout>
inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const int max_seqlen_k,
const int col_idx_offset_ = 0) {
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
static_assert(Layout::rank == 2, "Only support 2D Tensor");
const int lane_id = threadIdx.x % 32;
const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
#pragma unroll
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
const int col_idx_base = col_idx_offset + nj * 8;
#pragma unroll
for (int j = 0; j < size<1, 0>(tensor); ++j) {
const int col_idx = col_idx_base + j;
if (col_idx >= max_seqlen_k) {
// Without the "make_coord" we get wrong results
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); ++mi) {
tensor(mi, make_coord(j, nj)) = -INFINITY;
}
}
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////

template <bool HasWSLeft=true, typename Engine, typename Layout>
inline __device__ void apply_mask_local(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
const int max_seqlen_k, const int row_idx_offset_,
const int max_seqlen_q, const int warp_row_stride,
const int window_size_left, const int window_size_right) {
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
static_assert(Layout::rank == 2, "Only support 2D Tensor");
const int lane_id = threadIdx.x % 32;
// const int row_idx_offset = row_idx_offset_ + lane_id / 4;
const int row_idx_offset = row_idx_offset_;
const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
#pragma unroll
for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
const int row_idx_base = row_idx_offset + mi * warp_row_stride;
#pragma unroll
for (int i = 0; i < size<0, 0>(tensor); ++i) {
const int row_idx = row_idx_base + i * 8;
const int col_idx_limit_left = std::max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left);
const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right);
template <int kNRows>
struct Softmax {

using TensorT = decltype(make_tensor<float>(Shape<Int<kNRows>>{}));
TensorT row_max, row_sum;

__forceinline__ __device__ Softmax() {};

template<bool Is_first, bool Check_inf=false, typename Tensor0, typename Tensor1>
__forceinline__ __device__ void softmax_rescale_o(Tensor0 &acc_s, Tensor1 &acc_o, float softmax_scale_log2) {
// Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
Tensor scores = make_tensor(acc_s.data(), pytorch_flash::convert_layout_acc_rowcol(acc_s.layout()));
static_assert(decltype(size<0>(scores))::value == kNRows);
if (Is_first) {
pytorch_flash::template reduce_max</*zero_init=*/true>(scores, row_max);
pytorch_flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
pytorch_flash::reduce_sum</*zero_init=*/true>(scores, row_sum);
} else {
Tensor scores_max_prev = make_fragment_like(row_max);
cute::copy(row_max, scores_max_prev);
pytorch_flash::template reduce_max</*zero_init=*/false>(scores, row_max);
// Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
Tensor acc_o_rowcol = make_tensor(acc_o.data(), pytorch_flash::convert_layout_acc_rowcol(acc_o.layout()));
static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
#pragma unroll
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
const int col_idx_base = col_idx_offset + nj * 8;
for (int mi = 0; mi < size(row_max); ++mi) {
float scores_max_cur = !Check_inf
? row_max(mi)
: (row_max(mi) == -INFINITY ? 0.0f : row_max(mi));
float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
row_sum(mi) *= scores_scale;
#pragma unroll
for (int j = 0; j < size<1, 0>(tensor); ++j) {
const int col_idx = col_idx_base + j;
if (col_idx >= col_idx_limit_right || (HasWSLeft && col_idx < col_idx_limit_left)) {
tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
}
}
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; }
}
// if (cute::thread0()) {
// printf("mi = %d, i = %d, row_idx = %d, max_seqlen_k = %d\n", mi, i, row_idx, max_seqlen_k);
// print(tensor(make_coord(i, mi), _));
// // print(tensor(_, j + nj * size<1, 0>(tensor)));
// }
pytorch_flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
// We don't do the reduce across threads here since we don't need to use the row_sum.
// We do that reduce at the end when we need to normalize the softmax.
pytorch_flash::reduce_sum</*zero_init=*/false>(scores, row_sum);
}
}
}
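The new Softmax struct above keeps a running row maximum (row_max) and running row sum (row_sum) across key blocks, and softmax_rescale_o rescales whatever output has already been accumulated by exp2((old_max - new_max) * softmax_scale_log2) whenever the maximum grows. The following is a minimal host-side sketch of that update rule for a single row, assuming plain float arithmetic; the function and variable names are illustrative and not the kernel's API.

#include <algorithm>
#include <cmath>
#include <vector>

// One online-softmax step for a single row: fold a new block of scores into the
// running (max, sum) statistics and rescale the already-accumulated output acc_o.
// Start with row_max = -infinity and row_sum = 0 before the first block.
void softmax_rescale_row(const std::vector<float>& block_scores,
                         std::vector<float>& acc_o,
                         float& row_max, float& row_sum,
                         float scale_log2) {
    float new_max = row_max;
    for (float s : block_scores) new_max = std::max(new_max, s);
    // Rescale the old statistics and output to the new reference maximum, mirroring
    // scores_scale = exp2f((scores_max_prev - scores_max_cur) * softmax_scale_log2).
    const float scores_scale = std::exp2((row_max - new_max) * scale_log2);
    row_sum *= scores_scale;
    for (float& o : acc_o) o *= scores_scale;
    // Accumulate the current block in base-2 exponential space (scale_apply_exp2 + reduce_sum).
    for (float s : block_scores) row_sum += std::exp2((s - new_max) * scale_log2);
    row_max = new_max;
    // The P*V accumulation into acc_o is omitted here; only the rescaling step is shown.
}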

template <typename Engine, typename Layout>
inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const int col_idx_offset_,
const int max_seqlen_k, const int row_idx_offset_,
const int max_seqlen_q, const int warp_row_stride) {
// Causal masking is equivalent to local masking with window_size_left = infinity and window_size_right = 0
apply_mask_local</*HasWSLeft=*/false>(tensor, col_idx_offset_, max_seqlen_k, row_idx_offset_,
max_seqlen_q, warp_row_stride, -1, 0);
}

template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
inline __device__ void apply_mask_causal_w_idx(
Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &idx_rowcol,
const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset_)
{
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
static_assert(Layout1::rank == 2, "Only support 2D Tensor");
CUTE_STATIC_ASSERT_V(size<0>(tensor) == size<0>(idx_rowcol));
CUTE_STATIC_ASSERT_V(size<1>(tensor) == size<1>(idx_rowcol));
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); ++mi) {
const int col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset_ + get<0>(idx_rowcol(mi, 0)));
#pragma unroll
for (int ni = 0; ni < size<1, 1>(tensor); ++ni) {
if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) {
tensor(mi, ni) = -INFINITY;
}
}
// if (cute::thread0()) {
// printf("ni = %d, j = %d, col_idx = %d, max_seqlen_k = %d\n", ni, j, col_idx, max_seqlen_k);
// print(tensor(_, make_coord(j, ni)));
// // print(tensor(_, j + ni * size<1, 0>(tensor)));
// }
}
}

template <bool encode_dropout_in_sign_bit=false, typename Engine, typename Layout>
inline __device__ void apply_dropout(Tensor<Engine, Layout> &tensor, uint8_t p_dropout_in_uint8_t,
unsigned long long seed, unsigned long long offset,
int block_row_start, int block_col_start,
int block_row_stride) {
// tensor has shape (8, MMA_M, MMA_N / 2)
using T = typename Engine::value_type;
auto encode_dropout = [](bool keep, T val) {
return keep ? val : (encode_dropout_in_sign_bit ? -val : T(0));
};
static_assert(decltype(size<2>(tensor))::value % 2 == 0);
const uint16_t p_dropout_8bit_in_uint16_t = uint16_t(p_dropout_in_uint8_t);
const uint32_t p_dropout_8bit_in_uint32_t = (uint32_t(p_dropout_8bit_in_uint16_t) << 16) | uint32_t(p_dropout_8bit_in_uint16_t);
// if (cute::thread0()) { printf("threshold2 = 0x%x\n", p_dropout_8bit_in_uint32_t); }
#pragma unroll
for (int m = 0; m < size<1>(tensor); ++m, block_row_start += block_row_stride) {
uint2 rowcol = make_uint2(block_row_start, block_col_start);

template<bool Is_dropout=false, bool Split=false, typename Tensor0>
__forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale, float rp_dropout=1.0) {
SumOp<float> sum_op;
quad_allreduce_(row_sum, row_sum, sum_op);
TensorT lse = make_fragment_like(row_sum);
Tensor acc_o_rowcol = make_tensor(acc_o.data(), pytorch_flash::convert_layout_acc_rowcol(acc_o.layout()));
static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
#pragma unroll
for (int n = 0; n < size<2>(tensor) / 2; ++n, ++rowcol.y) {
// if (cute::thread(32, 0)) { printf("m = %d, n = %d, row = %d, col = %d\n", m, n, int(rowcol.x), int(rowcol.y));}
uint4 random_uint4 = pytorch_flash::philox(seed, reinterpret_cast<unsigned long long&>(rowcol), offset);
// if (cute::thread0()) { printf("philox = %u, %d, %d, %d\n", random_uint4.x, random_uint4.y, random_uint4.z, random_uint4.w);}
uint8_t (&rnd_8)[16] = reinterpret_cast<uint8_t (&)[16]>(random_uint4);
// Special implementation for 16-bit types: we duplicate the threshold to the
// low and high 16 bits of a 32-bit value, then use the f16x2 comparison instruction
// to get a mask. The low 16 bits of the mask will be either 0xffff or 0x0000,
// and the high 16 bits will be either 0xffff or 0x0000, depending on whether
// the random value is less than the threshold.
// We then do a bit-wise AND between the mask and the original value (in 32-bit).
// We're exploiting the fact that floating point comparison is equivalent to integer
// comparison, since we're comparing unsigned integers whose top 8-bits are zero.
if (!encode_dropout_in_sign_bit
&& (std::is_same<T, cutlass::half_t>::value || std::is_same<T, cutlass::bfloat16_t>::value)) {
uint16_t rnd_16[16];
#pragma unroll
for (int i = 0; i < 16; i++) { rnd_16[i] = uint16_t(rnd_8[i]); }
uint32_t (&rnd_32)[8] = reinterpret_cast<uint32_t (&)[8]>(rnd_16);
#pragma unroll
for (int j = 0; j < 2; j++) {
Tensor tensor_uint32 = recast<uint32_t>(tensor(_, m, n * 2 + j));
// if (cute::thread0()) { printf("random = 0x%x, 0x%x, 0x%x, 0x%x\n", rnd_32[j * 4 + 0], rnd_32[j * 4 + 1], rnd_32[j * 4 + 2], rnd_32[j * 4 + 3]); }
// if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
#pragma unroll
for (int i = 0; i < 4; i++) {
uint32_t mask;
asm volatile("set.le.u32.f16x2 %0, %1, %2;\n" : "=r"(mask) : "r"(rnd_32[j * 4 + i]), "r"(p_dropout_8bit_in_uint32_t));
tensor_uint32(i) &= mask;
}
// if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
}
} else {
#pragma unroll
for (int j = 0; j < 2; j++) {
#pragma unroll
for (int i = 0; i < 8; i++) {
tensor(i, m, n * 2 + j) = encode_dropout(rnd_8[j * 8 + i] <= p_dropout_in_uint8_t, tensor(i, m, n * 2 + j));
}
Tensor tensor_uint32 = recast<uint32_t>(tensor(_, m, n * 2 + j));
// if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
}
}
// // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// // printf("n = %d, ph Philox: %u, %u, %u, %u\n", n, rnd_8.x, rnd_8.y, rnd_8.z, rnd_8.w);
// // }
for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
float sum = row_sum(mi);
float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum);
float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout;
#pragma unroll
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; }
}
}
}
return lse;
};
};

} // namespace pytorch_flash

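For reference, the keep/drop rule that apply_dropout implements reduces to a byte comparison against p_dropout_in_uint8_t: an element is kept when its random byte is <= the threshold and otherwise zeroed, or sign-flipped when encode_dropout_in_sign_bit is set so the backward pass can recover the mask; the f16x2 path simply performs two of these comparisons at once on packed 16-bit values. Below is a small host-side sketch of just that rule, with an ordinary RNG standing in for the Philox generator; it is illustrative only and not the kernel's code path.

#include <cstdint>
#include <random>
#include <vector>

// Apply the byte-threshold dropout rule to a row of scores.
// keep_threshold plays the role of p_dropout_in_uint8_t, roughly uint8_t(255 * keep_probability).
std::vector<float> apply_dropout_rule(std::vector<float> vals,
                                      uint8_t keep_threshold,
                                      bool encode_in_sign_bit,
                                      std::mt19937& rng) {
    std::uniform_int_distribution<int> byte(0, 255);
    for (float& v : vals) {
        // Same comparison as rnd_8[i] <= p_dropout_in_uint8_t in apply_dropout.
        const bool keep = uint8_t(byte(rng)) <= keep_threshold;
        // Drop: zero the value, or flip its sign so the mask can be reconstructed later.
        if (!keep) v = encode_in_sign_bit ? -v : 0.0f;
    }
    return vals;
}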
@ -14,6 +14,7 @@
/// some_function<BoolConst>(...);
/// });
/// ```

#define BOOL_SWITCH(COND, CONST_NAME, ...) \
[&] { \
if (COND) { \
@ -25,6 +26,46 @@
} \
}()

#ifdef FLASHATTENTION_DISABLE_DROPOUT
#define DROPOUT_SWITCH(COND, CONST_NAME, ...) \
[&] { \
constexpr static bool CONST_NAME = false; \
return __VA_ARGS__(); \
}()
#else
#define DROPOUT_SWITCH BOOL_SWITCH
#endif

#ifdef FLASHATTENTION_DISABLE_ALIBI
#define ALIBI_SWITCH(COND, CONST_NAME, ...) \
[&] { \
constexpr static bool CONST_NAME = false; \
return __VA_ARGS__(); \
}()
#else
#define ALIBI_SWITCH BOOL_SWITCH
#endif

#ifdef FLASHATTENTION_DISABLE_UNEVEN_K
#define EVENK_SWITCH(COND, CONST_NAME, ...) \
[&] { \
constexpr static bool CONST_NAME = true; \
return __VA_ARGS__(); \
}()
#else
#define EVENK_SWITCH BOOL_SWITCH
#endif

#ifdef FLASHATTENTION_DISABLE_LOCAL
#define LOCAL_SWITCH(COND, CONST_NAME, ...) \
[&] { \
constexpr static bool CONST_NAME = false; \
return __VA_ARGS__(); \
}()
#else
#define LOCAL_SWITCH BOOL_SWITCH
#endif

#define FP16_SWITCH(COND, ...) \
[&] { \
if (COND) { \
@ -36,7 +77,7 @@
} \
}()

#define FWD_HEADDIM_SWITCH(HEADDIM, ...) \
#define HEADDIM_SWITCH(HEADDIM, ...) \
[&] { \
if (HEADDIM <= 32) { \
constexpr static int kHeadDim = 32; \

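These *_SWITCH macros all follow the BOOL_SWITCH pattern: an immediately-invoked lambda turns a runtime flag into a constexpr template parameter, and when a feature is compiled out (for example via FLASHATTENTION_DISABLE_ALIBI) the macro pins the constant to false so only one kernel instantiation is emitted. A stripped-down sketch of the same idea without the macro machinery follows; the names are illustrative, not the header's.

#include <cstdio>

template <bool HasAlibi>
void run_kernel() {                      // stand-in for a templated flash-attention kernel launch
    if constexpr (HasAlibi) {
        std::puts("instantiated with ALiBi support");
    } else {
        std::puts("instantiated without ALiBi support");
    }
}

void dispatch(bool has_alibi_at_runtime) {
    // BOOL_SWITCH-style dispatch: branch once on the runtime flag,
    // then call a fully specialized template on each side of the branch.
    [&] {
        if (has_alibi_at_runtime) {
            constexpr static bool kHasAlibi = true;
            return run_kernel<kHasAlibi>();
        } else {
            constexpr static bool kHasAlibi = false;
            return run_kernel<kHasAlibi>();
        }
    }();
    // With FLASHATTENTION_DISABLE_ALIBI defined, ALIBI_SWITCH skips the branch
    // entirely and always instantiates the kHasAlibi = false variant.
}

int main() {
    dispatch(true);
    dispatch(false);
}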
@ -22,16 +22,17 @@
|
||||
#include <cutlass/numeric_conversion.h>
|
||||
#include <cutlass/numeric_types.h>
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace pytorch_flash {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename T>
|
||||
inline __device__ uint32_t relu2(const uint32_t x);
|
||||
__forceinline__ __device__ uint32_t relu2(const uint32_t x);
|
||||
|
||||
template<>
|
||||
inline __device__ uint32_t relu2<cutlass::half_t>(const uint32_t x) {
|
||||
__forceinline__ __device__ uint32_t relu2<cutlass::half_t>(const uint32_t x) {
|
||||
uint32_t res;
|
||||
const uint32_t zero = 0u;
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
@ -49,7 +50,7 @@ inline __device__ uint32_t relu2<cutlass::half_t>(const uint32_t x) {
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
template<>
|
||||
inline __device__ uint32_t relu2<cutlass::bfloat16_t>(const uint32_t x) {
|
||||
__forceinline__ __device__ uint32_t relu2<cutlass::bfloat16_t>(const uint32_t x) {
|
||||
uint32_t res;
|
||||
const uint32_t zero = 0u;
|
||||
asm volatile("max.bf16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero));
|
||||
@ -62,10 +63,10 @@ inline __device__ uint32_t relu2<cutlass::bfloat16_t>(const uint32_t x) {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
|
||||
template<typename T>
|
||||
inline __device__ uint32_t convert_relu2(const float2 x);
|
||||
__forceinline__ __device__ uint32_t convert_relu2(const float2 x);
|
||||
|
||||
template<>
|
||||
inline __device__ uint32_t convert_relu2<cutlass::half_t>(const float2 x) {
|
||||
__forceinline__ __device__ uint32_t convert_relu2<cutlass::half_t>(const float2 x) {
|
||||
uint32_t res;
|
||||
const uint32_t a = reinterpret_cast<const uint32_t&>(x.x);
|
||||
const uint32_t b = reinterpret_cast<const uint32_t&>(x.y);
|
||||
@ -74,7 +75,7 @@ inline __device__ uint32_t convert_relu2<cutlass::half_t>(const float2 x) {
|
||||
}
|
||||
|
||||
template<>
|
||||
inline __device__ uint32_t convert_relu2<cutlass::bfloat16_t>(const float2 x) {
|
||||
__forceinline__ __device__ uint32_t convert_relu2<cutlass::bfloat16_t>(const float2 x) {
|
||||
uint32_t res;
|
||||
const uint32_t a = reinterpret_cast<const uint32_t&>(x.x);
|
||||
const uint32_t b = reinterpret_cast<const uint32_t&>(x.y);
|
||||
@ -88,20 +89,20 @@ inline __device__ uint32_t convert_relu2<cutlass::bfloat16_t>(const float2 x) {
|
||||
|
||||
template<typename T>
|
||||
struct MaxOp {
|
||||
__device__ inline T operator()(T const & x, T const & y) { return x > y ? x : y; }
|
||||
__device__ __forceinline__ T operator()(T const & x, T const & y) { return x > y ? x : y; }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct MaxOp<float> {
|
||||
// This is slightly faster
|
||||
__device__ inline float operator()(float const &x, float const &y) { return max(x, y); }
|
||||
__device__ __forceinline__ float operator()(float const &x, float const &y) { return max(x, y); }
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename T>
|
||||
struct SumOp {
|
||||
__device__ inline T operator()(T const & x, T const & y) { return x + y; }
|
||||
__device__ __forceinline__ T operator()(T const & x, T const & y) { return x + y; }
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -110,7 +111,7 @@ template<int THREADS>
|
||||
struct Allreduce {
|
||||
static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
|
||||
template<typename T, typename Operator>
|
||||
static __device__ inline T run(T x, Operator &op) {
|
||||
static __device__ __forceinline__ T run(T x, Operator &op) {
|
||||
constexpr int OFFSET = THREADS / 2;
|
||||
x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
|
||||
return Allreduce<OFFSET>::run(x, op);
|
||||
@ -122,7 +123,7 @@ struct Allreduce {
|
||||
template<>
|
||||
struct Allreduce<2> {
|
||||
template<typename T, typename Operator>
|
||||
static __device__ inline T run(T x, Operator &op) {
|
||||
static __device__ __forceinline__ T run(T x, Operator &op) {
|
||||
x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
|
||||
return x;
|
||||
}
|
||||
@ -134,7 +135,7 @@ template<bool A_in_regs=false, bool B_in_regs=false, typename Tensor0, typename
|
||||
typename Tensor2, typename Tensor3, typename Tensor4,
|
||||
typename TiledMma, typename TiledCopyA, typename TiledCopyB,
|
||||
typename ThrCopyA, typename ThrCopyB>
|
||||
inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA,
|
||||
__forceinline__ __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsA,
|
||||
Tensor4 const& tCsB, TiledMma tiled_mma,
|
||||
TiledCopyA smem_tiled_copy_A, TiledCopyB smem_tiled_copy_B,
|
||||
ThrCopyA smem_thr_copy_A, ThrCopyB smem_thr_copy_B) {
|
||||
@ -161,9 +162,9 @@ inline __device__ void gemm(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3
|
||||
|
||||
template<typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3,
|
||||
typename TiledMma, typename TiledCopy, typename ThrCopy>
|
||||
inline __device__ void gemm_A_in_regs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB,
|
||||
TiledMma tiled_mma, TiledCopy smem_tiled_copy_B,
|
||||
ThrCopy smem_thr_copy_B) {
|
||||
__forceinline__ __device__ void gemm_rs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB, Tensor3 const& tCsB,
|
||||
TiledMma tiled_mma, TiledCopy smem_tiled_copy_B,
|
||||
ThrCopy smem_thr_copy_B) {
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(acc)); // MMA_M
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tCrB) == size<2>(acc)); // MMA_N
|
||||
CUTE_STATIC_ASSERT_V(size<2>(tCrA) == size<2>(tCrB)); // MMA_K
|
||||
@ -183,42 +184,48 @@ inline __device__ void gemm_A_in_regs(Tensor0 &acc, Tensor1 &tCrA, Tensor2 &tCrB
|
||||
|
||||
// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
|
||||
template<typename Layout>
|
||||
inline __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
|
||||
__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout acc_layout) {
|
||||
static_assert(decltype(size<0>(acc_layout))::value == 4);
|
||||
static_assert(decltype(rank(acc_layout))::value == 3);
|
||||
auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N)
|
||||
// TD [2023-08-13]: Idk why but get<0, 1>(l) doesn't work for Cutlass 3.2, I'm getting
|
||||
// "int_tuple.hpp(74): error: conversion to inaccessible base class"
|
||||
// return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));
|
||||
return make_layout(make_layout(get<1>(get<0>(l)), get<1>(l)), make_layout(get<0>(get<0>(l)), get<2>(l)));
|
||||
return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Convert rowcol_layout from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
|
||||
// if using m16n8k16, or to ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
|
||||
// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
|
||||
// if using m16n8k16, or to (4, MMA_M, MMA_N) if using m16n8k8.
|
||||
template<typename MMA_traits, typename Layout>
|
||||
inline __device__ auto convert_layout_rowcol_Aregs(Layout rowcol_layout) {
|
||||
__forceinline__ __device__ auto convert_layout_acc_Aregs(Layout acc_layout) {
|
||||
using X = Underscore;
|
||||
static_assert(decltype(size<0, 0>(rowcol_layout))::value == 2);
|
||||
static_assert(decltype(size<1, 0>(rowcol_layout))::value == 2);
|
||||
static_assert(decltype(size<0>(acc_layout))::value == 4);
|
||||
static_assert(decltype(rank(acc_layout))::value == 3);
|
||||
constexpr int mma_shape_K = get<2>(typename MMA_traits::Shape_MNK{});
|
||||
static_assert(mma_shape_K == 8 || mma_shape_K == 16);
|
||||
constexpr int MMA_N_divisor = mma_shape_K == 8 ? 1 : 2;
|
||||
auto l = logical_divide(rowcol_layout, Shape<X, Shape<X, Int<MMA_N_divisor>>>{}); // ((2, MMA_M), (2, (2, MMA_N / 2)))
|
||||
// TD [2023-08-13]: Same error as above on Cutlass 3.2
|
||||
// return make_layout(make_layout(get<1, 0>(l), get<0, 0>(l), get<1, 1, 0>(l)),
|
||||
// get<0, 1>(l),
|
||||
// get<1, 1, 1>(l));
|
||||
return make_layout(make_layout(get<0>(get<1>(l)), get<0>(get<0>(l)), get<0>(get<1>(get<1>(l)))),
|
||||
get<1>(get<0>(l)),
|
||||
get<1>(get<1>(get<1>(l))));
|
||||
if constexpr (mma_shape_K == 8) {
|
||||
return acc_layout;
|
||||
} else {
|
||||
auto l = logical_divide(acc_layout, Shape<X, X, _2>{}); // (4, MMA_M, (2, MMA_N / 2)))
|
||||
return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l));
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
|
||||
template<typename Layout>
|
||||
__forceinline__ __device__ auto convert_layout_acc_dropout(Layout acc_layout) {
|
||||
using X = Underscore;
|
||||
static_assert(decltype(size<0>(acc_layout))::value == 4);
|
||||
static_assert(decltype(rank(acc_layout))::value == 3);
|
||||
auto l = logical_divide(acc_layout, Shape<X, X, _2>{}); // (4, MMA_M, (2, MMA_N / 2)))
|
||||
return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l));
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename To_type, typename Engine, typename Layout>
|
||||
inline __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) {
|
||||
__forceinline__ __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) {
|
||||
using From_type = typename Engine::value_type;
|
||||
constexpr int numel = decltype(size(tensor))::value;
|
||||
cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
|
||||
@ -230,7 +237,7 @@ inline __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) {
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename Engine, typename Layout>
|
||||
inline __device__ void relu_(Tensor<Engine, Layout> &tensor) {
|
||||
__forceinline__ __device__ void relu_(Tensor<Engine, Layout> &tensor) {
|
||||
constexpr int numel = decltype(size(tensor))::value;
|
||||
static_assert(numel % 2 == 0);
|
||||
using value_t = typename Engine::value_type;
|
||||
@ -246,7 +253,7 @@ inline __device__ void relu_(Tensor<Engine, Layout> &tensor) {
|
||||
|
||||
// On SM80 and above, we can fuse fp32 -> fp16/bf16 conversion and relu into 1 instruction
|
||||
template <typename To_type, typename Engine, typename Layout>
|
||||
inline __device__ auto convert_type_relu(Tensor<Engine, Layout> const &tensor) {
|
||||
__forceinline__ __device__ auto convert_type_relu(Tensor<Engine, Layout> const &tensor) {
|
||||
using From_type = typename Engine::value_type;
|
||||
static_assert(std::is_same_v<To_type, cutlass::half_t> || std::is_same_v<To_type, cutlass::bfloat16_t>);
|
||||
static_assert(std::is_same_v<float, From_type>);
|
||||
@ -288,7 +295,7 @@ void cp_async_wait() {
|
||||
template <bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true,
|
||||
typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
|
||||
typename Engine2, typename Layout2, typename Engine3, typename Layout3>
|
||||
inline __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S,
|
||||
__forceinline__ __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S,
|
||||
Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
|
||||
Tensor<Engine3, Layout3> const &predicate_K, const int max_MN=0) {
|
||||
CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
|
||||
@ -357,7 +364,7 @@ inline __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const
|
||||
template <bool Is_even_K=true,
|
||||
typename Engine0, typename Layout0, typename Engine1, typename Layout1,
|
||||
typename Engine2, typename Layout2, typename Engine3, typename Layout3>
|
||||
inline __device__ void copy_w_min_idx(Tensor<Engine0, Layout0> const &S,
|
||||
__forceinline__ __device__ void copy_w_min_idx(Tensor<Engine0, Layout0> const &S,
|
||||
Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
|
||||
Tensor<Engine3, Layout3> const &predicate_K,
|
||||
const int max_MN=0, const int min_MN=0) {
|
||||
@ -384,137 +391,4 @@ inline __device__ void copy_w_min_idx(Tensor<Engine0, Layout0> const &S,
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <bool Is_even_K=true, bool Clear_OOB_K=true,
|
||||
typename Engine0, typename Layout0, typename Engine1, typename Layout1,
|
||||
typename Engine2, typename Layout2, typename Engine3, typename Layout3>
|
||||
inline __device__ void copy_rotary_interleaved(Tensor<Engine0, Layout0> const &S,
|
||||
Tensor<Engine1, Layout1> &D,
|
||||
Tensor<Engine2, Layout2> const &Cos,
|
||||
Tensor<Engine2, Layout2> const &Sin,
|
||||
Tensor<Engine3, Layout3> const &identity_MN,
|
||||
const int max_MN, const int min_MN,
|
||||
const int dim, const int rotary_dim) {
|
||||
CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
|
||||
CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
|
||||
CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA
|
||||
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M
|
||||
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K
|
||||
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos)); // MMA_M
|
||||
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos)); // MMA_K
|
||||
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin)); // MMA_M
|
||||
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin)); // MMA_K
|
||||
CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin)); // MMA_K
|
||||
static_assert(decltype(size<0>(S))::value == decltype(size<0>(Cos))::value * 2);
|
||||
static_assert(decltype(size<0>(Cos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32
|
||||
Tensor rCos = make_fragment_like(Cos);
|
||||
Tensor rSin = make_fragment_like(Sin);
|
||||
Tensor rS = make_fragment_like(S);
|
||||
#pragma unroll
|
||||
for (int m = 0; m < size<1>(S); ++m) {
|
||||
if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) {
|
||||
#pragma unroll
|
||||
for (int k = 0; k < size<2>(S); ++k) {
|
||||
if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) {
|
||||
cute::copy(S(_, m, k), rS(_, m, k));
|
||||
if (get<1>(identity_MN(0, 0, k)) < rotary_dim) {
|
||||
cute::copy(Cos(_, m, k), rCos(_, m, k));
|
||||
cute::copy(Sin(_, m, k), rSin(_, m, k));
|
||||
Tensor S_fp32 = convert_type<float>(rS(_, m, k));
|
||||
Tensor cos_fp32 = convert_type<float>(rCos(_, m, k));
|
||||
Tensor sin_fp32 = convert_type<float>(rSin(_, m, k));
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size<0>(rS) / 2; ++i) {
|
||||
float real = S_fp32(2 * i) * cos_fp32(i) - S_fp32(2 * i + 1) * sin_fp32(i);
|
||||
float imag = S_fp32(2 * i) * sin_fp32(i) + S_fp32(2 * i + 1) * cos_fp32(i);
|
||||
S_fp32(2 * i) = real;
|
||||
S_fp32(2 * i + 1) = imag;
|
||||
}
|
||||
// Idk but I need to copy for the convert_type to work
|
||||
Tensor S_fp32_copy = make_fragment_like(S_fp32);
|
||||
cute::copy(S_fp32, S_fp32_copy);
|
||||
using T = typename Engine0::value_type;
|
||||
Tensor S_og_type = convert_type<T>(S_fp32_copy);
|
||||
cute::copy(S_og_type, rS(_, m, k));
|
||||
}
|
||||
cute::copy(rS(_, m, k), D(_, m, k));
|
||||
} else if (Clear_OOB_K) {
|
||||
cute::clear(D(_, m, k));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <bool Is_even_K=true, bool Clear_OOB_K=true,
|
||||
typename Engine0, typename Layout0, typename Engine1, typename Layout1,
|
||||
typename Engine2, typename Layout2, typename Engine3, typename Layout3>
|
||||
inline __device__ void copy_rotary_contiguous(Tensor<Engine0, Layout0> const &S,
|
||||
Tensor<Engine1, Layout1> &D,
|
||||
Tensor<Engine2, Layout2> const &Cos,
|
||||
Tensor<Engine2, Layout2> const &Sin,
|
||||
Tensor<Engine3, Layout3> const &identity_MN,
|
||||
const int max_MN, const int min_MN,
|
||||
const int dim, const int rotary_dim) {
|
||||
CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
|
||||
CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
|
||||
CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA
|
||||
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M
|
||||
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K
|
||||
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Cos)); // MMA_M
|
||||
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Cos)); // MMA_K
|
||||
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(Sin)); // MMA_M
|
||||
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(Sin)); // MMA_K
|
||||
CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(Cos)); // MMA
|
||||
CUTE_STATIC_ASSERT_V(size<0>(Cos) == size<0>(Sin));
|
||||
static_assert(decltype(size<0>(Cos))::value % 2 == 0); // Since we do fast conversion from fp16/bf16 to fp32
|
||||
Tensor rCos = make_fragment_like(Cos);
|
||||
Tensor rSin = make_fragment_like(Sin);
|
||||
Tensor rS = make_fragment_like(S);
|
||||
Tensor rS_other = make_fragment_like(rS(_, 0, 0));
|
||||
#pragma unroll
|
||||
for (int m = 0; m < size<1>(S); ++m) {
|
||||
if (get<0>(identity_MN(0, m, 0)) >= min_MN && get<0>(identity_MN(0, m, 0)) < max_MN) {
|
||||
#pragma unroll
|
||||
for (int k = 0; k < size<2>(S); ++k) {
|
||||
if (Is_even_K || get<1>(identity_MN(0, 0, k)) < dim) {
|
||||
cute::copy(S(_, m, k), rS(_, m, k));
|
||||
if (get<1>(identity_MN(0, 0, k)) < rotary_dim) {
|
||||
const bool is_left = get<1>(identity_MN(0, 0, k)) < rotary_dim / 2;
|
||||
Tensor gS_other = make_tensor(S(_, m, k).data() + (is_left ? rotary_dim / 2 : -rotary_dim / 2), S(_, m, k).layout());
|
||||
cute::copy(gS_other, rS_other);
|
||||
// if (cute::thread0()) { print_tensor(rS(_, m, k)); print_tensor(rS_other); }
|
||||
Tensor gCos = make_tensor(Cos(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Cos(_, m, k).layout());
|
||||
Tensor gSin = make_tensor(Sin(_, m, k).data() + (is_left ? 0 : -rotary_dim / 2), Sin(_, m, k).layout());
|
||||
cute::copy(gCos, rCos(_, m, k));
|
||||
cute::copy(gSin, rSin(_, m, k));
|
||||
// if (cute::thread0()) { print_tensor(rCos(_, m, k)); print_tensor(rSin(_, m, k)); }
|
||||
Tensor S_fp32 = convert_type<float>(rS(_, m, k));
|
||||
Tensor S_other_fp32 = convert_type<float>(rS_other);
|
||||
Tensor cos_fp32 = convert_type<float>(rCos(_, m, k));
|
||||
Tensor sin_fp32 = convert_type<float>(rSin(_, m, k));
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size<0>(rS); ++i) {
|
||||
S_fp32(i) = S_fp32(i) * cos_fp32(i) + S_other_fp32(i) * (is_left ? -sin_fp32(i) : sin_fp32(i));
|
||||
}
|
||||
// Idk but I need to copy for the convert_type to work
|
||||
Tensor S_fp32_copy = make_fragment_like(S_fp32);
|
||||
cute::copy(S_fp32, S_fp32_copy);
|
||||
using T = typename Engine0::value_type;
|
||||
Tensor S_og_type = convert_type<T>(S_fp32_copy);
|
||||
cute::copy(S_og_type, rS(_, m, k));
|
||||
// if (cute::thread0()) { print_tensor(rS(_, m, k)); }
|
||||
}
|
||||
cute::copy(rS(_, m, k), D(_, m, k));
|
||||
} else if (Clear_OOB_K) {
|
||||
cute::clear(D(_, m, k));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace pytorch_flash
|
||||
|
@ -242,7 +242,7 @@ bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug)
return true;
}

bool check_requires_grad_and_head_dim_gt192_and_sm_ge86_lt90(
bool check_requires_grad_and_head_dim_gt192_constraints_on_sm86_89(
sdp_params const& params,
bool debug) {
// Flash Attention will raise an error in the backward pass if the head_dim
@ -252,11 +252,19 @@ bool check_requires_grad_and_head_dim_gt192_and_sm_ge86_lt90(
auto dprops = at::cuda::getCurrentDeviceProperties();
bool is_sm86_or_sm89 = check_sm_version<sm86, sm89>(dprops);
bool is_head_dim_gt192 = params.query.sym_size(-1) > 192;
if (input_requires_grad(params) && is_sm86_or_sm89 && is_head_dim_gt192) {
bool is_head_dim_lte224 = params.query.sym_size(-1) <= 224;
bool is_dropout = params.dropout > 0.0;
// head_dim size in (192, 224] is not supported on sm86 and sm89
bool cond1 = is_head_dim_gt192 && is_head_dim_lte224;
// head_dim size > 224 and is_dropout is not supported on sm86 and sm89
bool cond2 = params.query.sym_size(-1) > 224 && is_dropout;
if (input_requires_grad(params) && is_sm86_or_sm89 && (cond1 || cond2)) {
if (debug) {
TORCH_WARN(
"Flash attention currently doesn't support training with head_dim greater than 192 on gpu architectures in the range[sm86, sm89].",
"Attempting to run with head_dim: ",
"Flash attention currently doesn't support training with head_dim ∈ (192, 224] or "
"(head_dim ∈ (224, 256] and dropout > 0.0) on gpu architectures in the range[sm86, sm89].",
"Attempting to run with dropout set to: ", params.dropout,
"and head_dim: ",
params.query.sym_size(-1), " on a sm ", dprops->major, ".",
dprops->minor, " gpu.");
}
@ -467,7 +475,7 @@ bool can_use_flash_attention(sdp_params const& params, bool debug) {
check_for_attn_mask,
check_head_dim_size_flash,
check_flash_attention_hardware_support,
check_requires_grad_and_head_dim_gt192_and_sm_ge86_lt90,
check_requires_grad_and_head_dim_gt192_constraints_on_sm86_89,
check_flash_causal_non_square_seqlens,
check_dtypes_low_precision);
for (auto& constraint : general_constraints) {

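The updated check encodes two separate sm86/sm89 limitations for training: head_dim in (192, 224] is always rejected, and head_dim in (224, 256] is rejected only when dropout is non-zero. The following is a tiny standalone sketch of that predicate with a few sample inputs; the helper is hypothetical and mirrors cond1/cond2 above rather than calling the real sdp_params API.

#include <cstdio>
#include <initializer_list>

// Returns true when FlashAttention's backward pass must be refused on sm86/sm89.
bool rejects_backward_on_sm86_89(long head_dim, double dropout_p, bool requires_grad) {
    const bool cond1 = head_dim > 192 && head_dim <= 224;   // (192, 224] is never supported
    const bool cond2 = head_dim > 224 && dropout_p > 0.0;   // (224, 256] only works without dropout
    return requires_grad && (cond1 || cond2);
}

int main() {
    for (long d : {192L, 204L, 224L, 256L}) {
        for (double p : {0.0, 0.2}) {
            std::printf("head_dim=%ld dropout=%.1f -> %s\n", d, p,
                        rejects_backward_on_sm86_89(d, p, /*requires_grad=*/true)
                            ? "rejected" : "allowed");
        }
    }
}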
@ -106,10 +106,11 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
const float p_dropout,
const float softmax_scale,
bool is_causal,
const int window_size_left,
int window_size_left,
int window_size_right,
const bool return_softmax,
c10::optional<at::Generator> gen_) {
@ -311,13 +312,14 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
const at::Tensor &cu_seqlens_q, // b+1
const at::Tensor &cu_seqlens_k, // b+1
c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
const int max_seqlen_q,
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
int max_seqlen_q,
const int max_seqlen_k,
const float p_dropout,
const float softmax_scale,
const bool zero_tensors,
const bool is_causal,
const int window_size_left,
bool is_causal,
int window_size_left,
int window_size_right,
const bool return_softmax,
c10::optional<at::Generator> gen_) {
@ -343,11 +345,13 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
c10::optional<at::Tensor> &dq_, // batch_size x seqlen_q x num_heads x head_size
c10::optional<at::Tensor> &dk_, // batch_size x seqlen_k x num_heads_k x head_size
c10::optional<at::Tensor> &dv_, // batch_size x seqlen_k x num_heads_k x head_size
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or batch_size x num_heads
const float p_dropout, // probability to drop
const float softmax_scale,
const bool is_causal,
const int window_size_left,
int window_size_left,
int window_size_right,
const bool deterministic,
const at::Tensor philox_seed,
const at::Tensor philox_offset) {
check_gpu_arch();
@ -630,14 +634,16 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
c10::optional<at::Tensor> &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &cu_seqlens_q, // b+1
const at::Tensor &cu_seqlens_k, // b+1
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
const int max_seqlen_q,
const int max_seqlen_k, // max sequence length to choose the kernel
const float p_dropout, // probability to drop
const float softmax_scale,
const bool zero_tensors,
const bool is_causal,
const int window_size_left,
int window_size_left,
int window_size_right,
const bool deterministic,
const at::Tensor philox_seed,
const at::Tensor philox_offset) {
TORCH_CHECK(false, "mha_varlen_bwd not supported on ROCm");

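The new alibi_slopes_ argument threaded through these signatures carries one slope per attention head; conceptually, ALiBi penalizes each attention score in proportion to the distance between the query and key positions before the softmax, and the ALIBI_SWITCH shown earlier compiles this path out entirely when the feature is disabled. A minimal scalar sketch of one common formulation of that bias, independent of the kernel code and purely illustrative:

#include <cmath>
#include <cstddef>
#include <vector>

// Penalize one head's score matrix by slope * distance(query position, key position).
void add_alibi_bias(std::vector<std::vector<float>>& scores, float slope) {
    for (std::size_t i = 0; i < scores.size(); ++i) {          // query positions
        for (std::size_t j = 0; j < scores[i].size(); ++j) {   // key positions
            scores[i][j] -= slope * std::fabs(static_cast<float>(i) - static_cast<float>(j));
        }
    }
}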
@ -1328,13 +1328,17 @@ class TestSDPAFailureModes(NNTestCase):
|
||||
_do_cuda_non_default_stream = True
|
||||
|
||||
@onlyCUDA
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION or not isSM8XDevice,
|
||||
"Does not support fused SDPA or not SM86+ hardware")
|
||||
@unittest.skipIf(
|
||||
not PLATFORM_SUPPORTS_FLASH_ATTENTION or not isSM8XDevice,
|
||||
"Does not support fused SDPA or not SM86+ hardware",
|
||||
)
|
||||
@parametrize("head_dim", [193, 204, 256])
|
||||
def test_flash_backward_failure_sm86plus(self, device, head_dim: int):
|
||||
@parametrize("dropout_p", [0.0, 0.2])
|
||||
def test_flash_backward_failure_sm86plus(self, device, head_dim: int, dropout_p: float):
|
||||
dtype = torch.float16
|
||||
make_tensor = partial(torch.rand, device=device, dtype=dtype)
|
||||
# See check_requires_grad_and_head_dim_gt64_and_sm_ge86 in pytorch/aten/src/ATen/native/transformers/cuda/sdp_utils.h
|
||||
# See check_requires_grad_and_head_dim_gt192_constraints_on_sm86_89 in
|
||||
# pytorch/aten/src/ATen/native/transformers/cuda/sdp_utils.h
|
||||
size = (2, 2, 4, head_dim)
|
||||
q, k, v = make_tensor(size), make_tensor(size), make_tensor(size)
|
||||
|
||||
@ -1351,8 +1355,15 @@ class TestSDPAFailureModes(NNTestCase):
|
||||
q = make_tensor(size, requires_grad=True)
|
||||
k = make_tensor(size, requires_grad=True)
|
||||
v = make_tensor(size, requires_grad=True)
|
||||
self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
|
||||
q, k, v, None, 0.0, False))
|
||||
if 192 < head_dim <= 224 or (head_dim > 224 and dropout_p != 0.0):
|
||||
self.assertRaises(
|
||||
RuntimeError,
|
||||
lambda: torch.nn.functional.scaled_dot_product_attention(
|
||||
q, k, v, None, dropout_p, False
|
||||
),
|
||||
)
|
||||
else:
|
||||
flash_ref = torch.nn.functional.scaled_dot_product_attention(q, k, v, None, dropout_p, False)
|
||||
|
||||
@onlyCUDA
|
||||
def test_dispatch_fails_no_backend(self, device):
|
||||
@ -1589,7 +1600,6 @@ class TestSDPAFailureModes(NNTestCase):
|
||||
self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
|
||||
q, k, v, None, 0.0, False))
|
||||
|
||||
|
||||
@onlyCUDA
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION or not isLessThanSM80Device,
|
||||
"Current platform does not support fused SDPA or is an SM80+ device.")
|
||||
@ -1670,37 +1680,35 @@ class TestSDPAFailureModes(NNTestCase):
|
||||
self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
|
||||
q, k, v, None, 0.0, is_causal=True))
|
||||
|
||||
def _get_block_size(device, head_dim, is_causal):
|
||||
def _get_block_size_n(device, head_dim, is_dropout, is_causal):
|
||||
# This should match the block sizes in the CUDA kernel
|
||||
# Mask is only interesting when we are setting dropout
|
||||
is_dropout = True
|
||||
assert head_dim <= 256
|
||||
major, minor = torch.cuda.get_device_capability(device)
|
||||
is_sm8x = major == 8 and minor > 0 # Only include sm86 and sm89, exclude sm80 (A100)
|
||||
is_sm80 = major == 8 and minor == 0
|
||||
is_sm90 = major == 9 and minor == 0
|
||||
if head_dim <= 32:
|
||||
return 128, 128
|
||||
return 128
|
||||
if head_dim <= 64:
|
||||
return (128, 128) if not is_dropout else (128, 64)
|
||||
return 128 if not is_dropout else 64
|
||||
elif head_dim <= 96:
|
||||
return (64, 64) if (is_sm8x and is_causal) else (128, 64)
|
||||
return 64
|
||||
elif head_dim <= 128:
|
||||
if is_sm8x:
|
||||
return (64, 64) if (not is_dropout and is_causal) else (128, 32)
|
||||
return 64 if (not is_dropout and is_causal) else 32
|
||||
else:
|
||||
return 128, (64 if not is_dropout else 32)
|
||||
return 64 if not is_dropout else 32
|
||||
elif head_dim <= 160:
|
||||
if is_sm8x:
|
||||
return (128, 64) if not is_causal else (64, 64)
|
||||
return 64
|
||||
else:
|
||||
return 128, 32
|
||||
return 32
|
||||
elif head_dim <= 192:
|
||||
return (128, 64) if not is_dropout else (64, 64)
|
||||
return 64
|
||||
elif head_dim <= 224:
|
||||
return (128, 64) if (is_sm80 or is_sm90) else (64, 64)
|
||||
return 64
|
||||
elif head_dim <= 256:
|
||||
return (128, 64) if is_sm80 else (64, 64)
|
||||
return 64
|
||||
|
||||
|
||||
def pad_last_dim(input_tensor, alignment_size, slice: bool = False):
|
||||
@ -1963,7 +1971,114 @@ class TestSDPACudaOnly(NNTestCase):
|
||||
_do_cuda_memory_leak_check = True
|
||||
_do_cuda_non_default_stream = True
|
||||
|
||||
def convert_flash_attn_S_to_softmax(self, S, query_padding_mask, key_padding_mask, head_dim, causal=False):
|
||||
# TODO USED FOR TESTING THE SCORES, e.g. testing ALIBI we don't need this now
|
||||
def normalize_flash_attn_S(
|
||||
self,
|
||||
attn_unnorm,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
query_padding_mask=None,
|
||||
key_padding_mask=None,
|
||||
attn_bias=None,
|
||||
is_dropout=False,
|
||||
causal=False,
|
||||
window_size=(-1, -1), # -1 means infinite window size
|
||||
scale=None,
|
||||
):
|
||||
"""
|
||||
Arguments:
|
||||
q: (batch_size, seqlen_q, nheads, head_dim)
|
||||
k, v: (batch_size, seqlen_k, nheads, head_dim)
|
||||
key_padding_mask: (batch_size, seqlen_q)
|
||||
attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k)
|
||||
Output:
|
||||
softmax_lse: (batch_size, nheads, seqlen_q)
|
||||
softmax_max: (batch_size, nheads, seqlen_q)
|
||||
"""
|
||||
q = q.transpose(1, 2)
|
||||
k = k.transpose(1, 2)
|
||||
v = v.transpose(1, 2)
|
||||
if causal:
|
||||
window_size = (window_size[0], 0)
|
||||
q, k, v = q.float(), k.float(), v.float()
|
||||
_, seqlen_q, _, head_dim = q.shape
|
||||
seqlen_k = k.shape[1]
|
||||
b = q.shape[0]
|
||||
from torch.nn.attention.bias import _calculate_scale
|
||||
scale = _calculate_scale(head_dim, scale)
|
||||
scores = torch.matmul(q.transpose(1, 2) * scale, k.permute(0, 2, 3, 1))
|
||||
if key_padding_mask is not None:
|
||||
scores.masked_fill_(~key_padding_mask.view(b, 1, 1, -1), float("-inf"))
|
||||
if window_size[0] >= 0 or window_size[1] >= 0:
|
||||
local_mask = self.construct_local_mask(
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
window_size,
|
||||
query_padding_mask,
|
||||
key_padding_mask,
|
||||
q.device,
|
||||
)
|
||||
scores.masked_fill_(local_mask, float("-inf"))
|
||||
if attn_bias is not None:
|
||||
scores = scores + attn_bias.to(dtype=scores.dtype)
|
||||
block_size_n = _get_block_size_n(scores.device, head_dim, is_dropout, causal)
|
||||
scores_block = scores.split(block_size_n, dim=-1)
|
||||
lse_block = torch.stack([torch.logsumexp(s, dim=-1) for s in scores_block], dim=-1)
|
||||
lse = torch.logsumexp(lse_block, dim=-1)
|
||||
# lse could be -inf (i.e. all values in scores are -inf), and we want to set those to inf
|
||||
# so that when we do torch.exp(m - lse), we get 0.0 instead of NaN.
|
||||
lse[lse == float("-inf")] = float("inf")
|
||||
scores_max_block = torch.stack([torch.amax(s, dim=-1) for s in scores_block], dim=-1)
|
||||
cummax_block = torch.cummax(scores_max_block.flip(-1), dim=-1).values.flip(-1).unbind(dim=-1)
|
||||
attn_unnorm_block = attn_unnorm.split(block_size_n, dim=-1)
|
||||
attn_norm = torch.cat(
|
||||
[
|
||||
a * (torch.exp(m - lse)).unsqueeze(-1)
|
||||
for a, m in zip(attn_unnorm_block, cummax_block)
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
if query_padding_mask is not None:
|
||||
attn_norm.masked_fill_(~query_padding_mask.view(b, 1, -1, 1), 0.0)
|
||||
# attn_norm.masked_fill_(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
|
||||
return attn_norm.to(dtype=attn_unnorm.dtype)
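In other words, normalize_flash_attn_S recovers the normalized probabilities from the per-block unnormalized scores: with row log-sum-exp $\ell_i = \log \sum_b \sum_{j \in b} e^{S_{ij}}$ and $m^{(b)}_i$ the running row maximum over blocks $\geq b$ (the suffix cummax above), each block is rescaled as

$$\mathrm{attn}^{(b)}_{ij} = \mathrm{attn\_unnorm}^{(b)}_{ij} \, e^{\,m^{(b)}_i - \ell_i}.$$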
|
||||
|
||||
def construct_local_mask(self, seqlen_q, seqlen_k, window_size, query_padding_mask, key_padding_mask, device):
|
||||
# row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1")
|
||||
row_idx = torch.arange(seqlen_q, device=device, dtype=torch.long).view(-1, 1)
|
||||
col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long)
|
||||
sk = (
|
||||
seqlen_k
|
||||
if key_padding_mask is None
|
||||
else key_padding_mask.sum(-1).view(-1, 1, 1, 1)
|
||||
# else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
|
||||
)
|
||||
sq = (
|
||||
seqlen_q
|
||||
if query_padding_mask is None
|
||||
else query_padding_mask.sum(-1).view(-1, 1, 1, 1)
|
||||
# else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
|
||||
)
|
||||
if window_size[0] < 0:
|
||||
return col_idx > row_idx + sk - sq + window_size[1]
|
||||
else:
|
||||
sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk
|
||||
return torch.logical_or(
|
||||
col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk),
|
||||
col_idx < row_idx + sk - sq - window_size[0],
|
||||
)
|
||||
|
||||
def convert_flash_attn_S_to_softmax(
|
||||
self,
|
||||
S,
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
query_padding_mask,
|
||||
key_padding_mask,
|
||||
causal=False,
|
||||
window_size=(-1, -1), # -1 means infinite window size
|
||||
):
|
||||
"""FlashAttention stores the S matrix in a different way.
|
||||
Arguments:
|
||||
S: (batch_size, nheads, seqlen_q, seqlen_k)
|
||||
@ -1972,53 +2087,45 @@ class TestSDPACudaOnly(NNTestCase):
|
||||
"""
|
||||
if TEST_WITH_ROCM:
|
||||
return S
|
||||
|
||||
b, h, seqlen_q, seqlen_k = S.shape
|
||||
warps_n = 4
|
||||
blocksize_m, blocksize_n = _get_block_size(S.device, head_dim, causal)
|
||||
nblocks_m = (seqlen_q + blocksize_m - 1) // blocksize_m
|
||||
nblocks_n = (seqlen_k + blocksize_n - 1) // blocksize_n
|
||||
mmas_n = (blocksize_n + 16 - 1) // 16
|
||||
|
||||
# Reshape S using PyTorch native functions
|
||||
S_flat = S.view(b, h, nblocks_m, blocksize_m, nblocks_n, blocksize_n)
|
||||
S_flat = S_flat.permute(0, 1, 2, 4, 3, 5)
|
||||
S_flat = S_flat.reshape(b, h, nblocks_m, nblocks_n, (blocksize_m * blocksize_n))
|
||||
S_converted = S_flat.view(b, h, nblocks_m, nblocks_n, mmas_n, -1, warps_n, 8, 4, 2, 2, 2)
|
||||
S_converted = S_converted.permute(0, 1, 2, 5, 6, 10, 7, 3, 4, 9, 8, 11)
|
||||
S_converted = S_converted.reshape(b, h, (nblocks_m * S_converted.size(3) *
|
||||
warps_n * 2 * 8), (nblocks_n * mmas_n * 2 * 4 * 2))
|
||||
b = S.shape[0]
|
||||
|
||||
if causal:
|
||||
causal_mask = torch.triu(torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=S.device), 1)
|
||||
S_converted.masked_fill_(causal_mask, 0.0)
|
||||
window_size = (window_size[0], 0)
|
||||
seqlen_q_rounded, seqlen_k_rounded = S.shape[-2:]
|
||||
S_converted = S
|
||||
if window_size[0] >= 0 or window_size[1] >= 0:
|
||||
local_mask = self.construct_local_mask(
|
||||
seqlen_q,
|
||||
seqlen_k,
|
||||
window_size,
|
||||
query_padding_mask,
|
||||
key_padding_mask,
|
||||
S.device,
|
||||
)
|
||||
local_mask = F.pad(
|
||||
local_mask,
|
||||
(0, seqlen_k_rounded - seqlen_k, 0, seqlen_q_rounded - seqlen_q),
|
||||
value=True,
|
||||
)
|
||||
S_converted = S_converted.masked_fill(local_mask, 0.0)
|
||||
|
||||
# Need to zero out things not in attention_mask in case S was initialized with random values
|
||||
# and some of those values aren't overwritten.
|
||||
seqlen_q_og = query_padding_mask.shape[-1] if query_padding_mask is not None else seqlen_q
|
||||
seqlen_q_og = (
|
||||
query_padding_mask.shape[-1] if query_padding_mask is not None else seqlen_q_rounded
|
||||
)
|
||||
if query_padding_mask is not None:
|
||||
if seqlen_q_og < seqlen_q:
|
||||
query_padding_mask = F.pad(query_padding_mask, (0, seqlen_q - seqlen_q_og))
|
||||
else:
|
||||
query_padding_mask = query_padding_mask[:, :seqlen_q]
|
||||
q_mask_fill = ~query_padding_mask.view(query_padding_mask.shape[0], 1, query_padding_mask.shape[1], 1)
|
||||
S_converted = S_converted.masked_fill(q_mask_fill, 0.0)
|
||||
query_padding_mask = F.pad(query_padding_mask, (0, seqlen_q_rounded - seqlen_q_og))
|
||||
# S_converted = S_converted.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
|
||||
S_converted = S_converted.masked_fill(~query_padding_mask.view(b, 1, -1, 1), 0.0)
|
||||
seqlen_k_og = key_padding_mask.shape[-1] if key_padding_mask is not None else seqlen_k
|
||||
if key_padding_mask is not None:
|
||||
if seqlen_k_og < seqlen_k:
|
||||
key_padding_mask = F.pad(key_padding_mask, (0, seqlen_k - seqlen_k_og))
|
||||
else:
|
||||
key_padding_mask = key_padding_mask[:, :seqlen_k]
|
||||
k_mask_fill = ~key_padding_mask.view(key_padding_mask.shape[0], 1, 1, key_padding_mask.shape[1])
|
||||
S_converted = S_converted.masked_fill(k_mask_fill, 0.0)
|
||||
if seqlen_q_og < seqlen_q:
|
||||
S_converted = S_converted[:, :, :seqlen_q_og, :]
|
||||
else:
|
||||
S_converted = F.pad(S_converted, (0, 0, 0, seqlen_q_og - seqlen_q))
|
||||
if seqlen_k_og < seqlen_k:
|
||||
S_converted = S_converted[:, :, :, :seqlen_k_og]
|
||||
else:
|
||||
S_converted = F.pad(S_converted, (0, seqlen_k_og - seqlen_k))
|
||||
return S_converted
|
||||
key_padding_mask = F.pad(key_padding_mask, (0, seqlen_k_rounded - seqlen_k_og))
|
||||
S_converted = S_converted.masked_fill(~key_padding_mask.view(b, 1, 1, -1), 0.0)
|
||||
# S_converted = S_converted.masked_fill(rearrange(~key_padding_mask, "b s -> b 1 1 s"), 0.0)
|
||||
S_converted = F.pad(S_converted, (0, 0, 0, seqlen_q_og - seqlen_q_rounded))
|
||||
S_converted = F.pad(S_converted, (0, seqlen_k_og - seqlen_k_rounded))
|
||||
return S_converted[:, :, :seqlen_q, :seqlen_k]
|
||||
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system")
|
||||
@parametrize("mask_dim", [1, 2, 3, 4])
|
||||
@ -2370,28 +2477,29 @@ class TestSDPACudaOnly(NNTestCase):
|
||||
query, key, value = make_tensor(shape), make_tensor(shape), make_tensor(shape)
|
||||
|
||||
with use_deterministic_algorithims(True, warn_only=warn_only):
|
||||
# Note that this should switch to a testing version when we remove the old context manager
|
||||
with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]):
|
||||
assert torch._fused_sdp_choice(query, key, value) == SDPBackend.EFFICIENT_ATTENTION.value
|
||||
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Platform does not support fused SDPA")
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system")
|
||||
@parametrize("fused_kernel", PLATFORM_SPECIFIC_SDPA)
|
||||
@parametrize("warn_only", [True, False])
|
||||
def test_mem_eff_backwards_throws_determinism_warning(self, device, warn_only):
|
||||
def test_fused_backwards_throws_determinism_warning(self, device, warn_only, fused_kernel):
|
||||
batch_size, seq_len, num_heads, head_dim = 1, 64, 8, 64
|
||||
shape = SdpaShape(batch_size, num_heads, seq_len, head_dim)
|
||||
make_tensor = partial(rand_sdpa_tensor, type="dense", device=device, dtype=torch.float32, packed=False, requires_grad=True)
|
||||
make_tensor = partial(rand_sdpa_tensor, type="dense", device=device, dtype=torch.float16, packed=False, requires_grad=True)
|
||||
query, key, value = make_tensor(shape), make_tensor(shape), make_tensor(shape)
|
||||
|
||||
kernel_name = "Memory Efficient attention" if fused_kernel == SDPBackend.EFFICIENT_ATTENTION else "Flash Attention"
|
||||
warning_context = (
|
||||
self.assertWarnsRegex(
|
||||
UserWarning,
|
||||
"Memory Efficient attention defaults to a non-deterministic algorithm.",
|
||||
f"{kernel_name} defaults to a non-deterministic algorithm.",
|
||||
)
|
||||
if warn_only
|
||||
else contextlib.nullcontext()
|
||||
)
|
||||
with use_deterministic_algorithims(True, warn_only=warn_only):
|
||||
with sdpa_kernel(backends=[SDPBackend.EFFICIENT_ATTENTION]):
|
||||
with sdpa_kernel(backends=[fused_kernel]):
|
||||
with warning_context:
|
||||
torch.nn.functional.scaled_dot_product_attention(query, key, value).sum().backward()
|
||||
|
||||
@ -2710,8 +2818,6 @@ class TestSDPACudaOnly(NNTestCase):
|
||||
is_dropout = dropout_p > 0.0
|
||||
|
||||
if not is_dropout:
|
||||
# Problem: We pad sizes in the composite region of the top level SDPA. But we need the
|
||||
# Debug mask when we have dropout. So I am going to manually pad up here when testing dropout
|
||||
with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
|
||||
out = F.scaled_dot_product_attention(query, key, value, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
|
||||
with sdpa_kernel(backends=[SDPBackend.MATH]):
|
||||
@ -2722,6 +2828,8 @@ class TestSDPACudaOnly(NNTestCase):
|
||||
out_lp_ref = F.scaled_dot_product_attention(
|
||||
query_ref_lp, key_ref_lp, value_ref_lp, is_causal=is_causal, scale=scale)
|
||||
else:
|
||||
# Problem: We pad sizes in the composite region of the top level SDPA. But we need the
|
||||
# Debug mask when have dropout. So I am going to manualy pad up here when testing dropout
|
||||
q_padded, q_og_size = pad_last_dim(query, 8)
|
||||
k_padded, k_og_size = pad_last_dim(key, 8)
|
||||
v_padded, v_og_size = pad_last_dim(value, 8)
|
||||
@ -2740,9 +2848,14 @@ class TestSDPACudaOnly(NNTestCase):
|
||||
batch_size, seq_len_k, device=device, dtype=torch.bool)
|
||||
|
||||
softmax_mask = self.convert_flash_attn_S_to_softmax(
|
||||
dbug_mask, query_padding_mask, key_padding_mask, head_dim=head_dim,
|
||||
dbug_mask, seq_len_q, seq_len_k, query_padding_mask, key_padding_mask,
|
||||
causal=is_causal)[:, :, :seq_len_q, :seq_len_k]
|
||||
dropout_mask = softmax_mask >= 0
|
||||
# attn_unnorm = softmax_mask.abs()
|
||||
# attn = self.normalize_flash_attn_S(attn_unnorm, q_padded,
|
||||
# k_padded, v_padded, query_padding_mask,
|
||||
# key_padding_mask, None, True, is_causal, scale=scale)
|
||||
|
||||
# High Precision Math Reference
|
||||
out_ref = torch.ops.aten._scaled_dot_product_attention_math(
|
||||
query_ref, key_ref, value_ref, dropout_p=dropout_p, is_causal=is_causal, scale=scale, dropout_mask=dropout_mask)[0]
|
||||
@ -2823,7 +2936,8 @@ class TestSDPACudaOnly(NNTestCase):
|
||||
batch_size, seq_len_k, device=device, dtype=torch.bool)
|
||||
|
||||
softmax_mask = self.convert_flash_attn_S_to_softmax(
|
||||
dbug_mask, query_padding_mask, key_padding_mask, head_dim=head_dim, causal=is_causal)
|
||||
dbug_mask, seq_len_q, seq_len_k, query_padding_mask, key_padding_mask,
|
||||
causal=is_causal)[:, :, :seq_len_q, :seq_len_k]
|
||||
dropout_mask = softmax_mask >= 0
|
||||
return dropout_mask
|
||||
|
||||
@ -3178,7 +3292,7 @@ class TestSDPACudaOnly(NNTestCase):
|
||||
key_padding_mask = key_padding_mask.to("cuda")
|
||||
|
||||
softmax_mask = self.convert_flash_attn_S_to_softmax(
|
||||
dbug_mask, query_padding_mask, key_padding_mask, head_dim=head_dim, causal=is_causal)
|
||||
dbug_mask, max_seq_len_q, max_seq_len_kv, query_padding_mask, key_padding_mask, causal=is_causal)
|
||||
dropout_mask = softmax_mask >= 0
|
||||
nt_stack = []
|
||||
for tensor_component in range(batch_size):
|
||||
|