mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
This reverts commit 24e9bbe22af296048f8242c6112d13cff726c588. Reverted https://github.com/pytorch/pytorch/pull/108827 on behalf of https://github.com/huydhn due to I need to land this revert properly as there are new failures showing up on trunk ([comment](https://github.com/pytorch/pytorch/pull/108827#issuecomment-1711020924))
This commit is contained in:
@ -159,6 +159,14 @@ if [[ "$BUILD_ENVIRONMENT" == *cuda* && -z "$TORCH_CUDA_ARCH_LIST" ]]; then
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# We only build FlashAttention files for CUDA 8.0+, and they require large amounts of
|
||||
# memory to build and will OOM
|
||||
if [[ "$BUILD_ENVIRONMENT" == *cuda* ]] && [[ "$TORCH_CUDA_ARCH_LIST" == *"8.6"* || "$TORCH_CUDA_ARCH_LIST" == *"8.0"* ]]; then
|
||||
echo "WARNING: FlashAttention files require large amounts of memory to build and will OOM"
|
||||
echo "Setting MAX_JOBS=(nproc-2)/3 to reduce memory usage"
|
||||
export MAX_JOBS="$(( $(nproc --ignore=2) / 3 ))"
|
||||
fi
|
||||
|
||||
if [[ "${BUILD_ENVIRONMENT}" == *clang* ]]; then
|
||||
export CC=clang
|
||||
export CXX=clang++
|
||||
|
||||
@ -155,8 +155,8 @@ EOL
|
||||
|
||||
# nproc doesn't exist on darwin
|
||||
if [[ "$(uname)" != Darwin ]]; then
|
||||
# Because most Circle executors only have 20 CPUs, using more causes OOMs w/ Ninja and nvcc parallelization
|
||||
MEMORY_LIMIT_MAX_JOBS=18
|
||||
# This was lowered from 18 to 12 to avoid OOMs when compiling FlashAttentionV2
|
||||
MEMORY_LIMIT_MAX_JOBS=12
|
||||
NUM_CPUS=$(( $(nproc) - 2 ))
|
||||
|
||||
# Defaults here for **binary** linux builds so they can be changed in one place
|
||||
|
||||
@ -730,7 +730,7 @@ include(cmake/Dependencies.cmake)
|
||||
cmake_dependent_option(
|
||||
USE_FLASH_ATTENTION
|
||||
"Whether to build the flash_attention kernel for scaled dot product attention" ON
|
||||
"USE_CUDA AND NOT ROCM AND NOT CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.6" OFF)
|
||||
"USE_CUDA AND NOT ROCM AND NOT MSVC AND NOT CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.6" OFF)
|
||||
|
||||
# Flash Attention2 will error while building for sm52 while Mem Eff Attention won't
|
||||
cmake_dependent_option(
|
||||
|
||||
@ -160,6 +160,7 @@ file(GLOB native_utils_cpp "native/utils/*.cpp")
|
||||
|
||||
# flash_attention sources
|
||||
file(GLOB flash_attention_cuda_cu "native/transformers/cuda/flash_attn/*.cu")
|
||||
file(GLOB flash_attention_cuda_kernels_cu "native/transformers/cuda/flash_attn/kernels/*.cu")
|
||||
file(GLOB flash_attention_cuda_cpp "native/transformers/cuda/flash_attn/*.cpp")
|
||||
|
||||
#Mem_eff attention sources
|
||||
@ -169,6 +170,7 @@ file(GLOB mem_eff_attention_cuda_cpp "native/transformers/cuda/mem_eff_attention
|
||||
|
||||
if(USE_FLASH_ATTENTION)
|
||||
list(APPEND native_transformers_cuda_cu ${flash_attention_cuda_cu})
|
||||
list(APPEND native_transformers_cuda_cu ${flash_attention_cuda_kernels_cu})
|
||||
list(APPEND native_transformers_cuda_cpp ${flash_attention_cuda_cpp})
|
||||
endif()
|
||||
|
||||
|
||||
@ -14278,7 +14278,7 @@
|
||||
variants: function
|
||||
tags: nondeterministic_seeded
|
||||
|
||||
- func: _scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor ouput, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
|
||||
- func: _scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
|
||||
dispatch:
|
||||
CPU: _scaled_dot_product_flash_attention_cpu
|
||||
CUDA: _scaled_dot_product_flash_attention_cuda
|
||||
@ -14304,7 +14304,7 @@
|
||||
CUDA: _scaled_dot_product_efficient_attention_backward_cuda
|
||||
tags: nondeterministic_seeded
|
||||
|
||||
- func: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
|
||||
- func: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, int? max_q, int? max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
|
||||
variants: function
|
||||
dispatch:
|
||||
CUDA: _flash_attention_forward
|
||||
|
||||
@ -679,16 +679,7 @@ inline auto sdpa_nested_preprocessing(
|
||||
|
||||
} // namespace
|
||||
|
||||
std::tuple<
|
||||
Tensor,
|
||||
Tensor,
|
||||
Tensor,
|
||||
Tensor,
|
||||
int64_t,
|
||||
int64_t,
|
||||
Tensor,
|
||||
Tensor,
|
||||
Tensor>
|
||||
std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t, int64_t, Tensor, Tensor, Tensor>
|
||||
_scaled_dot_product_flash_attention_nestedtensor_cuda(
|
||||
const Tensor& query,
|
||||
const Tensor& key,
|
||||
@ -710,8 +701,12 @@ _scaled_dot_product_flash_attention_nestedtensor_cuda(
|
||||
max_seqlen_batch_kv,
|
||||
output_shape) = sdpa_nested_preprocessing(query, key, value);
|
||||
|
||||
Tensor attention, log_sumexp, debug_attn_mask, philox_seed, philox_offset;
|
||||
std::tie(attention, log_sumexp, philox_seed, philox_offset, debug_attn_mask) =
|
||||
auto
|
||||
[attention,
|
||||
logsumexp,
|
||||
philox_seed,
|
||||
philox_offset,
|
||||
debug_attn_mask] =
|
||||
at::_flash_attention_forward(
|
||||
query_buffer_reshaped,
|
||||
key_buffer_reshaped,
|
||||
@ -728,7 +723,7 @@ _scaled_dot_product_flash_attention_nestedtensor_cuda(
|
||||
attention = wrap_buffer(attention.view(-1), output_shape).transpose(1, 2);
|
||||
return std::make_tuple(
|
||||
attention,
|
||||
log_sumexp,
|
||||
logsumexp,
|
||||
cumulative_sequence_length_q,
|
||||
cumulative_sequence_length_kv,
|
||||
max_seqlen_batch_q,
|
||||
|
||||
@ -14,6 +14,7 @@
|
||||
#include <ATen/native/transformers/sdp_utils_cpp.h>
|
||||
#include <utility>
|
||||
#include <c10/util/typeid.h>
|
||||
#include <c10/core/SymInt.h>
|
||||
#include <c10/core/SymIntArrayRef.h>
|
||||
#include <c10/util/Logging.h>
|
||||
#include <c10/util/Exception.h>
|
||||
@ -629,6 +630,37 @@ at::Tensor preprocess_mask(
|
||||
|
||||
return attn_mask;
|
||||
}
|
||||
// FlashAttentionV2 requires that head dimension be a multiple of 8
|
||||
// This was previously done within the kernel, however
|
||||
// This causes the kernel to maybe alias query, key, value
|
||||
// So instead we pad the head_dimensions to be a multiple of 8 in the composite
|
||||
// region
|
||||
template <int alignment_size, bool slice>
|
||||
at::Tensor pad_last_dim(const at::Tensor& attn_bias) {
|
||||
auto last_dim_size = attn_bias.sym_size(-1);
|
||||
if (last_dim_size % alignment_size == 0) {
|
||||
return attn_bias;
|
||||
}
|
||||
auto pad_count = alignment_size - (last_dim_size % alignment_size);
|
||||
auto padded_bias = at::pad_symint(attn_bias, {c10::SymInt(0), pad_count});
|
||||
if (slice) {
|
||||
return padded_bias.slice_symint(-1, 0, last_dim_size);
|
||||
}
|
||||
return padded_bias;
|
||||
}
|
||||
|
||||
at::Tensor post_process_flash_output(
|
||||
at::Tensor out,
|
||||
c10::SymInt const& og_size) {
|
||||
if (!out.is_nested()) {
|
||||
out = out.slice_symint(-1, 0, og_size);
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
out.size(-1) == og_size,
|
||||
"FlashAttentionV2 returned a nested tensor with an incorrect size")
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
@ -679,6 +711,18 @@ Tensor scaled_dot_product_attention(
|
||||
c10::optional<Tensor> attn_mask = convert_boolean_attn_mask(attn_mask_, query_.dtype());
|
||||
switch (backend) {
|
||||
case sdp::SDPBackend::flash_attention: {
|
||||
if(query_.device().type() == DeviceType::CUDA){
|
||||
c10::SymInt og_size = query_.sym_size(-1);
|
||||
Tensor query_padded = pad_last_dim<8, false>(query_);
|
||||
Tensor key_padded = pad_last_dim<8, false>(key);
|
||||
Tensor value_padded = pad_last_dim<8, false>(value);
|
||||
// We need to calculate the scale based off the OG head dim size
|
||||
auto og_scale = sdp::calculate_scale(query_, scale);
|
||||
auto out_lse_softmax = at::_scaled_dot_product_flash_attention(
|
||||
query_padded, key_padded, value_padded, dropout_p, is_causal, false /*return_debug_mask*/, og_scale.as_float_unchecked());
|
||||
return post_process_flash_output(std::get<0>(out_lse_softmax), og_size);
|
||||
}
|
||||
// For the CPU case we do not need to pad the last dim
|
||||
auto out_lse_softmax = at::_scaled_dot_product_flash_attention(
|
||||
query_, key, value, dropout_p, is_causal, false /*return_debug_mask*/, scale);
|
||||
return std::get<0>(out_lse_softmax);
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
#include <ATen/native/cuda/MemoryAccess.cuh>
|
||||
#include <ATen/native/cuda/PersistentSoftmax.cuh>
|
||||
#include <ATen/native/cuda/block_reduce.cuh>
|
||||
#include <c10/util/Optional.h>
|
||||
|
||||
#ifndef AT_PER_OPERATOR_HEADERS
|
||||
#include <ATen/Functions.h>
|
||||
@ -37,7 +38,7 @@
|
||||
|
||||
#ifdef USE_FLASH_ATTENTION
|
||||
// FlashAttention Specific Imports
|
||||
#include <ATen/native/transformers/cuda/flash_attn/fmha_api.h>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_api.h>
|
||||
#endif
|
||||
#ifdef USE_MEM_EFF_ATTENTION
|
||||
// MemoryEfficient Attention Specific Imports
|
||||
@ -654,63 +655,43 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, int64_t, int64_t, Tensor, Tensor, Ten
|
||||
// Query (Batch x Num_heads x Q_seq_len x Dim_per_head)
|
||||
// Key (Batch x Num_heads x KV_seq_len x Dim_per_head)
|
||||
// Value (Batch x Num_heads x KV_seq_len x Dim_per_head)
|
||||
const int64_t batch_size = query.size(0);
|
||||
const int64_t num_heads = query.size(1);
|
||||
const int64_t max_seqlen_batch_q = query.size(2);
|
||||
const int64_t head_dim = query.size(3);
|
||||
|
||||
const int64_t max_seqlen_batch_q = query.size(2);
|
||||
const int64_t max_seqlen_batch_k = key.size(2);
|
||||
const int64_t max_seqlen_batch_v = value.size(2);
|
||||
TORCH_CHECK(
|
||||
max_seqlen_batch_k == max_seqlen_batch_v,
|
||||
"Key and Value must have the same sequence length");
|
||||
|
||||
// Query -> Query(Batch x Q_seq_len x Num_heads x Dim_per_head)
|
||||
// Key -> Key(Batch x KV_seq_len x Num_heads x Dim_per_head)
|
||||
// Value -> Value(Batch x KV_seq_len x Num_heads x Dim_per_head)
|
||||
// Query -> Query(Batch x Q_seq_len x Num_heads x Dim_per_head)
|
||||
// Key -> Key (Batch x KV_seq_len x Num_heads x Dim_per_head)
|
||||
// Value -> Value(Batch x KV_seq_len x Num_heads x Dim_per_head)
|
||||
Tensor q_t = query.transpose(1, 2);
|
||||
Tensor k_t = key.transpose(1, 2);
|
||||
Tensor v_t = value.transpose(1, 2);
|
||||
|
||||
Tensor cumulative_sequence_length_q = at::arange(
|
||||
0,
|
||||
(batch_size + 1) * max_seqlen_batch_q,
|
||||
max_seqlen_batch_q,
|
||||
TensorOptions().device(at::kCUDA).dtype(at::kInt));
|
||||
|
||||
Tensor cumulative_sequence_length_k = at::arange(
|
||||
0,
|
||||
(batch_size + 1) * max_seqlen_batch_k,
|
||||
max_seqlen_batch_k,
|
||||
TensorOptions().device(at::kCUDA).dtype(at::kInt));
|
||||
|
||||
int64_t Nnz_q{batch_size * max_seqlen_batch_q};
|
||||
int64_t Nnz_kv{batch_size * max_seqlen_batch_k};
|
||||
|
||||
// For the standard MHA these will actually be views
|
||||
Tensor query_reshaped = q_t.reshape({Nnz_q, num_heads, head_dim});
|
||||
Tensor key_reshaped = k_t.reshape({Nnz_kv, num_heads, head_dim});
|
||||
Tensor value_reshaped = v_t.reshape({Nnz_kv, num_heads, head_dim});
|
||||
|
||||
Tensor attention, log_sumexp, debug_attn_mask, philox_seed, philox_offset;
|
||||
std::tie(attention, log_sumexp, philox_seed, philox_offset, debug_attn_mask) =
|
||||
at::_flash_attention_forward(
|
||||
query_reshaped,
|
||||
key_reshaped,
|
||||
value_reshaped,
|
||||
cumulative_sequence_length_q,
|
||||
cumulative_sequence_length_k,
|
||||
max_seqlen_batch_q,
|
||||
max_seqlen_batch_k,
|
||||
dropout_p,
|
||||
is_causal,
|
||||
return_debug_mask,
|
||||
scale);
|
||||
auto
|
||||
[output,
|
||||
logsumexp,
|
||||
philox_seed,
|
||||
philox_offset,
|
||||
debug_attn_mask] =
|
||||
at::_flash_attention_forward(
|
||||
q_t,
|
||||
k_t,
|
||||
v_t,
|
||||
c10::nullopt,
|
||||
c10::nullopt,
|
||||
c10::nullopt,
|
||||
c10::nullopt,
|
||||
dropout_p,
|
||||
is_causal,
|
||||
return_debug_mask,
|
||||
scale);
|
||||
// Reshape output to convert nnz to batch_size and seq_len
|
||||
attention =
|
||||
attention.view({batch_size, max_seqlen_batch_q, num_heads, head_dim}).transpose(1,2);
|
||||
Tensor attention = output.transpose(1,2);
|
||||
|
||||
return std::make_tuple(attention, log_sumexp, cumulative_sequence_length_q, cumulative_sequence_length_k, max_seqlen_batch_q, max_seqlen_batch_k, philox_seed, philox_offset, debug_attn_mask);
|
||||
return std::make_tuple(attention, logsumexp, Tensor(), Tensor(), max_seqlen_batch_q, max_seqlen_batch_k, philox_seed, philox_offset, debug_attn_mask);
|
||||
}
|
||||
|
||||
std::tuple<Tensor, Tensor, Tensor, Tensor> _scaled_dot_product_efficient_attention_cuda(
|
||||
@ -743,7 +724,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor> _scaled_dot_product_efficient_attenti
|
||||
c10::nullopt,
|
||||
c10::nullopt,
|
||||
c10::nullopt,
|
||||
dropout_p /*dropout_p*/,
|
||||
dropout_p,
|
||||
static_cast<int64_t>(custom_mask_type),
|
||||
compute_log_sumexp,
|
||||
scale);
|
||||
@ -765,52 +746,102 @@ int64_t _fused_sdp_choice_cuda(const Tensor& query_, const Tensor& key, const Te
|
||||
return static_cast<int64_t>(backend);
|
||||
}
|
||||
|
||||
std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor> _flash_attention_forward(
|
||||
std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor>
|
||||
_flash_attention_forward(
|
||||
const Tensor& query,
|
||||
const Tensor& key,
|
||||
const Tensor& value,
|
||||
const Tensor& cumulative_sequence_length_q,
|
||||
const Tensor& cumulative_sequence_length_k,
|
||||
const int64_t max_seqlen_batch_q,
|
||||
const int64_t max_seqlen_batch_k,
|
||||
const c10::optional<Tensor>& cumulative_sequence_length_q,
|
||||
const c10::optional<Tensor>& cumulative_sequence_length_k,
|
||||
c10::optional<int64_t> max_seqlen_batch_q,
|
||||
c10::optional<int64_t> max_seqlen_batch_k,
|
||||
double dropout_p,
|
||||
bool is_causal,
|
||||
bool return_debug_mask,
|
||||
c10::optional<double> scale) {
|
||||
#if defined(USE_FLASH_ATTENTION)
|
||||
/*
|
||||
num_splits determines how much to parallelize over the seqlen_q dimension
|
||||
num_splits=0 means
|
||||
it will be set by an internal heuristic. We're exposing num_splits mostly for
|
||||
benchmarking. We will hard code it to 0 for now
|
||||
*/
|
||||
constexpr int num_splits{0};
|
||||
const auto softmax_scale = sdp::calculate_scale(query, scale).as_float_unchecked();
|
||||
at::Tensor output = at::empty_like(query);
|
||||
|
||||
auto [logsumexp, philox_seed, philox_offset, debug_attn_mask] = pytorch_fmha::mha_fwd(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
output,
|
||||
cumulative_sequence_length_q,
|
||||
cumulative_sequence_length_k,
|
||||
max_seqlen_batch_q,
|
||||
max_seqlen_batch_k,
|
||||
dropout_p,
|
||||
softmax_scale,
|
||||
false, /*zero_tensors = false for all calls here*/
|
||||
is_causal,
|
||||
return_debug_mask, /*return_softmax (this is used for testing)*/
|
||||
num_splits);
|
||||
const auto softmax_scale =
|
||||
sdp::calculate_scale(query, scale).as_float_unchecked();
|
||||
c10::optional<Tensor> out = c10::nullopt;
|
||||
|
||||
// We are going to have two paths:
|
||||
// 1. The standard MHA path for dense tensors
|
||||
// 2. The Varseqlen path
|
||||
TORCH_CHECK(
|
||||
cumulative_sequence_length_q.has_value() ==
|
||||
cumulative_sequence_length_k.has_value(),
|
||||
"cumulative_sequence_length_q and cumulative_sequence_length_k must be both set or both not set");
|
||||
TORCH_CHECK(
|
||||
max_seqlen_batch_q.has_value() == max_seqlen_batch_k.has_value(),
|
||||
"max_seqlen_batch_q and max_seqlen_batch_k must be both set or both not set");
|
||||
Tensor output, q_padded, k_padded, v_padded, logsumexp, output_shape,
|
||||
philox_seed, philox_offset, debug_attn_mask;
|
||||
if (cumulative_sequence_length_q.has_value()) {
|
||||
TORCH_CHECK(
|
||||
max_seqlen_batch_q.has_value(),
|
||||
"max_seqlen_batch_q must be set when cumulative_sequence_length_q is set");
|
||||
std::tie(
|
||||
output,
|
||||
q_padded,
|
||||
k_padded,
|
||||
v_padded,
|
||||
logsumexp,
|
||||
philox_seed,
|
||||
philox_offset,
|
||||
debug_attn_mask) =
|
||||
pytorch_flash::mha_varlen_fwd(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
out,
|
||||
cumulative_sequence_length_q.value(),
|
||||
cumulative_sequence_length_k.value(),
|
||||
max_seqlen_batch_q.value(),
|
||||
max_seqlen_batch_k.value(),
|
||||
dropout_p,
|
||||
softmax_scale,
|
||||
false /*zero_tensors*/,
|
||||
is_causal,
|
||||
return_debug_mask,
|
||||
c10::nullopt /*gen_*/);
|
||||
} else {
|
||||
std::tie(
|
||||
output,
|
||||
q_padded,
|
||||
k_padded,
|
||||
v_padded,
|
||||
logsumexp,
|
||||
philox_seed,
|
||||
philox_offset,
|
||||
debug_attn_mask) =
|
||||
pytorch_flash::mha_fwd(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
out,
|
||||
dropout_p,
|
||||
softmax_scale,
|
||||
is_causal,
|
||||
return_debug_mask, /*return_softmax (this is used for testing)*/
|
||||
c10::nullopt);
|
||||
}
|
||||
debug_attn_mask =
|
||||
return_debug_mask ? debug_attn_mask : at::empty({0}, query.options());
|
||||
return std::make_tuple(
|
||||
output,
|
||||
logsumexp,
|
||||
philox_seed,
|
||||
philox_offset,
|
||||
debug_attn_mask);
|
||||
|
||||
return std::make_tuple(output, logsumexp, philox_seed, philox_offset, debug_attn_mask);
|
||||
#endif
|
||||
TORCH_CHECK(false, "USE_FLASH_ATTENTION was not enabled for build.")
|
||||
return std::make_tuple(Tensor(), Tensor(), Tensor(), Tensor(), Tensor());
|
||||
return std::make_tuple(
|
||||
Tensor(),
|
||||
Tensor(),
|
||||
Tensor(),
|
||||
Tensor(),
|
||||
Tensor());
|
||||
}
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor, Tensor, Tensor> _efficient_attention_forward(
|
||||
|
||||
@ -18,7 +18,7 @@
|
||||
|
||||
#ifdef USE_FLASH_ATTENTION
|
||||
// FlashAttention Specific Imports
|
||||
#include <ATen/native/transformers/cuda/flash_attn/fmha_api.h>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_api.h>
|
||||
#endif
|
||||
#ifdef USE_MEM_EFF_ATTENTION
|
||||
// MemoryEfficient Attention Specific Imports
|
||||
@ -40,54 +40,71 @@ std::tuple<Tensor, Tensor, Tensor> _flash_attention_backward(
|
||||
const Tensor& logsumexp,
|
||||
const Tensor& cumulative_sequence_length_q,
|
||||
const Tensor& cumulative_sequence_length_k,
|
||||
const int64_t max_seqlen_batch_q,
|
||||
const int64_t max_seqlen_batch_k,
|
||||
int64_t max_seqlen_batch_q,
|
||||
int64_t max_seqlen_batch_k,
|
||||
double dropout_p,
|
||||
bool is_causal,
|
||||
const Tensor& philox_seed,
|
||||
const Tensor& philox_offset,
|
||||
c10::optional<double> scale) {
|
||||
#if defined(USE_FLASH_ATTENTION)
|
||||
/*
|
||||
num_splits determines how much to parallelize over the seqlen_q dimension
|
||||
num_splits=0 means
|
||||
it will be set by an internal heuristic. We're exposing num_splits mostly for
|
||||
benchmarking. We will hard code it to 0 for now
|
||||
*/
|
||||
constexpr int num_splits{0};
|
||||
const auto softmax_scale = sdp::calculate_scale(query, scale).as_float_unchecked();
|
||||
// CUDA code assumes that dout is contiguous
|
||||
auto contiguous_grad_out = grad_out.contiguous();
|
||||
auto contiguous_out = out.contiguous();
|
||||
Tensor dq = at::empty_like(query);
|
||||
Tensor dk = at::empty_like(key);
|
||||
Tensor dv = at::empty_like(value);
|
||||
|
||||
c10::optional<at::Tensor> dq{c10::nullopt};
|
||||
c10::optional<at::Tensor> dk{c10::nullopt};
|
||||
c10::optional<at::Tensor> dv{c10::nullopt};
|
||||
|
||||
// The kernel computes irregadless we will drop for this functions return
|
||||
Tensor grad_softmax;
|
||||
|
||||
std::tie(dq, dk, dv, grad_softmax) = pytorch_fmha::mha_bwd(
|
||||
contiguous_grad_out,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
contiguous_out,
|
||||
logsumexp,
|
||||
dq,
|
||||
dk,
|
||||
dv,
|
||||
cumulative_sequence_length_q,
|
||||
cumulative_sequence_length_k,
|
||||
max_seqlen_batch_q,
|
||||
max_seqlen_batch_k,
|
||||
dropout_p,
|
||||
softmax_scale,
|
||||
false, /*zero_tensors = false for all calls here*/
|
||||
is_causal,
|
||||
num_splits,
|
||||
philox_seed,
|
||||
philox_offset
|
||||
);
|
||||
return std::make_tuple(dq, dk, dv);
|
||||
// We check the whether the cumulative_sequence_length_q is defined
|
||||
// in order to determine whether we are using varlen or dense forward
|
||||
if (cumulative_sequence_length_q.defined()) {
|
||||
// Varlen forward
|
||||
TORCH_CHECK(false, "Dont go down this path yet");
|
||||
auto [dQuery, dKey, dValue, dSoftmax] = pytorch_flash::mha_varlen_bwd(
|
||||
contiguous_grad_out,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
contiguous_out,
|
||||
logsumexp,
|
||||
dq,
|
||||
dk,
|
||||
dv,
|
||||
cumulative_sequence_length_q,
|
||||
cumulative_sequence_length_k,
|
||||
max_seqlen_batch_q,
|
||||
max_seqlen_batch_k,
|
||||
dropout_p,
|
||||
softmax_scale,
|
||||
false /*zero_tensors*/,
|
||||
is_causal,
|
||||
philox_seed,
|
||||
philox_offset);
|
||||
return std::make_tuple(dQuery, dKey, dValue);
|
||||
} else {
|
||||
// Dense forward
|
||||
auto [dQuery, dKey, dValue, dSoftmax] = pytorch_flash::mha_bwd(
|
||||
contiguous_grad_out,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
contiguous_out,
|
||||
logsumexp,
|
||||
dq,
|
||||
dk,
|
||||
dv,
|
||||
dropout_p,
|
||||
softmax_scale,
|
||||
is_causal,
|
||||
philox_seed,
|
||||
philox_offset);
|
||||
return std::make_tuple(dQuery, dKey, dValue);
|
||||
}
|
||||
#endif
|
||||
TORCH_CHECK(false, "USE_FLASH_ATTENTION was not enabled for build.");
|
||||
return std::make_tuple(Tensor(), Tensor(), Tensor());
|
||||
@ -499,32 +516,20 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> _scaled_dot_product_flash_attenti
|
||||
return std::make_tuple(Tensor{}, Tensor{}, Tensor{});
|
||||
}
|
||||
|
||||
const int64_t batch_size = query.size(0);
|
||||
const int64_t num_heads = query.size(1);
|
||||
const int64_t head_dim = query.size(3);
|
||||
|
||||
Tensor q_t = query.transpose(1, 2);
|
||||
Tensor k_t = key.transpose(1, 2);
|
||||
Tensor v_t = value.transpose(1, 2);
|
||||
|
||||
int64_t Nnz_q{batch_size * max_seqlen_batch_q};
|
||||
int64_t Nnz_kv{batch_size * max_seqlen_batch_k};
|
||||
|
||||
// For the standard MHA these will actually be views
|
||||
Tensor query_reshaped = q_t.reshape({Nnz_q, num_heads, head_dim});
|
||||
Tensor key_reshaped = k_t.reshape({Nnz_kv, num_heads, head_dim});
|
||||
Tensor value_reshaped = v_t.reshape({Nnz_kv, num_heads, head_dim});
|
||||
|
||||
auto grad_out_reshaped = grad_out_.transpose(1,2).reshape({{Nnz_q, num_heads, head_dim}});
|
||||
auto out_reshaped = out.transpose(1,2).reshape({Nnz_q, num_heads, head_dim});
|
||||
Tensor grad_out_t = grad_out_.transpose(1,2);
|
||||
Tensor out_t = out.transpose(1,2);
|
||||
|
||||
Tensor grad_q, grad_k, grad_v;
|
||||
std::tie(grad_q, grad_k, grad_v) = at::_flash_attention_backward(
|
||||
grad_out_reshaped,
|
||||
query_reshaped,
|
||||
key_reshaped,
|
||||
value_reshaped,
|
||||
out_reshaped,
|
||||
grad_out_t,
|
||||
q_t,
|
||||
k_t,
|
||||
v_t,
|
||||
out_t,
|
||||
logsumexp,
|
||||
cumulative_sequence_length_q,
|
||||
cumulative_sequence_length_k,
|
||||
@ -536,9 +541,9 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> _scaled_dot_product_flash_attenti
|
||||
philox_offset,
|
||||
scale);
|
||||
|
||||
grad_q = grad_q.view({batch_size, max_seqlen_batch_q, num_heads, head_dim}).transpose(1,2);
|
||||
grad_k = grad_k.view({batch_size, max_seqlen_batch_k, num_heads, head_dim}).transpose(1,2);
|
||||
grad_v = grad_v.view({batch_size, max_seqlen_batch_k, num_heads, head_dim}).transpose(1,2);
|
||||
grad_q = grad_q.transpose(1,2);
|
||||
grad_k = grad_k.transpose(1,2);
|
||||
grad_v = grad_v.transpose(1,2);
|
||||
|
||||
return std::make_tuple(grad_q, grad_k, grad_v);
|
||||
}
|
||||
|
||||
@ -0,0 +1,41 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2023, Tri Dao.
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace pytorch_flash {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<bool Varlen=true>
|
||||
struct BlockInfo {
|
||||
|
||||
template<typename Params>
|
||||
__device__ BlockInfo(const Params ¶ms, const int bidb)
|
||||
: sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb])
|
||||
, sum_s_k(!Varlen || params.cu_seqlens_k == nullptr ? -1 : params.cu_seqlens_k[bidb])
|
||||
, actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q)
|
||||
, actual_seqlen_k(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : params.cu_seqlens_k[bidb + 1] - sum_s_k)
|
||||
{
|
||||
}
|
||||
|
||||
template <typename index_t>
|
||||
inline __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
|
||||
return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride;
|
||||
}
|
||||
|
||||
template <typename index_t>
|
||||
inline __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
|
||||
return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride;
|
||||
}
|
||||
|
||||
const int sum_s_q;
|
||||
const int sum_s_k;
|
||||
const uint32_t actual_seqlen_q;
|
||||
const uint32_t actual_seqlen_k;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace pytorch_flash
|
||||
143
aten/src/ATen/native/transformers/cuda/flash_attn/flash.h
Normal file
143
aten/src/ATen/native/transformers/cuda/flash_attn/flash.h
Normal file
@ -0,0 +1,143 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2023, Tri Dao.
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cuda.h>
|
||||
#include <vector>
|
||||
|
||||
#include <ATen/cuda/CUDAGeneratorImpl.h>
|
||||
#include <ATen/cuda/CUDAGraphsUtils.cuh>
|
||||
|
||||
namespace pytorch_flash{
|
||||
|
||||
|
||||
constexpr int TOTAL_DIM = 0;
|
||||
constexpr int H_DIM = 1;
|
||||
constexpr int D_DIM = 2;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct Qkv_params {
|
||||
using index_t = uint32_t;
|
||||
// The QKV matrices.
|
||||
void *__restrict__ q_ptr;
|
||||
void *__restrict__ k_ptr;
|
||||
void *__restrict__ v_ptr;
|
||||
|
||||
// The stride between rows of the Q, K and V matrices.
|
||||
index_t q_batch_stride;
|
||||
index_t k_batch_stride;
|
||||
index_t v_batch_stride;
|
||||
index_t q_row_stride;
|
||||
index_t k_row_stride;
|
||||
index_t v_row_stride;
|
||||
index_t q_head_stride;
|
||||
index_t k_head_stride;
|
||||
index_t v_head_stride;
|
||||
|
||||
// The number of heads.
|
||||
int h, h_k;
|
||||
// In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be
|
||||
// different from nheads (query).
|
||||
int h_h_k_ratio; // precompute h / h_k,
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct Flash_fwd_params : public Qkv_params {
|
||||
|
||||
// The O matrix (output).
|
||||
void * __restrict__ o_ptr;
|
||||
|
||||
// The stride between rows of O.
|
||||
index_t o_batch_stride;
|
||||
index_t o_row_stride;
|
||||
index_t o_head_stride;
|
||||
|
||||
// The pointer to the P matrix.
|
||||
void * __restrict__ p_ptr;
|
||||
|
||||
// The pointer to the softmax sum.
|
||||
void * __restrict__ softmax_lse_ptr;
|
||||
|
||||
// The dimensions.
|
||||
int b, seqlen_q, seqlen_k, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded;
|
||||
|
||||
// The scaling factors for the kernel.
|
||||
float scale_softmax;
|
||||
float scale_softmax_log2;
|
||||
|
||||
// array of length b+1 holding starting offset of each sequence.
|
||||
int * __restrict__ cu_seqlens_q;
|
||||
int * __restrict__ cu_seqlens_k;
|
||||
|
||||
int *__restrict__ blockmask;
|
||||
|
||||
// The dropout probability (probability of keeping an activation).
|
||||
float p_dropout;
|
||||
// uint32_t p_dropout_in_uint;
|
||||
// uint16_t p_dropout_in_uint16_t;
|
||||
uint8_t p_dropout_in_uint8_t;
|
||||
|
||||
// Scale factor of 1 / (1 - p_dropout).
|
||||
float rp_dropout;
|
||||
float scale_softmax_rp_dropout;
|
||||
|
||||
// Random state.
|
||||
at::PhiloxCudaState philox_args;
|
||||
int64_t * extragraph_offset;
|
||||
int64_t * seed;
|
||||
|
||||
bool is_bf16;
|
||||
bool is_causal;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct Flash_bwd_params : public Flash_fwd_params {
|
||||
|
||||
// The dO and dQKV matrices.
|
||||
void *__restrict__ do_ptr;
|
||||
void *__restrict__ dq_ptr;
|
||||
void *__restrict__ dk_ptr;
|
||||
void *__restrict__ dv_ptr;
|
||||
|
||||
// To accumulate dQ
|
||||
void *__restrict__ dq_accum_ptr;
|
||||
void *__restrict__ dk_accum_ptr;
|
||||
void *__restrict__ dv_accum_ptr;
|
||||
|
||||
// // To accumulate dK and dV in case we're splitting the bwd along seqlen_q
|
||||
// dimension void *__restrict__ dk_accum_ptr; void *__restrict__
|
||||
// dv_accum_ptr;
|
||||
|
||||
// The stride between rows of the dO, dQ, dK and dV matrices.
|
||||
// TD [2022-04-16]: We're using 32-bit indexing to save registers.
|
||||
// The code probably won't work for arrays larger than 2GB.
|
||||
index_t do_batch_stride;
|
||||
index_t do_row_stride;
|
||||
index_t do_head_stride;
|
||||
index_t dq_batch_stride;
|
||||
index_t dk_batch_stride;
|
||||
index_t dv_batch_stride;
|
||||
index_t dq_row_stride;
|
||||
index_t dk_row_stride;
|
||||
index_t dv_row_stride;
|
||||
index_t dq_head_stride;
|
||||
index_t dk_head_stride;
|
||||
index_t dv_head_stride;
|
||||
|
||||
// The pointer to the softmax d sum.
|
||||
void *__restrict__ dsoftmax_sum;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename T, int Headdim> void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream);
|
||||
|
||||
template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure);
|
||||
|
||||
|
||||
} // namespace pytorch_flash
|
||||
964
aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp
Normal file
964
aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp
Normal file
@ -0,0 +1,964 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2022, Tri Dao.
|
||||
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright
|
||||
* notice, this list of conditions and the following disclaimer in the
|
||||
* documentation and/or other materials provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the
|
||||
* names of its contributors may be used to endorse or promote products
|
||||
* derived from this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
******************************************************************************/
|
||||
|
||||
#include <cstdint>
|
||||
#include <tuple>
|
||||
|
||||
|
||||
#ifdef USE_FLASH_ATTENTION
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#include <ATen/cuda/CUDAGraphsUtils.cuh>
|
||||
#ifndef AT_PER_OPERATOR_HEADERS
|
||||
#include <ATen/Functions.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#else
|
||||
#include <ATen/ops/scalar_tensor.h>
|
||||
#endif
|
||||
|
||||
|
||||
#include <cutlass/numeric_types.h>
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash.h>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_api.h>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/static_switch.h>
|
||||
|
||||
#include <c10/util/Exception.h>
|
||||
|
||||
|
||||
namespace pytorch_flash {
|
||||
|
||||
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == at::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
|
||||
|
||||
|
||||
void set_params_fprop(Flash_fwd_params ¶ms,
|
||||
// sizes
|
||||
const size_t b,
|
||||
const size_t seqlen_q,
|
||||
const size_t seqlen_k,
|
||||
const size_t seqlen_q_rounded,
|
||||
const size_t seqlen_k_rounded,
|
||||
const size_t h,
|
||||
const size_t h_k,
|
||||
const size_t d,
|
||||
const size_t d_rounded,
|
||||
// device pointers
|
||||
const at::Tensor q,
|
||||
const at::Tensor k,
|
||||
const at::Tensor v,
|
||||
at::Tensor out,
|
||||
void *cu_seqlens_q_d,
|
||||
void *cu_seqlens_k_d,
|
||||
void *p_d,
|
||||
void *softmax_lse_d,
|
||||
float p_dropout,
|
||||
float softmax_scale,
|
||||
bool is_causal) {
|
||||
|
||||
// Reset the parameters
|
||||
// TODO should be equivalent
|
||||
params = {};
|
||||
// memset(¶ms, 0, sizeof(params));
|
||||
|
||||
params.is_bf16 = q.dtype() == at::kBFloat16;
|
||||
|
||||
// Set the pointers and strides.
|
||||
params.q_ptr = q.data_ptr();
|
||||
params.k_ptr = k.data_ptr();
|
||||
params.v_ptr = v.data_ptr();
|
||||
// All stride are in elements, not bytes.
|
||||
params.q_row_stride = q.stride(-3);
|
||||
params.k_row_stride = k.stride(-3);
|
||||
params.v_row_stride = v.stride(-3);
|
||||
params.q_head_stride = q.stride(-2);
|
||||
params.k_head_stride = k.stride(-2);
|
||||
params.v_head_stride = v.stride(-2);
|
||||
params.o_ptr = out.data_ptr();
|
||||
params.o_row_stride = out.stride(-3);
|
||||
params.o_head_stride = out.stride(-2);
|
||||
|
||||
if (cu_seqlens_q_d == nullptr) {
|
||||
params.q_batch_stride = q.stride(0);
|
||||
params.k_batch_stride = k.stride(0);
|
||||
params.v_batch_stride = v.stride(0);
|
||||
params.o_batch_stride = out.stride(0);
|
||||
}
|
||||
|
||||
params.cu_seqlens_q = static_cast<int *>(cu_seqlens_q_d);
|
||||
params.cu_seqlens_k = static_cast<int *>(cu_seqlens_k_d);
|
||||
|
||||
// P = softmax(QK^T)
|
||||
params.p_ptr = p_d;
|
||||
|
||||
// Softmax sum
|
||||
params.softmax_lse_ptr = softmax_lse_d;
|
||||
|
||||
// Set the dimensions.
|
||||
params.b = b;
|
||||
params.h = h;
|
||||
params.h_k = h_k;
|
||||
params.h_h_k_ratio = h / h_k;
|
||||
params.seqlen_q = seqlen_q;
|
||||
params.seqlen_k = seqlen_k;
|
||||
params.seqlen_q_rounded = seqlen_q_rounded;
|
||||
params.seqlen_k_rounded = seqlen_k_rounded;
|
||||
params.d = d;
|
||||
params.d_rounded = d_rounded;
|
||||
|
||||
// Set the different scale values.
|
||||
params.scale_softmax = softmax_scale;
|
||||
params.scale_softmax_log2 = softmax_scale * M_LOG2E;
|
||||
|
||||
// Set this to probability of keeping an element to simplify things.
|
||||
params.p_dropout = 1.f - p_dropout;
|
||||
// Convert p from float to int so we don't have to convert the random uint to float to compare.
|
||||
// [Minor] We want to round down since when we do the comparison we use <= instead of <
|
||||
// params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0));
|
||||
// params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0));
|
||||
params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0));
|
||||
params.rp_dropout = 1.f / params.p_dropout;
|
||||
params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax;
|
||||
TORCH_CHECK(p_dropout < 1.f);
|
||||
|
||||
params.is_causal = is_causal;
|
||||
}
|
||||
|
||||
void set_params_dgrad(Flash_bwd_params ¶ms,
|
||||
// sizes
|
||||
const size_t b,
|
||||
const size_t seqlen_q,
|
||||
const size_t seqlen_k,
|
||||
const size_t seqlen_q_rounded,
|
||||
const size_t seqlen_k_rounded,
|
||||
const size_t h,
|
||||
const size_t h_k,
|
||||
const size_t d,
|
||||
const size_t d_rounded,
|
||||
// device pointers
|
||||
const at::Tensor q,
|
||||
const at::Tensor k,
|
||||
const at::Tensor v,
|
||||
const at::Tensor out,
|
||||
const at::Tensor dout,
|
||||
at::Tensor dq,
|
||||
at::Tensor dk,
|
||||
at::Tensor dv,
|
||||
void *cu_seqlens_q_d,
|
||||
void *cu_seqlens_k_d,
|
||||
void *dq_accum_d,
|
||||
void *dk_accum_d,
|
||||
void *dv_accum_d,
|
||||
void *softmax_lse_d,
|
||||
void *dsoftmax_sum_d,
|
||||
float p_dropout,
|
||||
float softmax_scale,
|
||||
bool is_causal) {
|
||||
|
||||
set_params_fprop(params,
|
||||
b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded,
|
||||
q, k, v, out,
|
||||
cu_seqlens_q_d,
|
||||
cu_seqlens_k_d,
|
||||
nullptr,
|
||||
softmax_lse_d,
|
||||
p_dropout,
|
||||
softmax_scale,
|
||||
is_causal);
|
||||
|
||||
// Set the pointers and strides.
|
||||
params.do_ptr = dout.data_ptr();
|
||||
params.do_row_stride = dout.stride(-3);
|
||||
params.do_head_stride = dout.stride(-2);
|
||||
params.dq_ptr = dq.data_ptr();
|
||||
params.dk_ptr = dk.data_ptr();
|
||||
params.dv_ptr = dv.data_ptr();
|
||||
params.dq_row_stride = dq.stride(-3);
|
||||
params.dk_row_stride = dk.stride(-3);
|
||||
params.dv_row_stride = dv.stride(-3);
|
||||
params.dq_head_stride = dq.stride(-2);
|
||||
params.dk_head_stride = dk.stride(-2);
|
||||
params.dv_head_stride = dv.stride(-2);
|
||||
|
||||
if (cu_seqlens_q_d == nullptr) {
|
||||
params.do_batch_stride = dout.stride(0);
|
||||
params.dq_batch_stride = dq.stride(0);
|
||||
params.dk_batch_stride = dk.stride(0);
|
||||
params.dv_batch_stride = dv.stride(0);
|
||||
}
|
||||
|
||||
params.dq_accum_ptr = dq_accum_d;
|
||||
params.dk_accum_ptr = dk_accum_d;
|
||||
params.dv_accum_ptr = dv_accum_d;
|
||||
|
||||
// Softmax sum
|
||||
params.dsoftmax_sum = dsoftmax_sum_d;
|
||||
}
|
||||
|
||||
void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
FP16_SWITCH(!params.is_bf16, [&] {
|
||||
FWD_HEADDIM_SWITCH(params.d, [&] {
|
||||
run_mha_fwd_<elem_type, kHeadDim>(params, stream);
|
||||
});
|
||||
});
|
||||
}
|
||||
// return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p};
|
||||
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
|
||||
mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
|
||||
const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
|
||||
const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
|
||||
c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
|
||||
const float p_dropout,
|
||||
const float softmax_scale,
|
||||
const bool is_causal,
|
||||
const bool return_softmax,
|
||||
c10::optional<at::Generator> gen_) {
|
||||
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
// bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
|
||||
bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
|
||||
bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
|
||||
TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
|
||||
// We will support Turing in the near future
|
||||
// TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer.");
|
||||
|
||||
auto q_dtype = q.dtype();
|
||||
TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16,
|
||||
"FlashAttention only support fp16 and bf16 data type");
|
||||
if (q_dtype == at::kBFloat16) {
|
||||
TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
|
||||
}
|
||||
TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
|
||||
TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
|
||||
|
||||
TORCH_CHECK(q.is_cuda(), "Input tensor must be on CUDA device");
|
||||
TORCH_CHECK(k.is_cuda(), "Input tensor must be on CUDA device");
|
||||
TORCH_CHECK(v.is_cuda(), "Input tensor must be on CUDA device");
|
||||
|
||||
TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
||||
TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
||||
TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
||||
|
||||
const auto sizes = q.sizes();
|
||||
|
||||
const int batch_size = sizes[0];
|
||||
const int seqlen_q = sizes[1];
|
||||
const int num_heads = sizes[2];
|
||||
const int head_size_og = sizes[3];
|
||||
const int seqlen_k = k.size(1);
|
||||
const int num_heads_k = k.size(2);
|
||||
TORCH_CHECK(batch_size > 0, "batch size must be positive");
|
||||
TORCH_CHECK(head_size_og % 8 == 0, "head_size must be a multiple of 8, this is ensured by padding!");
|
||||
TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
|
||||
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
|
||||
|
||||
CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
|
||||
CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_og);
|
||||
CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_og);
|
||||
|
||||
at::Tensor q_padded, k_padded, v_padded;
|
||||
q_padded = q;
|
||||
k_padded = k;
|
||||
v_padded = v;
|
||||
|
||||
at::Tensor out;
|
||||
if (out_.has_value()) {
|
||||
out = out_.value();
|
||||
TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
|
||||
TORCH_CHECK(out.is_cuda(), "Output tensor must be on CUDA device");
|
||||
TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
|
||||
CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og);
|
||||
if (head_size_og % 8 != 0) { out = at::empty_like(q_padded); }
|
||||
} else {
|
||||
out = at::empty_like(q_padded);
|
||||
}
|
||||
|
||||
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
|
||||
const int head_size = round_multiple(head_size_og, 8);
|
||||
const int head_size_rounded = round_multiple(head_size, 32);
|
||||
const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
|
||||
const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
|
||||
|
||||
// Otherwise the kernel will be launched from cuda:0 device
|
||||
// Cast to char to avoid compiler warning about narrowing
|
||||
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
|
||||
|
||||
auto opts = q.options();
|
||||
|
||||
auto softmax_lse = at::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
|
||||
at::Tensor p;
|
||||
// Only return softmax if there's dropout to reduce compilation time
|
||||
if (return_softmax) {
|
||||
TORCH_CHECK(p_dropout > 0.0f, "return_softmax is only supported when p_dropout > 0.0");
|
||||
p = at::empty({ batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded }, opts);
|
||||
}
|
||||
|
||||
Flash_fwd_params params;
|
||||
set_params_fprop(params,
|
||||
batch_size,
|
||||
seqlen_q, seqlen_k,
|
||||
seqlen_q_rounded, seqlen_k_rounded,
|
||||
num_heads, num_heads_k,
|
||||
head_size, head_size_rounded,
|
||||
q_padded, k_padded, v_padded, out,
|
||||
/*cu_seqlens_q_d=*/nullptr,
|
||||
/*cu_seqlens_k_d=*/nullptr,
|
||||
return_softmax ? p.data_ptr() : nullptr,
|
||||
softmax_lse.data_ptr(),
|
||||
p_dropout,
|
||||
softmax_scale,
|
||||
is_causal);
|
||||
|
||||
|
||||
// We want to checkpoint and save the RNG state for backward if dropout
|
||||
// We get the default generator and return the seed and offset which will
|
||||
// be used in the backward function
|
||||
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(c10::nullopt, at::cuda::detail::getDefaultCUDAGenerator());
|
||||
at::Tensor seed_t, offset_t;
|
||||
if (p_dropout > 0.0) {
|
||||
// number of times random will be generated per thread, to offset philox counter in thc random
|
||||
// state
|
||||
// We use a custom RNG that increases the offset by batch_size * nheads * 32.
|
||||
int64_t counter_offset = params.b * params.h * 32;
|
||||
// See Note [Acquire lock when using random generators]
|
||||
std::lock_guard<std::mutex> lock(gen->mutex_);
|
||||
at::PhiloxCudaState philox_state = gen->philox_cuda_state(counter_offset);
|
||||
if (at::cuda::currentStreamCaptureStatus() == at::cuda::CaptureStatus::None) {
|
||||
auto [seed, offset] = at::cuda::philox::unpack(philox_state);
|
||||
seed_t = at::scalar_tensor(at::Scalar(static_cast<int64_t>(seed)), at::dtype(at::kLong));
|
||||
offset_t = at::scalar_tensor(at::Scalar(static_cast<int64_t>(offset)), at::dtype(at::kLong));
|
||||
} else {
|
||||
seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
|
||||
offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
|
||||
params.seed = seed_t.data_ptr<int64_t>();
|
||||
params.extragraph_offset = offset_t.data_ptr<int64_t>();
|
||||
}
|
||||
params.philox_args = philox_state;
|
||||
} else {
|
||||
if (at::cuda::currentStreamCaptureStatus() != at::cuda::CaptureStatus::None) {
|
||||
seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
|
||||
offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
|
||||
} else {
|
||||
seed_t = at::empty({}, at::dtype(at::kLong));
|
||||
offset_t = at::empty({}, at::dtype(at::kLong));
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
run_mha_fwd(params, stream);
|
||||
|
||||
return {out, q_padded, k_padded, v_padded, softmax_lse, seed_t, offset_t, p};
|
||||
}
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
|
||||
mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
|
||||
const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
|
||||
const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
|
||||
c10::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
|
||||
const at::Tensor &cu_seqlens_q, // b+1
|
||||
const at::Tensor &cu_seqlens_k, // b+1
|
||||
const int max_seqlen_q,
|
||||
const int max_seqlen_k,
|
||||
const float p_dropout,
|
||||
const float softmax_scale,
|
||||
const bool zero_tensors,
|
||||
const bool is_causal,
|
||||
const bool return_softmax,
|
||||
c10::optional<at::Generator> gen_) {
|
||||
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
// bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
|
||||
bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
|
||||
bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
|
||||
TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
|
||||
// We will support Turing in the near future
|
||||
// TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer.");
|
||||
|
||||
auto q_dtype = q.dtype();
|
||||
TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16,
|
||||
"FlashAttention only support fp16 and bf16 data type");
|
||||
if (q_dtype == at::kBFloat16) {
|
||||
TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
|
||||
}
|
||||
TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
|
||||
TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
|
||||
TORCH_CHECK(cu_seqlens_q.dtype() == at::kInt, "cu_seqlens_q must have dtype int32");
|
||||
TORCH_CHECK(cu_seqlens_k.dtype() == at::kInt, "cu_seqlens_k must have dtype int32");
|
||||
|
||||
TORCH_CHECK(q.is_cuda(), "Input tensor must be on CUDA device");
|
||||
TORCH_CHECK(k.is_cuda(), "Input tensor must be on CUDA device");
|
||||
TORCH_CHECK(v.is_cuda(), "Input tensor must be on CUDA device");
|
||||
TORCH_CHECK(cu_seqlens_q.is_cuda(), "cu_seqlens_q must be on CUDA device");
|
||||
TORCH_CHECK(cu_seqlens_k.is_cuda(), "cu_seqlens_k must be on CUDA device");
|
||||
|
||||
TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
||||
TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
||||
TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
||||
TORCH_CHECK(cu_seqlens_q.is_contiguous(), "cu_seqlens_q must be contiguous");
|
||||
TORCH_CHECK(cu_seqlens_k.is_contiguous(), "cu_seqlens_k must be contiguous");
|
||||
|
||||
const auto sizes = q.sizes();
|
||||
|
||||
const int total_q = sizes[0];
|
||||
const int batch_size = cu_seqlens_q.numel() - 1;
|
||||
const int num_heads = sizes[1];
|
||||
const int head_size_og = sizes[2];
|
||||
const int total_k = k.size(0);
|
||||
const int num_heads_k = k.size(1);
|
||||
TORCH_CHECK(batch_size > 0, "batch size must be positive");
|
||||
TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
|
||||
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
|
||||
TORCH_CHECK(head_size_og % 8 == 0, "head_size must be a multiple of 8, this is ensured by padding!")
|
||||
|
||||
CHECK_SHAPE(q, total_q, num_heads, head_size_og);
|
||||
CHECK_SHAPE(k, total_k, num_heads_k, head_size_og);
|
||||
CHECK_SHAPE(v, total_k, num_heads_k, head_size_og);
|
||||
CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
|
||||
CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
|
||||
|
||||
at::Tensor q_padded, k_padded, v_padded;
|
||||
q_padded = q;
|
||||
k_padded = k;
|
||||
v_padded = v;
|
||||
|
||||
at::Tensor out;
|
||||
if (out_.has_value()) {
|
||||
out = out_.value();
|
||||
TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
|
||||
TORCH_CHECK(out.is_cuda(), "Output tensor must be on CUDA device");
|
||||
TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
|
||||
CHECK_SHAPE(out, total_q, num_heads, head_size_og);
|
||||
if (head_size_og % 8 != 0) { out = at::empty_like(q_padded); }
|
||||
} else {
|
||||
out = at::empty_like(q_padded);
|
||||
}
|
||||
|
||||
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
|
||||
const int head_size = round_multiple(head_size_og, 8);
|
||||
const int head_size_rounded = round_multiple(head_size, 32);
|
||||
const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128);
|
||||
const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128);
|
||||
|
||||
// Otherwise the kernel will be launched from cuda:0 device
|
||||
// Cast to char to avoid compiler warning about narrowing
|
||||
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
|
||||
|
||||
auto opts = q.options();
|
||||
|
||||
auto softmax_lse = at::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
|
||||
at::Tensor p;
|
||||
// Only return softmax if there's dropout to reduce compilation time
|
||||
if (return_softmax) {
|
||||
TORCH_CHECK(p_dropout > 0.0f, "return_softmax is only supported when p_dropout > 0.0");
|
||||
p = at::empty({ batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded }, opts);
|
||||
}
|
||||
|
||||
if (zero_tensors) {
|
||||
out.zero_();
|
||||
softmax_lse.fill_(-std::numeric_limits<float>::infinity());
|
||||
if (return_softmax) {p.zero_();}
|
||||
}
|
||||
|
||||
Flash_fwd_params params;
|
||||
set_params_fprop(params,
|
||||
batch_size,
|
||||
max_seqlen_q, max_seqlen_k,
|
||||
seqlen_q_rounded, seqlen_k_rounded,
|
||||
num_heads, num_heads_k,
|
||||
head_size, head_size_rounded,
|
||||
q_padded, k_padded, v_padded, out,
|
||||
cu_seqlens_q.data_ptr(),
|
||||
cu_seqlens_k.data_ptr(),
|
||||
return_softmax ? p.data_ptr() : nullptr,
|
||||
softmax_lse.data_ptr(),
|
||||
p_dropout,
|
||||
softmax_scale,
|
||||
is_causal);
|
||||
|
||||
// We want to checkpoint and save the RNG state for backward if dropout
|
||||
// We get the default generator and return the seed and offset which will
|
||||
// be used in the backward function
|
||||
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(c10::nullopt, at::cuda::detail::getDefaultCUDAGenerator());
|
||||
at::Tensor seed_t, offset_t;
|
||||
if (p_dropout > 0.0) {
|
||||
// number of times random will be generated per thread, to offset philox counter in thc random
|
||||
// state
|
||||
// We use a custom RNG that increases the offset by batch_size * nheads * 32.
|
||||
int64_t counter_offset = params.b * params.h * 32;
|
||||
// See Note [Acquire lock when using random generators]
|
||||
std::lock_guard<std::mutex> lock(gen->mutex_);
|
||||
at::PhiloxCudaState philox_state = gen->philox_cuda_state(counter_offset);
|
||||
if (at::cuda::currentStreamCaptureStatus() == at::cuda::CaptureStatus::None) {
|
||||
auto [seed, offset] = at::cuda::philox::unpack(philox_state);
|
||||
seed_t = at::scalar_tensor(at::Scalar(static_cast<int64_t>(seed)), at::dtype(at::kLong));
|
||||
offset_t = at::scalar_tensor(at::Scalar(static_cast<int64_t>(offset)), at::dtype(at::kLong));
|
||||
} else {
|
||||
seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
|
||||
offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
|
||||
params.seed = seed_t.data_ptr<int64_t>();
|
||||
params.extragraph_offset = offset_t.data_ptr<int64_t>();
|
||||
}
|
||||
params.philox_args = philox_state;
|
||||
} else {
|
||||
if (at::cuda::currentStreamCaptureStatus() != at::cuda::CaptureStatus::None) {
|
||||
seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
|
||||
offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
|
||||
} else {
|
||||
seed_t = at::empty({}, at::dtype(at::kLong));
|
||||
offset_t = at::empty({}, at::dtype(at::kLong));
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
run_mha_fwd(params, stream);
|
||||
|
||||
at::Tensor out_padded = out;
|
||||
return {out, q_padded, k_padded, v_padded, softmax_lse, seed_t, offset_t, p};
|
||||
}
|
||||
|
||||
void run_mha_bwd(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
FP16_SWITCH(!params.is_bf16, [&] {
|
||||
if (params.d <= 32) {
|
||||
run_mha_bwd_<elem_type, 32>(params, stream, configure);
|
||||
} else if (params.d <= 64) {
|
||||
run_mha_bwd_<elem_type, 64>(params, stream, configure);
|
||||
} else if (params.d <= 96) {
|
||||
run_mha_bwd_<elem_type, 96>(params, stream, configure);
|
||||
} else if (params.d <= 128) {
|
||||
run_mha_bwd_<elem_type, 128>(params, stream, configure);
|
||||
} else if (params.d <= 160) {
|
||||
run_mha_bwd_<elem_type, 160>(params, stream, configure);
|
||||
} else if (params.d <= 192) {
|
||||
run_mha_bwd_<elem_type, 192>(params, stream, configure);
|
||||
} else if (params.d <= 224) {
|
||||
run_mha_bwd_<elem_type, 224>(params, stream, configure);
|
||||
} else if (params.d <= 256) {
|
||||
run_mha_bwd_<elem_type, 256>(params, stream, configure);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
|
||||
mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_size_og
|
||||
const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
|
||||
const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
|
||||
const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
|
||||
const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size
|
||||
const at::Tensor &softmax_lse, // b x h x seqlen_q
|
||||
c10::optional<at::Tensor> &dq_, // batch_size x seqlen_q x num_heads x head_size
|
||||
c10::optional<at::Tensor> &dk_, // batch_size x seqlen_k x num_heads_k x head_size
|
||||
c10::optional<at::Tensor> &dv_, // batch_size x seqlen_k x num_heads_k x head_size
|
||||
const float p_dropout, // probability to drop
|
||||
const float softmax_scale,
|
||||
const bool is_causal,
|
||||
const at::Tensor philox_seed,
|
||||
const at::Tensor philox_offset) {
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
// bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
|
||||
bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
|
||||
bool is_sm80 = dprops->major == 8 && dprops->minor == 0;
|
||||
bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
|
||||
TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
|
||||
// We will support Turing in the near future
|
||||
// TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer.");
|
||||
|
||||
bool is_dropout = p_dropout > 0.0;
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
|
||||
auto q_dtype = q.dtype();
|
||||
TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16,
|
||||
"FlashAttention only support fp16 and bf16 data type");
|
||||
if (q_dtype == at::kBFloat16) {
|
||||
TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
|
||||
}
|
||||
TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
|
||||
TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
|
||||
TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
|
||||
TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype");
|
||||
|
||||
TORCH_CHECK(q.is_cuda(), "Input tensor must be on CUDA device");
|
||||
TORCH_CHECK(k.is_cuda(), "Input tensor must be on CUDA device");
|
||||
TORCH_CHECK(v.is_cuda(), "Input tensor must be on CUDA device");
|
||||
TORCH_CHECK(out.is_cuda(), "out tensor must be on CUDA device");
|
||||
TORCH_CHECK(dout.is_cuda(), "dout tensor must be on CUDA device");
|
||||
TORCH_CHECK(softmax_lse.is_cuda(), "softmax_lse tensor must be on CUDA device");
|
||||
|
||||
TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
||||
TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
||||
TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
||||
TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
|
||||
TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");
|
||||
|
||||
const auto sizes = q.sizes();
|
||||
|
||||
const int batch_size = sizes[0];
|
||||
const int seqlen_q = sizes[1];
|
||||
const int num_heads = sizes[2];
|
||||
const int head_size_og = dout.size(3);
|
||||
const int head_size = sizes[3];
|
||||
const int seqlen_k = k.size(1);
|
||||
const int num_heads_k = k.size(2);
|
||||
TORCH_CHECK(batch_size > 0, "batch size must be positive");
|
||||
TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
|
||||
TORCH_CHECK(head_size_og % 8 == 0, "head_size_og should be a multiple of 8, this is ensured by padding!");
|
||||
TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256");
|
||||
if (head_size > 192) {
|
||||
TORCH_CHECK(is_sm80 || is_sm90, "FlashAttention backward for head dim > 192 requires A100/A800 or H100/H800");
|
||||
}
|
||||
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
|
||||
|
||||
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
|
||||
const int head_size_rounded = round_multiple(head_size, 32);
|
||||
const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
|
||||
const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
|
||||
|
||||
TORCH_CHECK(head_size == round_multiple(head_size_og, 8), "head_size must be head_size_og rounded to a multiple of 8");
|
||||
|
||||
CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size);
|
||||
CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size);
|
||||
CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size);
|
||||
CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size);
|
||||
CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size_og);
|
||||
|
||||
at::Tensor dq, dk, dv;
|
||||
if (dq_.has_value()) {
|
||||
dq = dq_.value();
|
||||
TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q");
|
||||
TORCH_CHECK(dq.is_cuda(), "dq must be on CUDA device");
|
||||
TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
|
||||
CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size);
|
||||
} else {
|
||||
dq = at::empty_like(q);
|
||||
}
|
||||
if (dk_.has_value()) {
|
||||
dk = dk_.value();
|
||||
TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q");
|
||||
TORCH_CHECK(dk.is_cuda(), "dk must be on CUDA device");
|
||||
TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
|
||||
CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size);
|
||||
} else {
|
||||
dk = at::empty_like(k);
|
||||
}
|
||||
if (dv_.has_value()) {
|
||||
dv = dv_.value();
|
||||
TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q");
|
||||
TORCH_CHECK(dv.is_cuda(), "dv must be on CUDA device");
|
||||
TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
|
||||
CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size);
|
||||
} else {
|
||||
dv = at::empty_like(k);
|
||||
}
|
||||
|
||||
const at::Tensor& dout_padded = dout;
|
||||
|
||||
// bool loop = seqlen_k > blocksize_c;
|
||||
// TODO: change later, for now set to true for simplicity
|
||||
bool loop = true;
|
||||
|
||||
// Otherwise the kernel will be launched from cuda:0 device
|
||||
// Cast to char to avoid compiler warning about narrowing
|
||||
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
|
||||
|
||||
auto opts = q.options();
|
||||
auto softmax_d = at::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
|
||||
at::Tensor dq_accum;
|
||||
at::Tensor dk_accum, dv_accum;
|
||||
if (loop) {
|
||||
dq_accum = at::empty({batch_size, num_heads, seqlen_q_rounded, head_size_rounded}, opts.dtype(at::kFloat));
|
||||
// dk_accum = torch::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat));
|
||||
// dv_accum = torch::empty({batch_size, num_heads_k, seqlen_k_rounded, head_size_rounded}, opts.dtype(at::kFloat));
|
||||
}
|
||||
|
||||
at::Tensor dk_expanded, dv_expanded;
|
||||
if (num_heads_k != num_heads) { // MQA / GQA
|
||||
dk_expanded = at::empty({batch_size, seqlen_k, num_heads, head_size}, opts);
|
||||
dv_expanded = at::empty({batch_size, seqlen_k, num_heads, head_size}, opts);
|
||||
} else {
|
||||
dk_expanded = dk;
|
||||
dv_expanded = dv;
|
||||
}
|
||||
|
||||
Flash_bwd_params params;
|
||||
|
||||
set_params_dgrad(params,
|
||||
batch_size,
|
||||
seqlen_q, seqlen_k,
|
||||
seqlen_q_rounded, seqlen_k_rounded,
|
||||
num_heads, num_heads_k,
|
||||
head_size, head_size_rounded,
|
||||
q, k, v, out,
|
||||
dout_padded, dq, dk_expanded, dv_expanded,
|
||||
nullptr,
|
||||
nullptr,
|
||||
loop ? dq_accum.data_ptr() : nullptr,
|
||||
// loop ? dk_accum.data_ptr() : nullptr,
|
||||
// loop ? dv_accum.data_ptr() : nullptr,
|
||||
nullptr,
|
||||
nullptr,
|
||||
softmax_lse.data_ptr(),
|
||||
softmax_d.data_ptr(),
|
||||
p_dropout,
|
||||
softmax_scale,
|
||||
is_causal);
|
||||
|
||||
auto launch = &run_mha_bwd;
|
||||
// launch(params, stream, /*configure=*/true);
|
||||
|
||||
at::PhiloxCudaState philox_args;
|
||||
if (is_dropout) {
|
||||
if (at::cuda::currentStreamCaptureStatus() ==
|
||||
at::cuda::CaptureStatus::None)
|
||||
{
|
||||
philox_args = at::PhiloxCudaState(*philox_seed.data_ptr<int64_t>(), *philox_offset.data_ptr<int64_t>());
|
||||
} else { // dropout + capture
|
||||
philox_args = at::PhiloxCudaState(
|
||||
philox_seed.data_ptr<int64_t>(), philox_offset.data_ptr<int64_t>(), 0);
|
||||
}
|
||||
}
|
||||
params.philox_args = philox_args;
|
||||
|
||||
launch(params, stream, /*configure=*/false);
|
||||
|
||||
// For MQA/GQA we need to sum dK and dV across the groups
|
||||
if (num_heads_k != num_heads) {
|
||||
at::sum_out(dk, at::reshape(dk_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3});
|
||||
at::sum_out(dv, at::reshape(dv_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3});
|
||||
}
|
||||
return { dq, dk, dv, softmax_d };
|
||||
}
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
|
||||
mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
|
||||
const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
|
||||
const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
|
||||
const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
|
||||
const at::Tensor &out, // total_q x num_heads x head_size
|
||||
const at::Tensor &softmax_lse, // b x h x s softmax logsumexp
|
||||
c10::optional<at::Tensor> &dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
|
||||
c10::optional<at::Tensor> &dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
|
||||
c10::optional<at::Tensor> &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
|
||||
const at::Tensor &cu_seqlens_q, // b+1
|
||||
const at::Tensor &cu_seqlens_k, // b+1
|
||||
const int max_seqlen_q,
|
||||
const int max_seqlen_k, // max sequence length to choose the kernel
|
||||
const float p_dropout, // probability to drop
|
||||
const float softmax_scale,
|
||||
const bool zero_tensors,
|
||||
const bool is_causal,
|
||||
const at::Tensor philox_seed,
|
||||
const at::Tensor philox_offset)
|
||||
{
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
// bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
|
||||
bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
|
||||
bool is_sm80 = dprops->major == 8 && dprops->minor == 0;
|
||||
bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
|
||||
TORCH_CHECK(is_sm90 || is_sm8x, "FlashAttention only supports Ampere GPUs or newer.");
|
||||
// We will support Turing in the near future
|
||||
// TORCH_CHECK(is_sm90 || is_sm8x || is_sm75, "FlashAttention only supports Turing GPUs or newer.");
|
||||
bool is_dropout = p_dropout > 0.0;
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
|
||||
auto q_dtype = q.dtype();
|
||||
TORCH_CHECK(q_dtype == at::kHalf|| q_dtype == at::kBFloat16,
|
||||
"FlashAttention only support fp16 and bf16 data type");
|
||||
if (q_dtype == at::kBFloat16) {
|
||||
TORCH_CHECK(is_sm90 || is_sm8x, "bfloat16 is only supported on Ampere GPUs or newer");
|
||||
}
|
||||
TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
|
||||
TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
|
||||
TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype");
|
||||
TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype");
|
||||
TORCH_CHECK(cu_seqlens_q.dtype() == at::kInt, "cu_seqlens_q must have dtype int32");
|
||||
TORCH_CHECK(cu_seqlens_k.dtype() == at::kInt, "cu_seqlens_k must have dtype int32");
|
||||
|
||||
TORCH_CHECK(q.is_cuda(), "Input tensor must be on CUDA device");
|
||||
TORCH_CHECK(k.is_cuda(), "Input tensor must be on CUDA device");
|
||||
TORCH_CHECK(v.is_cuda(), "Input tensor must be on CUDA device");
|
||||
TORCH_CHECK(out.is_cuda(), "out tensor must be on CUDA device");
|
||||
TORCH_CHECK(dout.is_cuda(), "dout tensor must be on CUDA device");
|
||||
TORCH_CHECK(softmax_lse.is_cuda(), "softmax_lse tensor must be on CUDA device");
|
||||
TORCH_CHECK(cu_seqlens_q.is_cuda(), "cu_seqlens_q must be on CUDA device");
|
||||
TORCH_CHECK(cu_seqlens_k.is_cuda(), "cu_seqlens_k must be on CUDA device");
|
||||
|
||||
TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
||||
TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
||||
TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension");
|
||||
TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension");
|
||||
TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension");
|
||||
TORCH_CHECK(cu_seqlens_q.is_contiguous(), "cu_seqlens_q must be contiguous");
|
||||
TORCH_CHECK(cu_seqlens_k.is_contiguous(), "cu_seqlens_k must be contiguous");
|
||||
|
||||
const auto sizes = q.sizes();
|
||||
|
||||
const int total_q = sizes[0];
|
||||
const int batch_size = cu_seqlens_q.numel() - 1;
|
||||
const int num_heads = sizes[1];
|
||||
const int head_size_og = dout.size(2);
|
||||
const int head_size = sizes[2];
|
||||
const int total_k = k.size(0);
|
||||
const int num_heads_k = k.size(1);
|
||||
TORCH_CHECK(batch_size > 0, "batch size must be positive");
|
||||
TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
|
||||
TORCH_CHECK(head_size_og % 8 == 0, "head_size_og should be a multiple of 8, this is ensured by padding!");
|
||||
TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256");
|
||||
if (head_size > 192) {
|
||||
TORCH_CHECK(is_sm80 || is_sm90, "FlashAttention backward for head dim > 192 requires A100/A800 or H100/H800");
|
||||
}
|
||||
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
|
||||
|
||||
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
|
||||
const int head_size_rounded = round_multiple(head_size, 32);
|
||||
const int seqlen_q_rounded = round_multiple(max_seqlen_q, 128);
|
||||
const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128);
|
||||
|
||||
TORCH_CHECK(head_size == round_multiple(head_size_og, 8), "head_size must be head_size_og rounded to a multiple of 8");
|
||||
|
||||
CHECK_SHAPE(q, total_q, num_heads, head_size);
|
||||
CHECK_SHAPE(k, total_k, num_heads_k, head_size);
|
||||
CHECK_SHAPE(v, total_k, num_heads_k, head_size);
|
||||
CHECK_SHAPE(out, total_q, num_heads, head_size);
|
||||
CHECK_SHAPE(dout, total_q, num_heads, head_size_og);
|
||||
CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
|
||||
CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
|
||||
|
||||
at::Tensor dq, dk, dv;
|
||||
if (dq_.has_value()) {
|
||||
dq = dq_.value();
|
||||
TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q");
|
||||
TORCH_CHECK(dq.is_cuda(), "dq must be on CUDA device");
|
||||
TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension");
|
||||
CHECK_SHAPE(dq, total_q, num_heads, head_size);
|
||||
} else {
|
||||
dq = at::empty_like(q);
|
||||
}
|
||||
if (dk_.has_value()) {
|
||||
dk = dk_.value();
|
||||
TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q");
|
||||
TORCH_CHECK(dk.is_cuda(), "dk must be on CUDA device");
|
||||
TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension");
|
||||
CHECK_SHAPE(dk, total_k, num_heads_k, head_size);
|
||||
} else {
|
||||
dk = at::empty_like(k);
|
||||
}
|
||||
if (dv_.has_value()) {
|
||||
dv = dv_.value();
|
||||
TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q");
|
||||
TORCH_CHECK(dv.is_cuda(), "dv must be on CUDA device");
|
||||
TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension");
|
||||
CHECK_SHAPE(dv, total_k, num_heads_k, head_size);
|
||||
} else {
|
||||
dv = at::empty_like(k);
|
||||
}
|
||||
|
||||
const at::Tensor& dout_padded = dout;
|
||||
|
||||
// bool loop = max_seqlen_k > blocksize_c;
|
||||
// TODO: change later, for now set to true for simplicity
|
||||
bool loop = true;
|
||||
|
||||
// Otherwise the kernel will be launched from cuda:0 device
|
||||
// Cast to char to avoid compiler warning about narrowing
|
||||
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
|
||||
|
||||
auto opts = q.options();
|
||||
auto softmax_d = at::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
|
||||
at::Tensor dq_accum;
|
||||
if (loop) {
|
||||
dq_accum = at::empty({batch_size, num_heads, seqlen_q_rounded, head_size_rounded}, opts.dtype(at::kFloat));
|
||||
}
|
||||
|
||||
at::Tensor dk_expanded, dv_expanded;
|
||||
if (num_heads_k != num_heads) { // MQA / GQA
|
||||
dk_expanded = at::empty({total_k, num_heads, head_size}, opts);
|
||||
dv_expanded = at::empty({total_k, num_heads, head_size}, opts);
|
||||
} else {
|
||||
dk_expanded = dk;
|
||||
dv_expanded = dv;
|
||||
}
|
||||
|
||||
if( zero_tensors ) {
|
||||
dq.zero_();
|
||||
dk_expanded.zero_();
|
||||
dv_expanded.zero_();
|
||||
softmax_d.zero_();
|
||||
}
|
||||
|
||||
Flash_bwd_params params;
|
||||
|
||||
set_params_dgrad(params,
|
||||
batch_size,
|
||||
max_seqlen_q, max_seqlen_k,
|
||||
seqlen_q_rounded, seqlen_k_rounded,
|
||||
num_heads, num_heads_k,
|
||||
head_size, head_size_rounded,
|
||||
q, k, v, out,
|
||||
dout_padded, dq, dk_expanded, dv_expanded,
|
||||
cu_seqlens_q.data_ptr(),
|
||||
cu_seqlens_k.data_ptr(),
|
||||
loop ? dq_accum.data_ptr() : nullptr,
|
||||
nullptr,
|
||||
nullptr,
|
||||
softmax_lse.data_ptr(),
|
||||
softmax_d.data_ptr(),
|
||||
p_dropout,
|
||||
softmax_scale,
|
||||
is_causal);
|
||||
|
||||
auto launch = &run_mha_bwd;
|
||||
// launch(params, stream, /*configure=*/true);
|
||||
|
||||
at::PhiloxCudaState philox_args;
|
||||
if (is_dropout) {
|
||||
if (at::cuda::currentStreamCaptureStatus() ==
|
||||
at::cuda::CaptureStatus::None)
|
||||
{
|
||||
philox_args = at::PhiloxCudaState(*philox_seed.data_ptr<int64_t>(), *philox_offset.data_ptr<int64_t>());
|
||||
} else { // dropout + capture
|
||||
philox_args = at::PhiloxCudaState(
|
||||
philox_seed.data_ptr<int64_t>(), philox_offset.data_ptr<int64_t>(), 0);
|
||||
}
|
||||
}
|
||||
params.philox_args = philox_args;
|
||||
|
||||
launch(params, stream, /*configure=*/false);
|
||||
|
||||
// For MQA/GQA we need to sum dK and dV across the groups
|
||||
if (num_heads_k != num_heads) {
|
||||
at::sum_out(dk, at::reshape(dk_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size}), {2});
|
||||
at::sum_out(dv, at::reshape(dv_expanded, {total_k, num_heads_k, num_heads / num_heads_k, head_size}), {2});
|
||||
}
|
||||
|
||||
return { dq, dk, dv, softmax_d };
|
||||
}
|
||||
} // namespace pytorch_fmha
|
||||
|
||||
#endif
|
||||
@ -0,0 +1,75 @@
|
||||
#pragma once
|
||||
#include <cstddef>
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <c10/util/Exception.h>
|
||||
|
||||
namespace pytorch_flash {
|
||||
|
||||
TORCH_API
|
||||
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
|
||||
mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
|
||||
const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
|
||||
const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
|
||||
c10::optional<at::Tensor> &out_, // batch_size x seqlen_q x num_heads x head_size
|
||||
const float p_dropout,
|
||||
const float softmax_scale,
|
||||
const bool is_causal,
|
||||
const bool return_softmax,
|
||||
c10::optional<at::Generator> gen_);
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
|
||||
mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
|
||||
const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
|
||||
const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
|
||||
c10::optional<at::Tensor> &out_, // total_q x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
|
||||
const at::Tensor &cu_seqlens_q, // b+1
|
||||
const at::Tensor &cu_seqlens_k, // b+1
|
||||
const int max_seqlen_q,
|
||||
const int max_seqlen_k,
|
||||
const float p_dropout,
|
||||
const float softmax_scale,
|
||||
const bool zero_tensors,
|
||||
const bool is_causal,
|
||||
const bool return_softmax,
|
||||
c10::optional<at::Generator> gen_);
|
||||
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
|
||||
mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_size_og
|
||||
const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
|
||||
const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size
|
||||
const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size
|
||||
const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size
|
||||
const at::Tensor &softmax_lse, // b x h x seqlen_q
|
||||
c10::optional<at::Tensor> &dq_, // batch_size x seqlen_q x num_heads x head_size
|
||||
c10::optional<at::Tensor> &dk_, // batch_size x seqlen_k x num_heads_k x head_size
|
||||
c10::optional<at::Tensor> &dv_, // batch_size x seqlen_k x num_heads_k x head_size
|
||||
const float p_dropout, // probability to drop
|
||||
const float softmax_scale,
|
||||
const bool is_causal,
|
||||
const at::Tensor philox_seed,
|
||||
const at::Tensor philox_offset);
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
|
||||
mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
|
||||
const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
|
||||
const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
|
||||
const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
|
||||
const at::Tensor &out, // total_q x num_heads x head_size
|
||||
const at::Tensor &softmax_lse, // b x h x s softmax logsumexp
|
||||
c10::optional<at::Tensor> &dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
|
||||
c10::optional<at::Tensor> &dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
|
||||
c10::optional<at::Tensor> &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i
|
||||
const at::Tensor &cu_seqlens_q, // b+1
|
||||
const at::Tensor &cu_seqlens_k, // b+1
|
||||
const int max_seqlen_q,
|
||||
const int max_seqlen_k, // max sequence length to choose the kernel
|
||||
const float p_dropout, // probability to drop
|
||||
const float softmax_scale,
|
||||
const bool zero_tensors,
|
||||
const bool is_causal,
|
||||
const at::Tensor philox_seed,
|
||||
const at::Tensor philox_offset);
|
||||
|
||||
} // namespace pytorch_flash
|
||||
1581
aten/src/ATen/native/transformers/cuda/flash_attn/flash_bwd_kernel.h
Normal file
1581
aten/src/ATen/native/transformers/cuda/flash_attn/flash_bwd_kernel.h
Normal file
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,364 @@
|
||||
// Copyright (c) 2022, Tri Dao.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/static_switch.h>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash.h>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_bwd_kernel.h>
|
||||
|
||||
namespace pytorch_flash {
|
||||
|
||||
template<bool Clear_dQaccum=true, typename Kernel_traits>
|
||||
__global__ void flash_bwd_dot_do_o_kernel(Flash_bwd_params params) {
|
||||
compute_dot_do_o<Clear_dQaccum, Kernel_traits>(params);
|
||||
}
|
||||
|
||||
template<typename Kernel_traits>
|
||||
__global__ void flash_bwd_clear_dkvaccum_kernel(Flash_bwd_params params) {
|
||||
clear_dKVaccum<Kernel_traits>(params);
|
||||
}
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_M, bool Is_even_K>
|
||||
__global__ void flash_bwd_dq_dk_dv_loop_kernel(Flash_bwd_params params) {
|
||||
compute_dq_dk_dv<Kernel_traits, Is_dropout, Is_causal, Is_even_M, Is_even_K>(params);
|
||||
}
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K>
|
||||
__global__ void flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel(Flash_bwd_params params) {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
compute_dq_dk_dv_seqk_parallel<Kernel_traits, Is_dropout, Is_causal, Is_even_MN, Is_even_K>(params);
|
||||
#else
|
||||
printf("FATAL: FlashAttention requires to be build with sm80-sm90, but was built for < 5.3!");
|
||||
#endif
|
||||
}
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K>
|
||||
__global__ void flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel(Flash_bwd_params params) {
|
||||
compute_dq_dk_dv_seqq_parallel<Kernel_traits, Is_dropout, Is_causal, Is_even_N, Is_even_K>(params);
|
||||
}
|
||||
|
||||
template<typename Kernel_traits>
|
||||
__global__ void flash_bwd_convert_dq_kernel(Flash_bwd_params params) {
|
||||
convert_dQ<Kernel_traits>(params);
|
||||
}
|
||||
|
||||
template<typename Kernel_traits>
|
||||
__global__ void flash_bwd_convert_dkv_kernel(Flash_bwd_params params) {
|
||||
convert_dKV<Kernel_traits>(params);
|
||||
}
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout>
|
||||
void run_flash_bwd_seqk_parallel(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
|
||||
dim3 grid_m(num_m_block, params.b, params.h);
|
||||
const int num_n_block = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN;
|
||||
dim3 grid_n(num_n_block, params.b, params.h);
|
||||
|
||||
flash_bwd_dot_do_o_kernel<true, Kernel_traits><<<grid_m, Kernel_traits::kNThreads, 0, stream>>>(params);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
|
||||
// We want to specialize to is_even_MN and not just is_even_M, since in the case where N is not
|
||||
// a multiple of kBlockN, we'll need to apply mask in the loop.
|
||||
const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_q % Kernel_traits::kBlockM == 0 && params.seqlen_k % Kernel_traits::kBlockN == 0;
|
||||
const bool is_even_K = params.d == Kernel_traits::kHeadDim;
|
||||
constexpr int smem_size_dq_dk_dv = Kernel_traits::kSmemSize1colblock;
|
||||
// printf("smem_size_dq_dk_dv = %d\n", smem_size_dq_dk_dv);
|
||||
BOOL_SWITCH(params.is_causal, IsCausalConst, [&] {
|
||||
BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
|
||||
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
|
||||
auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenMNConst, IsEvenKConst>;
|
||||
// auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenMNConst, true>;
|
||||
if (smem_size_dq_dk_dv >= 48 * 1024) {
|
||||
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
|
||||
}
|
||||
kernel<<<grid_n, Kernel_traits::kNThreads, smem_size_dq_dk_dv, stream>>>(params);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
auto kernel_dq = &flash_bwd_convert_dq_kernel<Kernel_traits>;
|
||||
if (Kernel_traits::kSmemdQSize >= 48 * 1024) {
|
||||
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
kernel_dq, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemdQSize));
|
||||
}
|
||||
kernel_dq<<<grid_m, Kernel_traits::kNThreads, Kernel_traits::kSmemdQSize, stream>>>(params);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
}
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout>
|
||||
void run_flash_bwd_seqq_parallel(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
const int num_n_block = (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN;
|
||||
dim3 grid_n(num_n_block, params.b, params.h_k);
|
||||
flash_bwd_clear_dkvaccum_kernel<Kernel_traits><<<grid_n, Kernel_traits::kNThreads, 0, stream>>>(params);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
|
||||
const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
|
||||
dim3 grid_m(num_m_block, params.b, params.h);
|
||||
// We also use is_even_N to set Unpadded in the BlockInfo constructor, so we need to check
|
||||
// for cu_seqlens_k as well.
|
||||
const bool is_even_N = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0;
|
||||
const bool is_even_K = params.d == Kernel_traits::kHeadDim;
|
||||
constexpr int smem_size_dq_dk_dv = Kernel_traits::kSmemSize1rowblock;
|
||||
// printf("smem_size_dq_dk_dv = %d\n", smem_size_dq_dk_dv);
|
||||
BOOL_SWITCH(params.is_causal, IsCausalConst, [&] {
|
||||
BOOL_SWITCH(is_even_N, IsEvenNConst, [&] {
|
||||
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
|
||||
auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenNConst, IsEvenKConst>;
|
||||
// auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel<Kernel_traits, false, false, IsEvenNConst, IsEvenKConst>;
|
||||
if (smem_size_dq_dk_dv >= 48 * 1024) {
|
||||
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
|
||||
}
|
||||
kernel<<<grid_m, Kernel_traits::kNThreads, smem_size_dq_dk_dv, stream>>>(params);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
auto kernel_dkv = &flash_bwd_convert_dkv_kernel<Kernel_traits>;
|
||||
if (Kernel_traits::kSmemKVSize >= 48 * 1024) {
|
||||
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
kernel_dkv, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemKVSize));
|
||||
}
|
||||
kernel_dkv<<<grid_n, Kernel_traits::kNThreads, Kernel_traits::kSmemKVSize, stream>>>(params);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
}
|
||||
//
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout>
|
||||
void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
if (configure) return;
|
||||
// dim3 grid(params.b, params.h);
|
||||
// const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
|
||||
// dim3 grid_m(num_m_block, params.b, params.h);
|
||||
|
||||
// if (params.h == params.h_k) { // No multi-query or grouped-query attention (MQA/GQA)
|
||||
run_flash_bwd_seqk_parallel<Kernel_traits, Is_dropout>(params, stream, configure);
|
||||
// } else {
|
||||
// run_flash_bwd_seqq_parallel<Kernel_traits, Is_dropout>(params, stream, configure);
|
||||
// }
|
||||
|
||||
// // We also use is_even_M to set Unpadded in the BlockInfo constructor, so we need to check
|
||||
// // for cu_seqlens_q as well.
|
||||
// const bool is_even_M = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_q % Kernel_traits::kBlockM == 0;
|
||||
// const bool is_even_K = params.d == Kernel_traits::kHeadDim;
|
||||
// constexpr int smem_size_dq_dk_dv = Kernel_traits::kSmemSize;
|
||||
// BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
// BOOL_SWITCH(params.is_causal, IsCausalConst, [&] {
|
||||
// BOOL_SWITCH(is_even_M, IsEvenMConst, [&] {
|
||||
// BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
|
||||
// // auto kernel = &flash_bwd_dq_dk_dv_loop_kernel<Kernel_traits, Is_dropout, IsCausalConst>;
|
||||
// auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenMConst, IsEvenKConst>;
|
||||
// if (smem_size_dq_dk_dv >= 48 * 1024) {
|
||||
// C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
// kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
|
||||
// }
|
||||
// kernel<<<grid_m, Kernel_traits::kNThreads, smem_size_dq_dk_dv, stream>>>(params);
|
||||
// C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
// });
|
||||
// });
|
||||
// });
|
||||
// });
|
||||
|
||||
// auto kernel_dq = &flash_bwd_convert_dq_kernel<Kernel_traits>;
|
||||
// if (Kernel_traits::kSmemdQSize >= 48 * 1024) {
|
||||
// C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
// kernel_dq, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemdQSize));
|
||||
// }
|
||||
// kernel_dq<<<grid_m, Kernel_traits::kNThreads, Kernel_traits::kSmemdQSize, stream>>>(params);
|
||||
// C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
}
|
||||
//
|
||||
|
||||
template<typename T>
|
||||
void run_mha_bwd_hdim32(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
constexpr int Headdim = 32;
|
||||
int device;
|
||||
cudaGetDevice(&device);
|
||||
int max_smem_per_block;
|
||||
cudaError status_ = cudaDeviceGetAttribute(
|
||||
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
if (max_smem_per_block >= 2 * ((3 * 128 + 2 * 128) * Headdim + 2 * 128 * 128)) { // 104 KB
|
||||
if constexpr(!Is_dropout) { // We can afford more registers to keep V in registers
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream, configure);
|
||||
} else {
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
|
||||
}
|
||||
} else { // 96 KB
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream, configure);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void run_mha_bwd_hdim64(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
constexpr int Headdim = 64;
|
||||
int device;
|
||||
cudaGetDevice(&device);
|
||||
int max_smem_per_block;
|
||||
cudaError status_ = cudaDeviceGetAttribute(
|
||||
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
|
||||
// printf("max_smem_per_block = %d\n", max_smem_per_block);
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
// Changing AtomLayoutMdQ from 2 to 4 takes the same time
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>>(params, stream, configure);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, true, false, T>>(params, stream, configure);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 2, 4, 4, false, false, T>>(params, stream, configure);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, false, false, T>, Is_dropout>(params, stream, configure);
|
||||
// This is slightly faster. We want to split M more so we need fewer registers to store LSE.
|
||||
if (max_smem_per_block >= 144 * 1024) {
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
|
||||
// This has a lot of register spilling
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, true, false, T>, Is_dropout>(params, stream, configure);
|
||||
} else {
|
||||
// if (params.h == params.h_k) {
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream, configure);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, false, false, T>, Is_dropout>(params, stream, configure);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, true, false, T>, Is_dropout>(params, stream, configure);
|
||||
// } else {
|
||||
// run_flash_bwd_seqq_parallel<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, false, false, T>, Is_dropout>(params, stream, configure);
|
||||
// }
|
||||
}
|
||||
});
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 4, true, false, T>>(params, stream, configure);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 4, 2, 2, 2, true, false, T>>(params, stream, configure);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 128, 4, 1, 4, 1, false, false, T>>(params, stream, configure);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 16, 128, 4, 1, 4, 1, false, false, T>>(params, stream, configure);
|
||||
// M=128, N=64 is quite slow, I think because we need to read/write dQaccum twice as many times
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 2, 2, 2, false, T>>(params, stream, configure);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, false, T>>(params, stream, configure);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 4, false, T>>(params, stream, configure);
|
||||
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 4, 4, 2, 4, false, false, T>>(params, stream, configure);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void run_mha_bwd_hdim96(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
constexpr int Headdim = 96;
|
||||
int device;
|
||||
cudaGetDevice(&device);
|
||||
int max_smem_per_block;
|
||||
cudaError status_ = cudaDeviceGetAttribute(
|
||||
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
|
||||
// printf("max_smem_per_block = %d\n", max_smem_per_block);
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
// if (params.h == params.h_k) {
|
||||
if (max_smem_per_block >= 116 * 1024) {
|
||||
if constexpr(!Is_dropout) { // 92KB
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream, configure);
|
||||
} else { // 116 KB
|
||||
// This is faster for dropout since we don't have many registers to spare
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
|
||||
}
|
||||
} else {
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, true, false, T>, Is_dropout>(params, stream, configure);
|
||||
}
|
||||
// } else {
|
||||
// run_flash_bwd_seqq_parallel<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 4, 4, false, false, T>>(params, stream, configure);
|
||||
// }
|
||||
});
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void run_mha_bwd_hdim128(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
constexpr int Headdim = 128;
|
||||
int device;
|
||||
cudaGetDevice(&device);
|
||||
int max_smem_per_block;
|
||||
cudaError status_ = cudaDeviceGetAttribute(
|
||||
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
|
||||
// printf("max_smem_per_block = %d\n", max_smem_per_block);
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
// if (params.h == params.h_k) {
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 32, 128, 8, 2, 2, 2, false, false, T>>(params, stream, configure);
|
||||
// This is faster, in the case of sequence-parallel bwd (where we need fewer registers).
|
||||
// Out of these three, the 2nd one is slightly faster (2% faster than the first). Idk why.
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 2, 2, false, false, T>>(params, stream, configure);
|
||||
if (max_smem_per_block >= 144 * 1024) {
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, false, false, T>, Is_dropout>(params, stream, configure);
|
||||
// run_flash_bwd_seqk_parallel<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
|
||||
// run_flash_bwd_seqk_parallel<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, true, T>, Is_dropout>(params, stream, configure);
|
||||
// run_flash_bwd_seqq_parallel<Flash_bwd_kernel_traits<Headdim, 128, 128, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 2, true, false, T>, Is_dropout>(params, stream, configure);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream, configure);
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 2, 2, true, false, T>, Is_dropout>(params, stream, configure);
|
||||
} else {
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream, configure);
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, false, T>, Is_dropout>(params, stream, configure);
|
||||
}
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 128, 8, 2, 4, 4, false, false, T>>(params, stream, configure);
|
||||
|
||||
// run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 4, 4, false, false, T>>(params, stream, configure);
|
||||
// } else {
|
||||
// run_flash_bwd_seqq_parallel<Flash_bwd_kernel_traits<Headdim, 128, 64, 8, 4, 4, 4, false, false, T>>(params, stream, configure);
|
||||
// }
|
||||
});
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void run_mha_bwd_hdim160(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
constexpr int Headdim = 160;
|
||||
int device;
|
||||
cudaGetDevice(&device);
|
||||
int max_smem_per_block;
|
||||
cudaError status_ = cudaDeviceGetAttribute(
|
||||
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
if (max_smem_per_block >= 116 * 1024) {
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
|
||||
} else {
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, true, T>, Is_dropout>(params, stream, configure);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void run_mha_bwd_hdim192(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
constexpr int Headdim = 192;
|
||||
int device;
|
||||
cudaGetDevice(&device);
|
||||
int max_smem_per_block;
|
||||
cudaError status_ = cudaDeviceGetAttribute(
|
||||
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
if (max_smem_per_block >= 136 * 1024) {
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream, configure);
|
||||
} else {
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, true, true, T>, Is_dropout>(params, stream, configure);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void run_mha_bwd_hdim224(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
constexpr int Headdim = 224;
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 4, 4, false, false, T>, Is_dropout>(params, stream, configure);
|
||||
});
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void run_mha_bwd_hdim256(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
constexpr int Headdim = 256;
|
||||
int device;
|
||||
cudaGetDevice(&device);
|
||||
int max_smem_per_block;
|
||||
cudaError status_ = cudaDeviceGetAttribute(
|
||||
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
if (max_smem_per_block >= 176 * 1024) { // H100
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream, configure);
|
||||
} else { // A100, we don't do double buffering to save smem
|
||||
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, true, T>, Is_dropout>(params, stream, configure);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
}; // namespace pytorch_fmha
|
||||
@ -0,0 +1,590 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2023, Tri Dao.
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cmath>
|
||||
#include <cute/algorithm/copy.hpp>
|
||||
#include <cute/algorithm/gemm.hpp>
|
||||
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/array.h>
|
||||
#include <cutlass/numeric_types.h>
|
||||
#include <cutlass/numeric_conversion.h>
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/block_info.h>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/kernel_traits.h>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/utils.h>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/softmax.h>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/philox.cuh>
|
||||
|
||||
namespace pytorch_flash {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <int MMA_M,
|
||||
class... Args,
|
||||
class TiledMMA>
|
||||
CUTE_HOST_DEVICE
|
||||
auto
|
||||
make_tiled_copy_A_warpcontiguousM(Copy_Atom<Args...> const& copy_atom,
|
||||
TiledMMA const& tiled_mma) {
|
||||
using TileShape_MNK = typename TiledMMA::TiledShape_MNK;
|
||||
using AtomShape_MNK = typename TiledMMA::AtomShape_MNK;
|
||||
constexpr int AtomShape_M = decltype(size<0>(AtomShape_MNK{}))::value;
|
||||
constexpr int kNWarps = decltype(size<0>(TileShape_MNK{}))::value / AtomShape_M;
|
||||
constexpr int MMAStride_M = MMA_M * AtomShape_M;
|
||||
auto t = make_tile(Layout<Shape<Int<AtomShape_M>, Int<kNWarps>>,
|
||||
Stride<_1, Int<MMAStride_M>> >{},
|
||||
make_layout(size<2>(TileShape_MNK{})));
|
||||
// if (cute::thread0()) {printf("make_tiled_copy_A_warpcontiguousM "); print(t); printf("\n"); }
|
||||
return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutA_TV(), t);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <int MMA_M,
|
||||
class... Args,
|
||||
class TiledMMA>
|
||||
CUTE_HOST_DEVICE
|
||||
auto
|
||||
make_tiled_copy_C_warpcontiguousM(Copy_Atom<Args...> const& copy_atom,
|
||||
TiledMMA const& tiled_mma) {
|
||||
using TileShape_MNK = typename TiledMMA::TiledShape_MNK;
|
||||
using AtomShape_MNK = typename TiledMMA::AtomShape_MNK;
|
||||
constexpr int AtomShape_M = decltype(size<0>(AtomShape_MNK{}))::value;
|
||||
constexpr int kNWarps = decltype(size<0>(TileShape_MNK{}))::value / AtomShape_M;
|
||||
constexpr int MMAStride_M = MMA_M * AtomShape_M;
|
||||
auto t = make_tile(Layout<Shape<Int<AtomShape_M>, Int<kNWarps>>,
|
||||
Stride<_1, Int<MMAStride_M>> >{},
|
||||
// TODO: Shouldn't this be size<1>?
|
||||
make_layout(size<2>(TileShape_MNK{})));
|
||||
// if (cute::thread0()) {printf("make_tiled_copy_C_warpcontiguousM "); print(t); printf("\n"); }
|
||||
return make_tiled_copy_impl(copy_atom, tiled_mma.get_layoutC_TV(), t);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<bool Is_first, bool Check_inf=false, typename Tensor0, typename Tensor1, typename Tensor2>
|
||||
inline __device__ void softmax_rescale_o(Tensor0 &scores, Tensor1 &scores_max, Tensor1 &scores_sum,
|
||||
Tensor2 &acc_o, float softmax_scale_log2) {
|
||||
if (Is_first) {
|
||||
pytorch_flash::template reduce_max</*zero_init=*/true>(scores, scores_max);
|
||||
pytorch_flash::scale_apply_exp2(scores, scores_max, softmax_scale_log2);
|
||||
pytorch_flash::reduce_sum(scores, scores_sum);
|
||||
} else {
|
||||
Tensor scores_max_prev = make_fragment_like(scores_max);
|
||||
cute::copy(scores_max, scores_max_prev);
|
||||
pytorch_flash::template reduce_max</*zero_init=*/false>(scores, scores_max);
|
||||
// Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
|
||||
Tensor acc_o_rowcol = make_tensor(acc_o.data(), pytorch_flash::convert_layout_acc_rowcol(acc_o.layout()));
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size(scores_max); ++mi) {
|
||||
float scores_max_cur = !Check_inf
|
||||
? scores_max(mi)
|
||||
: (scores_max(mi) == -INFINITY ? 0.0f : scores_max(mi));
|
||||
float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
|
||||
scores_sum(mi) *= scores_scale;
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale; }
|
||||
}
|
||||
pytorch_flash::scale_apply_exp2(scores, scores_max, softmax_scale_log2);
|
||||
Tensor scores_sum_cur = make_fragment_like(scores_sum);
|
||||
pytorch_flash::reduce_sum(scores, scores_sum_cur);
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size(scores_sum); ++mi) { scores_sum(mi) += scores_sum_cur(mi); }
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename TiledCopy>
|
||||
inline __device__ void write_softmax_to_gmem(
|
||||
Tensor<Engine0, Layout0> const &tOrP, Tensor<Engine1, Layout1> &tPgP, TiledCopy gmem_tiled_copy_P
|
||||
) {
|
||||
// Reshape tOrP from (8, MMA_M, MMA_N) to (8, MMA_M * MMA_N)
|
||||
Layout l = tOrP.layout();
|
||||
Tensor tPrP = make_tensor(tOrP.data(), make_layout(get<0>(l), make_layout(get<1>(l), get<2>(l))));
|
||||
CUTE_STATIC_ASSERT_V(size<2>(tPgP) == _1{});
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tPrP) == size<1>(tPgP));
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<1>(tPrP); ++mi) {
|
||||
cute::copy(gmem_tiled_copy_P, tPrP(_, mi), tPgP(_, mi, 0));
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K, bool Return_softmax, typename Params>
|
||||
inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bidb, const int bidh, const int m_block) {
|
||||
|
||||
using Element = typename Kernel_traits::Element;
|
||||
using ElementAccum = typename Kernel_traits::ElementAccum;
|
||||
using index_t = typename Kernel_traits::index_t;
|
||||
|
||||
// Shared memory.
|
||||
extern __shared__ char smem_[];
|
||||
|
||||
// The thread index.
|
||||
const int tidx = threadIdx.x;
|
||||
|
||||
constexpr int kBlockM = Kernel_traits::kBlockM;
|
||||
constexpr int kBlockN = Kernel_traits::kBlockN;
|
||||
constexpr int kHeadDim = Kernel_traits::kHeadDim;
|
||||
constexpr int kNWarps = Kernel_traits::kNWarps;
|
||||
constexpr int MMA_M = kBlockM / decltype(size<0>(typename Kernel_traits::TiledMma::TiledShape_MNK{}))::value;
|
||||
|
||||
const BlockInfo</*Varlen=*/!Is_even_N> binfo(params, bidb);
|
||||
if (m_block * kBlockM >= binfo.actual_seqlen_q || binfo.actual_seqlen_k == 0) return;
|
||||
|
||||
int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN);
|
||||
if (Is_causal) {
|
||||
n_block_max = std::min(n_block_max, cute::ceil_div((m_block + 1) * kBlockM, kBlockN));
|
||||
// if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) {
|
||||
// printf("m_block = %d, n_block_max = %d\n", m_block, n_block_max);
|
||||
// }
|
||||
}
|
||||
|
||||
// We iterate over the blocks in reverse order. This is because the last block is the only one
|
||||
// that needs masking when we read K and V from global memory. Moreover, iterating in reverse
|
||||
// might save us 1 register (we just need n_block instead of both n_block and n_block_max).
|
||||
|
||||
const index_t row_offset_q = binfo.q_offset(params.q_batch_stride, params.q_row_stride, bidb)
|
||||
+ m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
|
||||
// We move K and V to the last block.
|
||||
const index_t row_offset_k = binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb)
|
||||
+ (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride;
|
||||
const index_t row_offset_v = binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb)
|
||||
+ (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride;
|
||||
const index_t row_offset_p = ((bidb * params.h + bidh) * params.seqlen_q_rounded
|
||||
+ m_block * kBlockM) * params.seqlen_k_rounded + (n_block_max - 1) * kBlockN;
|
||||
|
||||
Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
|
||||
Shape<Int<kBlockM>, Int<kHeadDim>>{},
|
||||
make_stride(params.q_row_stride, _1{}));
|
||||
Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
|
||||
Shape<Int<kBlockN>, Int<kHeadDim>>{},
|
||||
make_stride(params.k_row_stride, _1{}));
|
||||
Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.v_ptr) + row_offset_v),
|
||||
Shape<Int<kBlockN>, Int<kHeadDim>>{},
|
||||
make_stride(params.v_row_stride, _1{}));
|
||||
Tensor gP = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.p_ptr) + row_offset_p),
|
||||
Shape<Int<kBlockM>, Int<kBlockN>>{},
|
||||
make_stride(params.seqlen_k_rounded, _1{}));
|
||||
|
||||
Tensor sQ = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)),
|
||||
typename Kernel_traits::SmemLayoutQ{});
|
||||
// Careful we're using the same smem for sQ and sK | sV if Share_Q_K_smem;
|
||||
Tensor sK = make_tensor(sQ.data() + (Kernel_traits::Share_Q_K_smem ? 0 : size(sQ)),
|
||||
typename Kernel_traits::SmemLayoutKV{});
|
||||
Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{});
|
||||
Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});
|
||||
Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});
|
||||
|
||||
typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
|
||||
auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
|
||||
typename Kernel_traits::GmemTiledCopyP gmem_tiled_copy_P;
|
||||
auto gmem_thr_copy_P = gmem_tiled_copy_P.get_thread_slice(tidx);
|
||||
|
||||
Tensor tQgQ = gmem_thr_copy_QKV.partition_S(gQ);
|
||||
Tensor tQsQ = gmem_thr_copy_QKV.partition_D(sQ);
|
||||
Tensor tKgK = gmem_thr_copy_QKV.partition_S(gK); // (KCPY, KCPY_N, KCPY_K)
|
||||
Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);
|
||||
Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV); // (VCPY, VCPY_N, VCPY_K)
|
||||
Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);
|
||||
Tensor tPgP = gmem_thr_copy_P.partition_D(gP);
|
||||
|
||||
typename Kernel_traits::TiledMma tiled_mma;
|
||||
auto thr_mma = tiled_mma.get_thread_slice(tidx);
|
||||
Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K)
|
||||
Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K)
|
||||
Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N)
|
||||
|
||||
Tensor acc_o = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kHeadDim>>{}); // MMA, MMA_M, MMA_K
|
||||
|
||||
//
|
||||
// Copy Atom retiling
|
||||
//
|
||||
|
||||
auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
|
||||
auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
|
||||
// auto smem_thr_copy_Q = make_tiled_copy_A_warpcontiguousM<MMA_M>(typename Kernel_traits::SmemCopyAtom{}, tiled_mma).get_thread_slice(tidx);
|
||||
// if (cute::thread0()) {smem_thr_copy_Q.print_all();}
|
||||
Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);
|
||||
// if (cute::thread0()) {print(tSsQ.layout()); printf("\n");}
|
||||
|
||||
auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
|
||||
auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
|
||||
Tensor tSsK = smem_thr_copy_K.partition_S(sK);
|
||||
|
||||
auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma);
|
||||
auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
|
||||
Tensor tOsVt = smem_thr_copy_V.partition_S(sVt);
|
||||
|
||||
// TODO: this might need to change if we change the mma instruction in SM70
|
||||
Tensor scores_max = make_tensor<ElementAccum>(Shape<Int<2 * size<1>(acc_o)>>{});
|
||||
Tensor scores_sum = make_fragment_like(scores_max);
|
||||
|
||||
//
|
||||
// PREDICATES
|
||||
//
|
||||
|
||||
// // Allocate predicate tensors for m and n
|
||||
// Tensor tQpQ = make_tensor<bool>(make_shape(size<1>(tQsQ), size<2>(tQsQ)), Stride<_1,_0>{});
|
||||
// Tensor tKVpKV = make_tensor<bool>(make_shape(size<1>(tKsK), size<2>(tKsK)), Stride<_1,_0>{});
|
||||
|
||||
// Construct identity layout for sQ and sK
|
||||
Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
|
||||
Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k)
|
||||
// Tensor tScQ = thr_mma.partition_A(cQ); // (MMA,MMA_M,MMA_K)
|
||||
// if (cute::thread0()) {
|
||||
// print(tScQ.layout()); printf("\n");
|
||||
// for (int i = 0; i < size(tScQ); ++i) {
|
||||
// printf("%d ", get<0>(tScQ(i)));
|
||||
// }
|
||||
// printf("\n");
|
||||
// for (int i = 0; i < size(tScQ); ++i) {
|
||||
// printf("%d ", get<1>(tScQ(i)));
|
||||
// }
|
||||
// printf("\n");
|
||||
// }
|
||||
|
||||
// Repeat the partitioning with identity layouts
|
||||
Tensor tQcQ = gmem_thr_copy_QKV.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
|
||||
Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k)
|
||||
|
||||
// Allocate predicate tensors for k
|
||||
Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));
|
||||
Tensor tKVpKV = make_tensor<bool>(make_shape(size<2>(tKsK)));
|
||||
|
||||
// Set predicates for k bounds
|
||||
if (!Is_even_K) {
|
||||
#pragma unroll
|
||||
for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(0, 0, k)) < params.d; }
|
||||
#pragma unroll
|
||||
for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(0, 0, k)) < params.d; }
|
||||
}
|
||||
|
||||
// Prologue
|
||||
|
||||
Tensor tQrQ = make_fragment_like(tQgQ);
|
||||
// We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
|
||||
pytorch_flash::copy</*Is_even_MN=*/false, Is_even_K>(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ,
|
||||
binfo.actual_seqlen_q - m_block * kBlockM);
|
||||
if (Kernel_traits::Is_Q_in_regs) { cute::cp_async_fence(); }
|
||||
|
||||
// // Copy rmem to smem
|
||||
// // copy(tQrQ, tQsQ);
|
||||
// pytorch_flash::cp_async_wait<0>();
|
||||
// __syncthreads();
|
||||
// // if (cute::thread(1, 0)) { print(tQsQ); }
|
||||
// // Tensor sQNoSwizzle = make_tensor(make_smem_ptr(reinterpret_cast<Element *>(smem_)), typename Kernel_traits::SmemLayoutQNoSwizzle{});
|
||||
// // if (cute::thread0()) { print(sQNoSwizzle); }
|
||||
|
||||
if (Kernel_traits::Share_Q_K_smem) {
|
||||
pytorch_flash::cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M
|
||||
cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
int n_block = n_block_max - 1;
|
||||
// We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
|
||||
pytorch_flash::copy<Is_even_N, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV,
|
||||
binfo.actual_seqlen_k - n_block * kBlockN);
|
||||
cute::cp_async_fence();
|
||||
// if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z < 2) { print(tKgK); }
|
||||
// __syncthreads();
|
||||
|
||||
if (Kernel_traits::Is_Q_in_regs && !Kernel_traits::Share_Q_K_smem) {
|
||||
pytorch_flash::cp_async_wait<1>();
|
||||
__syncthreads();
|
||||
Tensor tSrQ_copy_view = smem_thr_copy_Q.retile_D(tSrQ);
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tSsQ) == size<1>(tSrQ_copy_view)); // M
|
||||
cute::copy(smem_tiled_copy_Q, tSsQ, tSrQ_copy_view);
|
||||
}
|
||||
|
||||
auto seeds = at::cuda::philox::unpack(params.philox_args);
|
||||
if (params.philox_args.captured_) {
|
||||
*params.seed = std::get<0>(seeds);
|
||||
*params.extragraph_offset = std::get<1>(seeds);
|
||||
}
|
||||
|
||||
unsigned long long seed = std::get<0>(seeds);
|
||||
unsigned long long offset = std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32;
|
||||
|
||||
clear(acc_o);
|
||||
|
||||
// For performance reason, we separate out two kinds of iterations:
|
||||
// those that need masking on S, and those that don't.
|
||||
// We need masking on S for the very last block when K and V has length not multiple of kBlockN.
|
||||
// We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks.
|
||||
// We will have at least 1 "masking" iteration.
|
||||
|
||||
constexpr int n_masking_steps = Is_causal ? cute::ceil_div(kBlockM, kBlockN) : 1;
|
||||
#pragma unroll
|
||||
for (int masking_step = 0; masking_step < n_masking_steps; ++masking_step, --n_block) {
|
||||
Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_M, MMA_N)
|
||||
clear(acc_s);
|
||||
pytorch_flash::cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
|
||||
// Advance gV
|
||||
if (masking_step > 0) {
|
||||
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
|
||||
pytorch_flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
|
||||
} else {
|
||||
// Clear the smem tiles to account for predicated off loads
|
||||
pytorch_flash::copy<Is_even_N, Is_even_K, /*Clear_OOB_MN=*/true>(
|
||||
gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
|
||||
);
|
||||
}
|
||||
cute::cp_async_fence();
|
||||
|
||||
pytorch_flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
|
||||
acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
|
||||
smem_thr_copy_Q, smem_thr_copy_K
|
||||
);
|
||||
// if (cute::thread0()) { print(acc_s); }
|
||||
|
||||
// Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
|
||||
Tensor scores = make_tensor(acc_s.data(), pytorch_flash::convert_layout_acc_rowcol(acc_s.layout()));
|
||||
// if (cute::thread0()) { print(scores); }
|
||||
// We don't put the masking before the matmul S = Q K^T because we don't clear sK
|
||||
// for rows outside actual_seqlen_k. So those rows could have Inf / NaN, and the matmul
|
||||
// can produce Inf / NaN.
|
||||
if (!Is_causal) {
|
||||
if (!Is_even_N) { pytorch_flash::apply_mask(scores, binfo.actual_seqlen_k - n_block * kBlockN); }
|
||||
} else {
|
||||
// Tensor caccS = make_identity_tensor(Shape<Int<kBlockM>, Int<kBlockN>>{}); // (BLK_M,BLK_N) -> (blk_m,blk_n)
|
||||
// Tensor taccScS = thr_mma.partition_C(caccS); // (MMA,MMA_M,MMA_N)
|
||||
// static_assert(decltype(size<0>(taccScS))::value == 4);
|
||||
// // Convert to ((2, 2), MMA_M, MMA_N) then take only the row indices.
|
||||
// Tensor idx_row = logical_divide(taccScS, Shape<_2>{})(make_coord(0, _), _, 0);
|
||||
// Tensor idx_rowcol = make_tensor(taccScS.data(), pytorch_flash::convert_layout_acc_rowcol(taccScS.layout()));
|
||||
// pytorch_flash::apply_mask_causal_w_idx(scores, idx_rowcol, n_block * kBlockN, binfo.actual_seqlen_k,
|
||||
// m_block * kBlockM);
|
||||
// Idk why it's get<1> and not get<0> of the stride.
|
||||
// if (cute::thread0()) { print(idx_row.layout()); print(stride<1>(idx_row)); printf("stride = %d \n", get<1>(stride<1>(idx_row))); }
|
||||
// I can't get the stride from idx_row
|
||||
pytorch_flash::apply_mask_causal(scores, n_block * kBlockN, binfo.actual_seqlen_k,
|
||||
// m_block * kBlockM + get<0>(idx_row(0)),
|
||||
m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4,
|
||||
kNWarps * 16);
|
||||
// m_block * kBlockM + (tidx / 32) * 16, kNWarps * 16);
|
||||
// m_block * kBlockM + (tidx / 32) * (kBlockM / kNWarps), 16);
|
||||
}
|
||||
|
||||
pytorch_flash::cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
if (n_block > 0) {
|
||||
// Advance gK
|
||||
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
|
||||
pytorch_flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
|
||||
// This cp_async_fence needs to be in the if block, otherwise the synchronization
|
||||
// isn't right and we get race conditions.
|
||||
cute::cp_async_fence();
|
||||
}
|
||||
|
||||
// TODO: when we have key_padding_mask we'll need to Check_inf
|
||||
masking_step == 0
|
||||
? softmax_rescale_o</*Is_first=*/true, /*Check_inf=*/Is_causal>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2)
|
||||
: softmax_rescale_o</*Is_first=*/false, /*Check_inf=*/Is_causal>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
|
||||
|
||||
// Convert scores from fp32 to fp16/bf16
|
||||
Tensor rP = pytorch_flash::convert_type<Element>(scores);
|
||||
// Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
|
||||
// if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
|
||||
Tensor tOrP = make_tensor(rP.data(), pytorch_flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
|
||||
uint32_t block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
|
||||
uint32_t block_col_idx = n_block * (kBlockN / 32);
|
||||
if (Return_softmax) {
|
||||
Tensor tOrP_copy = make_fragment_like(tOrP);
|
||||
cute::copy(tOrP, tOrP_copy);
|
||||
pytorch_flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>(
|
||||
tOrP_copy, params.p_dropout_in_uint8_t, seed, offset,
|
||||
block_row_idx, block_col_idx, kNWarps
|
||||
);
|
||||
pytorch_flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_tiled_copy_P);
|
||||
tPgP.data() = tPgP.data() + (-kBlockN);
|
||||
}
|
||||
if (Is_dropout) {
|
||||
pytorch_flash::apply_dropout(tOrP, params.p_dropout_in_uint8_t, seed, offset,
|
||||
block_row_idx, block_col_idx, kNWarps);
|
||||
}
|
||||
// if (cute::thread0()) { print(tOrP); }
|
||||
|
||||
pytorch_flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
|
||||
// if (cute::thread0()) { print(scores); }
|
||||
|
||||
// This check is at the end of the loop since we always have at least 1 iteration
|
||||
if (n_masking_steps > 1 && n_block <= 0) {
|
||||
--n_block;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// These are the iterations where we don't need masking on S
|
||||
for (; n_block >= 0; --n_block) {
|
||||
Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); // (MMA=4, MMA_M, MMA_N)
|
||||
clear(acc_s);
|
||||
pytorch_flash::cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
// Advance gV
|
||||
tVgV.data() = tVgV.data() + (-int(kBlockN * params.v_row_stride));
|
||||
pytorch_flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV);
|
||||
cute::cp_async_fence();
|
||||
|
||||
pytorch_flash::gemm</*A_in_regs=*/Kernel_traits::Is_Q_in_regs>(
|
||||
acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K,
|
||||
smem_thr_copy_Q, smem_thr_copy_K
|
||||
);
|
||||
|
||||
pytorch_flash::cp_async_wait<0>();
|
||||
__syncthreads();
|
||||
if (n_block > 0) {
|
||||
// Advance gK
|
||||
tKgK.data() = tKgK.data() + (-int(kBlockN * params.k_row_stride));
|
||||
pytorch_flash::copy</*Is_even_MN=*/true, Is_even_K>(gmem_tiled_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV);
|
||||
// This cp_async_fence needs to be in the if block, otherwise the synchronization
|
||||
// isn't right and we get race conditions.
|
||||
cute::cp_async_fence();
|
||||
}
|
||||
|
||||
// Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
|
||||
Tensor scores = make_tensor(acc_s.data(), pytorch_flash::convert_layout_acc_rowcol(acc_s.layout()));
|
||||
softmax_rescale_o</*Is_first=*/false>(scores, scores_max, scores_sum, acc_o, params.scale_softmax_log2);
|
||||
|
||||
Tensor rP = pytorch_flash::convert_type<Element>(scores);
|
||||
// Reshape rP from (nrow=(2, MMA_M), ncol=(2, MMA_N)) to ((2, 2, 2), MMA_M, MMA_N / 2)
|
||||
// if using m16n8k16 or ((2, 2, 1), MMA_M, MMA_N) if using m16n8k8.
|
||||
Tensor tOrP = make_tensor(rP.data(), pytorch_flash::convert_layout_rowcol_Aregs<Kernel_traits::TiledMma>(rP.layout()));
|
||||
uint32_t block_row_idx = m_block * (kBlockM / 16) + tidx / 32;
|
||||
uint32_t block_col_idx = n_block * (kBlockN / 32);
|
||||
if (Return_softmax) {
|
||||
Tensor tOrP_copy = make_fragment_like(tOrP);
|
||||
cute::copy(tOrP, tOrP_copy);
|
||||
pytorch_flash::apply_dropout</*encode_dropout_in_sign_bit=*/true>(
|
||||
tOrP_copy, params.p_dropout_in_uint8_t, seed, offset,
|
||||
block_row_idx, block_col_idx, kNWarps
|
||||
);
|
||||
pytorch_flash::write_softmax_to_gmem(tOrP_copy, tPgP, gmem_tiled_copy_P);
|
||||
tPgP.data() = tPgP.data() + (-kBlockN);
|
||||
}
|
||||
if (Is_dropout) {
|
||||
pytorch_flash::apply_dropout(tOrP, params.p_dropout_in_uint8_t, seed, offset,
|
||||
block_row_idx, block_col_idx, kNWarps);
|
||||
}
|
||||
|
||||
pytorch_flash::gemm_A_in_regs(acc_o, tOrP, tOrVt, tOsVt, tiled_mma, smem_tiled_copy_V, smem_thr_copy_V);
|
||||
}
|
||||
|
||||
// Epilogue
|
||||
|
||||
// Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
|
||||
Tensor acc_o_rowcol = make_tensor(acc_o.data(), pytorch_flash::convert_layout_acc_rowcol(acc_o.layout()));
|
||||
Tensor lse = make_fragment_like(scores_sum);
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
|
||||
float sum = scores_sum(mi);
|
||||
float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
|
||||
lse(mi) = (sum == 0.f || sum != sum) ? INFINITY : scores_max(mi) * params.scale_softmax + __logf(sum);
|
||||
float scale = !Is_dropout ? inv_sum : inv_sum * params.rp_dropout;
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; }
|
||||
}
|
||||
|
||||
// if (cute::thread0()) { print(acc_o_rowcol); }
|
||||
|
||||
// Convert acc_o from fp32 to fp16/bf16
|
||||
Tensor rO = pytorch_flash::convert_type<Element>(acc_o);
|
||||
Tensor sO = make_tensor(sQ.data(), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N)
|
||||
// Partition sO to match the accumulator partitioning
|
||||
auto smem_tiled_copy_O = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma);
|
||||
auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(tidx);
|
||||
// auto smem_thr_copy_O = make_tiled_copy_C_warpcontiguousM<MMA_M>(typename Kernel_traits::SmemCopyAtomO{}, tiled_mma).get_thread_slice(tidx);
|
||||
Tensor taccOrO = smem_thr_copy_O.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N)
|
||||
Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N)
|
||||
|
||||
// sO has the same size as sQ, so we don't need to sync here.
|
||||
if (Kernel_traits::Share_Q_K_smem) { __syncthreads(); }
|
||||
|
||||
cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);
|
||||
|
||||
const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb)
|
||||
+ m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
|
||||
const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
|
||||
Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.o_ptr) + row_offset_o),
|
||||
Shape<Int<kBlockM>, Int<kHeadDim>>{},
|
||||
make_stride(params.o_row_stride, _1{}));
|
||||
Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
|
||||
Shape<Int<kBlockM>>{}, Stride<_1>{});
|
||||
|
||||
typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O;
|
||||
auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx);
|
||||
Tensor tOsO = gmem_thr_copy_O.partition_S(sO); // ((Atom,AtomNum),ATOM_M,ATOM_N)
|
||||
Tensor tOgO = gmem_thr_copy_O.partition_D(gO);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
Tensor tOrO = make_tensor<Element>(shape(tOgO));
|
||||
cute::copy(gmem_tiled_copy_O, tOsO, tOrO);
|
||||
|
||||
Tensor caccO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDim>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
|
||||
Tensor taccOcO = thr_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K)
|
||||
static_assert(decltype(size<0>(taccOcO))::value == 4);
|
||||
// Convert to ((2, 2), MMA_M, MMA_K) then take only the row indices.
|
||||
Tensor taccOcO_row = logical_divide(taccOcO, Shape<_2>{})(make_coord(0, _), _, 0);
|
||||
CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M
|
||||
if (get<1>(taccOcO_row(0)) == 0) {
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size(lse); ++mi) {
|
||||
const int row = get<0>(taccOcO_row(mi));
|
||||
if (row < binfo.actual_seqlen_q - m_block * kBlockM) { gLSE(row) = lse(mi); }
|
||||
}
|
||||
}
|
||||
|
||||
// Construct identity layout for sO
|
||||
Tensor cO = make_identity_tensor(make_shape(size<0>(sO), size<1>(sO))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
|
||||
// Repeat the partitioning with identity layouts
|
||||
Tensor tOcO = gmem_thr_copy_O.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
|
||||
Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgO)));
|
||||
if (!Is_even_K) {
|
||||
#pragma unroll
|
||||
for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(0, 0, k)) < params.d; }
|
||||
}
|
||||
// Clear_OOB_K must be false since we don't want to write zeros to gmem
|
||||
pytorch_flash::copy</*Is_even_MN=*/false, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
|
||||
gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, binfo.actual_seqlen_q - m_block * kBlockM
|
||||
);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K, bool Return_softmax, typename Params>
|
||||
inline __device__ void compute_attn(const Params ¶ms) {
|
||||
const int m_block = blockIdx.x;
|
||||
// The block index for the batch.
|
||||
const int bidb = blockIdx.y;
|
||||
// The block index for the head.
|
||||
const int bidh = blockIdx.z;
|
||||
|
||||
// We want the fwd and bwd to generate the same dropout pattern (RNG), without restricting
|
||||
// them to have the same number of threads or have to traverse the attention matrix
|
||||
// in the same order.
|
||||
// In the Philox RNG, we use the offset to store the batch, head, and the lane id
|
||||
// (within a warp). We use the subsequence to store the location of the 16 x 32 blocks within
|
||||
// the attention matrix. This way, as long as we have the batch, head, and the location of
|
||||
// the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern.
|
||||
|
||||
pytorch_flash::compute_attn_1rowblock<Kernel_traits, Is_dropout, Is_causal, Is_even_N, Is_even_K, Return_softmax>(params, bidb, bidh, m_block);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace flash
|
||||
@ -0,0 +1,259 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2023, Tri Dao.
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash.h>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/static_switch.h>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_fwd_kernel.h>
|
||||
|
||||
namespace pytorch_flash {
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K, bool Return_softmax>
|
||||
__global__ void flash_fwd_kernel(Flash_fwd_params params) {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
pytorch_flash::compute_attn<Kernel_traits, Is_dropout, Is_causal, Is_even_N, Is_even_K, Return_softmax>(params);
|
||||
#else
|
||||
printf("FATAL: FlashAttention requires to be build with sm80-sm90, but was built for < 5.3!");
|
||||
#endif
|
||||
}
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
|
||||
void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
constexpr size_t smem_size = Kernel_traits::kSmemSize;
|
||||
// printf("smem_size = %d\n", smem_size);
|
||||
|
||||
// Work-around for gcc 7. It doesn't like nested BOOL_SWITCH.
|
||||
// https://github.com/kokkos/kokkos-kernels/issues/349
|
||||
// https://github.com/HazyResearch/flash-attention/issues/21
|
||||
|
||||
const int num_m_block = (params.seqlen_q + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM;
|
||||
dim3 grid(num_m_block, params.b, params.h);
|
||||
// We also use is_even_N to set Unpadded in the BlockInfo constructor, so we need to check
|
||||
// for cu_seqlens_q as well.
|
||||
const bool is_even_N = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0;
|
||||
const bool is_even_K = params.d == Kernel_traits::kHeadDim;
|
||||
const bool return_softmax = params.p_ptr != nullptr;
|
||||
BOOL_SWITCH(is_even_N, IsEvenNConst, [&] {
|
||||
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
|
||||
BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] {
|
||||
// Will only return softmax if dropout, to reduce compilation time.
|
||||
auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenNConst, IsEvenKConst, ReturnSoftmaxConst && Is_dropout>;
|
||||
// auto kernel = &flash_fwd_kernel<Kernel_traits, Is_dropout, Is_causal, IsEvenNConst, true, ReturnSoftmaxConst && Is_dropout>;
|
||||
if (smem_size >= 48 * 1024) {
|
||||
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
|
||||
}
|
||||
int ctas_per_sm;
|
||||
cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size);
|
||||
// printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm);
|
||||
kernel<<<grid, Kernel_traits::kNThreads, smem_size, stream>>>(params);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void run_mha_fwd_hdim32(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
constexpr int Headdim = 32;
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void run_mha_fwd_hdim64(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
constexpr int Headdim = 64;
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
||||
if constexpr(!Is_dropout) {
|
||||
// Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower
|
||||
// Using block size (64 x 256) is 27% slower for seqlen=2k
|
||||
// Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
|
||||
} else {
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void run_mha_fwd_hdim96(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
constexpr int Headdim = 96;
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
||||
// For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
|
||||
if (is_sm8x) {
|
||||
if constexpr(!Is_causal) {
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
} else {
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
}
|
||||
} else {
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
}
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
|
||||
// These two are always slower
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<96, 128, 128, 4, true, T>>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<96, 64, 128, 4, true, T>>(params, stream);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void run_mha_fwd_hdim128(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
constexpr int Headdim = 128;
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
||||
if constexpr(!Is_dropout) {
|
||||
// For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
|
||||
// and 128 x 32 (48 KB smem) is the fastest for non-causal since we get 2 CTAs per SM.
|
||||
if (is_sm8x) {
|
||||
if constexpr(!Is_causal) {
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
} else {
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
}
|
||||
} else {
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
}
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
// Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
// 1st ones are good for H100, A100
|
||||
// 2nd one is good for A6000 bc we get slightly better occupancy
|
||||
} else {
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, true, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, true, true, T>, Is_dropout, Is_causal>(params, stream);
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void run_mha_fwd_hdim160(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
constexpr int Headdim = 160;
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
bool is_sm8x = dprops->major == 8 && dprops->minor > 0;
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
||||
// For A100, H100, 128 x 32 is the fastest.
|
||||
// For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square),
|
||||
// and 128 x 64 with 8 warps is the fastest for non-causal.
|
||||
if (is_sm8x) {
|
||||
if constexpr(!Is_causal) {
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
} else {
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
}
|
||||
} else {
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
}
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, true, T>, Is_dropout, Is_causal>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, T>>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, T>>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, T>>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, T>>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, T>>(params, stream);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void run_mha_fwd_hdim192(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
constexpr int Headdim = 192;
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
||||
if constexpr(!Is_dropout) {
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
} else {
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
}
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 4, false, T>>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 128, 4, false, T>>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 8, false, T>>(params, stream);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void run_mha_fwd_hdim224(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
constexpr int Headdim = 224;
|
||||
int device;
|
||||
cudaGetDevice(&device);
|
||||
int max_smem_per_block;
|
||||
cudaError status_ = cudaDeviceGetAttribute(
|
||||
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
|
||||
// printf("max_smem_per_block = %d\n", max_smem_per_block);
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
||||
if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64)) { // 112 KB
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
} else {
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
}
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
// We can't do 128 x 32 with 8 warps because with headdim 224, kBlockKSmem = 32.
|
||||
// If we have N = 32, there are only 1024 elements to load at once, where each load
|
||||
// is 8 elements. This means we can only use 128 threads and not 256 threads.
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void run_mha_fwd_hdim256(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
constexpr int Headdim = 256;
|
||||
int device;
|
||||
cudaGetDevice(&device);
|
||||
int max_smem_per_sm, max_smem_per_block;
|
||||
cudaError status_ = cudaDeviceGetAttribute(
|
||||
&max_smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, device);
|
||||
status_ = cudaDeviceGetAttribute(
|
||||
&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);
|
||||
// printf("max_smem_per_sm = %d, max_smem_per_block = %d\n", max_smem_per_sm, max_smem_per_block);
|
||||
BOOL_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
|
||||
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
|
||||
// For A100, we want to run with 128 x 64 (128KB smem).
|
||||
// For H100 we want to run with 64 x 64 (96KB smem) since then we can get 2 CTAs per SM.
|
||||
if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64) && max_smem_per_sm < 4 * Headdim * (64 + 2 * 64)) {
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 64, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
} else {
|
||||
run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 64, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
}
|
||||
// 64 KB
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 64, 32, 4, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
// 96 KB
|
||||
// run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 32, 8, false, false, T>, Is_dropout, Is_causal>(params, stream);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
}; // namespace pytorch_fmha
|
||||
@ -1,214 +0,0 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright
|
||||
* notice, this list of conditions and the following disclaimer in the
|
||||
* documentation and/or other materials provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the
|
||||
* names of its contributors may be used to endorse or promote products
|
||||
* derived from this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cuda.h>
|
||||
#include <vector>
|
||||
|
||||
#ifdef OLD_GENERATOR_PATH
|
||||
#include <ATen/CUDAGeneratorImpl.h>
|
||||
#else
|
||||
#include <ATen/cuda/CUDAGeneratorImpl.h>
|
||||
#endif
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <ATen/cuda/CUDAGraphsUtils.cuh>
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/fmha_utils.h>
|
||||
|
||||
namespace pytorch_fmha {
|
||||
|
||||
constexpr int TOTAL_DIM = 0;
|
||||
constexpr int H_DIM = 1;
|
||||
constexpr int D_DIM = 2;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct Qkv_params {
|
||||
// The QKV matrices.
|
||||
void *__restrict__ q_ptr;
|
||||
void *__restrict__ k_ptr;
|
||||
void *__restrict__ v_ptr;
|
||||
|
||||
// The stride between rows of the Q, K and V matrices.
|
||||
// size_t qkv_stride_in_elts;
|
||||
// size_t qkv_stride_in_bytes;
|
||||
// TD [2022-04-16]: We're using 32-bit indexing to save registers.
|
||||
// The code probably won't work for arrays larger than 2GB.
|
||||
uint32_t q_row_stride_in_elts;
|
||||
uint32_t k_row_stride_in_elts;
|
||||
uint32_t v_row_stride_in_elts;
|
||||
uint32_t q_head_stride_in_elts;
|
||||
uint32_t k_head_stride_in_elts;
|
||||
uint32_t v_head_stride_in_elts;
|
||||
|
||||
// The number of heads.
|
||||
int h;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct FMHA_fprop_params : public Qkv_params {
|
||||
|
||||
// The O matrix (output).
|
||||
void * __restrict__ o_ptr;
|
||||
|
||||
// The stride between rows of O.
|
||||
// size_t o_stride_in_elts;
|
||||
// size_t o_stride_in_bytes;
|
||||
uint32_t o_row_stride_in_elts;
|
||||
uint32_t o_head_stride_in_elts;
|
||||
uint32_t o_tmp_row_stride_in_elts;
|
||||
uint32_t o_tmp_head_stride_in_elts;
|
||||
|
||||
// The pointer to the O_tmp matrix, which holds O intermediate value during
|
||||
// the loop;
|
||||
void *__restrict__ o_tmp_ptr;
|
||||
|
||||
// The pointer to the S matrix.
|
||||
void * __restrict__ s_ptr;
|
||||
// The stride between rows of the S matrix.
|
||||
// int64_t s_stride_in_bytes;
|
||||
uint32_t s_stride_in_bytes;
|
||||
|
||||
// The pointer to the softmax sum.
|
||||
void * __restrict__ softmax_lse_ptr;
|
||||
|
||||
// The dimensions.
|
||||
int b, seqlen_q, seqlen_k, d;
|
||||
|
||||
// The scaling factors for the kernel.
|
||||
float scale_bmm1f;
|
||||
uint32_t scale_bmm1;
|
||||
|
||||
// array of length b+1 holding starting offset of each sequence.
|
||||
int * __restrict__ cu_seqlens_q;
|
||||
int * __restrict__ cu_seqlens_k;
|
||||
|
||||
int *__restrict__ blockmask;
|
||||
|
||||
// The dropout probability (probability of keeping an activation).
|
||||
float p_dropout;
|
||||
uint32_t p_dropout_in_uint;
|
||||
uint16_t p_dropout_in_uint16_t;
|
||||
|
||||
// Scale factor of 1 / (1 - p_dropout).
|
||||
float rp_dropout;
|
||||
float scale_bmm1_rp_dropout;
|
||||
|
||||
// Scale factor of 1 / (1 - p_dropout), in half2.
|
||||
uint32_t scale_dropout;
|
||||
|
||||
// Random state.
|
||||
at::PhiloxCudaState philox_args;
|
||||
int64_t * extragraph_offset;
|
||||
int64_t * seed;
|
||||
|
||||
bool is_bf16;
|
||||
bool is_causal;
|
||||
|
||||
int num_splits; // How many SMs per attention matrix.
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct FMHA_dgrad_params : public FMHA_fprop_params {
|
||||
|
||||
// The dQKV matrices.
|
||||
void *__restrict__ dq_ptr;
|
||||
void *__restrict__ dk_ptr;
|
||||
void *__restrict__ dv_ptr;
|
||||
|
||||
// // To accumulate dK and dV in case we're splitting the bwd along seqlen_q dimension
|
||||
// void *__restrict__ dk_accum_ptr;
|
||||
// void *__restrict__ dv_accum_ptr;
|
||||
|
||||
// The stride between rows of the dQ, dK and dV matrices.
|
||||
// TD [2022-04-16]: We're using 32-bit indexing to save registers.
|
||||
// The code probably won't work for arrays larger than 2GB.
|
||||
uint32_t dq_row_stride_in_elts;
|
||||
uint32_t dk_row_stride_in_elts;
|
||||
uint32_t dv_row_stride_in_elts;
|
||||
uint32_t dq_head_stride_in_elts;
|
||||
uint32_t dk_head_stride_in_elts;
|
||||
uint32_t dv_head_stride_in_elts;
|
||||
|
||||
// The dO matrix. We assume it is contiguous.
|
||||
void * __restrict__ do_ptr;
|
||||
|
||||
// The pointer to the softmax d sum.
|
||||
void * __restrict__ dsoftmax_sum;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename Kernel_params>
|
||||
struct Launch_params{
|
||||
Launch_params(cudaDeviceProp * props_,
|
||||
cudaStream_t stream_,
|
||||
bool is_dropout_,
|
||||
bool return_softmax_)
|
||||
: elts_per_thread(0)
|
||||
, props(props_)
|
||||
, stream(stream_)
|
||||
, is_dropout(is_dropout_)
|
||||
, return_softmax(return_softmax_) {
|
||||
}
|
||||
|
||||
size_t elts_per_thread;
|
||||
|
||||
cudaDeviceProp * props;
|
||||
|
||||
cudaStream_t stream;
|
||||
|
||||
bool is_dropout;
|
||||
bool return_softmax;
|
||||
|
||||
Kernel_params params;
|
||||
int num_full_heads;
|
||||
int num_main_groups;
|
||||
int heads_last_wave;
|
||||
int main_steps;
|
||||
int rest_steps;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void run_fmha_fwd_hdim32(Launch_params<FMHA_fprop_params> &launch_params);
|
||||
void run_fmha_fwd_hdim64(Launch_params<FMHA_fprop_params> &launch_params);
|
||||
void run_fmha_fwd_hdim128(Launch_params<FMHA_fprop_params> &launch_params);
|
||||
|
||||
void run_fmha_bwd_hdim32(FMHA_dgrad_params ¶ms, cudaStream_t stream, const bool configure);
|
||||
void run_fmha_bwd_hdim64(FMHA_dgrad_params ¶ms, cudaStream_t stream, const bool configure);
|
||||
void run_fmha_bwd_hdim128(FMHA_dgrad_params ¶ms, cudaStream_t stream, const bool configure);
|
||||
|
||||
void run_fmha_block_fp16_sm80(Launch_params<FMHA_fprop_params> &launch_params, const bool configure);
|
||||
|
||||
void run_fmha_block_dgrad_fp16_sm80(const FMHA_dgrad_params ¶ms, cudaStream_t stream);
|
||||
|
||||
}; // namespace pytorch_fmha
|
||||
@ -1,540 +0,0 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2022, Tri Dao.
|
||||
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright
|
||||
* notice, this list of conditions and the following disclaimer in the
|
||||
* documentation and/or other materials provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the
|
||||
* names of its contributors may be used to endorse or promote products
|
||||
* derived from this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
******************************************************************************/
|
||||
|
||||
#include <cstdint>
|
||||
#include <tuple>
|
||||
#ifdef USE_FLASH_ATTENTION
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#include <ATen/cuda/CUDAGraphsUtils.cuh>
|
||||
#ifndef AT_PER_OPERATOR_HEADERS
|
||||
#include <ATen/Functions.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#else
|
||||
#include <ATen/ops/scalar_tensor.h>
|
||||
#endif
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/fmha.h>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/fmha_api.h>
|
||||
|
||||
#include <c10/util/Exception.h>
|
||||
|
||||
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == at::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
|
||||
|
||||
namespace pytorch_fmha {
|
||||
|
||||
void set_params_fprop(FMHA_fprop_params ¶ms,
|
||||
// sizes
|
||||
const size_t b,
|
||||
const size_t seqlen_q,
|
||||
const size_t seqlen_k,
|
||||
const size_t h,
|
||||
const size_t d,
|
||||
// device pointers
|
||||
const at::Tensor q,
|
||||
const at::Tensor k,
|
||||
const at::Tensor v,
|
||||
at::Tensor out,
|
||||
void *cu_seqlens_q_d,
|
||||
void *cu_seqlens_k_d,
|
||||
void *o_tmp_d,
|
||||
void *s_d,
|
||||
void *softmax_lse_d,
|
||||
float p_dropout,
|
||||
float softmax_scale,
|
||||
bool is_causal,
|
||||
int num_splits) {
|
||||
|
||||
Data_type data_type = !(q.dtype() == at::kBFloat16) ? DATA_TYPE_FP16 : DATA_TYPE_BF16;
|
||||
|
||||
// Reset the parameters
|
||||
params = {};
|
||||
|
||||
params.is_bf16 = q.dtype() == at::kBFloat16;
|
||||
|
||||
// Set the pointers and strides.
|
||||
params.q_ptr = q.data_ptr();
|
||||
params.k_ptr = k.data_ptr();
|
||||
params.v_ptr = v.data_ptr();
|
||||
params.q_row_stride_in_elts = q.stride(0);
|
||||
params.k_row_stride_in_elts = k.stride(0);
|
||||
params.v_row_stride_in_elts = v.stride(0);
|
||||
params.q_head_stride_in_elts = q.stride(1);
|
||||
params.k_head_stride_in_elts = k.stride(1);
|
||||
params.v_head_stride_in_elts = v.stride(1);
|
||||
params.o_ptr = out.data_ptr();
|
||||
params.o_row_stride_in_elts = out.stride(0);
|
||||
params.o_head_stride_in_elts = out.stride(1);
|
||||
params.o_tmp_ptr = o_tmp_d;
|
||||
params.o_tmp_row_stride_in_elts = h * d;
|
||||
params.o_tmp_head_stride_in_elts = d;
|
||||
|
||||
params.cu_seqlens_q = static_cast<int *>(cu_seqlens_q_d);
|
||||
params.cu_seqlens_k = static_cast<int *>(cu_seqlens_k_d);
|
||||
|
||||
// S = softmax(P)
|
||||
params.s_ptr = s_d;
|
||||
params.s_stride_in_bytes = get_size_in_bytes(b * h * seqlen_k, data_type);
|
||||
|
||||
// Softmax sum
|
||||
params.softmax_lse_ptr = softmax_lse_d;
|
||||
|
||||
// Set the dimensions.
|
||||
params.b = b;
|
||||
params.h = h;
|
||||
params.seqlen_q = seqlen_q;
|
||||
params.seqlen_k = seqlen_k;
|
||||
params.d = d;
|
||||
|
||||
// Set the different scale values.
|
||||
// const float scale_bmm1 = 1.f / sqrtf(d);
|
||||
const float scale_bmm1 = softmax_scale;
|
||||
|
||||
params.scale_bmm1f = scale_bmm1;
|
||||
set_alpha(params.scale_bmm1, scale_bmm1, data_type);
|
||||
|
||||
// Set this to probability of keeping an element to simplify things.
|
||||
params.p_dropout = 1.f - p_dropout;
|
||||
// Convert p from float to int so we don't have to convert the random uint to float to compare.
|
||||
// [Minor] We want to round down since when we do the comparison we use <= instead of <
|
||||
params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0));
|
||||
params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0));
|
||||
params.rp_dropout = 1.f / params.p_dropout;
|
||||
params.scale_bmm1_rp_dropout = params.rp_dropout * params.scale_bmm1f;
|
||||
TORCH_CHECK(p_dropout < 1.f);
|
||||
set_alpha(params.scale_dropout, params.rp_dropout, data_type);
|
||||
|
||||
params.is_causal = is_causal;
|
||||
params.num_splits = num_splits;
|
||||
}
|
||||
|
||||
void set_params_dgrad(FMHA_dgrad_params ¶ms,
|
||||
// sizes
|
||||
const size_t b,
|
||||
const size_t seqlen_q,
|
||||
const size_t seqlen_k,
|
||||
const size_t h,
|
||||
const size_t d,
|
||||
// device pointers
|
||||
const at::Tensor q,
|
||||
const at::Tensor k,
|
||||
const at::Tensor v,
|
||||
const at::Tensor out,
|
||||
at::Tensor dq,
|
||||
at::Tensor dk,
|
||||
at::Tensor dv,
|
||||
void *cu_seqlens_q_d,
|
||||
void *cu_seqlens_k_d,
|
||||
void *dq_tmp_d,
|
||||
void *do_packed_d,
|
||||
void *softmax_lse_d,
|
||||
void *dsoftmax_sum_d,
|
||||
float p_dropout,
|
||||
float softmax_scale,
|
||||
bool is_causal,
|
||||
int num_splits) {
|
||||
|
||||
set_params_fprop(params,
|
||||
b, seqlen_q, seqlen_k, h, d,
|
||||
q, k, v, out,
|
||||
cu_seqlens_q_d,
|
||||
cu_seqlens_k_d,
|
||||
dq_tmp_d, // Reusing the o_tmp_ptr variable to store dq_tmp
|
||||
nullptr,
|
||||
softmax_lse_d,
|
||||
p_dropout,
|
||||
softmax_scale,
|
||||
is_causal,
|
||||
num_splits);
|
||||
|
||||
// Set the pointers and strides.
|
||||
params.dq_ptr = dq.data_ptr();
|
||||
params.dk_ptr = dk.data_ptr();
|
||||
params.dv_ptr = dv.data_ptr();
|
||||
params.dq_row_stride_in_elts = dq.stride(0);
|
||||
params.dk_row_stride_in_elts = dk.stride(0);
|
||||
params.dv_row_stride_in_elts = dv.stride(0);
|
||||
params.dq_head_stride_in_elts = dq.stride(1);
|
||||
params.dk_head_stride_in_elts = dk.stride(1);
|
||||
params.dv_head_stride_in_elts = dv.stride(1);
|
||||
params.do_ptr = do_packed_d;
|
||||
|
||||
// Softmax sum
|
||||
params.dsoftmax_sum = dsoftmax_sum_d;
|
||||
}
|
||||
|
||||
void run_fmha_fwd(Launch_params<FMHA_fprop_params> &launch_params) {
|
||||
if (launch_params.params.d <= 32) {
|
||||
run_fmha_fwd_hdim32(launch_params);
|
||||
} else if (launch_params.params.d <= 64) {
|
||||
run_fmha_fwd_hdim64(launch_params);
|
||||
} else if (launch_params.params.d <= 128) {
|
||||
run_fmha_fwd_hdim128(launch_params);
|
||||
}
|
||||
}
|
||||
// The tensor `out` will get populated the output attention
|
||||
// First return value is softmax_logsumexp
|
||||
// Second return value is the random generator state
|
||||
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
|
||||
mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
|
||||
const at::Tensor &k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
|
||||
const at::Tensor &v, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
|
||||
at::Tensor &out,
|
||||
const at::Tensor &cu_seqlens_q, // b+1
|
||||
const at::Tensor &cu_seqlens_k, // b+1
|
||||
const int max_seqlen_q_,
|
||||
const int max_seqlen_k_,
|
||||
const float p_dropout,
|
||||
const float softmax_scale,
|
||||
const bool zero_tensors,
|
||||
const bool is_causal,
|
||||
const bool return_softmax,
|
||||
const int num_splits) {
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
|
||||
bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
|
||||
bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
|
||||
TORCH_CHECK(is_sm90 || is_sm8x || is_sm75);
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
bool is_dropout = p_dropout > 0.0;
|
||||
Launch_params<FMHA_fprop_params> launch_params(dprops, stream, is_dropout, return_softmax);
|
||||
|
||||
auto q_dtype = q.dtype();
|
||||
TORCH_CHECK(q_dtype == at::kHalf || ((is_sm8x || is_sm90) && q_dtype == at::kBFloat16));
|
||||
TORCH_CHECK(k.dtype() == q_dtype);
|
||||
TORCH_CHECK(v.dtype() == q_dtype);
|
||||
TORCH_CHECK(out.dtype() == q_dtype);
|
||||
TORCH_CHECK(cu_seqlens_q.dtype() == at::kInt);
|
||||
TORCH_CHECK(cu_seqlens_k.dtype() == at::kInt);
|
||||
|
||||
TORCH_CHECK(q.is_cuda());
|
||||
TORCH_CHECK(k.is_cuda());
|
||||
TORCH_CHECK(v.is_cuda());
|
||||
TORCH_CHECK(out.is_cuda());
|
||||
TORCH_CHECK(cu_seqlens_q.is_cuda());
|
||||
TORCH_CHECK(cu_seqlens_k.is_cuda());
|
||||
|
||||
TORCH_CHECK(q.stride(-1) == 1);
|
||||
TORCH_CHECK(k.stride(-1) == 1);
|
||||
TORCH_CHECK(v.stride(-1) == 1);
|
||||
TORCH_CHECK(cu_seqlens_k.is_contiguous());
|
||||
TORCH_CHECK(cu_seqlens_k.is_contiguous());
|
||||
|
||||
const auto sizes = q.sizes();
|
||||
|
||||
const int batch_size = cu_seqlens_q.numel() - 1;
|
||||
const int total_q = sizes[TOTAL_DIM];
|
||||
const int num_heads = sizes[H_DIM];
|
||||
const int head_size = sizes[D_DIM];
|
||||
const int total_k = k.size(TOTAL_DIM);
|
||||
TORCH_CHECK(batch_size > 0);
|
||||
TORCH_CHECK((head_size % 8 == 0) && (head_size <= 128));
|
||||
|
||||
CHECK_SHAPE(q, total_q, num_heads, head_size);
|
||||
CHECK_SHAPE(k, total_k, num_heads, head_size);
|
||||
CHECK_SHAPE(v, total_k, num_heads, head_size);
|
||||
CHECK_SHAPE(out, total_q, num_heads, head_size);
|
||||
CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
|
||||
CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
|
||||
|
||||
int blocksize_c = head_size > 64 ? 128 : 256;
|
||||
// Need to round max_seqlen_k to multiples of blocksize_c
|
||||
int max_seqlen_k = ((max_seqlen_k_ + blocksize_c - 1) / blocksize_c) * blocksize_c;
|
||||
if( max_seqlen_k_ <= 128 ) {
|
||||
max_seqlen_k = 128;
|
||||
} else if( max_seqlen_k_ <= 256 ) {
|
||||
max_seqlen_k = 256;
|
||||
}
|
||||
int max_seqlen_q = ((max_seqlen_q_ + 16 - 1) / 16) * 16;
|
||||
bool loop = max_seqlen_k > blocksize_c;
|
||||
|
||||
// Otherwise the kernel will be launched from cuda:0 device
|
||||
at::cuda::CUDAGuard device_guard{q.device()};
|
||||
|
||||
auto opts = q.options();
|
||||
|
||||
// auto o = torch::empty({ total_q, num_heads, head_size }, opts);
|
||||
|
||||
at::Tensor o_tmp;
|
||||
if (loop) { o_tmp = at::empty({total_q, num_heads, head_size}, opts.dtype(at::kFloat)); }
|
||||
|
||||
auto softmax_lse = at::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
|
||||
// auto softmax_lse = torch::full({batch_size, num_heads, max_seqlen_k}, -std::numeric_limits<float>::infinity(), opts.dtype(at::kFloat));
|
||||
|
||||
at::Tensor flash_softmax;
|
||||
if (return_softmax) {flash_softmax = at::empty({ batch_size, num_heads, max_seqlen_q, max_seqlen_k }, opts); }
|
||||
|
||||
if( zero_tensors ) {
|
||||
out.zero_();
|
||||
softmax_lse.fill_(-std::numeric_limits<float>::infinity());
|
||||
if (return_softmax) {flash_softmax.zero_();}
|
||||
}
|
||||
|
||||
set_params_fprop(launch_params.params,
|
||||
batch_size,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
num_heads,
|
||||
head_size,
|
||||
q, k, v, out,
|
||||
cu_seqlens_q.data_ptr(),
|
||||
cu_seqlens_k.data_ptr(),
|
||||
loop ? o_tmp.data_ptr() : nullptr,
|
||||
return_softmax ? flash_softmax.data_ptr() : nullptr,
|
||||
softmax_lse.data_ptr(),
|
||||
p_dropout,
|
||||
softmax_scale,
|
||||
is_causal,
|
||||
num_splits);
|
||||
|
||||
// number of times random will be generated per thread, to offset philox counter in thc random
|
||||
// state
|
||||
// We use a custom RNG that increases the offset by batch_size * nheads * 32.
|
||||
int64_t counter_offset = launch_params.params.b * launch_params.params.h * 32;
|
||||
|
||||
// We want to checkpoint and save the RNG state for backward if dropout
|
||||
// We get the default generator and return the seed and offset which will
|
||||
// be used in the backward function
|
||||
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(c10::nullopt, at::cuda::detail::getDefaultCUDAGenerator());
|
||||
at::Tensor seed_t, offset_t;
|
||||
if (is_dropout) {
|
||||
|
||||
// See Note [Acquire lock when using random generators]
|
||||
std::lock_guard<std::mutex> lock(gen->mutex_);
|
||||
at::PhiloxCudaState philox_state = gen->philox_cuda_state(counter_offset);
|
||||
if (at::cuda::currentStreamCaptureStatus() == at::cuda::CaptureStatus::None) {
|
||||
auto [seed, offset] = at::cuda::philox::unpack(philox_state);
|
||||
seed_t = at::scalar_tensor(at::Scalar(static_cast<int64_t>(seed)), at::dtype(at::kLong));
|
||||
offset_t = at::scalar_tensor(at::Scalar(static_cast<int64_t>(offset)), at::dtype(at::kLong));
|
||||
} else {
|
||||
seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
|
||||
offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
|
||||
launch_params.params.seed = seed_t.data_ptr<int64_t>();
|
||||
launch_params.params.extragraph_offset = offset_t.data_ptr<int64_t>();
|
||||
}
|
||||
launch_params.params.philox_args = philox_state;
|
||||
} else {
|
||||
if (at::cuda::currentStreamCaptureStatus() != at::cuda::CaptureStatus::None) {
|
||||
seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
|
||||
offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA));
|
||||
} else {
|
||||
seed_t = at::empty({}, at::dtype(at::kLong));
|
||||
offset_t = at::empty({}, at::dtype(at::kLong));
|
||||
}
|
||||
}
|
||||
|
||||
run_fmha_fwd(launch_params);
|
||||
return {softmax_lse, seed_t, offset_t, flash_softmax};
|
||||
}
|
||||
|
||||
void run_fmha_bwd(FMHA_dgrad_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
if (params.d <= 32) {
|
||||
run_fmha_bwd_hdim32(params, stream, configure);
|
||||
} else if (params.d <= 64) {
|
||||
run_fmha_bwd_hdim64(params, stream, configure);
|
||||
} else if (params.d <= 128) {
|
||||
run_fmha_bwd_hdim128(params, stream, configure);
|
||||
}
|
||||
}
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
|
||||
mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
|
||||
const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
|
||||
const at::Tensor &k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
|
||||
const at::Tensor &v, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
|
||||
const at::Tensor &out, // total_q x num_heads x head_size
|
||||
const at::Tensor &softmax_lse_, // b x h x s softmax logsumexp
|
||||
at::Tensor &dq, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
|
||||
at::Tensor &dk, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
|
||||
at::Tensor &dv, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
|
||||
const at::Tensor &cu_seqlens_q, // b+1
|
||||
const at::Tensor &cu_seqlens_k, // b+1
|
||||
const int max_seqlen_q_,
|
||||
const int max_seqlen_k_, // max sequence length to choose the kernel
|
||||
const float p_dropout, // probability to drop
|
||||
const float softmax_scale,
|
||||
const bool zero_tensors,
|
||||
const bool is_causal,
|
||||
const int num_splits,
|
||||
at::Tensor philox_seed,
|
||||
at::Tensor philox_offset
|
||||
) {
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
|
||||
bool is_sm80 = dprops->major == 8 && dprops->minor == 0;
|
||||
bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
|
||||
bool is_sm90 = dprops->major == 9 && dprops->minor == 0;
|
||||
TORCH_CHECK(is_sm90 || is_sm8x || is_sm75);
|
||||
auto launch = &run_fmha_bwd;
|
||||
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
|
||||
auto q_dtype = q.dtype();
|
||||
|
||||
TORCH_CHECK(q_dtype == at::kHalf || ((is_sm8x || is_sm90) && q_dtype == at::kBFloat16));
|
||||
TORCH_CHECK(k.dtype() == q_dtype);
|
||||
TORCH_CHECK(v.dtype() == q_dtype);
|
||||
TORCH_CHECK(out.dtype() == q_dtype);
|
||||
TORCH_CHECK(dout.dtype() == q_dtype);
|
||||
TORCH_CHECK(dq.dtype() == q_dtype);
|
||||
TORCH_CHECK(dk.dtype() == q_dtype);
|
||||
TORCH_CHECK(dv.dtype() == q_dtype);
|
||||
TORCH_CHECK(cu_seqlens_q.dtype() == at::kInt);
|
||||
TORCH_CHECK(cu_seqlens_k.dtype() == at::kInt);
|
||||
|
||||
TORCH_CHECK(q.is_cuda());
|
||||
TORCH_CHECK(k.is_cuda());
|
||||
TORCH_CHECK(v.is_cuda());
|
||||
TORCH_CHECK(out.is_cuda());
|
||||
TORCH_CHECK(dout.is_cuda());
|
||||
TORCH_CHECK(softmax_lse_.is_cuda());
|
||||
TORCH_CHECK(cu_seqlens_q.is_cuda());
|
||||
TORCH_CHECK(cu_seqlens_k.is_cuda());
|
||||
|
||||
TORCH_CHECK(q.stride(-1) == 1);
|
||||
TORCH_CHECK(k.stride(-1) == 1);
|
||||
TORCH_CHECK(v.stride(-1) == 1);
|
||||
TORCH_CHECK(out.is_contiguous());
|
||||
TORCH_CHECK(dout.is_contiguous());
|
||||
TORCH_CHECK(dq.stride(-1) == 1);
|
||||
TORCH_CHECK(dk.stride(-1) == 1);
|
||||
TORCH_CHECK(dv.stride(-1) == 1);
|
||||
TORCH_CHECK(cu_seqlens_q.is_contiguous());
|
||||
TORCH_CHECK(cu_seqlens_k.is_contiguous());
|
||||
|
||||
const auto sizes = q.sizes();
|
||||
|
||||
const int batch_size = cu_seqlens_q.numel() - 1;
|
||||
const int total_q = sizes[TOTAL_DIM];
|
||||
const int num_heads = sizes[H_DIM];
|
||||
const int head_size = sizes[D_DIM];
|
||||
const int total_k = k.size(TOTAL_DIM);
|
||||
TORCH_CHECK(batch_size > 0);
|
||||
TORCH_CHECK((head_size % 8 == 0) && (head_size <= 128));
|
||||
if (head_size > 64) { // TODO: eventually we should support SM86 and SM70 with d=128 as well
|
||||
TORCH_CHECK(is_sm80 || is_sm90);
|
||||
}
|
||||
|
||||
CHECK_SHAPE(q, total_q, num_heads, head_size);
|
||||
CHECK_SHAPE(k, total_k, num_heads, head_size);
|
||||
CHECK_SHAPE(v, total_k, num_heads, head_size);
|
||||
CHECK_SHAPE(out, total_q, num_heads, head_size);
|
||||
CHECK_SHAPE(dout, total_q, num_heads, head_size);
|
||||
CHECK_SHAPE(dq, total_q, num_heads, head_size);
|
||||
CHECK_SHAPE(dk, total_k, num_heads, head_size);
|
||||
CHECK_SHAPE(dv, total_k, num_heads, head_size);
|
||||
CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
|
||||
CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
|
||||
|
||||
int blocksize_c = (head_size > 64 || (is_sm75 && head_size > 32)) ? 128 : 256;
|
||||
int max_seqlen_k = ((max_seqlen_k_ + blocksize_c - 1) / blocksize_c) * blocksize_c;
|
||||
if( max_seqlen_k_ <= 128 ) {
|
||||
max_seqlen_k = 128;
|
||||
} else if( max_seqlen_k_ <= 256 ) {
|
||||
max_seqlen_k = 256;
|
||||
}
|
||||
int max_seqlen_q = ((max_seqlen_q_ + 16 - 1) / 16) * 16;
|
||||
bool loop = max_seqlen_k > blocksize_c;
|
||||
|
||||
// Otherwise the kernel will be launched from cuda:0 device
|
||||
// Cast to char to avoid compiler warning about narrowing
|
||||
at::cuda::CUDAGuard device_guard{(char)q.get_device()};
|
||||
|
||||
// It's possible the softmax_lse_ from the fwd has a different length since blocksize_c could be different.
|
||||
auto softmax_lse = softmax_lse_.index({at::indexing::Slice(), at::indexing::Slice(), at::indexing::Slice(at::indexing::None, max_seqlen_q)}).contiguous();
|
||||
|
||||
auto opts = q.options();
|
||||
auto softmax_d = at::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
|
||||
at::Tensor dq_tmp;
|
||||
if (loop) { dq_tmp = at::empty({total_q, num_heads, head_size}, opts.dtype(at::kFloat)); }
|
||||
|
||||
if( zero_tensors ) {
|
||||
dq.zero_();
|
||||
dk.zero_();
|
||||
dv.zero_();
|
||||
softmax_d.zero_();
|
||||
}
|
||||
|
||||
FMHA_dgrad_params params;
|
||||
|
||||
set_params_dgrad(params,
|
||||
batch_size,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
num_heads,
|
||||
head_size,
|
||||
q, k, v, out,
|
||||
dq, dk, dv,
|
||||
cu_seqlens_q.data_ptr(),
|
||||
cu_seqlens_k.data_ptr(),
|
||||
loop ? dq_tmp.data_ptr() : nullptr,
|
||||
dout.data_ptr(),
|
||||
softmax_lse.data_ptr(),
|
||||
softmax_d.data_ptr(),
|
||||
p_dropout,
|
||||
softmax_scale,
|
||||
is_causal,
|
||||
num_splits);
|
||||
|
||||
launch(params, stream, /*configure=*/true);
|
||||
|
||||
if (params.num_splits > 1) {
|
||||
if (!dq_tmp.defined()) {
|
||||
dq_tmp = at::zeros({total_q, num_heads, head_size}, opts.dtype(at::kFloat));
|
||||
params.o_tmp_ptr = dq_tmp.data_ptr(); // o_tmp stores dq_tmp in the backward pass
|
||||
} else {
|
||||
dq_tmp.zero_();
|
||||
}
|
||||
}
|
||||
bool is_dropout = p_dropout > 0.0;
|
||||
at::PhiloxCudaState philox_args;
|
||||
if (is_dropout) {
|
||||
if (at::cuda::currentStreamCaptureStatus() ==
|
||||
at::cuda::CaptureStatus::None)
|
||||
{
|
||||
philox_args = at::PhiloxCudaState(*philox_seed.data_ptr<int64_t>(), *philox_offset.data_ptr<int64_t>());
|
||||
} else { // dropout + capture
|
||||
philox_args = at::PhiloxCudaState(
|
||||
philox_seed.data_ptr<int64_t>(), philox_offset.data_ptr<int64_t>(), 0);
|
||||
}
|
||||
}
|
||||
params.philox_args = philox_args;
|
||||
|
||||
launch(params, stream, /*configure=*/false);
|
||||
|
||||
if (params.num_splits > 1) {
|
||||
dq.copy_(dq_tmp);
|
||||
}
|
||||
|
||||
return std::make_tuple(dq, dk, dv, softmax_d);
|
||||
}
|
||||
} // namespace fmha
|
||||
|
||||
#endif
|
||||
@ -1,50 +0,0 @@
|
||||
#pragma once
|
||||
#include <cstddef>
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <c10/util/Exception.h>
|
||||
|
||||
namespace pytorch_fmha {
|
||||
|
||||
TORCH_API
|
||||
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
|
||||
mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
|
||||
const at::Tensor &k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
|
||||
const at::Tensor &v, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
|
||||
at::Tensor &out,
|
||||
const at::Tensor &cu_seqlens_q, // b+1
|
||||
const at::Tensor &cu_seqlens_k, // b+1
|
||||
const int max_seqlen_q_,
|
||||
const int max_seqlen_k_,
|
||||
const float p_dropout,
|
||||
const float softmax_scale,
|
||||
const bool zero_tensors,
|
||||
const bool is_causal,
|
||||
const bool return_softmax,
|
||||
const int num_splits);
|
||||
|
||||
TORCH_API
|
||||
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
|
||||
mha_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
|
||||
const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
|
||||
const at::Tensor &k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
|
||||
const at::Tensor &v, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
|
||||
const at::Tensor &out, // total_q x num_heads x head_size
|
||||
const at::Tensor &softmax_lse_, // b x h x s softmax logsumexp
|
||||
at::Tensor &dq, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
|
||||
at::Tensor &dk, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
|
||||
at::Tensor &dv, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
|
||||
const at::Tensor &cu_seqlens_q, // b+1
|
||||
const at::Tensor &cu_seqlens_k, // b+1
|
||||
const int max_seqlen_q_,
|
||||
const int max_seqlen_k_, // max sequence length to choose the kernel
|
||||
const float p_dropout, // probability to drop
|
||||
const float softmax_scale,
|
||||
const bool zero_tensors,
|
||||
const bool is_causal,
|
||||
const int num_splits,
|
||||
at::Tensor philox_seed,
|
||||
at::Tensor philox_offset
|
||||
);
|
||||
|
||||
} // namespace fmha
|
||||
@ -1,16 +0,0 @@
|
||||
// Copyright (c) 2022, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/fmha_bwd_launch_template.h>
|
||||
|
||||
namespace pytorch_fmha {
|
||||
|
||||
void run_fmha_bwd_hdim128(FMHA_dgrad_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
FP16_SWITCH(params.is_bf16, ([&] {
|
||||
using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 8, 0x100u, elem_type>;
|
||||
run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);
|
||||
}));
|
||||
}
|
||||
|
||||
}; // namespace pytorch_fmha
|
||||
@ -1,21 +0,0 @@
|
||||
// Copyright (c) 2022, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/fmha_bwd_launch_template.h>
|
||||
|
||||
namespace pytorch_fmha {
|
||||
|
||||
void run_fmha_bwd_hdim32(FMHA_dgrad_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
FP16_SWITCH(params.is_bf16, ([&] {
|
||||
if (params.seqlen_k == 128) {
|
||||
using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 8, 0x08u, elem_type>;
|
||||
run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);
|
||||
} else if (params.seqlen_k >= 256) {
|
||||
using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 8, 0x08u, elem_type>;
|
||||
run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);
|
||||
}
|
||||
}));
|
||||
}
|
||||
|
||||
}; // namespace pytorch_fmha
|
||||
@ -1,38 +0,0 @@
|
||||
// Copyright (c) 2022, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/fmha_bwd_launch_template.h>
|
||||
|
||||
namespace pytorch_fmha {
|
||||
|
||||
void run_fmha_bwd_hdim64(FMHA_dgrad_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
FP16_SWITCH(params.is_bf16, ([&] {
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
if (params.seqlen_k == 128) {
|
||||
using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>;
|
||||
run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);
|
||||
} else if (params.seqlen_k >= 256) {
|
||||
if ((dprops->major == 8 && dprops->minor == 0) ||
|
||||
(dprops->major == 9 && dprops->minor == 0)) {
|
||||
// Don't share smem for K & V, and don't keep V in registers
|
||||
// This speeds things up by 2-3% by avoiding register spills, but it
|
||||
// uses more shared memory, which is fine on A100 and H100 but not other
|
||||
// GPUs. For other GPUs, we keep V in registers.
|
||||
using Kernel_traits =
|
||||
FMHA_kernel_traits<256, 64, 16, 1, 8, 0x100u, elem_type>;
|
||||
run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);
|
||||
} else if (dprops->major == 8 && dprops->minor > 0) {
|
||||
using Kernel_traits =
|
||||
FMHA_kernel_traits<256, 64, 16, 1, 8, 0x08u, elem_type>;
|
||||
run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);
|
||||
} else if (dprops->major == 7 && dprops->minor == 5) {
|
||||
using Kernel_traits =
|
||||
FMHA_kernel_traits<128, 64, 16, 1, 8, 0x08u, elem_type>;
|
||||
run_fmha_bwd_loop<Kernel_traits>(params, stream, configure);
|
||||
}
|
||||
}
|
||||
}));
|
||||
}
|
||||
|
||||
}; // namespace pytorch_fmha
|
||||
@ -1,119 +0,0 @@
|
||||
// Copyright (c) 2022, Tri Dao.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/static_switch.h>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/fmha.h>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/fmha_dgrad_kernel_1xN_loop.h>
|
||||
|
||||
namespace pytorch_fmha {
|
||||
|
||||
// Pick whether we should parallelize across seqlen_k (num_splits > 1) or not (num_splits=1).
|
||||
// Parallelizing will have better occupancy, but has some overhead due to having to zero out
|
||||
// dq_tmp and having to copy dq_tmp to dq.
|
||||
inline int num_splits_heuristic_bwd(int batch_nheads, int num_SMs, int ctas_per_sm, int seqlen,
|
||||
int blocksize, bool is_causal) {
|
||||
float n_waves_1 = float(batch_nheads) / (num_SMs * ctas_per_sm);
|
||||
float eff_1 = n_waves_1 / ceil(n_waves_1);
|
||||
int num_splits_parallel = seqlen / blocksize;
|
||||
float n_waves_parallel = float(batch_nheads * num_splits_parallel) / (num_SMs * ctas_per_sm);
|
||||
float eff_parallel_raw = n_waves_parallel / ceil(n_waves_parallel);
|
||||
float discount_factor;
|
||||
if (!is_causal) {
|
||||
discount_factor = 1.f + float(blocksize) / seqlen;
|
||||
} else { // For causal, parallelizing seems to help with load-balancing as well
|
||||
// For example, if headdim=128, seqlen >= 1280 always prefers parallel
|
||||
if (seqlen / blocksize >= 10) return num_splits_parallel;
|
||||
discount_factor = 1.f + 0.5 * float(blocksize) / seqlen;
|
||||
}
|
||||
float eff_parallel = eff_parallel_raw / discount_factor;
|
||||
return eff_1 >= eff_parallel ? 1 : num_splits_parallel;
|
||||
}
|
||||
|
||||
template<typename Kernel_traits>
|
||||
__global__ void fmha_bwd_dot_do_o_kernel(FMHA_dgrad_params params) {
|
||||
fmha::compute_dot_do_o<Kernel_traits>(params);
|
||||
}
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, int loop_steps=-1>
|
||||
__global__ void fmha_bwd_dq_dk_dv_loop_kernel(FMHA_dgrad_params params) {
|
||||
fmha::compute_dq_dk_dv_1xN<Kernel_traits, Is_dropout, Is_causal, loop_steps>(params);
|
||||
}
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout, bool Is_causal>
|
||||
__global__ void fmha_bwd_q_dk_dv_loop_seqparallel_kernel(FMHA_dgrad_params params) {
|
||||
fmha::compute_dq_dk_dv_seqparallel<Kernel_traits, Is_dropout, Is_causal>(params);
|
||||
}
|
||||
|
||||
template<typename Kernel_traits>
|
||||
void run_fmha_bwd_loop(FMHA_dgrad_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
constexpr int smem_size_softmax = Kernel_traits::Cta_tile_p::M * Kernel_traits::Cta_tile_p::WARPS_N * sizeof(float);
|
||||
constexpr int smem_size_q = Kernel_traits::Smem_tile_q::BYTES_PER_TILE;
|
||||
constexpr int smem_size_v = Kernel_traits::Smem_tile_v::BYTES_PER_TILE;
|
||||
constexpr int smem_size_dq = Kernel_traits::Smem_tile_o::BYTES_PER_TILE;
|
||||
|
||||
using Smem_tile_s = fmha::Smem_tile_mma_transposed<typename Kernel_traits::Cta_tile_p>;
|
||||
constexpr int smem_size_s = Smem_tile_s::BYTES_PER_TILE;
|
||||
static_assert(smem_size_s == 16 * Kernel_traits::Cta_tile_p::N * 2);
|
||||
static_assert(smem_size_dq == 16 * Kernel_traits::Cta_tile_p::K * 4 * Kernel_traits::Cta_tile_p::WARPS_N);
|
||||
|
||||
constexpr int smem_size_dq_dk_dv = smem_size_q * 2 + smem_size_v * (Kernel_traits::V_IN_REGS ? 1 : 2) + smem_size_dq + smem_size_s * 2;
|
||||
constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N;
|
||||
// printf("blocksize_c = %d, WARPS_N = %d, Smem size = %d\n", blocksize_c, Kernel_traits::Cta_tile_p::WARPS_N, smem_size_dq_dk_dv);
|
||||
|
||||
bool is_dropout = params.p_dropout < 1.f; // params.p_dropout is the probability of "keeping"
|
||||
// Work-around for gcc 7. It doesn't like nested BOOL_SWITCH.
|
||||
BOOL_SWITCH(is_dropout, IsDropoutConst, ([&] {
|
||||
auto kernel = params.is_causal
|
||||
? &fmha_bwd_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, true>
|
||||
: &fmha_bwd_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, false>;
|
||||
if (params.seqlen_k == blocksize_c) {
|
||||
kernel = params.is_causal
|
||||
? &fmha_bwd_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, true, /*loop_steps=*/1>
|
||||
: &fmha_bwd_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, false, /*loop_steps=*/1>;
|
||||
} else if (params.seqlen_k == blocksize_c * 2) {
|
||||
kernel = params.is_causal
|
||||
? &fmha_bwd_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, true, /*loop_steps=*/2>
|
||||
: &fmha_bwd_dq_dk_dv_loop_kernel<Kernel_traits, IsDropoutConst, false, /*loop_steps=*/2>;
|
||||
}
|
||||
auto kernel_seqparallel = params.is_causal
|
||||
? &fmha_bwd_q_dk_dv_loop_seqparallel_kernel<Kernel_traits, IsDropoutConst, true>
|
||||
: &fmha_bwd_q_dk_dv_loop_seqparallel_kernel<Kernel_traits, IsDropoutConst, false>;
|
||||
if( smem_size_dq_dk_dv >= 48 * 1024 ) {
|
||||
FMHA_CHECK_CUDA(cudaFuncSetAttribute(
|
||||
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
|
||||
FMHA_CHECK_CUDA(cudaFuncSetAttribute(
|
||||
kernel_seqparallel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
|
||||
}
|
||||
// Automatically set num_splits to maximize occupancy
|
||||
if (params.num_splits <= 0) {
|
||||
int ctas_per_sm;
|
||||
cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&ctas_per_sm, kernel, Kernel_traits::THREADS, smem_size_dq_dk_dv);
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
// printf("CTAS_PER_SM = %d, nSMs = %d\n", ctas_per_sm, dprops->multiProcessorCount);
|
||||
// We don't want more than 10 splits due to numerical error.
|
||||
// Numerical error on dk/dv scales as sqrt(num_splits).
|
||||
params.num_splits = num_splits_heuristic_bwd(
|
||||
params.b * params.h, dprops->multiProcessorCount,
|
||||
ctas_per_sm, params.seqlen_k, blocksize_c, params.is_causal
|
||||
);
|
||||
}
|
||||
if (configure) return;
|
||||
if (params.num_splits == 1) {
|
||||
dim3 grid(params.b, params.h, params.num_splits);
|
||||
kernel<<<grid, Kernel_traits::THREADS, smem_size_dq_dk_dv, stream>>>(params);
|
||||
} else {
|
||||
dim3 grid_dot(params.b, params.h, (params.seqlen_q + 128 - 1) / 128);
|
||||
fmha_bwd_dot_do_o_kernel<Kernel_traits><<<grid_dot, Kernel_traits::THREADS, 0, stream>>>(params);
|
||||
int num_splits = params.seqlen_k / blocksize_c; // seqlen_k is divisible by blocksize_c
|
||||
dim3 grid(params.b, params.h, num_splits);
|
||||
kernel_seqparallel<<<grid, Kernel_traits::THREADS, smem_size_dq_dk_dv, stream>>>(params);
|
||||
}
|
||||
FMHA_CHECK_CUDA(cudaPeekAtLastError());
|
||||
}));
|
||||
}
|
||||
|
||||
}; // namespace pytorch_fmha
|
||||
@ -1,839 +0,0 @@
|
||||
/* Copyright (c) 2022, Tri Dao.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/fmha_fprop_kernel_1xN.h>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/fmha_kernel.h>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/kernel_traits.h>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/gemm.h>
|
||||
|
||||
namespace fmha {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <int ROWS, int THREADS_PER_ROW, typename elem_type=__half, int M, typename Gmem_softmax_sum>
|
||||
inline __device__ void dot_do_o(const uint4 (&do_)[M], const uint4 (&o)[M], const float scale,
|
||||
Gmem_softmax_sum gmem_softmax_d, int tidx) {
|
||||
float sum[M];
|
||||
fmha::SumOp<float> sum_op;
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < M; ++mi) {
|
||||
sum[mi] = fmha::Allreduce<THREADS_PER_ROW>::run(
|
||||
fmha::hmulsum8<elem_type>(do_[mi], o[mi]), sum_op
|
||||
) * scale;
|
||||
}
|
||||
const int dp_sum_row = tidx / THREADS_PER_ROW;
|
||||
if ((dp_sum_row < ROWS) && (tidx % THREADS_PER_ROW == 0)) {
|
||||
gmem_softmax_d.store_row(reinterpret_cast<const uint32_t (&)[M]>(sum), dp_sum_row);
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Just compute dot(do, o) and write the result (softmax_d) to global memory as a separate kernel.
|
||||
// This is used in the case where we want to parallelize the backward across seqlen_k.
|
||||
template<typename Kernel_traits, typename Params>
|
||||
inline __device__ void compute_dot_do_o(const Params ¶ms) {
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
using elem_type = typename Kernel_traits::elem_type;
|
||||
#else
|
||||
constexpr bool is_fp16_type = std::is_same<typename Kernel_traits::elem_type, __half>::value;
|
||||
assert(is_fp16_type);
|
||||
using elem_type = __half;
|
||||
#endif
|
||||
|
||||
// The description of the CTA tile for the 1st batched GEMM.
|
||||
using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
|
||||
// The description of the CTA tile for the 3rd batched GEMM.
|
||||
using Cta_tile_dkv =
|
||||
fmha::Cta_tile_extd<Cta_tile_p::N, Cta_tile_p::K, Cta_tile_p::M, Cta_tile_p::WARPS_N, 1, Cta_tile_p::WARPS_M>;
|
||||
|
||||
static_assert(Cta_tile_dkv::N == 16 || Cta_tile_dkv::N == 32 || Cta_tile_dkv::N == 64 || Cta_tile_dkv::N == 128);
|
||||
static_assert(Cta_tile_dkv::K == 16);
|
||||
|
||||
// The global memory tile to load dO.
|
||||
using Gmem_tile_do = typename Kernel_traits::Gmem_tile_do;
|
||||
|
||||
// The global memory tile to load O.Loading O here is similar to loading dO.
|
||||
using Gmem_tile_o = Gmem_tile_do;
|
||||
|
||||
using Gmem_softmax_sum = typename Kernel_traits::Gmem_softmax_sum;
|
||||
|
||||
// The block index for the batch.
|
||||
const int bidb = blockIdx.x;
|
||||
// The block index for the head.
|
||||
const int bidh = blockIdx.y;
|
||||
// The thread index.
|
||||
const int tidx = threadIdx.x;
|
||||
|
||||
// How many steps to jump per iteration.
|
||||
const int step_stride = gridDim.z;
|
||||
|
||||
const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);
|
||||
if( binfo.stop_early() ) return;
|
||||
|
||||
// Allocate the global memory tile loader for dO.
|
||||
Gmem_tile_do gmem_do(params.do_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts,
|
||||
params.d, binfo, tidx, true);
|
||||
|
||||
// Allocate the global memory tile loader for O.
|
||||
Gmem_tile_o gmem_o(params.o_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts,
|
||||
params.d, binfo, tidx, true);
|
||||
|
||||
Gmem_softmax_sum gmem_softmax_d(params.dsoftmax_sum, params, tidx);
|
||||
|
||||
static_assert(Cta_tile_p::N % Cta_tile_p::M == 0);
|
||||
const int steps = (params.seqlen_q + Cta_tile_p::M - 1) / Cta_tile_p::M;
|
||||
// Wind gmem tiles to the correct position.
|
||||
gmem_do.move(blockIdx.z);
|
||||
gmem_o.move(blockIdx.z);
|
||||
gmem_softmax_d.move(blockIdx.z);
|
||||
|
||||
// Load over the entire sequence length.
|
||||
for (int l = blockIdx.z; l < steps; l += step_stride) {
|
||||
if (l * Cta_tile_p::M >= binfo.actual_seqlen_q)
|
||||
break;
|
||||
|
||||
gmem_do.load();
|
||||
gmem_do.move(step_stride);
|
||||
gmem_o.load();
|
||||
gmem_o.move(step_stride);
|
||||
|
||||
dot_do_o<Gmem_tile_do::ROWS, Gmem_tile_do::THREADS_PER_ROW, elem_type>(
|
||||
gmem_do.fetch_, gmem_o.fetch_, params.p_dropout, gmem_softmax_d, tidx
|
||||
);
|
||||
gmem_softmax_d.move(step_stride);
|
||||
} // Outer loop over the sequence length.
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_first, bool Is_last, bool Seq_parallel=false, typename Params, typename Prng>
|
||||
inline __device__ void compute_dq_dk_dv_1xN_one_iter(const Params ¶ms, Prng &ph,
|
||||
const int loop_step_idx) {
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
using elem_type = typename Kernel_traits::elem_type;
|
||||
#else
|
||||
constexpr bool is_fp16_type = std::is_same<typename Kernel_traits::elem_type, __half>::value;
|
||||
assert(is_fp16_type);
|
||||
using elem_type = __half;
|
||||
#endif
|
||||
|
||||
// The description of the CTA tile for the 1st batched GEMM.
|
||||
using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
|
||||
// The description of the CTA tile for the 2nd batched GEMM.
|
||||
using Cta_tile_dq = typename Kernel_traits::Cta_tile_o;
|
||||
// The description of the CTA tile for the 3rd batched GEMM.
|
||||
using Cta_tile_dkv =
|
||||
fmha::Cta_tile_extd<Cta_tile_p::N, Cta_tile_p::K, Cta_tile_p::M, Cta_tile_p::WARPS_N, 1, Cta_tile_p::WARPS_M>;
|
||||
|
||||
static_assert(Cta_tile_dkv::M == 512 || Cta_tile_dkv::M == 256 || Cta_tile_dkv::M == 128);
|
||||
static_assert(Cta_tile_dkv::N == 16 || Cta_tile_dkv::N == 32 || Cta_tile_dkv::N == 64 || Cta_tile_dkv::N == 128);
|
||||
static_assert(Cta_tile_dkv::K == 16);
|
||||
|
||||
// The MMA tile for the 1st GEMM.
|
||||
using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;
|
||||
// The MMA tile for the 2nd GEMM.
|
||||
using Mma_tile_dq = fmha::Hmma_tile<Cta_tile_dq>;
|
||||
// The MMA tile for the 3rd GEMM.
|
||||
using Mma_tile_dkv = fmha::Hmma_tile<Cta_tile_dkv>;
|
||||
|
||||
// The global memory tile to load Q.
|
||||
using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q;
|
||||
// The shared memory tile to reload Q transposed.
|
||||
using Smem_tile_qt = fmha::Smem_tile_b<Cta_tile_dkv, fmha::Row, Gmem_tile_q::BYTES_PER_LDG, 2>;
|
||||
|
||||
// The global memory tile to load K.
|
||||
using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k;
|
||||
// The shared memory tile to swizzle K^T. Treat K^T as V
|
||||
using Smem_tile_kt = typename Kernel_traits::Smem_tile_v;
|
||||
|
||||
// Treating V as K. We need to use Kernel_traits::Smem_tile_k otherwise loading will be wrong
|
||||
// The global memory tile to load V.
|
||||
using Gmem_tile_v = typename Kernel_traits::Gmem_tile_k;
|
||||
// The shared memory tile to swizzle V.
|
||||
using Smem_tile_v = typename Kernel_traits::Smem_tile_k;
|
||||
|
||||
// The global memory tile to load dO.
|
||||
using Gmem_tile_do = typename Kernel_traits::Gmem_tile_do;
|
||||
// The shared memory tile to load dO.
|
||||
// Treating dO as Q.
|
||||
using Smem_tile_do = typename Kernel_traits::Smem_tile_q;
|
||||
// The shared memory tile to reload dO transposed.
|
||||
using Smem_tile_dot = fmha::Smem_tile_b<Cta_tile_dkv, fmha::Row, Gmem_tile_q::BYTES_PER_LDG, 2>;
|
||||
|
||||
// The global memory tile to load O.Loading O here is similar to loading dO.
|
||||
using Gmem_tile_o = Gmem_tile_do;
|
||||
|
||||
// The global memory tile to store dQ.
|
||||
using Gmem_tile_dq = typename Kernel_traits::Gmem_tile_o;
|
||||
using Gmem_tile_dq_tmp = fmha::Gmem_tile_o<Cta_tile_dq, 4>;
|
||||
// The shared memory tile to swizzle dQ.
|
||||
using Smem_tile_dq = typename Kernel_traits::Smem_tile_o;
|
||||
|
||||
// The global memory tile to store dV.
|
||||
using Gmem_tile_dv = typename Kernel_traits::Gmem_tile_v;
|
||||
// The shared memory tile to swizzle dV.
|
||||
using Smem_tile_dv = fmha::Smem_tile_mma_epilogue<Cta_tile_dkv>;
|
||||
|
||||
// The global memory tile to store dK.
|
||||
using Gmem_tile_dk = typename Kernel_traits::Gmem_tile_v;
|
||||
// The shared memory tile to swizzle dK.
|
||||
using Smem_tile_dk = fmha::Smem_tile_mma_epilogue<Cta_tile_dkv>;
|
||||
static_assert(Smem_tile_dk::NUM_LDS == Gmem_tile_dk::LDGS);
|
||||
static_assert(Smem_tile_dk::THREADS_PER_ROW == Gmem_tile_dk::THREADS_PER_ROW);
|
||||
|
||||
using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s;
|
||||
|
||||
using Smem_tile_st = typename Kernel_traits::Smem_tile_st;
|
||||
|
||||
using Gmem_softmax_sum = typename Kernel_traits::Gmem_softmax_sum;
|
||||
|
||||
// using Gemm1 = Gemm_Q_K<Kernel_traits, Kernel_traits::K_IN_REGS>;
|
||||
using Gemm1 = Gemm_Q_K<Kernel_traits, /*K-in_regs=*/false, elem_type>;
|
||||
|
||||
using Softmax = fmha::Softmax<Cta_tile_p, Kernel_traits>;
|
||||
|
||||
// Shared memory.
|
||||
extern __shared__ char smem_[];
|
||||
// Shared memory layout if we keep V in registers:
|
||||
// dO | Q | K / V | dQ | S | dP | dP_sum
|
||||
// dV | dK
|
||||
// Shared memory layout if we keep V shared memory:
|
||||
// dO | Q | K | V | dQ | S | dP | dP_sum
|
||||
// dV | dK
|
||||
|
||||
|
||||
// The block index for the batch.
|
||||
const int bidb = blockIdx.x;
|
||||
// The block index for the head.
|
||||
const int bidh = blockIdx.y;
|
||||
// The thread index.
|
||||
const int tidx = threadIdx.x;
|
||||
|
||||
const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);
|
||||
// if( binfo.stop_early() ) return;
|
||||
if( binfo.stop_early(loop_step_idx * Cta_tile_p::N) ) return;
|
||||
|
||||
Gemm1 gemm_q_k(&smem_[Smem_tile_do::BYTES_PER_TILE], tidx);
|
||||
// Allocate the global memory tile loader for Q.
|
||||
Gmem_tile_q gmem_q(params.q_ptr, params.q_row_stride_in_elts, params.q_head_stride_in_elts,
|
||||
params.d, binfo, tidx, true);
|
||||
// Allocate the global memory tile loader for dQ.
|
||||
Gmem_tile_dq gmem_dq(params.dq_ptr, params.dq_row_stride_in_elts, params.dq_head_stride_in_elts,
|
||||
params.d, binfo, tidx);
|
||||
Gmem_tile_dq_tmp gmem_dq_tmp(params.o_tmp_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts,
|
||||
params.d, binfo, tidx);
|
||||
// Allocate the global memory tile loader for S.
|
||||
Gmem_tile_s gmem_s(params, binfo, tidx);
|
||||
|
||||
fmha::Mask<Cta_tile_p, Is_causal> mask(binfo, tidx, loop_step_idx);
|
||||
|
||||
// Allocate the global memory tile loader for K.
|
||||
Gmem_tile_k gmem_k(params.k_ptr, params.k_row_stride_in_elts, params.k_head_stride_in_elts,
|
||||
params.d, binfo, tidx, false);
|
||||
// Allocate the global memory tile loader for V.
|
||||
Gmem_tile_v gmem_v(params.v_ptr, params.v_row_stride_in_elts, params.v_head_stride_in_elts,
|
||||
params.d, binfo, tidx, false);
|
||||
// The base pointer of smem_v;
|
||||
char *smem_v_ = &smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_V];
|
||||
|
||||
// Allocate the shared memory tile loader for V. We use the same as K so be careful!!!
|
||||
Smem_tile_v smem_v(smem_v_, tidx);
|
||||
// Allocate the shared memory tile loader for K^T. We use the same as K so be careful!!!
|
||||
Smem_tile_kt smem_kt(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::Smem_tile_q::BYTES_PER_TILE], tidx);
|
||||
|
||||
// Allocate the global memory tile loader for dO.
|
||||
Gmem_tile_do gmem_do(params.do_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts,
|
||||
params.d, binfo, tidx, true);
|
||||
// Allocate the shared memory tile loader for dO.
|
||||
Smem_tile_do smem_do(&smem_[0], tidx);
|
||||
Smem_tile_dot smem_dot(&smem_[0], tidx);
|
||||
// Allocate the shared memory tile loader for Q^T.
|
||||
// TODO: assert that this points to the same memory as gemm_q_k.smem_q
|
||||
Smem_tile_qt smem_qt(&smem_[Smem_tile_do::BYTES_PER_TILE], tidx);
|
||||
|
||||
Smem_tile_st smem_s(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O + Smem_tile_dq::BYTES_PER_TILE], tidx);
|
||||
Smem_tile_st smem_dp(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O + Smem_tile_dq::BYTES_PER_TILE + Smem_tile_st::BYTES_PER_TILE], tidx);
|
||||
|
||||
// Allocate the global memory tile loader for O.
|
||||
Gmem_tile_o gmem_o(params.o_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts,
|
||||
params.d, binfo, tidx, true);
|
||||
|
||||
// Allocate the shared memory tile loader for O. We use the same as K so be careful!!!
|
||||
Smem_tile_dq smem_dq(&smem_[Smem_tile_do::BYTES_PER_TILE + Gemm1::SMEM_OFFSET_O], tidx);
|
||||
|
||||
Gmem_softmax_sum gmem_softmax_lse(params.softmax_lse_ptr, params, tidx);
|
||||
Gmem_softmax_sum gmem_softmax_d(params.dsoftmax_sum, params, tidx);
|
||||
|
||||
static_assert(Cta_tile_p::N % Cta_tile_p::M == 0);
|
||||
int begin = Is_causal ? loop_step_idx * Cta_tile_p::N / Cta_tile_p::M : 0;
|
||||
// Otherwise we'd be reading out-of-bound memory before the loop
|
||||
if (begin * Cta_tile_p::M >= binfo.actual_seqlen_q) {
|
||||
// Still need to zero out dk and dv before returning
|
||||
static_assert(Smem_tile_dk::NUM_LDS == Smem_tile_dv::NUM_LDS);
|
||||
uint4 dkv_out[Smem_tile_dk::NUM_LDS];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < Smem_tile_dk::NUM_LDS; ++i) { dkv_out[i] = make_uint4(0u, 0u, 0u, 0u); }
|
||||
Gmem_tile_dk gmem_dk(params.dk_ptr, params.dk_row_stride_in_elts, params.dk_head_stride_in_elts,
|
||||
params.d, binfo, tidx, false);
|
||||
if (!Is_first) { gmem_dk.move(loop_step_idx); }
|
||||
gmem_dk.store(dkv_out);
|
||||
Gmem_tile_dv gmem_dv(params.dv_ptr, params.dv_row_stride_in_elts, params.dv_head_stride_in_elts,
|
||||
params.d, binfo, tidx, false);
|
||||
if (!Is_first) { gmem_dv.move(loop_step_idx); }
|
||||
gmem_dv.store(dkv_out);
|
||||
return;
|
||||
}
|
||||
|
||||
const int steps = (params.seqlen_q + Cta_tile_p::M - 1) / Cta_tile_p::M - begin;
|
||||
// Wind gmem tiles to the correct position.
|
||||
gmem_q.move(begin);
|
||||
gmem_do.move(begin);
|
||||
gmem_o.move(begin);
|
||||
if (!Seq_parallel) { gmem_dq.move(begin); } // If Seq_parallel, we're not using gmem_dq at all
|
||||
gmem_dq_tmp.move(begin);
|
||||
// TODO: need to move gmem_s if we want the intermediate result for debugging
|
||||
gmem_softmax_lse.move(begin);
|
||||
gmem_softmax_d.move(begin);
|
||||
|
||||
if (!Is_first) {
|
||||
gmem_k.move(loop_step_idx);
|
||||
gmem_v.move(loop_step_idx);
|
||||
}
|
||||
|
||||
// Trigger the loads for K.
|
||||
gmem_k.load();
|
||||
// Trigger the loads for Q.
|
||||
gmem_q.load();
|
||||
// Trigger the loads for V.
|
||||
gmem_v.load();
|
||||
// Trigger the loads for dO.
|
||||
gmem_do.load();
|
||||
// Trigger the loads for O.
|
||||
if (Is_first) { gmem_o.load(); }
|
||||
|
||||
float p_lse[Mma_tile_p::MMAS_M * 2];
|
||||
gmem_softmax_lse.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(p_lse));
|
||||
|
||||
if (!Is_first) { __syncthreads(); }
|
||||
// Commit the data for Q, dO, and V to shared memory.
|
||||
gmem_q.commit(gemm_q_k.smem_q);
|
||||
gmem_do.commit(smem_do);
|
||||
if (Is_first) {
|
||||
dot_do_o<Gmem_tile_do::ROWS, Gmem_tile_do::THREADS_PER_ROW, elem_type>(
|
||||
gmem_do.fetch_, gmem_o.fetch_, params.p_dropout, gmem_softmax_d, tidx
|
||||
);
|
||||
}
|
||||
|
||||
// // Instead of scaling dP by rp_dropout, we scale V instead
|
||||
// if (Is_dropout) {
|
||||
// const uint32_t scale_dropout = params.scale_dropout;
|
||||
// #pragma unroll
|
||||
// for(int it=0; it < Gmem_tile_v::LDGS; it++){
|
||||
// gmem_v.fetch_[it] = fmha::hmul8(scale_dropout, gmem_v.fetch_[it]);
|
||||
// }
|
||||
// }
|
||||
|
||||
gmem_v.commit(smem_v);
|
||||
|
||||
// const uint32_t scale_bmm1 = reinterpret_cast<const uint32_t&>(params.scale_bmm1);
|
||||
// #pragma unroll
|
||||
// for(int it=0; it < Gmem_tile_k::LDGS; it++){
|
||||
// gmem_k.fetch_[it] = fmha::hmul8(scale_bmm1, gmem_k.fetch_[it]);
|
||||
// }
|
||||
|
||||
// Commit the data for K to shared memory.
|
||||
if( !Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
|
||||
gmem_k.commit(gemm_q_k.smem_k);
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Load the fragments for Q.
|
||||
gemm_q_k.load_q();
|
||||
|
||||
// Load the fragments for V. We keep the data in registers during the entire kernel.
|
||||
typename Smem_tile_v::Fragment frag_v[Kernel_traits::V_IN_REGS ? Mma_tile_p::MMAS_K : 2][Mma_tile_p::MMAS_N];
|
||||
if (Kernel_traits::V_IN_REGS) {
|
||||
#pragma unroll
|
||||
for( int ki = 0; ki < Mma_tile_p::MMAS_K; ++ki ) {
|
||||
smem_v.load(frag_v[ki], ki);
|
||||
}
|
||||
}
|
||||
|
||||
float dp_sum[Mma_tile_p::MMAS_M * 2];
|
||||
gmem_softmax_d.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(dp_sum));
|
||||
|
||||
// Commit the data for V to shared memory if it has not been done already.
|
||||
if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
|
||||
// Make sure we are done loading the fragments for K.
|
||||
__syncthreads();
|
||||
|
||||
// Commit the data to shared memory for V.
|
||||
gmem_k.commit(gemm_q_k.smem_k);
|
||||
|
||||
// Make sure the data is in shared memory.
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// Load the fragments for K.
|
||||
gemm_q_k.load_k();
|
||||
// Load the fragments for K^T.
|
||||
// typename Smem_tile_kt::Fragment frag_kt[2][Mma_tile_dq::MMAS_N];
|
||||
// smem_kt.load(frag_kt[0], 0);
|
||||
// typename Smem_tile_kt::Fragment frag_kt[Mma_tile_dq::MMAS_K][Mma_tile_dq::MMAS_N];
|
||||
// #pragma unroll
|
||||
// for( int ki = 0; ki < Mma_tile_dq::MMAS_K; ++ki ) {
|
||||
// smem_kt.load(frag_kt[ki], ki);
|
||||
// }
|
||||
|
||||
// Create the object to do the softmax.
|
||||
// We won't be using the shared memory for this softmax at all
|
||||
Softmax softmax(params, smem_, tidx);
|
||||
|
||||
// Declare the accumulators for the 3rd gemm.
|
||||
fmha::Fragment_accumulator acc_dv[Mma_tile_dkv::MMAS_M][Mma_tile_dkv::MMAS_N];
|
||||
fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_dkv::WARPS_K>::apply(acc_dv);
|
||||
fmha::Fragment_accumulator acc_dk[Mma_tile_dkv::MMAS_M][Mma_tile_dkv::MMAS_N];
|
||||
fmha::Clear_accumulator<fmha::Accumulator_type, Cta_tile_dkv::WARPS_K>::apply(acc_dk);
|
||||
|
||||
// Load over the entire sequence length.
|
||||
for (int l = 0; l < steps; l++) {
|
||||
if ((begin + l) * Cta_tile_p::M >= binfo.actual_seqlen_q)
|
||||
break;
|
||||
|
||||
// Load the fragments for V.
|
||||
// typename Smem_tile_v::Fragment frag_v[2][Mma_tile_p::MMAS_N];
|
||||
if (!Kernel_traits::V_IN_REGS) { smem_v.load(frag_v[0], 0); }
|
||||
|
||||
// Load the fragments for dO.
|
||||
typename Smem_tile_do::Fragment frag_do[2][Mma_tile_p::MMAS_M];
|
||||
smem_do.load(frag_do[0], 0);
|
||||
|
||||
// Declare the accumulators for the 1st gemm.
|
||||
fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];
|
||||
fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_p::WARPS_K>::apply(acc_p);
|
||||
|
||||
// Do this part of P^T = (Q * K^T)^T.
|
||||
gemm_q_k(acc_p);
|
||||
|
||||
// Load the mask for that iteration.
|
||||
mask.load(begin + l);
|
||||
|
||||
// Convert from the accumulator type to FP32 for Softmax.
|
||||
softmax.unpack_noscale(acc_p);
|
||||
// Apply the mask.
|
||||
softmax.apply_mask(mask);
|
||||
// Scale by log-sum-exp of the softmax
|
||||
// softmax.apply_exp(p_lse);
|
||||
softmax.template scale_apply_exp</*scale_max=*/false>(p_lse, params.scale_bmm1f);
|
||||
if (Is_dropout) {
|
||||
// softmax.apply_dropout(ph, params.p_dropout_in_uint);
|
||||
// softmax.template apply_dropout</*encode_dropout_in_sign_bit=*/true>(ph, params.p_dropout_in_uint);
|
||||
// softmax.template apply_dropout_16bits</*encode_dropout_in_sign_bit=*/true>(ph, params.p_dropout_in_uint16_t);
|
||||
unsigned int warp_idx = threadIdx.x / 32;
|
||||
// TODO: this should change after we rearrange the warps (e.g. cutlass branch)
|
||||
unsigned int block_col_idx = loop_step_idx * Cta_tile_p::N / 16 + warp_idx;
|
||||
unsigned long long philox_subsequence = (begin + l) * (binfo.actual_seqlen_k / 16) + block_col_idx;
|
||||
softmax.template apply_dropout_16bits</*encode_dropout_in_sign_bit=*/true>(ph, params.p_dropout_in_uint16_t, philox_subsequence);
|
||||
}
|
||||
|
||||
using Frag_p = fmha::Fragment_a<fmha::Row>;
|
||||
Frag_p frag_p[Mma_tile_dq::MMAS_K][Mma_tile_dq::MMAS_M];
|
||||
static_assert(Mma_tile_dq::MMAS_M == Mma_tile_p::MMAS_M);
|
||||
static_assert(Mma_tile_dq::MMAS_K == Mma_tile_p::MMAS_N);
|
||||
softmax.template pack<elem_type>(frag_p);
|
||||
|
||||
// Store s * dmask to smem for transpose
|
||||
smem_s.store(frag_p);
|
||||
|
||||
// Trigger the load for the next Q values.
|
||||
if (l + 1 < steps) {
|
||||
gemm_q_k.smem_q.move_to_next_write_buffer();
|
||||
gmem_q.move();
|
||||
gmem_q.load();
|
||||
}
|
||||
|
||||
// if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V && l == 0 ) {
|
||||
// // if we share K and V, it could be that V was not fully read yet but we write into smem for reduction
|
||||
// __syncthreads();
|
||||
// }
|
||||
|
||||
fmha::Fragment_accumulator acc_dp[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < Mma_tile_p::MMAS_M; ++mi) {
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < Mma_tile_p::MMAS_N; ++ni) {
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < 8; ++ii) {
|
||||
acc_dp[mi][ni].elt(ii) = -dp_sum[mi * 2 + ((ii / 2) % 2)];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Do this part of dP^T = (dO * V^T)^T.
|
||||
#pragma unroll
|
||||
for( int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki ) {
|
||||
// Trigger the load from shared memory for the next series of dO values.
|
||||
smem_do.load(frag_do[ki & 1], ki);
|
||||
if (!Kernel_traits::V_IN_REGS) {
|
||||
smem_v.load(frag_v[ki & 1], ki);
|
||||
fmha::gemm_cl<elem_type>(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1) & 1]);
|
||||
} else {
|
||||
fmha::gemm_cl<elem_type>(acc_dp, frag_do[(ki - 1) & 1], frag_v[ki - 1]);
|
||||
}
|
||||
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l < 4)) {
|
||||
// float2 tmp = __half22float2(reinterpret_cast<__half2 &>(frag_do[(ki - 1) & 1]));
|
||||
// printf("frag_do=%.6f, %.6f\n", tmp.x, tmp.y);
|
||||
// tmp = __half22float2(reinterpret_cast<__half2 &>(frag_v[(ki - 1) & 1]));
|
||||
// printf("frag_v=%.6f, %.6f\n", tmp.x, tmp.y);
|
||||
// }
|
||||
}
|
||||
|
||||
// Do the final stage of math.
|
||||
{
|
||||
int ki = Mma_tile_p::MMAS_K;
|
||||
if (!Kernel_traits::V_IN_REGS) {
|
||||
fmha::gemm_cl<elem_type>(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1) & 1]);
|
||||
} else {
|
||||
fmha::gemm_cl<elem_type>(acc_dp, frag_do[(ki - 1) & 1], frag_v[(ki - 1)]);
|
||||
}
|
||||
}
|
||||
|
||||
auto pointwise_mult = [](float p, float dp, float d) {
|
||||
return p * ((!Is_dropout) || p >= 0.f ? dp : d);
|
||||
};
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < Mma_tile_p::MMAS_M; mi++) {
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < Mma_tile_p::MMAS_N; ni++) {
|
||||
softmax.elt_[2 * mi + 0][4 * ni + 0] = pointwise_mult(softmax.elt_[2 * mi + 0][4 * ni + 0], acc_dp[mi][ni].elt(0), dp_sum[2 * mi + 0]);
|
||||
softmax.elt_[2 * mi + 0][4 * ni + 1] = pointwise_mult(softmax.elt_[2 * mi + 0][4 * ni + 1], acc_dp[mi][ni].elt(1), dp_sum[2 * mi + 0]);
|
||||
softmax.elt_[2 * mi + 0][4 * ni + 2] = pointwise_mult(softmax.elt_[2 * mi + 0][4 * ni + 2], acc_dp[mi][ni].elt(4), dp_sum[2 * mi + 0]);
|
||||
softmax.elt_[2 * mi + 0][4 * ni + 3] = pointwise_mult(softmax.elt_[2 * mi + 0][4 * ni + 3], acc_dp[mi][ni].elt(5), dp_sum[2 * mi + 0]);
|
||||
softmax.elt_[2 * mi + 1][4 * ni + 0] = pointwise_mult(softmax.elt_[2 * mi + 1][4 * ni + 0], acc_dp[mi][ni].elt(2), dp_sum[2 * mi + 1]);
|
||||
softmax.elt_[2 * mi + 1][4 * ni + 1] = pointwise_mult(softmax.elt_[2 * mi + 1][4 * ni + 1], acc_dp[mi][ni].elt(3), dp_sum[2 * mi + 1]);
|
||||
softmax.elt_[2 * mi + 1][4 * ni + 2] = pointwise_mult(softmax.elt_[2 * mi + 1][4 * ni + 2], acc_dp[mi][ni].elt(6), dp_sum[2 * mi + 1]);
|
||||
softmax.elt_[2 * mi + 1][4 * ni + 3] = pointwise_mult(softmax.elt_[2 * mi + 1][4 * ni + 3], acc_dp[mi][ni].elt(7), dp_sum[2 * mi + 1]);
|
||||
}
|
||||
}
|
||||
|
||||
// Load the fragments for K^T.
|
||||
typename Smem_tile_kt::Fragment frag_kt[2][Mma_tile_dq::MMAS_N];
|
||||
smem_kt.load(frag_kt[0], 0);
|
||||
|
||||
// Trigger the load for the next dO values.
|
||||
if (l + 1 < steps) {
|
||||
smem_do.move_to_next_write_buffer();
|
||||
gmem_do.move();
|
||||
gmem_do.load();
|
||||
if (Is_first) {
|
||||
gmem_o.move();
|
||||
gmem_o.load();
|
||||
}
|
||||
}
|
||||
|
||||
softmax.template pack<elem_type>(frag_p);
|
||||
|
||||
// Store dp to smem for transpose
|
||||
smem_dp.store(frag_p);
|
||||
|
||||
// gmem_s.store(frag_p, mask);
|
||||
// gmem_s.move();
|
||||
|
||||
// Declare the accumulators for the 2nd gemm.
|
||||
fmha::Fragment_accumulator acc_dq[Mma_tile_dq::MMAS_M][Mma_tile_dq::MMAS_N];
|
||||
fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_dq::WARPS_K>::apply(acc_dq);
|
||||
|
||||
// Do this part of O = P^T * V^T.
|
||||
#pragma unroll
|
||||
for( int ki = 1; ki < Mma_tile_dq::MMAS_K; ++ki ) {
|
||||
// Trigger the load from shared memory for the next series of Q values.
|
||||
smem_kt.load(frag_kt[ki & 1], ki);
|
||||
// Do the math for the values already in registers.
|
||||
fmha::gemm_cl<elem_type>(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1) & 1]);
|
||||
// fmha::gemm_cl<elem_type>(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]);
|
||||
}
|
||||
// Do the final stage of math.
|
||||
{
|
||||
int ki = Mma_tile_dq::MMAS_K;
|
||||
fmha::gemm_cl<elem_type>(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1) & 1]);
|
||||
// fmha::gemm_cl<elem_type>(acc_dq, frag_p[ki - 1], frag_kt[(ki - 1)]);
|
||||
}
|
||||
|
||||
static_assert(Gmem_tile_dq::LOOPS == 1);
|
||||
|
||||
// Swizzle the elements and do the final reduction.
|
||||
// Need to syncthreads here, otherwise the smem_dq reads from the previous iteration
|
||||
// might happen after the smem_dq writes in this iteration.
|
||||
__syncthreads();
|
||||
smem_dq.store(acc_dq, 0);
|
||||
|
||||
typename Smem_tile_dot::Fragment frag_dot[2][Mma_tile_dkv::MMAS_N];
|
||||
static_assert(Smem_tile_dot::Fragment::NUM_REGS == 4);
|
||||
static_assert(Mma_tile_dkv::MMAS_K == 1);
|
||||
smem_dot.load(frag_dot[0], 0);
|
||||
|
||||
// Threads in a warp is communicating via shared memory (smem_s and smem_dp)
|
||||
__syncwarp();
|
||||
typename Smem_tile_st::Fragment frag_s[Mma_tile_dkv::MMAS_K][Mma_tile_dkv::MMAS_M];
|
||||
smem_s.load(frag_s);
|
||||
|
||||
if (Is_dropout) {
|
||||
#pragma unroll
|
||||
for( int ki = 0; ki < Mma_tile_dkv::MMAS_K; ki++ ) {
|
||||
#pragma unroll
|
||||
for( int mi = 0; mi < Mma_tile_dkv::MMAS_M; mi++ ) {
|
||||
frag_s[ki][mi].template hrelu_<elem_type>();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for( int ki = 1; ki < Mma_tile_dkv::MMAS_K; ++ki ) {
|
||||
// Trigger the load from shared memory for the next series of Q values.
|
||||
smem_dot.load(frag_dot[ki & 1], ki);
|
||||
// Do the math for the values already in registers.
|
||||
fmha::gemm_cl<elem_type>(acc_dv, frag_s[(ki - 1)], frag_dot[(ki - 1) & 1]);
|
||||
}
|
||||
|
||||
// Do the final stage of math.
|
||||
{
|
||||
int ki = Mma_tile_dkv::MMAS_K;
|
||||
fmha::gemm_cl<elem_type>(acc_dv, frag_s[(ki - 1)], frag_dot[(ki - 1) & 1]);
|
||||
}
|
||||
|
||||
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
|
||||
// float2 tmp0 = __half22float2(reinterpret_cast<__half2 &>(frag_dot[0][0]));
|
||||
// printf("frag_dot[0][0]=%.6f, %.6f\n", tmp0.x, tmp0.y);
|
||||
// float2 tmp1 = __half22float2(reinterpret_cast<__half2 &>(frag_dot[0][1]));
|
||||
// printf("frag_dot[0][1]=%.6f, %.6f\n", tmp1.x, tmp1.y);
|
||||
// }
|
||||
|
||||
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
|
||||
// printf("l = %d, acc_dv[0][0]=%.6f, %.6f\n", l, acc_dv[0][0].elt(2), acc_dv[0][0].elt(3));
|
||||
// printf("l = %d, acc_dv[0][1]=%.6f, %.6f\n", l, acc_dv[0][1].elt(2), acc_dv[0][1].elt(3));
|
||||
// }
|
||||
// __syncthreads();
|
||||
// Commit the values for Q and dO into shared memory.
|
||||
if (l + 1 < steps) {
|
||||
gmem_q.commit(gemm_q_k.smem_q);
|
||||
}
|
||||
|
||||
uint4 dq_out[Gmem_tile_dq::STGS_PER_LOOP];
|
||||
if (!Is_first && !Seq_parallel) { gmem_dq_tmp.load(dq_out, 0); }
|
||||
|
||||
// __syncthreads();
|
||||
// Commit the values for Q and dO into shared memory.
|
||||
if (l + 1 < steps) {
|
||||
gmem_do.commit(smem_do);
|
||||
gmem_softmax_d.move();
|
||||
if (Is_first) {
|
||||
dot_do_o<Gmem_tile_do::ROWS, Gmem_tile_do::THREADS_PER_ROW, elem_type>(
|
||||
gmem_do.fetch_, gmem_o.fetch_, params.p_dropout, gmem_softmax_d, tidx
|
||||
);
|
||||
}
|
||||
gmem_softmax_lse.move();
|
||||
gmem_softmax_lse.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(p_lse));
|
||||
}
|
||||
|
||||
typename Smem_tile_st::Fragment frag_dpt[Mma_tile_dkv::MMAS_K][Mma_tile_dkv::MMAS_M];
|
||||
smem_dp.load(frag_dpt);
|
||||
|
||||
gemm_q_k.reload_k();
|
||||
|
||||
typename Smem_tile_qt::Fragment frag_qt[2][Mma_tile_dkv::MMAS_N];
|
||||
static_assert(Smem_tile_qt::Fragment::NUM_REGS == 4);
|
||||
static_assert(Mma_tile_dkv::MMAS_K == 1);
|
||||
smem_qt.load(frag_qt[0], 0);
|
||||
|
||||
#pragma unroll
|
||||
for( int ki = 1; ki < Mma_tile_dkv::MMAS_K; ++ki ) {
|
||||
// Trigger the load from shared memory for the next series of Q values.
|
||||
smem_qt.load(frag_qt[ki & 1], ki);
|
||||
// Do the math for the values already in registers.
|
||||
fmha::gemm_cl<elem_type>(acc_dk, frag_dpt[(ki - 1)], frag_qt[(ki - 1) & 1]);
|
||||
}
|
||||
|
||||
// Do the final stage of math.
|
||||
{
|
||||
int ki = Mma_tile_dkv::MMAS_K;
|
||||
fmha::gemm_cl<elem_type>(acc_dk, frag_dpt[(ki - 1)], frag_qt[(ki - 1) & 1]);
|
||||
}
|
||||
|
||||
// Make sure dQ is in shared memory.
|
||||
__syncthreads();
|
||||
|
||||
if (l + 1 < steps) {
|
||||
gmem_softmax_d.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(dp_sum));
|
||||
}
|
||||
|
||||
// Load from shared memory.
|
||||
smem_dq.template load</*zero_init=*/Is_first || Seq_parallel>(dq_out);
|
||||
|
||||
if (!Seq_parallel) {
|
||||
const bool is_final_write =
|
||||
Is_last
|
||||
|| ((loop_step_idx + 1) * Cta_tile_p::N >= binfo.actual_seqlen_k)
|
||||
|| ((Is_causal) && ((begin + l) * Cta_tile_p::M < (loop_step_idx + 1) * Cta_tile_p::N));
|
||||
if (is_final_write) {
|
||||
// if (Is_dropout) {
|
||||
// dq_out[0] = fmha::fmul4(dq_out[0], params.rp_dropout);
|
||||
// }
|
||||
for (int jj = 0; jj < Gmem_tile_dq::STGS_PER_LOOP; ++jj) {
|
||||
// dq_out[jj] = fmha::fmul4(dq_out[jj], params.scale_bmm1f);
|
||||
dq_out[jj] = fmha::fmul4(dq_out[jj], params.scale_bmm1_rp_dropout);
|
||||
}
|
||||
// Output the values.
|
||||
gmem_dq.template store<elem_type>(dq_out, 0);
|
||||
// Move to the next part of the output.
|
||||
gmem_dq.move();
|
||||
// TODO: for parallel, need to deal with the dropout scaling
|
||||
} else {
|
||||
// Output the values.
|
||||
gmem_dq_tmp.store(dq_out, 0);
|
||||
}
|
||||
} else {
|
||||
// We always scale dq_out before writing in this case, since we don't want to
|
||||
// have to scale at the end when copying from dq_tmp to dq.
|
||||
for (int jj = 0; jj < Gmem_tile_dq::STGS_PER_LOOP; ++jj) {
|
||||
// dq_out[jj] = fmha::fmul4(dq_out[jj], params.scale_bmm1f);
|
||||
dq_out[jj] = fmha::fmul4(dq_out[jj], params.scale_bmm1_rp_dropout);
|
||||
}
|
||||
gmem_dq_tmp.atomic_add(dq_out, 0);
|
||||
}
|
||||
|
||||
// Move to the next part of the output.
|
||||
if (!(Is_first && Is_last)) { gmem_dq_tmp.move(); }
|
||||
|
||||
// // Make sure the data is in shared memory.
|
||||
// __syncthreads();
|
||||
|
||||
// Commit the values for Q and dO into shared memory.
|
||||
if (l + 1 < steps) {
|
||||
gemm_q_k.smem_q.move_to_next_read_buffer();
|
||||
gemm_q_k.reload_q();
|
||||
smem_qt.move_to_next_read_buffer();
|
||||
// smem_qt.load(frag_qt[0], 0);
|
||||
smem_do.move_to_next_read_buffer();
|
||||
smem_dot.move_to_next_read_buffer();
|
||||
// smem_dot.load(frag_dot[0], 0);
|
||||
}
|
||||
|
||||
} // Outer loop over the sequence length.
|
||||
|
||||
if (Is_dropout) {
|
||||
for( int mi = 0; mi < Mma_tile_dkv::MMAS_M; mi++ ) {
|
||||
for( int ni = 0; ni < Mma_tile_dkv::MMAS_N; ni++ ) {
|
||||
acc_dv[mi][ni].mul_(params.rp_dropout);
|
||||
}
|
||||
}
|
||||
}
|
||||
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
|
||||
// printf("l final, acc_dv[0][0]=%.6f, %.6f\n", acc_dv[0][0].elt(2), acc_dv[0][0].elt(3));
|
||||
// printf("l final, acc_dv[0][1]=%.6f, %.6f\n", acc_dv[0][1].elt(2), acc_dv[0][1].elt(3));
|
||||
// }
|
||||
for( int mi = 0; mi < Mma_tile_dkv::MMAS_M; mi++ ) {
|
||||
for( int ni = 0; ni < Mma_tile_dkv::MMAS_N; ni++ ) {
|
||||
// acc_dk[mi][ni].mul_(Is_dropout ? params.rp_dropout * params.scale_bmm1f : params.scale_bmm1f);
|
||||
// acc_dk[mi][ni].mul_(params.scale_bmm1f);
|
||||
acc_dk[mi][ni].mul_(params.scale_bmm1_rp_dropout);
|
||||
}
|
||||
}
|
||||
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
|
||||
// printf("l final, acc_dk=%.6f, %.6f\n", acc_dk[0][0].elt(0), acc_dk[0][0].elt(1));
|
||||
// }
|
||||
|
||||
__syncthreads();
|
||||
// TODO [TD - 2022-05-04]: Are there cases where the shared mem for dV and dK are larger than
|
||||
// the total amount of shared mem?
|
||||
// Epilogue swizzle for dV
|
||||
Smem_tile_dv smem_dv(&smem_[0], tidx);
|
||||
smem_dv.template store<elem_type>(acc_dv);
|
||||
|
||||
// Epilogue swizzle for dK
|
||||
Smem_tile_dk smem_dk(&smem_[Smem_tile_dv::BYTES_PER_TILE], tidx);
|
||||
smem_dk.template store<elem_type>(acc_dk);
|
||||
|
||||
__syncthreads();
|
||||
uint4 dv_out[Smem_tile_dv::NUM_LDS];
|
||||
smem_dv.load(dv_out);
|
||||
Gmem_tile_dv gmem_dv(params.dv_ptr, params.dv_row_stride_in_elts, params.dv_head_stride_in_elts,
|
||||
params.d, binfo, tidx, false);
|
||||
if (!Is_first) {
|
||||
gmem_dv.move(loop_step_idx);
|
||||
}
|
||||
gmem_dv.store(dv_out);
|
||||
|
||||
uint4 dk_out[Smem_tile_dk::NUM_LDS];
|
||||
smem_dk.load(dk_out);
|
||||
Gmem_tile_dk gmem_dk(params.dk_ptr, params.dk_row_stride_in_elts, params.dk_head_stride_in_elts,
|
||||
params.d, binfo, tidx, false);
|
||||
if (!Is_first) {
|
||||
gmem_dk.move(loop_step_idx);
|
||||
}
|
||||
gmem_dk.store(dk_out);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// loop_steps = -1 means the number of steps will be params.seqlen_k / Kernel_traits::Cta_tile_p::N.
|
||||
// This template parameter is there so we can specialize with loop_steps == 1 and loop_steps == 2.
|
||||
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, int loop_steps=-1, typename Params>
|
||||
inline __device__ void compute_dq_dk_dv_1xN(const Params ¶ms) {
|
||||
constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N;
|
||||
|
||||
// The block index for the batch.
|
||||
const int bidb = blockIdx.x;
|
||||
// The block index for the head.
|
||||
const int bidh = blockIdx.y;
|
||||
// The thread index.
|
||||
const int tidx = threadIdx.x;
|
||||
auto seeds = at::cuda::philox::unpack(params.philox_args);
|
||||
Philox ph(std::get<0>(seeds), 0, std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32);
|
||||
|
||||
if (loop_steps == 1) {
|
||||
compute_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, true, true>(params, ph, 0);
|
||||
} else if (loop_steps == 2) {
|
||||
compute_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, true, false>(params, ph, 0);
|
||||
compute_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, false, true>(params, ph, 1);
|
||||
} else {
|
||||
if (params.seqlen_k == blocksize_c) {
|
||||
compute_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, true, true>(params, ph, 0);
|
||||
} else {
|
||||
const int max_loop_steps = (params.seqlen_k + blocksize_c - 1) / blocksize_c;
|
||||
compute_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, true, false>(params, ph, 0);
|
||||
for (int loop_step_idx = 1; loop_step_idx < max_loop_steps - 1; loop_step_idx++) {
|
||||
compute_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, false, false>(params, ph, loop_step_idx);
|
||||
}
|
||||
compute_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, false, true>(params, ph, max_loop_steps - 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, typename Params>
|
||||
inline __device__ void compute_dq_dk_dv_seqparallel(const Params ¶ms) {
|
||||
// The block index for the batch.
|
||||
const int bidb = blockIdx.x;
|
||||
// The block index for the head.
|
||||
const int bidh = blockIdx.y;
|
||||
// The thread index.
|
||||
const int tidx = threadIdx.x;
|
||||
auto seeds = at::cuda::philox::unpack(params.philox_args);
|
||||
Philox ph(std::get<0>(seeds), 0, std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32);
|
||||
|
||||
int loop_step_idx = blockIdx.z;
|
||||
compute_dq_dk_dv_1xN_one_iter<Kernel_traits, Is_dropout, Is_causal, false, false, /*Seq_parallel=*/true>(params, ph, loop_step_idx);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace fmha
|
||||
@ -1,707 +0,0 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2022, Tri Dao.
|
||||
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright
|
||||
* notice, this list of conditions and the following disclaimer in the
|
||||
* documentation and/or other materials provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the
|
||||
* names of its contributors may be used to endorse or promote products
|
||||
* derived from this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/fmha_kernel.h>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/kernel_traits.h>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/gemm.h>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/utils.h>
|
||||
|
||||
namespace fmha {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename Kernel_traits>
|
||||
struct Gemm_Q_K_base {
|
||||
using Smem_tile_o = typename Kernel_traits::Smem_tile_o;
|
||||
using Smem_tile_q = typename Kernel_traits::Smem_tile_q;
|
||||
using Smem_tile_k = typename Kernel_traits::Smem_tile_k;
|
||||
using Fragment_q = typename Smem_tile_q::Fragment;
|
||||
using Fragment_k = typename Smem_tile_k::Fragment;
|
||||
|
||||
// The description of the CTA tile for the 1st batched GEMM.
|
||||
using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
|
||||
|
||||
// The MMA tile for the 1st GEMM.
|
||||
using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;
|
||||
|
||||
static constexpr int SMEM_BYTES_SOFTMAX = Cta_tile_p::M * Cta_tile_p::WARPS_N * sizeof(float) * 2;
|
||||
|
||||
__device__ inline Gemm_Q_K_base(char * smem_ptr_q, char * smem_ptr_k, const int tidx)
|
||||
: smem_q(smem_ptr_q, tidx)
|
||||
, smem_k(smem_ptr_k, tidx) {
|
||||
|
||||
}
|
||||
|
||||
__device__ inline void load_q() {
|
||||
smem_q.load(frag_q[0], 0);
|
||||
}
|
||||
|
||||
__device__ inline void reload_q() {
|
||||
smem_q.load(frag_q[0], 0);
|
||||
}
|
||||
|
||||
Fragment_q frag_q[2][Mma_tile_p::MMAS_M];
|
||||
Smem_tile_q smem_q;
|
||||
Smem_tile_k smem_k;
|
||||
};
|
||||
|
||||
template<typename Kernel_traits, bool K_in_regs, typename elem_type_=__half>
|
||||
struct Gemm_Q_K : public Gemm_Q_K_base<Kernel_traits> {
|
||||
|
||||
using Base = Gemm_Q_K_base<Kernel_traits>;
|
||||
using Smem_tile_o = typename Base::Smem_tile_o;
|
||||
using Smem_tile_q = typename Base::Smem_tile_q;
|
||||
using Smem_tile_k = typename Base::Smem_tile_k;
|
||||
using Fragment_k = typename Base::Fragment_k;
|
||||
using Mma_tile_p = typename Base::Mma_tile_p;
|
||||
using elem_type = elem_type_;
|
||||
|
||||
static constexpr bool SHARE_SMEM_FOR_K_AND_V = Kernel_traits::SHARE_SMEM_FOR_K_AND_V;
|
||||
// If V is stored in shared memory, we can't load K using the same shared memory.
|
||||
static_assert(Kernel_traits::V_IN_REGS);
|
||||
|
||||
static constexpr int SMEM_OFFSET_O = Smem_tile_q::BYTES_PER_TILE;
|
||||
static constexpr int SMEM_OFFSET_SOFTMAX = SMEM_OFFSET_O + Smem_tile_o::BYTES_PER_TILE;
|
||||
static constexpr int SMEM_OFFSET_V = Smem_tile_q::BYTES_PER_TILE + (SHARE_SMEM_FOR_K_AND_V ? 0 : Smem_tile_k::BYTES_PER_TILE);
|
||||
|
||||
// Q | K / V
|
||||
// | O | SOFTMAX
|
||||
static constexpr int SMEM_BYTES = Smem_tile_q::BYTES_PER_TILE
|
||||
+ std::max((SHARE_SMEM_FOR_K_AND_V ? 1 : 2) * Smem_tile_k::BYTES_PER_TILE,
|
||||
Smem_tile_o::BYTES_PER_TILE + Base::SMEM_BYTES_SOFTMAX);
|
||||
|
||||
__device__ inline Gemm_Q_K(char * smem_, const int tidx)
|
||||
: Base(smem_, smem_ + Smem_tile_q::BYTES_PER_TILE, tidx) {
|
||||
}
|
||||
|
||||
__device__ inline void load_k(){
|
||||
#pragma unroll
|
||||
for( int ki = 0; ki < Mma_tile_p::MMAS_K; ++ki ) {
|
||||
Base::smem_k.load(frag_k[ki], ki);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename Acc, int M, int N>
|
||||
__device__ inline void operator()(Acc (&acc_p)[M][N]){
|
||||
// Do this part of P^T = (Q * K^T)^T.
|
||||
#pragma unroll
|
||||
for( int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki ) {
|
||||
// Trigger the load from shared memory for the next series of Q values.
|
||||
Base::smem_q.load(Base::frag_q[ki & 1], ki);
|
||||
// Do the math for the values already in registers.
|
||||
fmha::gemm_cl<elem_type>(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);
|
||||
}
|
||||
// Do the final stage of math.
|
||||
{
|
||||
int ki = Mma_tile_p::MMAS_K;
|
||||
fmha::gemm_cl<elem_type>(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1)]);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ inline void reload_k(){
|
||||
// Noop.
|
||||
}
|
||||
|
||||
Fragment_k frag_k[Mma_tile_p::MMAS_K][Mma_tile_p::MMAS_N];
|
||||
};
|
||||
|
||||
|
||||
template<typename Kernel_traits, typename elem_type_>
|
||||
struct Gemm_Q_K<Kernel_traits, false, elem_type_> : public Gemm_Q_K_base<Kernel_traits> {
|
||||
using Base = Gemm_Q_K_base<Kernel_traits>;
|
||||
using Smem_tile_o = typename Base::Smem_tile_o;
|
||||
using Smem_tile_q = typename Base::Smem_tile_q;
|
||||
using Smem_tile_k = typename Base::Smem_tile_k;
|
||||
using Smem_tile_v = typename Kernel_traits::Smem_tile_v;
|
||||
using Fragment_k = typename Base::Fragment_k;
|
||||
using Mma_tile_p = typename Base::Mma_tile_p;
|
||||
using elem_type = elem_type_;
|
||||
Fragment_k frag_k[2][Mma_tile_p::MMAS_N];
|
||||
|
||||
static constexpr bool SHARE_SMEM_FOR_K_AND_V = Kernel_traits::SHARE_SMEM_FOR_K_AND_V;
|
||||
static constexpr bool V_IN_REGS = Kernel_traits::V_IN_REGS;
|
||||
static_assert(V_IN_REGS || !SHARE_SMEM_FOR_K_AND_V);
|
||||
|
||||
static constexpr int SMEM_OFFSET_V = Smem_tile_q::BYTES_PER_TILE + (SHARE_SMEM_FOR_K_AND_V ? 0 : Smem_tile_k::BYTES_PER_TILE);
|
||||
static_assert(Smem_tile_v::BYTES_PER_TILE == (int) Smem_tile_k::BYTES_PER_TILE);
|
||||
static constexpr int SMEM_OFFSET_O = SMEM_OFFSET_V + Smem_tile_v::BYTES_PER_TILE;
|
||||
static constexpr int SMEM_OFFSET_SOFTMAX = SMEM_OFFSET_O + Smem_tile_o::BYTES_PER_TILE;
|
||||
|
||||
// If V_IN_REGS and SHARE_SMEM_FOR_K_AND_V: Q | K/V | O | SOFTMAX
|
||||
// If !V_IN_REGS (then !SHARE_SMEM_FOR_K_AND_V): Q | K | V | O | SOFTMAX
|
||||
static constexpr int SMEM_BYTES = Smem_tile_q::BYTES_PER_TILE
|
||||
+ (SHARE_SMEM_FOR_K_AND_V ? 1 : 2) * Smem_tile_k::BYTES_PER_TILE
|
||||
+ Smem_tile_o::BYTES_PER_TILE + Base::SMEM_BYTES_SOFTMAX;
|
||||
|
||||
__device__ inline Gemm_Q_K(char * smem_, const int tidx)
|
||||
: Base(smem_, smem_ + Smem_tile_q::BYTES_PER_TILE, tidx) {
|
||||
}
|
||||
|
||||
__device__ inline void load_k(){
|
||||
Base::smem_k.load(frag_k[0], 0);
|
||||
}
|
||||
|
||||
template<typename Acc, int M, int N>
|
||||
__device__ inline void operator()(Acc (&acc_p)[M][N]){
|
||||
// Do this part of P^T = (Q * K^T)^T.
|
||||
#pragma unroll
|
||||
for( int ki = 1; ki < Mma_tile_p::MMAS_K; ++ki ) {
|
||||
// Trigger the load from shared memory for the next series of Q values.
|
||||
Base::smem_q.load(Base::frag_q[ki & 1], ki);
|
||||
Base::smem_k.load(frag_k[ki & 1], ki);
|
||||
// Do the math for the values already in registers.
|
||||
fmha::gemm_cl<elem_type>(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]);
|
||||
}
|
||||
// Do the final stage of math.
|
||||
{
|
||||
int ki = Mma_tile_p::MMAS_K;
|
||||
fmha::gemm_cl<elem_type>(acc_p, Base::frag_q[(ki - 1) & 1], frag_k[(ki - 1) & 1]);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ inline void reload_k(){
|
||||
Base::smem_k.load(frag_k[0], 0);
|
||||
}
|
||||
};
|
||||
|
||||
template<typename Kernel_traits>
|
||||
constexpr size_t get_dynamic_smem_size(){
|
||||
return Gemm_Q_K<Kernel_traits, Kernel_traits::K_IN_REGS>::SMEM_BYTES;
|
||||
}
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Return_softmax, bool Is_first, bool Is_last, typename Params, typename Prng>
|
||||
inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const int bidh, int steps, Prng &ph, const int loop_step_idx) {
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
using elem_type = typename Kernel_traits::elem_type;
|
||||
#else
|
||||
constexpr bool is_fp16_type = std::is_same<typename Kernel_traits::elem_type, __half>::value;
|
||||
assert(is_fp16_type);
|
||||
using elem_type = __half;
|
||||
#endif
|
||||
|
||||
// The description of the CTA tile for the 1st batched GEMM.
|
||||
using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
|
||||
// The description of the CTA tile for the 2nd batched GEMM.
|
||||
using Cta_tile_o = typename Kernel_traits::Cta_tile_o;
|
||||
|
||||
// The MMA tile for the 1st GEMM.
|
||||
using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;
|
||||
// The MMA tile for the 2nd GEMM.
|
||||
using Mma_tile_o = fmha::Hmma_tile<Cta_tile_o>;
|
||||
|
||||
// The global memory tile to load Q.
|
||||
using Gmem_tile_q = typename Kernel_traits::Gmem_tile_q;
|
||||
|
||||
// The global memory tile to load K.
|
||||
using Gmem_tile_k = typename Kernel_traits::Gmem_tile_k;
|
||||
|
||||
// The global memory tile to load V.
|
||||
using Gmem_tile_v = typename Kernel_traits::Gmem_tile_v;
|
||||
// The shared memory tile to swizzle V.
|
||||
using Smem_tile_v = typename Kernel_traits::Smem_tile_v;
|
||||
|
||||
// The global memory tile to store O.
|
||||
using Gmem_tile_o = typename Kernel_traits::Gmem_tile_o;
|
||||
using Gmem_tile_o_tmp = fmha::Gmem_tile_o<Cta_tile_o, 4>;
|
||||
// The shared memory tile to swizzle O.
|
||||
using Smem_tile_o = typename Kernel_traits::Smem_tile_o;
|
||||
|
||||
using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s;
|
||||
|
||||
using Gmem_softmax_sum = typename Kernel_traits::Gmem_softmax_sum;
|
||||
|
||||
using Smem_softmax_sum = typename Kernel_traits::Smem_dp_sum;
|
||||
|
||||
using Gemm1 = Gemm_Q_K<Kernel_traits, Kernel_traits::K_IN_REGS, elem_type>;
|
||||
|
||||
using Softmax = fmha::Softmax<Cta_tile_p, Kernel_traits>;
|
||||
|
||||
// Shared memory.
|
||||
extern __shared__ char smem_[];
|
||||
|
||||
// The thread index.
|
||||
const int tidx = threadIdx.x;
|
||||
|
||||
// How many steps to jump per iteration, which is the same as params.num_splits.
|
||||
const int step_stride = gridDim.z;
|
||||
|
||||
const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);
|
||||
// if( binfo.stop_early() ) return;
|
||||
if( binfo.stop_early(loop_step_idx * Cta_tile_p::N) ) return;
|
||||
|
||||
Gemm1 gemm_q_k(smem_, tidx);
|
||||
// Allocate the global memory tile loader for Q.
|
||||
Gmem_tile_q gmem_q(params.q_ptr, params.q_row_stride_in_elts, params.q_head_stride_in_elts,
|
||||
params.d, binfo, tidx, true);
|
||||
// Allocate the global memory tile loader for O.
|
||||
Gmem_tile_o gmem_o(params.o_ptr, params.o_row_stride_in_elts, params.o_head_stride_in_elts,
|
||||
params.d, binfo, tidx);
|
||||
Gmem_tile_o_tmp gmem_o_tmp(params.o_tmp_ptr, params.o_tmp_row_stride_in_elts,
|
||||
params.o_tmp_head_stride_in_elts, params.d, binfo, tidx);
|
||||
// Allocate the global memory tile loader for S.
|
||||
Gmem_tile_s gmem_s(params, binfo, tidx);
|
||||
Gmem_softmax_sum gmem_softmax_lse(params.softmax_lse_ptr, params, tidx);
|
||||
|
||||
// Wind gmem tiles to the correct position.
|
||||
static_assert(Cta_tile_p::N % Cta_tile_p::M == 0);
|
||||
int begin = Is_causal ? loop_step_idx * Cta_tile_p::N / Cta_tile_p::M : 0;
|
||||
// We want begin to be a multiple of gridDim.z
|
||||
// This is because the row indices processed by each threadblock must align between the
|
||||
// loop steps, otherwise we have a dependency between the blocks.
|
||||
// For example, threadblock with blockIdx.z == 1 must process row indices that are
|
||||
// k * gridDim.z + 1 for integer k.
|
||||
const int begin_mod_z = begin % gridDim.z;
|
||||
begin = begin_mod_z <= blockIdx.z ? begin - begin_mod_z : begin + gridDim.z - begin_mod_z;
|
||||
// Otherwise we'd be reading out-of-bound memory before the loop
|
||||
if ((begin + blockIdx.z) * Cta_tile_p::M >= binfo.actual_seqlen_q) return;
|
||||
const int steps_og = steps;
|
||||
steps -= begin;
|
||||
gmem_q.move(begin + blockIdx.z);
|
||||
gmem_o.move(begin + blockIdx.z);
|
||||
gmem_o_tmp.move(begin + blockIdx.z);
|
||||
if (Return_softmax) {
|
||||
gmem_s.move(begin + blockIdx.z);
|
||||
}
|
||||
gmem_softmax_lse.move(begin + blockIdx.z);
|
||||
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
|
||||
// printf("begin = %d, steps = %d\n", begin, steps);
|
||||
// }
|
||||
|
||||
fmha::Mask<Cta_tile_p, Is_causal> mask(binfo, tidx, loop_step_idx);
|
||||
|
||||
// Allocate the global memory tile loader for K.
|
||||
Gmem_tile_k gmem_k(params.k_ptr, params.k_row_stride_in_elts, params.k_head_stride_in_elts,
|
||||
params.d, binfo, tidx, false);
|
||||
// Allocate the global memory tile loader for V.
|
||||
Gmem_tile_v gmem_v(params.v_ptr, params.v_row_stride_in_elts, params.v_head_stride_in_elts,
|
||||
params.d, binfo, tidx, false);
|
||||
// The base pointer of smem_v;
|
||||
char *smem_v_ = &smem_[Gemm1::SMEM_OFFSET_V];
|
||||
|
||||
// Allocate the shared memory tile loader for V. We use the same as K so be careful!!!
|
||||
Smem_tile_v smem_v(smem_v_, tidx);
|
||||
|
||||
// Allocate the shared memory tile loader for O. We use the same as K so be careful!!!
|
||||
Smem_tile_o smem_o(&smem_[Gemm1::SMEM_OFFSET_O], tidx);
|
||||
|
||||
if (!Is_first) {
|
||||
gmem_k.move(loop_step_idx);
|
||||
gmem_v.move(loop_step_idx);
|
||||
if (Return_softmax) { gmem_s.move(loop_step_idx * steps_og); }
|
||||
}
|
||||
|
||||
// Trigger the loads for K.
|
||||
gmem_k.load();
|
||||
// Trigger the loads for Q.
|
||||
gmem_q.load();
|
||||
// Trigger the loads for V.
|
||||
gmem_v.load();
|
||||
|
||||
if (!Is_first) { __syncthreads(); }
|
||||
|
||||
float p_prev_lse[Mma_tile_p::MMAS_M * 2];
|
||||
if (!Is_first) {
|
||||
gmem_softmax_lse.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(p_prev_lse));
|
||||
}
|
||||
|
||||
// Commit the data for Q and V to shared memory.
|
||||
gmem_q.commit(gemm_q_k.smem_q);
|
||||
gmem_v.commit(smem_v);
|
||||
|
||||
// const uint32_t scale_bmm1 = reinterpret_cast<const uint32_t&>(params.scale_bmm1);
|
||||
// #pragma unroll
|
||||
// for(int it=0;it < Gmem_tile_k::LDGS;it++){
|
||||
// gmem_k.fetch_[it] = fmha::hmul8(scale_bmm1, gmem_k.fetch_[it]);
|
||||
// }
|
||||
|
||||
// Commit the data for K to shared memory.
|
||||
if( !Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
|
||||
gmem_k.commit(gemm_q_k.smem_k);
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Load the fragments for Q.
|
||||
gemm_q_k.load_q();
|
||||
|
||||
// Load the fragments for V. We keep the data in registers during the entire kernel.
|
||||
typename Smem_tile_v::Fragment frag_v[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_N];
|
||||
#pragma unroll
|
||||
for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {
|
||||
smem_v.load(frag_v[ki], ki);
|
||||
}
|
||||
|
||||
// Commit the data for V to shared memory if it has not been done already.
|
||||
if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
|
||||
// Make sure we are done loading the fragments for K.
|
||||
__syncthreads();
|
||||
|
||||
// Commit the data to shared memory for V.
|
||||
gmem_k.commit(gemm_q_k.smem_k);
|
||||
|
||||
// Make sure the data is in shared memory.
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// Load the fragments for K.
|
||||
gemm_q_k.load_k();
|
||||
|
||||
// Create the object to do the softmax.
|
||||
Softmax softmax(params, &smem_[Gemm1::SMEM_OFFSET_SOFTMAX], tidx);
|
||||
|
||||
Smem_softmax_sum smem_softmax_lse(reinterpret_cast<float *>(&smem_[Gemm1::SMEM_BYTES]), tidx);
|
||||
|
||||
// Load over the entire sequence length.
|
||||
for (int l = blockIdx.z; l < steps; l += step_stride) {
|
||||
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (blockIdx.z <= 1)) {
|
||||
// printf("l = %d\n", l);
|
||||
// }
|
||||
if ((begin + l) * Cta_tile_p::M >= binfo.actual_seqlen_q) break;
|
||||
|
||||
// Declare the accumulators for the 1st gemm.
|
||||
fmha::Fragment_accumulator acc_p[Mma_tile_p::MMAS_M][Mma_tile_p::MMAS_N];
|
||||
fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_p::WARPS_K>::apply(acc_p);
|
||||
|
||||
// Do this part of P = Q * K^T.
|
||||
gemm_q_k(acc_p);
|
||||
|
||||
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) {
|
||||
// printf("acc_p=%.6f, %.6f\n", acc_p[0][0].elt(0), acc_p[0][0].elt(1));
|
||||
// }
|
||||
|
||||
uint4 out[Gmem_tile_o::STGS_PER_LOOP];
|
||||
if (!Is_first) { gmem_o_tmp.load(out, 0); }
|
||||
|
||||
// Trigger the load for the next Q values.
|
||||
if (l + step_stride < steps) {
|
||||
gemm_q_k.smem_q.move_to_next_write_buffer();
|
||||
gmem_q.move(step_stride);
|
||||
gmem_q.load();
|
||||
}
|
||||
|
||||
// Load the mask for that iteration.
|
||||
mask.load(begin + l);
|
||||
|
||||
// Convert from the accumulator type to FP32 for Softmax.
|
||||
softmax.unpack_noscale(acc_p);
|
||||
|
||||
// Apply the mask.
|
||||
softmax.apply_mask(mask);
|
||||
|
||||
if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V && l < step_stride ) {
|
||||
// if we share K and V, it could be that V was not fully read yet but we write into smem for reduction
|
||||
__syncthreads();
|
||||
}
|
||||
// if (!Is_first) {
|
||||
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l >= 0)) {
|
||||
// printf("p_prev_lse=%.6f, %.6f\n", p_prev_lse[0], p_prev_lse[1]);
|
||||
// }
|
||||
// }
|
||||
// Compute the max.
|
||||
float p_max[Mma_tile_p::MMAS_M * 2];
|
||||
if (!Is_first) {
|
||||
smem_softmax_lse.store_pair(p_prev_lse);
|
||||
// for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) { p_max[mi] = p_prev_lse[mi]; }
|
||||
for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) { p_max[mi] = p_prev_lse[mi] / params.scale_bmm1f; }
|
||||
}
|
||||
|
||||
// Trigger the load for the next LSE values.
|
||||
if (l + step_stride < steps) {
|
||||
if (!Is_first) {
|
||||
gmem_softmax_lse.load_next(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(p_prev_lse),
|
||||
step_stride);
|
||||
}
|
||||
}
|
||||
|
||||
softmax.template reduce_max</*zero_init=*/Is_first>(p_max);
|
||||
|
||||
// if ((threadIdx.x == 0) && (l == 38)) {
|
||||
// printf("loop_step_idx %d, p_max = %.6f, %.6f., p_prev_lse = %.6f, %.6f\n", loop_step_idx, p_max[0], p_max[1], Is_first ? -10000.f : p_prev_lse[0], Is_first ? -10000.f : p_prev_lse[1]);
|
||||
// }
|
||||
|
||||
// if (!Is_first) {
|
||||
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) {
|
||||
// printf("after reduce_max=%.6f, %.6f\n", softmax.elt_[0][0], softmax.elt_[0][1]);
|
||||
// }
|
||||
// }
|
||||
|
||||
// Compute the exponential value.
|
||||
// softmax.apply_exp(p_max);
|
||||
softmax.scale_apply_exp(p_max, params.scale_bmm1f);
|
||||
|
||||
// if (!Is_first) {
|
||||
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) {
|
||||
// printf("after apply_exp=%.6f, %.6f\n", softmax.elt_[0][0], softmax.elt_[0][1]);
|
||||
// }
|
||||
// }
|
||||
|
||||
// Compute the sum.
|
||||
float p_sum[Mma_tile_p::MMAS_M * 2];
|
||||
// if (!Is_first) {
|
||||
// int warp = tidx / Cta_tile_p::THREADS_PER_WARP;
|
||||
// int lane = tidx % Cta_tile_p::THREADS_PER_WARP;
|
||||
// for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) {
|
||||
// p_sum[mi] = ((warp == 0) && (lane % 4 == 0)) ? expf(p_prev_lse[mi] - p_max[mi]) : 0;
|
||||
// }
|
||||
// }
|
||||
// softmax.reduce_sum(p_sum);
|
||||
softmax.reduce_sum_before_sync_(p_sum);
|
||||
// softmax.template reduce_sum_before_sync_</*zero_init=*/Is_first>(p_sum);
|
||||
|
||||
// float p_sum_log[Mma_tile_p::MMAS_M * 2];
|
||||
// for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; ++mi) {
|
||||
// float sum = p_sum[mi];
|
||||
// // p_sum_log[mi] = (sum == 0.f || sum != sum) ? INFINITY : p_max[mi] + __logf(sum);
|
||||
// constexpr float kLog2e = M_LOG2E;
|
||||
// p_sum_log[mi] = (sum == 0.f || sum != sum) ? INFINITY : p_max[mi] * kLog2e + __log2f(sum);
|
||||
// }
|
||||
// // gmem_softmax_lse.store(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(p_sum));
|
||||
// gmem_softmax_lse.store(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(p_sum_log));
|
||||
// gmem_softmax_lse.move();
|
||||
|
||||
// // Finalize softmax on the accumulators of P^T.
|
||||
// softmax.scale(p_sum);
|
||||
|
||||
constexpr bool encode_dropout_in_sign_bit = Return_softmax;
|
||||
if (Is_dropout) {
|
||||
// softmax.template apply_dropout<encode_dropout_in_sign_bit>(ph, params.p_dropout_in_uint);
|
||||
// softmax.template apply_dropout<encode_dropout_in_sign_bit>(ph, ph1, params.p_dropout_in_uint);
|
||||
// softmax.template apply_dropout_16bits<encode_dropout_in_sign_bit>(ph, ph1, params.p_dropout_in_uint16_t);
|
||||
unsigned int warp_idx = threadIdx.x / 32;
|
||||
// TODO: this should change after we rearrange the warps (e.g. cutlass branch)
|
||||
unsigned int block_col_idx = loop_step_idx * Cta_tile_p::N / 16 + warp_idx;
|
||||
// We want to use actual_seqlen_k, not seqlen_k, since seqlen_k could be rounded
|
||||
// differently in the fwd and bwd pass. E.g., for d=128 on A100, fwd rounds seqlen_k
|
||||
// to multiples of 256 while bwd rounds seqlen_k to multiples of 128.
|
||||
unsigned long long philox_subsequence = (begin + l) * (binfo.actual_seqlen_k / 16) + block_col_idx;
|
||||
softmax.template apply_dropout_16bits<encode_dropout_in_sign_bit>(ph, params.p_dropout_in_uint16_t, philox_subsequence);
|
||||
}
|
||||
|
||||
using Frag_p = fmha::Fragment_a<fmha::Row>;
|
||||
Frag_p frag_p[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M];
|
||||
static_assert(Mma_tile_o::MMAS_M == Mma_tile_p::MMAS_M);
|
||||
static_assert(Mma_tile_o::MMAS_K == Mma_tile_p::MMAS_N);
|
||||
softmax.template pack<elem_type>(frag_p);
|
||||
if (Return_softmax) {
|
||||
gmem_s.store(frag_p, mask);
|
||||
gmem_s.move(step_stride);
|
||||
}
|
||||
|
||||
// Commit the values for Q into shared memory.
|
||||
if (l + step_stride < steps) {
|
||||
gmem_q.commit(gemm_q_k.smem_q);
|
||||
}
|
||||
|
||||
if (Is_dropout && encode_dropout_in_sign_bit) {
|
||||
#pragma unroll
|
||||
for( int ki = 0; ki < Mma_tile_o::MMAS_K; ki++ ) {
|
||||
#pragma unroll
|
||||
for( int mi = 0; mi < Mma_tile_o::MMAS_M; mi++ ) {
|
||||
frag_p[ki][mi].template hrelu_<elem_type>();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Declare the accumulators for the 2nd gemm.
|
||||
fmha::Fragment_accumulator acc_o[Mma_tile_o::MMAS_M][Mma_tile_o::MMAS_N];
|
||||
fmha::Clear_accumulator<typename fmha::Accumulator_type, Cta_tile_o::WARPS_K>::apply(acc_o);
|
||||
|
||||
// Do this part of O = P^T * V^T.
|
||||
#pragma unroll
|
||||
for( int ki = 0; ki < Mma_tile_o::MMAS_K; ++ki ) {
|
||||
fmha::gemm_cl<elem_type>(acc_o, frag_p[ki], frag_v[ki]);
|
||||
// if ((threadIdx.x == 4) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) {
|
||||
// float2 tmp_p = __half22float2(reinterpret_cast<__half2 &>(frag_p[ki]));
|
||||
// float2 tmp_v = __half22float2(reinterpret_cast<__half2 &>(frag_v[ki]));
|
||||
// printf("Per warp, threadIdx.x = %d, frag_p = %.6f, %.6f, frag_v = %.6f, %.6f, acc_o=%.6f\n", threadIdx.x, tmp_p.x, tmp_p.y, tmp_v.x, tmp_v.y, acc_o[0][0].elt(0));
|
||||
// }
|
||||
}
|
||||
|
||||
// if ((threadIdx.x % 32 == 16) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) {
|
||||
// printf("Per warp, threadIdx.x = %d, acc_o=%.6f\n", threadIdx.x, acc_o[0][2].elt(0));
|
||||
// }
|
||||
|
||||
// The mapping from tidx to rows changes between the softmax and the
|
||||
// O-reduction. So we recalculate the max.
|
||||
float p_max_o[Gmem_tile_o::STGS_PER_LOOP][Mma_tile_o::MMAS_M];
|
||||
int rows[Gmem_tile_o::STGS_PER_LOOP];
|
||||
for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) {
|
||||
rows[jj] = tidx / Gmem_tile_o::THREADS_PER_ROW + jj * Gmem_tile_o::ROWS_PER_STG;
|
||||
}
|
||||
softmax.reduce_max_after_sync_(p_max_o, rows);
|
||||
static_assert(Mma_tile_o::MMAS_M == 1);
|
||||
for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) {
|
||||
p_max_o[jj][0] *= params.scale_bmm1f;
|
||||
}
|
||||
float p_prev_scale_o[Gmem_tile_o::STGS_PER_LOOP];
|
||||
if (!Is_first) {
|
||||
smem_softmax_lse.load(p_prev_scale_o, rows);
|
||||
}
|
||||
// if (!Is_first) {
|
||||
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) {
|
||||
// printf("p_prev_scale_o=%.6f\n", p_prev_scale_o[0]);
|
||||
// }
|
||||
// }
|
||||
|
||||
static_assert(Gmem_tile_o::LOOPS == 1);
|
||||
|
||||
// Swizzle the elements and do the final reduction.
|
||||
smem_o.store(acc_o, 0);
|
||||
|
||||
// Make sure the data is in shared memory.
|
||||
__syncthreads();
|
||||
|
||||
static_assert(Mma_tile_o::MMAS_M == 1);
|
||||
float p_sum_o[Gmem_tile_o::STGS_PER_LOOP][Mma_tile_o::MMAS_M];
|
||||
softmax.reduce_sum_after_sync_(p_sum_o, rows);
|
||||
if (!Is_first) {
|
||||
for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) {
|
||||
p_prev_scale_o[jj] = expf(p_prev_scale_o[jj] - p_max_o[jj][0]);
|
||||
p_sum_o[jj][0] += p_prev_scale_o[jj];
|
||||
}
|
||||
}
|
||||
|
||||
float p_sum_log[Gmem_tile_o::STGS_PER_LOOP][Mma_tile_o::MMAS_M];
|
||||
#pragma unroll
|
||||
for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) {
|
||||
float sum = p_sum_o[jj][0];
|
||||
p_sum_log[jj][0] = (sum == 0.f || sum != sum) ? -INFINITY : p_max_o[jj][0] + __logf(sum);
|
||||
// if (sum == 0.f || sum != sum) {
|
||||
// printf("loop_step_idx = %d, l = %d, tidx = %d, sum = %.6f, p_max_o = %.6f\n", loop_step_idx, l, tidx, sum, p_max_o[jj][0]);
|
||||
// }
|
||||
// if (Is_first) {
|
||||
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (l == 0)) {
|
||||
// printf("p_sum_log=%.6f\n", p_sum_log[jj][0]);
|
||||
// }
|
||||
// }
|
||||
if (tidx % Gmem_tile_o::THREADS_PER_ROW == 0) {
|
||||
gmem_softmax_lse.store_row(
|
||||
reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M]>(p_sum_log[jj]), rows[jj]);
|
||||
}
|
||||
}
|
||||
gmem_softmax_lse.move(step_stride);
|
||||
|
||||
// Load from shared memory.
|
||||
if (!Is_first) {
|
||||
for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) {
|
||||
out[jj] = fmha::fmul4(out[jj], p_prev_scale_o[jj]);
|
||||
}
|
||||
}
|
||||
smem_o.template load</*zero_init=*/Is_first>(out);
|
||||
|
||||
const bool is_final_write =
|
||||
Is_last
|
||||
|| ((loop_step_idx + 1) * Cta_tile_p::N >= binfo.actual_seqlen_k)
|
||||
|| ((Is_causal) && ((begin + l) * Cta_tile_p::M < (loop_step_idx + 1) * Cta_tile_p::N));
|
||||
#pragma unroll
|
||||
for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) {
|
||||
float sum = p_sum_o[jj][0];
|
||||
float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
|
||||
if (Is_dropout && is_final_write) {
|
||||
inv_sum *= params.rp_dropout;
|
||||
}
|
||||
out[jj] = fmha::fmul4(out[jj], inv_sum);
|
||||
}
|
||||
|
||||
// if (Is_dropout && Is_last) {
|
||||
// for (int jj = 0; jj < Gmem_tile_o::STGS_PER_LOOP; jj++) {
|
||||
// out[jj] = fmha::fmul4(out[jj], params.rp_dropout);
|
||||
// }
|
||||
// }
|
||||
|
||||
// Output the values.
|
||||
if (is_final_write) {
|
||||
gmem_o.template store<elem_type>(out, 0);
|
||||
gmem_o.move(step_stride);
|
||||
} else {
|
||||
gmem_o_tmp.store(out, 0);
|
||||
}
|
||||
|
||||
// Move to the next part of the output.
|
||||
if (!(Is_first && Is_last)) { gmem_o_tmp.move(step_stride); }
|
||||
gemm_q_k.reload_k();
|
||||
|
||||
// Make sure we are reading from the correct buffer.
|
||||
gemm_q_k.smem_q.move_to_next_read_buffer();
|
||||
// Trigger the load from shared memory for the next series of Q values.
|
||||
if (l + step_stride < steps) {
|
||||
gemm_q_k.reload_q();
|
||||
}
|
||||
} // Outer loop over the sequence length.
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Return_softmax, typename Params>
|
||||
inline __device__ void device_1xN_loop(const Params ¶ms) {
|
||||
|
||||
// The block index for the batch.
|
||||
const int bidb = blockIdx.x;
|
||||
// The block index for the head.
|
||||
const int bidh = blockIdx.y;
|
||||
// The thread index.
|
||||
const int tidx = threadIdx.x;
|
||||
|
||||
// We want the fwd and bwd to generate the same dropout pattern (RNG), without restricting
|
||||
// them to have the same number of threads or have to traverse the attention matrix
|
||||
// in the same order.
|
||||
// In the Philox RNG, we use the offset to store the batch, head, and the lane id
|
||||
// (within a warp). We use the subsequence to store the location of the 16 x 16 blocks within
|
||||
// the attention matrix. This way, as long as we have the batch, head, and the location of
|
||||
// the 16 x 16 block within the attention matrix, we can generate the exact same dropout pattern.
|
||||
auto seeds = at::cuda::philox::unpack(params.philox_args);
|
||||
if (params.philox_args.captured_) {
|
||||
*params.seed = std::get<0>(seeds);
|
||||
*params.extragraph_offset = std::get<1>(seeds);
|
||||
}
|
||||
|
||||
Philox ph(std::get<0>(seeds), 0, std::get<1>(seeds) + (bidb * params.h + bidh) * 32 + tidx % 32);
|
||||
constexpr int M = Kernel_traits::Cta_tile_p::M;
|
||||
const int STEPS = (params.seqlen_q + M - 1) / M;
|
||||
|
||||
constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N;
|
||||
if (params.seqlen_k == blocksize_c) {
|
||||
fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, true, true>(params, bidb, bidh, STEPS, ph, 0);
|
||||
} else {
|
||||
const int max_loop_steps = (params.seqlen_k + blocksize_c - 1) / blocksize_c;
|
||||
fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, true, false>(params, bidb, bidh, STEPS, ph, 0);
|
||||
for (int loop_step_idx = 1; loop_step_idx < max_loop_steps - 1; loop_step_idx++) {
|
||||
fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, false, false>(params, bidb, bidh, STEPS, ph, loop_step_idx);
|
||||
}
|
||||
fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, false, true>(params, bidb, bidh, STEPS, ph, max_loop_steps - 1);
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace fmha
|
||||
@ -1,16 +0,0 @@
|
||||
// Copyright (c) 2022, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/fmha_fwd_launch_template.h>
|
||||
|
||||
namespace pytorch_fmha {
|
||||
|
||||
void run_fmha_fwd_hdim128(Launch_params<FMHA_fprop_params> &launch_params) {
|
||||
FP16_SWITCH(launch_params.params.is_bf16, ([&] {
|
||||
using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>;
|
||||
run_fmha_fwd_loop<Kernel_traits>(launch_params);
|
||||
}));
|
||||
}
|
||||
|
||||
}; // namespace pytorch_fmha
|
||||
@ -1,21 +0,0 @@
|
||||
// Copyright (c) 2022, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/fmha_fwd_launch_template.h>
|
||||
|
||||
namespace pytorch_fmha {
|
||||
|
||||
void run_fmha_fwd_hdim32(Launch_params<FMHA_fprop_params> &launch_params) {
|
||||
FP16_SWITCH(launch_params.params.is_bf16, ([&] {
|
||||
if (launch_params.params.seqlen_k == 128) {
|
||||
using Kernel_traits = FMHA_kernel_traits<128, 32, 16, 1, 4, 0x08u, elem_type>;
|
||||
run_fmha_fwd_loop<Kernel_traits>(launch_params);
|
||||
} else if (launch_params.params.seqlen_k >= 256) {
|
||||
using Kernel_traits = FMHA_kernel_traits<256, 32, 16, 1, 4, 0x08u, elem_type>;
|
||||
run_fmha_fwd_loop<Kernel_traits>(launch_params);
|
||||
}
|
||||
}));
|
||||
}
|
||||
|
||||
}; // namespace pytorch_fmha
|
||||
@ -1,21 +0,0 @@
|
||||
// Copyright (c) 2022, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/fmha_fwd_launch_template.h>
|
||||
|
||||
namespace pytorch_fmha {
|
||||
|
||||
void run_fmha_fwd_hdim64(Launch_params<FMHA_fprop_params> &launch_params) {
|
||||
FP16_SWITCH(launch_params.params.is_bf16, ([&] {
|
||||
if (launch_params.params.seqlen_k == 128) {
|
||||
using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
|
||||
run_fmha_fwd_loop<Kernel_traits>(launch_params);
|
||||
} else if (launch_params.params.seqlen_k >= 256) {
|
||||
using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>;
|
||||
run_fmha_fwd_loop<Kernel_traits>(launch_params);
|
||||
}
|
||||
}));
|
||||
}
|
||||
|
||||
}; // namespace pytorch_fmha
|
||||
@ -1,96 +0,0 @@
|
||||
// Copyright (c) 2022, Tri Dao.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_bf16.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/fmha.h>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/static_switch.h>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/fmha_fprop_kernel_1xN.h>
|
||||
|
||||
namespace pytorch_fmha {
|
||||
|
||||
// Find the number of splits that maximizes the occupancy. For example, if we have
|
||||
// batch * n_heads = 48 and we have 108 SMs, having 2 splits (efficiency = 0.89) is
|
||||
// better than having 3 splits (efficiency = 0.67). However, we also don't want too many
|
||||
// splits as that would incur more HBM reads/writes.
|
||||
// So we find the best efficiency, then find the smallest number of splits that gets 95%
|
||||
// of the best efficiency.
|
||||
// [2022-11-25] TD: Mark this as "inline" otherwise we get "multiple definition" error.
|
||||
inline int num_splits_heuristic_fwd(int batch_nheads, int num_SMs, int ctas_per_sm, int max_splits) {
|
||||
float max_efficiency = 0.f;
|
||||
std::vector<float> efficiency;
|
||||
efficiency.reserve(max_splits);
|
||||
for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
|
||||
float n_waves = float(batch_nheads * num_splits) / (num_SMs * ctas_per_sm);
|
||||
float eff = n_waves / ceil(n_waves);
|
||||
// printf("num_splits = %d, eff = %f\n", num_splits, eff);
|
||||
if (eff > max_efficiency) { max_efficiency = eff; }
|
||||
efficiency.push_back(eff);
|
||||
}
|
||||
for (int num_splits = 1; num_splits <= max_splits; num_splits++) {
|
||||
if (efficiency[num_splits - 1] > 0.95 * max_efficiency) {
|
||||
// printf("num_splits chosen = %d\n", num_splits);
|
||||
return num_splits;
|
||||
}
|
||||
}
|
||||
return 1;
|
||||
}
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Return_softmax>
|
||||
__global__ void fmha_fwd_loop_kernel(FMHA_fprop_params params) {
|
||||
fmha::device_1xN_loop<Kernel_traits, Is_dropout, Is_causal, Return_softmax>(params);
|
||||
}
|
||||
|
||||
template<typename Kernel_traits>
|
||||
void run_fmha_fwd_loop(Launch_params<FMHA_fprop_params> &launch_params) {
|
||||
constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N;
|
||||
const int loop_steps = (launch_params.params.seqlen_k + blocksize_c - 1) / blocksize_c;
|
||||
|
||||
constexpr int smem_size_softmax_lse = Kernel_traits::Smem_dp_sum::BYTES_PER_TILE;
|
||||
// Don't need smem_size_softmax_lse if we're not looping
|
||||
const int smem_size = fmha::get_dynamic_smem_size<Kernel_traits>()
|
||||
+ (loop_steps > 1 ? smem_size_softmax_lse : 0);
|
||||
|
||||
// Work-around for gcc 7. It doesn't like nested BOOL_SWITCH.
|
||||
// https://github.com/kokkos/kokkos-kernels/issues/349
|
||||
// https://github.com/HazyResearch/flash-attention/issues/21
|
||||
BOOL_SWITCH(launch_params.is_dropout, IsDropoutConst, ([&] {
|
||||
auto kernel = launch_params.params.is_causal
|
||||
? (launch_params.return_softmax
|
||||
? &fmha_fwd_loop_kernel<Kernel_traits, IsDropoutConst, true, true>
|
||||
: &fmha_fwd_loop_kernel<Kernel_traits, IsDropoutConst, true, false>)
|
||||
: (launch_params.return_softmax
|
||||
? &fmha_fwd_loop_kernel<Kernel_traits, IsDropoutConst, false, true>
|
||||
: &fmha_fwd_loop_kernel<Kernel_traits, IsDropoutConst, false, false>);
|
||||
if( smem_size >= 48 * 1024 ) {
|
||||
FMHA_CHECK_CUDA(cudaFuncSetAttribute(
|
||||
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
|
||||
}
|
||||
// Automatically set num_splits to maximize occupancy
|
||||
if (launch_params.params.num_splits <= 0) {
|
||||
int ctas_per_sm;
|
||||
cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&ctas_per_sm, kernel, Kernel_traits::THREADS, smem_size);
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
// printf("CTAS_PER_SM = %d, nSMs = %d\n", ctas_per_sm, dprops->multiProcessorCount);
|
||||
constexpr int M = Kernel_traits::Cta_tile_p::M;
|
||||
launch_params.params.num_splits = num_splits_heuristic_fwd(
|
||||
launch_params.params.b * launch_params.params.h, dprops->multiProcessorCount,
|
||||
ctas_per_sm,
|
||||
/*max_splits=*/std::min(30, (launch_params.params.seqlen_q + M - 1 / M))
|
||||
);
|
||||
}
|
||||
// printf("smem_size = %d\n", smem_size);
|
||||
dim3 grid(launch_params.params.b, launch_params.params.h, launch_params.params.num_splits);
|
||||
kernel<<<grid, Kernel_traits::THREADS, smem_size, launch_params.stream>>>(
|
||||
launch_params.params);
|
||||
FMHA_CHECK_CUDA(cudaPeekAtLastError());
|
||||
}));
|
||||
}
|
||||
|
||||
}; // namespace pytorch_fmha
|
||||
@ -1,79 +0,0 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright
|
||||
* notice, this list of conditions and the following disclaimer in the
|
||||
* documentation and/or other materials provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the
|
||||
* names of its contributors may be used to endorse or promote products
|
||||
* derived from this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/fmha.h>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/utils.h>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/smem_tile.h>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/gmem_tile.h>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/mask.h>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/softmax.h>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/philox.cuh>
|
||||
|
||||
namespace fmha {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<int THREADS_PER_CTA>
|
||||
struct BlockInfoPadded {
|
||||
|
||||
template<typename Params>
|
||||
__device__ BlockInfoPadded(const Params ¶ms,
|
||||
const int bidb,
|
||||
const int bidh,
|
||||
const int tidx)
|
||||
: bidb(bidb), bidh(bidh), h(params.h) {
|
||||
|
||||
// The block index.
|
||||
sum_s_k = params.cu_seqlens_k[bidb];
|
||||
actual_seqlen_k = params.cu_seqlens_k[bidb + 1] - sum_s_k;
|
||||
sum_s_q = params.cu_seqlens_q[bidb];
|
||||
actual_seqlen_q = params.cu_seqlens_q[bidb + 1] - sum_s_q;
|
||||
|
||||
tidx_global = (bidb * params.h + bidh) * THREADS_PER_CTA + tidx;
|
||||
}
|
||||
|
||||
__device__ bool stop_early(const int start_col = 0) const {
|
||||
return actual_seqlen_k <= start_col;
|
||||
}
|
||||
|
||||
int actual_seqlen_q;
|
||||
int actual_seqlen_k;
|
||||
int sum_s_q;
|
||||
int sum_s_k;
|
||||
int bidh;
|
||||
int bidb;
|
||||
int tidx_global;
|
||||
int h;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace fmha
|
||||
@ -1,100 +0,0 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright
|
||||
* notice, this list of conditions and the following disclaimer in the
|
||||
* documentation and/or other materials provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the
|
||||
* names of its contributors may be used to endorse or promote products
|
||||
* derived from this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cassert>
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <cuda_runtime_api.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_bf16.h>
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#define FMHA_CHECK_CUDA( call ) \
|
||||
do { \
|
||||
cudaError_t status_ = call; \
|
||||
if( status_ != cudaSuccess ) { \
|
||||
fprintf( stderr, \
|
||||
"CUDA error (%s:%d): %s\n", \
|
||||
__FILE__, \
|
||||
__LINE__, \
|
||||
cudaGetErrorString( status_ ) ); \
|
||||
exit( 1 ); \
|
||||
} \
|
||||
} while( 0 )
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
enum Data_type { DATA_TYPE_FP16, DATA_TYPE_BF16, DATA_TYPE_FP32, DATA_TYPE_INT32, DATA_TYPE_INT8 };
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
static inline void set_alpha( uint32_t &alpha, float norm, Data_type dtype ) {
|
||||
if( dtype == DATA_TYPE_FP16 ) {
|
||||
half x = __float2half_rn( norm );
|
||||
uint16_t h = reinterpret_cast<const uint16_t &>( x );
|
||||
ushort2 h2 = { h, h };
|
||||
alpha = reinterpret_cast<const uint32_t &>( h2 );
|
||||
} else if( dtype == DATA_TYPE_BF16 ) {
|
||||
__nv_bfloat16 x = __float2bfloat16( norm );
|
||||
uint16_t h = reinterpret_cast<const uint16_t &>( x );
|
||||
ushort2 h2 = { h, h };
|
||||
alpha = reinterpret_cast<const uint32_t &>( h2 );
|
||||
} else if( dtype == DATA_TYPE_FP32 ) {
|
||||
alpha = reinterpret_cast<const uint32_t &>( norm );
|
||||
} else if( dtype == DATA_TYPE_INT32 ) {
|
||||
int32_t inorm = static_cast<int32_t>( norm );
|
||||
alpha = reinterpret_cast<const uint32_t &>( inorm );
|
||||
} else {
|
||||
assert( false );
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
static inline size_t get_size_in_bytes( size_t n, Data_type dtype ) {
|
||||
switch( dtype ) {
|
||||
case DATA_TYPE_FP32:
|
||||
return n * 4;
|
||||
case DATA_TYPE_FP16:
|
||||
return n * 2;
|
||||
case DATA_TYPE_BF16:
|
||||
return n * 2;
|
||||
case DATA_TYPE_INT32:
|
||||
return n * 4;
|
||||
case DATA_TYPE_INT8:
|
||||
return n;
|
||||
default:
|
||||
assert( false );
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -1,452 +0,0 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright
|
||||
* notice, this list of conditions and the following disclaimer in the
|
||||
* documentation and/or other materials provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the
|
||||
* names of its contributors may be used to endorse or promote products
|
||||
* derived from this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/utils.h>
|
||||
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/gemm/warp/default_mma_tensor_op.h>
|
||||
#include <cutlass/layout/layout.h>
|
||||
#include <cutlass/arch/mma.h>
|
||||
#include <cutlass/array.h>
|
||||
#include <cutlass/numeric_types.h>
|
||||
|
||||
namespace fmha {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template< typename Data_type_, int NUM_ELTS_, int BITS_PER_ELT_, int ALIGNMENT_ >
|
||||
struct Fragment_base_ {
|
||||
|
||||
// The data type.
|
||||
using Data_type = Data_type_;
|
||||
// default input type
|
||||
using Input_type_ = Data_type_;
|
||||
// Does it store the array of elements.
|
||||
static constexpr bool HAS_ELTS = BITS_PER_ELT_ >= 8;
|
||||
// The number of elements.
|
||||
static constexpr int NUM_ELTS = NUM_ELTS_;
|
||||
// The size of element in bits.
|
||||
static constexpr int BITS_PER_ELT = BITS_PER_ELT_;
|
||||
// The size of byte of a single register.
|
||||
static constexpr int BYTES_PER_REG = 4;
|
||||
// The size in bits.
|
||||
static constexpr int BITS_PER_REG = BYTES_PER_REG * 8;
|
||||
// The number of registers needed to store the fragment.
|
||||
static constexpr int NUM_REGS = DivUpConstexpr(NUM_ELTS * BITS_PER_ELT, BITS_PER_REG);
|
||||
// The size in bytes (as returned by sizeof(Fragment_base<>).
|
||||
static constexpr int SIZE_IN_BYTES = NUM_REGS * BYTES_PER_REG;
|
||||
// The alignment.
|
||||
static constexpr int ALIGNMENT = ALIGNMENT_ > 0 ? ALIGNMENT_ : MinConstexpr(NUM_REGS * BYTES_PER_REG, 16);
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<
|
||||
// The type of the elements.
|
||||
typename Data_type_,
|
||||
// The number of elements.
|
||||
int NUM_ELTS_,
|
||||
// The alignment if you want to force a value -- use 0 otherwise.
|
||||
int ALIGNMENT_ = 0,
|
||||
// The base class.
|
||||
typename Base_ = Fragment_base_<Data_type_, NUM_ELTS_, 8 * sizeof(Data_type_), ALIGNMENT_>
|
||||
>
|
||||
struct alignas(static_cast<int>(Base_::ALIGNMENT)) Fragment : public Base_ {
|
||||
|
||||
// The size of a load/store.
|
||||
static constexpr int BYTES_PER_LOAD_STORE = Base_::NUM_REGS * sizeof(uint32_t);
|
||||
|
||||
// Clear the fragment. Using PTX in that code seems to produce better SASS...
|
||||
inline __device__ void clear() {
|
||||
#pragma unroll
|
||||
for( int ii = 0; ii < Base_::NUM_REGS; ++ii ) {
|
||||
asm volatile("mov.u32 %0, 0; \n" : "=r"(this->reg(ii)) : );
|
||||
}
|
||||
}
|
||||
|
||||
// Immutable access to a register.
|
||||
inline __device__ const uint32_t& reg(int ii) const {
|
||||
return this->regs_[ii];
|
||||
}
|
||||
|
||||
// Mutable access to a register.
|
||||
inline __device__ uint32_t& reg(int ii) {
|
||||
return this->regs_[ii];
|
||||
}
|
||||
|
||||
uint32_t regs_[Base_::NUM_REGS];
|
||||
|
||||
// Immutable access to the elements.
|
||||
inline __device__ const Data_type_& elt(int ii) const {
|
||||
return reinterpret_cast<const Data_type_*>(&this->regs_[0])[ii];
|
||||
}
|
||||
|
||||
// Mutable access to the elements.
|
||||
inline __device__ Data_type_& elt(int ii) {
|
||||
return reinterpret_cast<Data_type_*>(&this->regs_[0])[ii];
|
||||
}
|
||||
|
||||
// Immutable access to the elements with a cast.
|
||||
template< typename Cast_type >
|
||||
inline __device__ const Cast_type& elt_as(int ii) const {
|
||||
return reinterpret_cast<const Cast_type*>(&this->regs_[0])[ii];
|
||||
}
|
||||
|
||||
// Mutable access to the elements.
|
||||
template< typename Cast_type >
|
||||
inline __device__ Cast_type& elt_as(int ii) {
|
||||
return reinterpret_cast<Cast_type*>(&this->regs_[0])[ii];
|
||||
}
|
||||
|
||||
// Add another fragment.
|
||||
inline __device__ void add(const Fragment &other) {
|
||||
// TODO (TD 2022-04-09): Shouldn't this be NUM_REGS instead of NUM_ELTS?
|
||||
// Also are we doing int addition or __half2 addition?
|
||||
#pragma unroll
|
||||
for( int ii = 0; ii < NUM_ELTS_; ++ii ) {
|
||||
this->elt(ii) += other.elt(ii);
|
||||
}
|
||||
}
|
||||
|
||||
// Multiply by another fragment.
|
||||
inline __device__ void hmul(const Fragment &other) {
|
||||
#pragma unroll
|
||||
for( int ii = 0; ii < Base_::NUM_REGS; ++ii ) {
|
||||
this->reg(ii) = fmha::hmul2(this->reg(ii), other.reg(ii));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename elem_type>
|
||||
inline __device__ void hrelu_() {
|
||||
#pragma unroll
|
||||
for( int ii = 0; ii < Base_::NUM_REGS; ++ii ) {
|
||||
this->reg(ii) = fmha::hrelu2<elem_type>(this->reg(ii));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template< typename Layout >
|
||||
struct Fragment_a : public Fragment<uint16_t, 8> {
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template< typename Layout >
|
||||
struct Fragment_b : public Fragment<uint16_t, 8> {
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct Fragment_accumulator : public Fragment<float, 8> {
|
||||
|
||||
// The base class.
|
||||
using Base = Fragment<float, 8>;
|
||||
|
||||
// Add two fragments.
|
||||
template< typename Other_fragment_ >
|
||||
inline __device__ void add(const Other_fragment_ &other) {
|
||||
for( int ii = 0; ii < Base::NUM_ELTS; ++ii ) {
|
||||
this->elt(ii) = this->elt(ii) + other.elt(ii);
|
||||
}
|
||||
}
|
||||
|
||||
inline __device__ void mul_(const float other) {
|
||||
for( int ii = 0; ii < Base::NUM_ELTS; ++ii ) {
|
||||
this->elt(ii) *= other;
|
||||
}
|
||||
}
|
||||
|
||||
// Do the HMMA.
|
||||
template< typename Layout_a, typename Layout_b >
|
||||
inline __device__ void mma(const Fragment_a<Layout_a> &a,
|
||||
const Fragment_b<Layout_b> &b) {
|
||||
asm volatile( \
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \n" \
|
||||
" {%0, %1, %2, %3}, \n" \
|
||||
" {%4, %5, %6, %7}, \n" \
|
||||
" {%8, %9}, \n" \
|
||||
" {%0, %1, %2, %3}; \n" \
|
||||
: "+f"( elt(0)), "+f"( elt(1)), "+f"( elt(2)), "+f"( elt(3))
|
||||
: "r"(a.reg(0)), "r"(a.reg(1)), "r"(a.reg(2)), "r"(a.reg(3))
|
||||
, "r"(b.reg(0)), "r"(b.reg(1)));
|
||||
asm volatile( \
|
||||
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 \n" \
|
||||
" {%0, %1, %2, %3}, \n" \
|
||||
" {%4, %5, %6, %7}, \n" \
|
||||
" {%8, %9}, \n" \
|
||||
" {%0, %1, %2, %3}; \n" \
|
||||
: "+f"( elt(4)), "+f"( elt(5)), "+f"( elt(6)), "+f"( elt(7))
|
||||
: "r"(a.reg(0)), "r"(a.reg(1)), "r"(a.reg(2)), "r"(a.reg(3))
|
||||
, "r"(b.reg(2)), "r"(b.reg(3)));
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template< typename Fragment, int M, int N >
|
||||
inline __device__ void clear(Fragment (&frag)[M][N]) {
|
||||
#pragma unroll
|
||||
for( int mi = 0; mi < M; ++mi ) {
|
||||
#pragma unroll
|
||||
for( int ni = 0; ni < N; ++ni ) {
|
||||
frag[mi][ni].clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template< typename Accumulator_type, int WARPS_K >
|
||||
struct Clear_accumulator {
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template< int WARPS_K >
|
||||
struct Clear_accumulator<float, WARPS_K> {
|
||||
template< typename Acc, int M, int N >
|
||||
static inline __device__ void apply(Acc (&acc)[M][N], bool = false) {
|
||||
fmha::clear(acc);
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename Acc, typename A, typename B, int M, int N>
|
||||
inline __device__ void gemm(Acc (&acc)[M][N], const A (&a)[M], const B (&b)[N]) {
|
||||
|
||||
#pragma unroll
|
||||
for( int mi = 0; mi < M; ++mi ) {
|
||||
#pragma unroll
|
||||
for( int ni = 0; ni < N; ++ni ) {
|
||||
acc[mi][ni].mma(a[mi], b[ni]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
/// Statically maps half types => cutlass data types
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
template <typename Type_>
|
||||
struct HalfTypeToCutlassType { using Type = Type_; };
|
||||
|
||||
/// Statically maps __half => cutlass::half_t
|
||||
template <> struct HalfTypeToCutlassType<__half> {
|
||||
using Type = cutlass::half_t;
|
||||
};
|
||||
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
|
||||
template <> struct HalfTypeToCutlassType<__nv_bfloat16> {
|
||||
using Type = cutlass::bfloat16_t;
|
||||
};
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename elem_type, typename Acc, typename A, typename B, int M, int N>
|
||||
inline __device__ void gemm_cl(Acc (&acc)[M][N], const A (&a)[M], const B (&b)[N]) {
|
||||
using Shape = cutlass::gemm::GemmShape<16 * M, 16 * N, 16>;
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
|
||||
#elif defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
|
||||
#else
|
||||
assert(0);
|
||||
// THIS IS NOT CORRECT BUT THE ASSERT WILL STOP THIS
|
||||
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
|
||||
// TD [2022-06-02] We don't support Volta (SM70) yet.
|
||||
#endif
|
||||
using Element = typename HalfTypeToCutlassType<elem_type>::Type;
|
||||
using ElementC = float;
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
|
||||
using WarpMma = typename cutlass::gemm::warp::DefaultMmaTensorOp<
|
||||
Shape, InstructionShape, Element, LayoutA, Element, LayoutB, ElementC,
|
||||
cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAdd, 1, true>::Type;
|
||||
|
||||
constexpr int kIters = Shape::kK / InstructionShape::kK;
|
||||
// using FragmentA = typename WarpMma::FragmentA;
|
||||
// using FragmentB = typename WarpMma::FragmentB;
|
||||
using FragmentA = typename WarpMma::ArchMmaOperator::FragmentA;
|
||||
using FragmentB = typename WarpMma::ArchMmaOperator::FragmentB;
|
||||
using FragmentC = typename WarpMma::FragmentC;
|
||||
|
||||
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y) == 0) {
|
||||
// printf("FragmentA::kStorageElements = %d\n", FragmentA::kStorageElements);
|
||||
// printf("Archmma::FragmentA::kStorageElements = %d\n", WarpMma::ArchMmaOperator::FragmentA::kStorageElements);
|
||||
// printf("FragmentB::kStorageElements = %d\n", FragmentB::kStorageElements);
|
||||
// printf("Archmma::FragmentB::kStorageElements = %d\n", WarpMma::ArchMmaOperator::FragmentB::kStorageElements);
|
||||
// printf("FragmentC::kStorageElements = %d\n", FragmentC::kStorageElements);
|
||||
// printf("Archmma::FragmentC::kStorageElements = %d\n", WarpMma::ArchMmaOperator::FragmentC::kStorageElements);
|
||||
// }
|
||||
|
||||
// static_assert(FragmentA::kStorageElements == M * a[0].NUM_REGS);
|
||||
// static_assert(FragmentB::kStorageElements == N * b[0].NUM_REGS);
|
||||
static_assert(FragmentA::kStorageElements * kIters == a[0].NUM_REGS);
|
||||
static_assert(FragmentB::kStorageElements * kIters * 16 / InstructionShape::kN == b[0].NUM_REGS);
|
||||
static_assert(FragmentC::kStorageElements == M * N * acc[0][0].NUM_REGS);
|
||||
// const FragmentA a_cl = reinterpret_cast<const FragmentA (&)>(a);
|
||||
// const FragmentB b_cl = reinterpret_cast<const FragmentB (&)>(b);
|
||||
FragmentC c_cl = reinterpret_cast<FragmentC (&)>(acc);
|
||||
FragmentA a_cl[kIters][M];
|
||||
FragmentA b_cl[kIters][N];
|
||||
constexpr int kRegs = InstructionShape::kK == 16 ? 4 : 2;
|
||||
#pragma unroll
|
||||
for (int iter = 0; iter < kIters; iter++) {
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < M; mi++) {
|
||||
uint32_t *a_ptr = a_cl[iter][mi].raw_data();
|
||||
#pragma unroll
|
||||
for (int ki = 0; ki < kRegs; ki++) {
|
||||
a_ptr[ki] = a[mi].regs_[iter * kRegs + ki];
|
||||
}
|
||||
}
|
||||
}
|
||||
#pragma unroll
|
||||
for (int iter = 0; iter < kIters; iter++) {
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < N; ni++) {
|
||||
uint32_t *b_ptr = b_cl[iter][ni].raw_data();
|
||||
#pragma unroll
|
||||
for (int ki = 0; ki < kRegs; ki++) {
|
||||
// b_ptr[ki] = b[ni].regs_[iter * kRegs + ki];
|
||||
// TD [2022-06-02] For some reason the order for frag_b is different.
|
||||
b_ptr[ki] = b[ni].regs_[InstructionShape::kK == 16 ? iter * kRegs + ki : ki * kRegs + iter];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
WarpMma mma_op;
|
||||
// mma_op(c_cl, a_cl, b_cl, c_cl);
|
||||
#pragma unroll
|
||||
for (int iter = 0; iter < kIters; iter++) {
|
||||
mma_op(c_cl, reinterpret_cast<const typename WarpMma::FragmentA (&)>(a_cl[iter]),
|
||||
reinterpret_cast<const typename WarpMma::FragmentB (&)>(b_cl[iter]), c_cl);
|
||||
}
|
||||
|
||||
// The modified c_cl is not copied back into acc, idk why
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < M; mi++) {
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < N; ni++) {
|
||||
#pragma unroll
|
||||
for (int i =0; i < 8; i++) {
|
||||
acc[mi][ni].elt(i) = c_cl[mi * N * 8 + ni * 8 + i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<
|
||||
// The number of rows in the CTA tile.
|
||||
int M_,
|
||||
// The number of cols in the CTA tile.
|
||||
int N_,
|
||||
// The number of elements in the the K dimension of the GEMM loop.
|
||||
int K_,
|
||||
// The number of rows of warps.
|
||||
int WARPS_M_,
|
||||
// The number of cols of warps.
|
||||
int WARPS_N_,
|
||||
// The number of warps in the K dimension of the GEMM loop.
|
||||
int WARPS_K_>
|
||||
struct Cta_tile_ {
|
||||
|
||||
static constexpr int M = M_, N = N_, K = K_;
|
||||
// The number of warps.
|
||||
static constexpr int WARPS_M = WARPS_M_, WARPS_N = WARPS_N_, WARPS_K = WARPS_K_;
|
||||
// The number of warps per CTA.
|
||||
static constexpr int WARPS_PER_CTA = WARPS_M * WARPS_N * WARPS_K;
|
||||
// The number of threads per warp.
|
||||
static constexpr int THREADS_PER_WARP = 32;
|
||||
// The number of threads per CTA.
|
||||
static constexpr int THREADS_PER_CTA = WARPS_PER_CTA * THREADS_PER_WARP;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename Cta_tile>
|
||||
struct Hmma_tile {
|
||||
// The number of elements computed with a single warp-MMA.
|
||||
static constexpr int M_PER_MMA = 16, N_PER_MMA = 16, K_PER_MMA = 16;
|
||||
|
||||
// The number of elements computed with a single CTA-MMA.
|
||||
static constexpr int M_PER_MMA_PER_CTA = M_PER_MMA * Cta_tile::WARPS_M,
|
||||
N_PER_MMA_PER_CTA = N_PER_MMA * Cta_tile::WARPS_N,
|
||||
K_PER_MMA_PER_CTA = K_PER_MMA * Cta_tile::WARPS_K;
|
||||
|
||||
// The number of MMAs needed to compute the GEMM.
|
||||
static constexpr int MMAS_M = DivUpConstexpr(Cta_tile::M, M_PER_MMA_PER_CTA),
|
||||
MMAS_N = DivUpConstexpr(Cta_tile::N, N_PER_MMA_PER_CTA),
|
||||
MMAS_K = DivUpConstexpr(Cta_tile::K, K_PER_MMA_PER_CTA);
|
||||
|
||||
// // The number of elements computed per warp.
|
||||
// static constexpr int M_PER_WARP = MMAS_M * M_PER_MMA,
|
||||
// N_PER_WARP = MMAS_N * N_PER_MMA,
|
||||
// K_PER_WARP = MMAS_K * K_PER_MMA;
|
||||
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
using A_type = uint16_t;
|
||||
using B_type = uint16_t;
|
||||
using C_type = uint16_t;
|
||||
using Accumulator_type = float;
|
||||
using Epilogue_type = float;
|
||||
|
||||
constexpr int BITS_PER_ELEMENT_A = sizeof(A_type) * 8;
|
||||
constexpr int BITS_PER_ELEMENT_B = sizeof(B_type) * 8;
|
||||
constexpr int BITS_PER_ELEMENT_C = sizeof(C_type) * 8;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<int M, int N, int K, int WARPS_M, int WARPS_N, int WARPS_K>
|
||||
using Cta_tile_extd = Cta_tile_<M, N, K, WARPS_M, WARPS_N, WARPS_K>;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename Cta_tile_>
|
||||
using Cta_tile_with_k_with_padding = Cta_tile_extd<Cta_tile_::M,
|
||||
Cta_tile_::N,
|
||||
Next_power_of_two<Cta_tile_::K>::VALUE,
|
||||
Cta_tile_::WARPS_M,
|
||||
Cta_tile_::WARPS_N,
|
||||
Cta_tile_::WARPS_K>;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace fmha
|
||||
@ -1,555 +0,0 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright
|
||||
* notice, this list of conditions and the following disclaimer in the
|
||||
* documentation and/or other materials provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the
|
||||
* names of its contributors may be used to endorse or promote products
|
||||
* derived from this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_bf16.h>
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/utils.h>
|
||||
|
||||
namespace fmha {
|
||||
|
||||
template<
|
||||
// The dimensions of the tile computed by the CTA.
|
||||
typename Cta_tile_,
|
||||
// The number of bits per element.
|
||||
int BITS_PER_ELEMENT,
|
||||
// The number of rows of Q, K or V loaded by this tile.
|
||||
int ROWS_,
|
||||
// The number of columns.
|
||||
int COLS,
|
||||
int BYTES_PER_LDGS_ = 16
|
||||
>
|
||||
struct Gmem_tile_qkv {
|
||||
|
||||
using Cta_tile = Cta_tile_;
|
||||
|
||||
static constexpr int BYTES_PER_ELEMENT = BITS_PER_ELEMENT / 8;
|
||||
// The size of each LDG.
|
||||
static constexpr int BYTES_PER_LDG = BYTES_PER_LDGS_;
|
||||
// The size of a row in bytes.
|
||||
static constexpr int BYTES_PER_ROW = COLS * BITS_PER_ELEMENT / 8;
|
||||
|
||||
// The number of threads to load a "row" of the matrix.
|
||||
static constexpr int THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_LDG;
|
||||
|
||||
static constexpr int ROWS = ROWS_;
|
||||
// The number of "rows" loaded per LDG.
|
||||
static constexpr int ROWS_PER_LDG = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW;
|
||||
// The number of LDGs needed to load a chunk of the Q matrix.
|
||||
static constexpr int LDGS = DivUpConstexpr(ROWS, ROWS_PER_LDG);
|
||||
|
||||
// Ctor.
|
||||
template< typename BInfo >
|
||||
inline __device__ Gmem_tile_qkv(void *ptr_, const uint32_t row_stride_in_elts,
|
||||
const uint32_t head_stride_in_elts, const int headdim,
|
||||
const BInfo &binfo, const int tidx, bool use_seqlen_q)
|
||||
: row_stride_in_bytes(row_stride_in_elts * BYTES_PER_ELEMENT)
|
||||
, actual_seqlen(use_seqlen_q ? binfo.actual_seqlen_q : binfo.actual_seqlen_k)
|
||||
, ptr(reinterpret_cast<char *>(ptr_))
|
||||
, tidx_(tidx)
|
||||
, col_predicate((tidx % THREADS_PER_ROW) * (BYTES_PER_LDG / BYTES_PER_ELEMENT) < headdim) {
|
||||
|
||||
// Compute the position in the sequence (within the CTA for the moment).
|
||||
int row = tidx / THREADS_PER_ROW;
|
||||
// Compute the position of the thread in the row.
|
||||
int col = tidx % THREADS_PER_ROW;
|
||||
|
||||
// Store the row as we need it to disable the loads.
|
||||
// TD [2022-04-16]: To minimize registers, we'll recompute row_ instead of storing it
|
||||
// row_ = row;
|
||||
|
||||
// The row offset in the batched GEMM. For each seq element, we store QKV in that order.
|
||||
// int64_t row_offset = (int64_t)row * params.qkv_stride_in_bytes;
|
||||
uint32_t row_offset = (uint32_t)(((use_seqlen_q ? binfo.sum_s_q : binfo.sum_s_k) + row) * row_stride_in_bytes);
|
||||
// Add the block index.
|
||||
// row_offset += (int64_t)((binfo.sum_s * NUM_MATS + qkv_offset) * binfo.h + binfo.bidh) * BYTES_PER_ROW;
|
||||
row_offset += (uint32_t)(binfo.bidh * head_stride_in_elts * BYTES_PER_ELEMENT);
|
||||
|
||||
// Assemble the final pointer.
|
||||
ptr += row_offset + col * BYTES_PER_LDG;
|
||||
}
|
||||
|
||||
// Store data to shared memory.
|
||||
template< typename Smem_tile >
|
||||
inline __device__ void commit(Smem_tile &smem_tile) {
|
||||
smem_tile.store(fetch_);
|
||||
}
|
||||
|
||||
inline __device__ void load() {
|
||||
int row_ = tidx_ / THREADS_PER_ROW;
|
||||
const void *ptrs[LDGS];
|
||||
uint32_t preds[LDGS];
|
||||
#pragma unroll
|
||||
for( int ii = 0; ii < LDGS; ++ii ) {
|
||||
// ptrs[ii] = ptr + (int64_t)ii * ROWS_PER_LDG * row_stride_in_bytes;
|
||||
ptrs[ii] = ptr + (uint32_t)ii * ROWS_PER_LDG * row_stride_in_bytes;
|
||||
preds[ii] = col_predicate && ((row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen));
|
||||
fetch_[ii] = make_uint4(0, 0, 0, 0);
|
||||
}
|
||||
|
||||
// not packing predicates removes restrictions (e.g. FP16 384, 4 warps)
|
||||
Ldg_functor<uint4, LDGS> fct(fetch_, ptrs);
|
||||
#pragma unroll
|
||||
for( int ii = 0; ii < LDGS; ++ii ) {
|
||||
fct.load(ii, preds[ii]);
|
||||
}
|
||||
}
|
||||
|
||||
// Store data to memory.
|
||||
inline __device__ void store(const uint4 (&data)[LDGS]) {
|
||||
int row_ = tidx_ / THREADS_PER_ROW;
|
||||
#pragma unroll
|
||||
for( int ii = 0; ii < LDGS; ++ii ) {
|
||||
// char *ptr_ = ptr + (int64_t)ii * ROWS_PER_LDG * row_stride_in_bytes;
|
||||
char *ptr_ = ptr + (uint32_t)ii * ROWS_PER_LDG * row_stride_in_bytes;
|
||||
if (col_predicate && (row_ + ii * ROWS_PER_LDG) < min(ROWS, actual_seqlen)) {
|
||||
fmha::stg(ptr_, data[ii]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inline __device__ void move(const int steps = 1) {
|
||||
// ptr += (int64_t)ROWS * row_stride_in_bytes * steps;
|
||||
ptr += (uint32_t)ROWS * row_stride_in_bytes * steps;
|
||||
actual_seqlen -= ROWS * steps;
|
||||
}
|
||||
|
||||
// The stride between rows for the QKV matrice.
|
||||
// int64_t row_stride_in_bytes;
|
||||
const uint32_t row_stride_in_bytes;
|
||||
// The pointer.
|
||||
char *ptr;
|
||||
// The fetch registers.
|
||||
uint4 fetch_[LDGS];
|
||||
// Keep track of the row the thread is processing as we move the tile.
|
||||
// int row_;
|
||||
const int tidx_;
|
||||
// The length of the sequence loaded by that memory tile.
|
||||
int actual_seqlen;
|
||||
const bool col_predicate;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<
|
||||
typename Cta_tile,
|
||||
int BYTES_PER_ELEMENT = 2
|
||||
>
|
||||
struct Gmem_tile_o {
|
||||
|
||||
static_assert(BYTES_PER_ELEMENT == 2 || BYTES_PER_ELEMENT == 4);
|
||||
|
||||
// The mma tile.
|
||||
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
|
||||
|
||||
// The size of each element.
|
||||
// static constexpr int BYTES_PER_ELEMENT = 2;
|
||||
// The size of each STG.
|
||||
static constexpr int BYTES_PER_STG = BYTES_PER_ELEMENT * 4;
|
||||
static constexpr int COLS = Cta_tile::N;
|
||||
// The size of a row in bytes.
|
||||
static constexpr int BYTES_PER_ROW = COLS * BYTES_PER_ELEMENT;
|
||||
|
||||
// The number of threads to store a "row" of the matrix.
|
||||
static constexpr int THREADS_PER_ROW = BYTES_PER_ROW / BYTES_PER_STG;
|
||||
// The number of "rows" stored per iteration of the loop. The output of 1 MMA.
|
||||
static constexpr int ROWS = Cta_tile::M;
|
||||
// The number of "rows" stored per iteration of the loop. The output of 1 MMA.
|
||||
static constexpr int ROWS_PER_LOOP = ROWS <= 64 ? ROWS : (int)Mma_tile::M_PER_MMA_PER_CTA;
|
||||
// The number of outter loop for the stores.
|
||||
static constexpr int LOOPS = ROWS / ROWS_PER_LOOP;
|
||||
|
||||
// The number of "rows" stored per STG.
|
||||
static constexpr int ROWS_PER_STG = Cta_tile::THREADS_PER_CTA / THREADS_PER_ROW;
|
||||
// Do we have to guard against partial writes/reads.
|
||||
static constexpr bool HAS_INCOMPLETE_STG = Cta_tile::M % ROWS_PER_STG != 0;
|
||||
// The number of STGs needed to store a chunk of the Q matrix.
|
||||
static constexpr int STGS_PER_LOOP = DivUpConstexpr(ROWS_PER_LOOP, ROWS_PER_STG);
|
||||
// The number of STGs needed to store a chunk of the Q matrix in total.
|
||||
static constexpr int STGS = STGS_PER_LOOP * LOOPS;
|
||||
|
||||
// Ctor.
|
||||
template<typename BInfo>
|
||||
// inline __device__ Gmem_tile_o(void *ptr, const size_t row_stride_in_elts, const BInfo &binfo, const int tidx)
|
||||
inline __device__ Gmem_tile_o(void *ptr, const uint32_t row_stride_in_elts,
|
||||
const uint32_t head_stride_in_elts, const int headdim,
|
||||
const BInfo &binfo, const int tidx)
|
||||
: row_stride_in_bytes(row_stride_in_elts * BYTES_PER_ELEMENT)
|
||||
, actual_seqlen_q(binfo.actual_seqlen_q)
|
||||
, ptr_(reinterpret_cast<char *>(ptr))
|
||||
, tidx_(tidx)
|
||||
, col_predicate((tidx % THREADS_PER_ROW) * (BYTES_PER_STG / BYTES_PER_ELEMENT) < headdim) {
|
||||
|
||||
// Compute the position in the sequence (within the CTA for the moment).
|
||||
int row = tidx / THREADS_PER_ROW;
|
||||
// Compute the position of the thread in the row.
|
||||
int col = tidx % THREADS_PER_ROW;
|
||||
|
||||
// Store the row as we need it to disable loads.
|
||||
// row_ = row;
|
||||
|
||||
// The row offset in the batched GEMM.
|
||||
// int64_t row_offset = (int64_t)row * row_stride_in_bytes + binfo.bidx * BYTES_PER_ROW;
|
||||
uint32_t row_offset = (uint32_t)((binfo.sum_s_q + row) * row_stride_in_bytes);
|
||||
row_offset += (uint32_t)(binfo.bidh * head_stride_in_elts * BYTES_PER_ELEMENT);
|
||||
// Assemble the final pointer.
|
||||
ptr_ += row_offset + col * BYTES_PER_STG;
|
||||
|
||||
// Is that thread active on the last STG?
|
||||
if( HAS_INCOMPLETE_STG ) {
|
||||
is_active_for_last_stg_ = row + (STGS - 1) * ROWS_PER_STG < Cta_tile::M;
|
||||
}
|
||||
}
|
||||
|
||||
// Store data to global memory.
|
||||
template<typename elem_type=__half>
|
||||
inline __device__ void store(const uint4 (&src)[STGS_PER_LOOP], int mi) {
|
||||
int row_ = tidx_ / THREADS_PER_ROW;
|
||||
#pragma unroll
|
||||
for( int ii = 0; ii < STGS_PER_LOOP; ++ii ) {
|
||||
int jj = mi * STGS_PER_LOOP + ii;
|
||||
if ((!col_predicate) || (row_ + jj * ROWS_PER_STG >= this->actual_seqlen_q)) {
|
||||
break;
|
||||
}
|
||||
|
||||
if (BYTES_PER_ELEMENT == 4) {
|
||||
if( !HAS_INCOMPLETE_STG || (jj < STGS - 1 || this->is_active_for_last_stg_) ) {
|
||||
fmha::stg(this->ptr_ + jj * ROWS_PER_STG * this->row_stride_in_bytes, src[ii]);
|
||||
}
|
||||
} else if (BYTES_PER_ELEMENT == 2) {
|
||||
float x = reinterpret_cast<const float &>(src[ii].x);
|
||||
float y = reinterpret_cast<const float &>(src[ii].y);
|
||||
float z = reinterpret_cast<const float &>(src[ii].z);
|
||||
float w = reinterpret_cast<const float &>(src[ii].w);
|
||||
uint2 out = fmha::float4_pack<elem_type>(x, y, z, w);
|
||||
if( !HAS_INCOMPLETE_STG || (jj < STGS - 1 || this->is_active_for_last_stg_) ) {
|
||||
fmha::stg(this->ptr_ + jj * ROWS_PER_STG * this->row_stride_in_bytes, out);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Store data to global memory with atomicAdd.
|
||||
inline __device__ void atomic_add(const uint4 (&src)[STGS_PER_LOOP], int mi) {
|
||||
static_assert(BYTES_PER_ELEMENT == 4); // Only do atomic add on floats
|
||||
int row_ = tidx_ / THREADS_PER_ROW;
|
||||
#pragma unroll
|
||||
for( int ii = 0; ii < STGS_PER_LOOP; ++ii ) {
|
||||
int jj = mi * STGS_PER_LOOP + ii;
|
||||
if ((!col_predicate) || (row_ + jj * ROWS_PER_STG >= this->actual_seqlen_q)) {
|
||||
break;
|
||||
}
|
||||
|
||||
if( !HAS_INCOMPLETE_STG || (jj < STGS - 1 || this->is_active_for_last_stg_) ) {
|
||||
float *ptr_ = reinterpret_cast<float *>(this->ptr_ + jj * ROWS_PER_STG * this->row_stride_in_bytes);
|
||||
#pragma unroll
|
||||
for (int jj = 0; jj < 4; ++jj) {
|
||||
atomicAdd(ptr_ + jj, reinterpret_cast<const float(&)[4]>(src[ii])[jj]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Load data from global memory.
|
||||
inline __device__ void load(uint4 (&dst)[STGS_PER_LOOP], int mi) {
|
||||
static_assert(BYTES_PER_ELEMENT == 4);
|
||||
int row_ = tidx_ / THREADS_PER_ROW;
|
||||
#pragma unroll
|
||||
for( int ii = 0; ii < STGS_PER_LOOP; ++ii ) {
|
||||
int jj = mi * STGS_PER_LOOP + ii;
|
||||
if ((!col_predicate) || (row_ + jj * ROWS_PER_STG >= this->actual_seqlen_q)) {
|
||||
break;
|
||||
}
|
||||
|
||||
if( !HAS_INCOMPLETE_STG || (jj < STGS - 1 || this->is_active_for_last_stg_) ) {
|
||||
fmha::ldg(dst[ii], this->ptr_ + jj * ROWS_PER_STG * this->row_stride_in_bytes);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inline __device__ void move(const int steps = 1) {
|
||||
// row_ += ROWS * steps;
|
||||
// ptr_ += (int64_t)ROWS * row_stride_in_bytes * steps;
|
||||
ptr_ += (uint32_t)ROWS * row_stride_in_bytes * steps;
|
||||
actual_seqlen_q -= ROWS * steps;
|
||||
}
|
||||
|
||||
// The stride between rows for the QKV matrice.
|
||||
// int64_t row_stride_in_bytes;
|
||||
const uint32_t row_stride_in_bytes;
|
||||
// The pointer.
|
||||
char *ptr_;
|
||||
// Is the thread active for the last STG?
|
||||
int is_active_for_last_stg_;
|
||||
// The length of the sequence loaded by that memory tile.
|
||||
int actual_seqlen_q;
|
||||
const int tidx_;
|
||||
const bool col_predicate;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template< typename Cta_tile, int BYTES_PER_ELEMENT >
|
||||
struct Gmem_tile_mma_sd {
|
||||
|
||||
// The mma tile.
|
||||
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
|
||||
|
||||
// Each STG stores 8 elements.
|
||||
static constexpr int BYTES_PER_STG = BYTES_PER_ELEMENT * 8;
|
||||
// The number of MMAs in the M dimension.
|
||||
static constexpr int MMAS_M = Mma_tile::MMAS_M;
|
||||
// The number of MMAs in the N dimension.
|
||||
static constexpr int MMAS_N = Mma_tile::MMAS_N;
|
||||
// The number of rows computed per MMA per thread block.
|
||||
static constexpr int M_PER_MMA_PER_CTA = Mma_tile::M_PER_MMA_PER_CTA;
|
||||
// The number of cols computed per MMA per thread block.
|
||||
static constexpr int N_PER_MMA_PER_CTA = Mma_tile::N_PER_MMA_PER_CTA;
|
||||
// The number of threads per block.
|
||||
static constexpr int THREADS_PER_CTA = Cta_tile::THREADS_PER_CTA;
|
||||
// The size of each row in bytes. I.e. how many bytes are stored per STG.
|
||||
static constexpr int BYTES_PER_ROW = THREADS_PER_CTA * BYTES_PER_STG;
|
||||
// The distance between elements stored per loop (in bytes).
|
||||
static constexpr int LOOP_STRIDE_BYTES = MMAS_M * MMAS_N * BYTES_PER_ROW;
|
||||
|
||||
// The type of elements stored per STG.
|
||||
using Type = typename fmha::Uint_from_size_in_bytes<BYTES_PER_STG>::Type;
|
||||
|
||||
// Ctor.
|
||||
template<typename Params>
|
||||
inline __device__ Gmem_tile_mma_sd(void *ptr, const Params ¶ms, const int bidb, const int bidh, const int tidx)
|
||||
: ptr_(static_cast<char *>(ptr)) {
|
||||
|
||||
// The block index.
|
||||
// size_t bidx = bidb * params.h + bidh;
|
||||
uint32_t bidx = bidb * params.h + bidh;
|
||||
|
||||
// The distance between two blocks (in bytes).
|
||||
// const size_t block_stride_bytes = params.seqlen_q * params.seqlen_k * BYTES_PER_ELEMENT;
|
||||
const uint32_t block_stride_bytes = params.seqlen_q * params.seqlen_k * BYTES_PER_ELEMENT;
|
||||
// Set store location for each thread at the beginning of the loop
|
||||
ptr_ += bidx * block_stride_bytes + tidx * BYTES_PER_STG;
|
||||
}
|
||||
|
||||
// Store to global memory.
|
||||
inline __device__ void store(const Type &data, const int mi, const int ni) {
|
||||
// size_t offset = (mi * MMAS_N + ni) * BYTES_PER_ROW;
|
||||
uint32_t offset = (mi * MMAS_N + ni) * BYTES_PER_ROW;
|
||||
fmha::stg(ptr_ + offset, data);
|
||||
}
|
||||
|
||||
// Load from global memory.
|
||||
inline __device__ void load(Type &data, const int mi, const int ni) {
|
||||
// size_t offset = (mi * MMAS_N + ni) * BYTES_PER_ROW;
|
||||
uint32_t offset = (mi * MMAS_N + ni) * BYTES_PER_ROW;
|
||||
fmha::ldg(data, ptr_ + offset);
|
||||
}
|
||||
|
||||
// Move to the next tile.
|
||||
inline __device__ void move(const int steps = 1) {
|
||||
ptr_ += LOOP_STRIDE_BYTES * steps;
|
||||
}
|
||||
|
||||
// The pointer in global memory.
|
||||
char *ptr_;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template< typename Cta_tile, typename Base = Gmem_tile_mma_sd<Cta_tile, sizeof(uint16_t)> >
|
||||
struct Gmem_tile_mma_s : public Base {
|
||||
|
||||
// The number of mmas in the vertical dimension.
|
||||
static constexpr int M = Base::MMAS_M;
|
||||
// The number of mmas in the horizontal dimension.
|
||||
static constexpr int N = Base::MMAS_N;
|
||||
// The type of the vectors stored by each STG.
|
||||
using Type = typename Base::Type;
|
||||
|
||||
// Ctor.
|
||||
template< typename Params, typename Block_info >
|
||||
inline __device__ Gmem_tile_mma_s(const Params ¶ms, const Block_info& binfo, const int tidx)
|
||||
: Base(params.s_ptr, params, binfo.bidb, binfo.bidh, tidx) {
|
||||
}
|
||||
|
||||
// Store to global memory.
|
||||
template<typename Mask, typename Fragment>
|
||||
inline __device__ void store(const Fragment (&frag)[N][M], const Mask& mask){
|
||||
#pragma unroll
|
||||
for( int mi = 0; mi < M; mi++ ) {
|
||||
#pragma unroll
|
||||
for( int ni = 0; ni < N; ni++ ) {
|
||||
uint4 dst;
|
||||
dst.x = frag[ni][mi].reg(0);
|
||||
dst.y = frag[ni][mi].reg(2);
|
||||
dst.z = frag[ni][mi].reg(1);
|
||||
dst.w = frag[ni][mi].reg(3);
|
||||
if( mask.any_valid(mi, ni) ) {
|
||||
Base::store(dst, mi, ni);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Load from global memory.
|
||||
template<typename Mask>
|
||||
inline __device__ void load(uint4 (®s)[M][N], const Mask &mask) {
|
||||
#pragma unroll
|
||||
for( int mi = 0; mi < M; mi++ ) {
|
||||
#pragma unroll
|
||||
for( int ni = 0; ni < N; ni++ ) {
|
||||
regs[mi][ni] = make_uint4(0, 0, 0, 0);
|
||||
if( mask.any_valid(mi, ni) ) {
|
||||
Base::load(regs[mi][ni], mi, ni);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<
|
||||
// The dimensions of the tile computed by the CTA.
|
||||
typename Cta_tile
|
||||
>
|
||||
struct Gmem_summary_stats {
|
||||
|
||||
// The Mma tile.
|
||||
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
|
||||
|
||||
// The number of MMAs in M/N dimensions.
|
||||
static constexpr int MMAS_M = Mma_tile::MMAS_M;
|
||||
|
||||
// The size of each element.
|
||||
static constexpr int BYTES_PER_ELEMENT = 4;
|
||||
static constexpr int BYTES_PER_MMA = (Cta_tile::THREADS_PER_WARP / 4) * 2 * BYTES_PER_ELEMENT;
|
||||
static constexpr int ROWS = Cta_tile::M;
|
||||
|
||||
// Ctor.
|
||||
template<typename Params>
|
||||
inline __device__ Gmem_summary_stats(void *ptr, const Params ¶ms, const int tidx)
|
||||
: ptr_(reinterpret_cast<char *>(ptr)), tidx_(tidx) {
|
||||
|
||||
// The block index for the batch.
|
||||
const int bidb = blockIdx.x;
|
||||
// The block index for the head.
|
||||
const int bidh = blockIdx.y;
|
||||
// The block index.
|
||||
// size_t bidx = bidb * params.h + bidh;
|
||||
uint32_t bidx = bidb * params.h + bidh;
|
||||
|
||||
// Extract the position in the warp.
|
||||
int warp = tidx / Cta_tile::THREADS_PER_WARP;
|
||||
int lane = tidx % Cta_tile::THREADS_PER_WARP;
|
||||
|
||||
// The distance between two blocks (in bytes).
|
||||
// size_t block_stride_bytes = params.seqlen_q * BYTES_PER_ELEMENT;
|
||||
uint32_t block_stride_bytes = params.seqlen_q * BYTES_PER_ELEMENT;
|
||||
|
||||
// Set store location for each thread at the beginning of the loop
|
||||
ptr_row_ = ptr_ + bidx * block_stride_bytes;
|
||||
ptr_ += bidx * block_stride_bytes + (lane / 4) * BYTES_PER_ELEMENT;
|
||||
}
|
||||
|
||||
// Store data to global memory.
|
||||
inline __device__ void store(const uint32_t (&data)[MMAS_M * 2]) {
|
||||
int warp = tidx_ / Cta_tile::THREADS_PER_WARP;
|
||||
int lane = tidx_ % Cta_tile::THREADS_PER_WARP;
|
||||
if ((warp == 0) && (lane % 4 == 0)) {
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < MMAS_M; ++mi) {
|
||||
// TODO: Not sure if it's right for MMAS_M > 1
|
||||
fmha::stg(ptr_ + mi * BYTES_PER_MMA + 0 * BYTES_PER_ELEMENT, data[mi * 2 + 0]);
|
||||
fmha::stg(ptr_ + mi * BYTES_PER_MMA + 8 * BYTES_PER_ELEMENT, data[mi * 2 + 1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Store data to global memory.
|
||||
inline __device__ void store_row(const uint32_t (&data)[MMAS_M], const int row) {
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < MMAS_M; ++mi) {
|
||||
// TODO: Not sure if it's right for MMAS_M > 1
|
||||
fmha::stg(ptr_row_ + mi * BYTES_PER_MMA + row * BYTES_PER_ELEMENT, data[mi]);
|
||||
}
|
||||
}
|
||||
|
||||
// Load from global memory.
|
||||
inline __device__ void load(uint32_t (&data)[MMAS_M * 2]) {
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < MMAS_M; ++mi) {
|
||||
// TODO: Not sure if it's right for MMAS_M > 1
|
||||
fmha::ldg(data[mi * 2 + 0], ptr_ + mi * BYTES_PER_MMA + 0 * BYTES_PER_ELEMENT);
|
||||
fmha::ldg(data[mi * 2 + 1], ptr_ + mi * BYTES_PER_MMA + 8 * BYTES_PER_ELEMENT);
|
||||
}
|
||||
}
|
||||
|
||||
// Load from global memory.
|
||||
inline __device__ void load_next(uint32_t (&data)[MMAS_M * 2], int move_steps=1) {
|
||||
char *ptr_next = ptr_ + move_steps * ROWS * BYTES_PER_ELEMENT;
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < MMAS_M; ++mi) {
|
||||
// TODO: Not sure if it's right for MMAS_M > 1
|
||||
fmha::ldg(data[mi * 2 + 0], ptr_next + mi * BYTES_PER_MMA + 0 * BYTES_PER_ELEMENT);
|
||||
fmha::ldg(data[mi * 2 + 1], ptr_next + mi * BYTES_PER_MMA + 8 * BYTES_PER_ELEMENT);
|
||||
}
|
||||
}
|
||||
|
||||
// Store data to global memory.
|
||||
template <int N>
|
||||
inline __device__ void load_row(uint32_t (&data)[N], const int row[N]) {
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < N; ++ni) {
|
||||
fmha::ldg(data[ni], ptr_row_ + row[ni] * BYTES_PER_ELEMENT);
|
||||
}
|
||||
}
|
||||
|
||||
// Move the pointer to the next location.
|
||||
inline __device__ void move() {
|
||||
ptr_ += ROWS * BYTES_PER_ELEMENT;
|
||||
ptr_row_ += ROWS * BYTES_PER_ELEMENT;
|
||||
}
|
||||
|
||||
// Move the pointer to the next location.
|
||||
inline __device__ void move(const int steps) {
|
||||
ptr_ += ROWS * BYTES_PER_ELEMENT * steps;
|
||||
ptr_row_ += ROWS * BYTES_PER_ELEMENT * steps;
|
||||
}
|
||||
|
||||
// The pointer.
|
||||
char *ptr_;
|
||||
char *ptr_row_;
|
||||
const int tidx_;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace fmha
|
||||
@ -1,121 +1,383 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright
|
||||
* notice, this list of conditions and the following disclaimer in the
|
||||
* documentation and/or other materials provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the
|
||||
* names of its contributors may be used to endorse or promote products
|
||||
* derived from this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
* Copyright (c) 2023, Tri Dao.
|
||||
******************************************************************************/
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/gemm.h>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/gmem_tile.h>
|
||||
|
||||
#pragma once
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
template<int S, int D, int STEP, int WARPS_M, int WARPS_N, uint32_t FLAGS = 0x08u, typename elem_type_=__half>
|
||||
struct FMHA_kernel_traits {
|
||||
#include <cute/algorithm/copy.hpp>
|
||||
|
||||
// The CTA description for the 1st GEMM.
|
||||
using Cta_tile_p = fmha::Cta_tile_extd<STEP, S, D, WARPS_M, WARPS_N, 1>;
|
||||
// The CTA description for the 2nd GEMM.
|
||||
using Cta_tile_o = fmha::Cta_tile_extd<STEP, D, S, WARPS_M, 1, WARPS_N>;
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/layout/layout.h>
|
||||
#include <cutlass/numeric_types.h>
|
||||
|
||||
// Do we use one buffer for K and V.
|
||||
static constexpr bool SHARE_SMEM_FOR_K_AND_V = (FLAGS & 0x08u) != 0u;
|
||||
// Do we keep K in registers.
|
||||
static constexpr bool K_IN_REGS = (FLAGS & 0x10u) == 0u;
|
||||
// Do we keep V in registers.
|
||||
static constexpr bool V_IN_REGS = (FLAGS & 0x100u) == 0u;
|
||||
namespace pytorch_flash{
|
||||
|
||||
// The global memory tile to load Q.
|
||||
using Gmem_tile_q = fmha::Gmem_tile_qkv<Cta_tile_p, fmha::BITS_PER_ELEMENT_A, STEP, D>;
|
||||
using namespace cute;
|
||||
|
||||
// The shared memory tile to swizzle Q.
|
||||
// using Smem_tile_q = fmha::Smem_tile_a<Cta_tile_p, fmha::Row, Gmem_tile_q::BYTES_PER_LDG, 1>;
|
||||
using Smem_tile_q = fmha::Smem_tile_a<Cta_tile_p, fmha::Row, Gmem_tile_q::BYTES_PER_LDG, 2>;
|
||||
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename elem_type=cutlass::half_t>
|
||||
struct Flash_kernel_traits {
|
||||
|
||||
// The global memory tile to load K.
|
||||
using Gmem_tile_k = fmha::Gmem_tile_qkv<Cta_tile_p, fmha::BITS_PER_ELEMENT_B, S, D>;
|
||||
// The shared memory tile to swizzle K.
|
||||
using Smem_tile_k = fmha::Smem_tile_b<Cta_tile_p, fmha::Col>;
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
using Element = elem_type;
|
||||
static constexpr bool Has_cp_async = true;
|
||||
#else
|
||||
using Element = cutlass::half_t;
|
||||
static constexpr bool Has_cp_async = false;
|
||||
#endif
|
||||
|
||||
// The global memory tile to load V.
|
||||
using Gmem_tile_v = fmha::Gmem_tile_qkv<Cta_tile_o, fmha::BITS_PER_ELEMENT_B, S, D>;
|
||||
// The shared memory tile to swizzle V.
|
||||
using Smem_tile_v = fmha::Smem_tile_v<Cta_tile_o>;
|
||||
using ElementAccum = float;
|
||||
using index_t = uint32_t;
|
||||
|
||||
// The global memory tile to store O.
|
||||
using Gmem_tile_o = fmha::Gmem_tile_o<Cta_tile_o>;
|
||||
// The shared memory tile for O.
|
||||
using Smem_tile_o = fmha::Smem_tile_o<Cta_tile_o>;;
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
using MMA_Atom_Arch = std::conditional_t<
|
||||
std::is_same_v<elem_type, cutlass::half_t>,
|
||||
MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>,
|
||||
MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>
|
||||
>;
|
||||
using ValLayoutMNK = Layout<Shape<_1, _2, _1>>;
|
||||
#else
|
||||
using MMA_Atom_Arch = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>;
|
||||
using ValLayoutMNK = Layout<Shape<_1, _2, _2>>;
|
||||
#endif
|
||||
|
||||
// The global memory tile to load/store S.
|
||||
using Gmem_tile_s = fmha::Gmem_tile_mma_s<Cta_tile_p>;
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
|
||||
using SmemCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, elem_type>;
|
||||
using SmemCopyAtomTransposed = Copy_Atom<SM75_U16x8_LDSM_T, elem_type>;
|
||||
#else
|
||||
using SmemCopyAtom = Copy_Atom<DefaultCopy, elem_type>;
|
||||
using SmemCopyAtomTransposed = Copy_Atom<DefaultCopy, elem_type>;
|
||||
#endif
|
||||
};
|
||||
|
||||
// The shared memory tile to transpose S.
|
||||
using Smem_tile_st = fmha::Smem_tile_mma_transposed<Cta_tile_p>;
|
||||
// If Share_Q_K_smem is true, that forces Is_Q_in_regs to be true
|
||||
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, bool Is_Q_in_regs_=false, bool Share_Q_K_smem_=false, typename elem_type=cutlass::half_t,
|
||||
typename Base=Flash_kernel_traits<kHeadDim_, kBlockM_, kBlockN_, kNWarps_, elem_type> >
|
||||
struct Flash_fwd_kernel_traits : public Base {
|
||||
using Element = typename Base::Element;
|
||||
using ElementAccum = typename Base::ElementAccum;
|
||||
using index_t = typename Base::index_t;
|
||||
static constexpr bool Has_cp_async = Base::Has_cp_async;
|
||||
using SmemCopyAtom = typename Base::SmemCopyAtom;
|
||||
using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed;
|
||||
|
||||
using Gmem_tile_do = fmha::Gmem_tile_qkv<Cta_tile_p, fmha::BITS_PER_ELEMENT_A, STEP, D>;
|
||||
|
||||
// // The global memory tile to store the accumulated dK and dV
|
||||
// // Hack: we set BYTES_PER_LDGS=32 to emulate the access pattern of dK and dV
|
||||
// // where there are 16 bits per lements and 16 bytes per load. In reality we won't
|
||||
// // be issue any load or store of size 32 bytes.
|
||||
// using Gmem_tile_dkv_accum = fmha::Gmem_tile_qkv<Cta_tile_o, 32, S, D, 32>;
|
||||
|
||||
// The global memory tile to store the softmax sum.
|
||||
using Gmem_softmax_sum = fmha::Gmem_summary_stats<Cta_tile_p>;
|
||||
|
||||
// The shared memory tile to store dp sum.
|
||||
using Smem_dp_sum = fmha::Smem_tile_dp_sum<Gmem_tile_q, 2>;
|
||||
|
||||
using elem_type = elem_type_;
|
||||
|
||||
// Make sure the number of threads match.
|
||||
static_assert((int)Gmem_tile_o::THREADS_PER_ROW == (int)Smem_tile_o::THREADS_PER_ROW, "");
|
||||
static constexpr bool Share_Q_K_smem = Share_Q_K_smem_;
|
||||
static constexpr bool Is_Q_in_regs = Is_Q_in_regs_ || Share_Q_K_smem;
|
||||
|
||||
// The number of threads.
|
||||
static constexpr int THREADS = Cta_tile_p::THREADS_PER_CTA;
|
||||
// Make sure the number of threads matches both CTAs.
|
||||
static_assert(THREADS == Cta_tile_o::THREADS_PER_CTA, "");
|
||||
static constexpr int kNWarps = kNWarps_;
|
||||
static constexpr int kNThreads = kNWarps * 32;
|
||||
|
||||
// The amount of shared memory needed to load Q and K.
|
||||
static constexpr int BYTES_PER_SMEM_QK = Smem_tile_q::BYTES_PER_TILE + Smem_tile_k::BYTES_PER_TILE;
|
||||
// The extra amount of shared memory needed to load V.
|
||||
static constexpr int BYTES_PER_SMEM_V = SHARE_SMEM_FOR_K_AND_V ? 0u : Smem_tile_v::BYTES_PER_TILE;
|
||||
// The amount of shared memory needed for Q, K and V..
|
||||
static constexpr int BYTES_PER_SMEM_QKV = BYTES_PER_SMEM_QK + BYTES_PER_SMEM_V;
|
||||
// The amount of shared memory needed to load Q and store O.
|
||||
static constexpr int BYTES_PER_SMEM_QO = Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE;
|
||||
static constexpr int kBlockM = kBlockM_;
|
||||
static constexpr int kBlockN = kBlockN_;
|
||||
static constexpr int kHeadDim = kHeadDim_;
|
||||
static_assert(kHeadDim % 32 == 0);
|
||||
static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
|
||||
static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32);
|
||||
static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;
|
||||
|
||||
using TiledMma = TiledMMA<
|
||||
typename Base::MMA_Atom_Arch,
|
||||
Layout<Shape<Int<kNWarps>,_1,_1>>, // 4x1x1 or 8x1x1 thread group
|
||||
typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
|
||||
|
||||
using SmemLayoutAtomQ = decltype(
|
||||
composition(Swizzle<kSwizzle, 3, 3>{},
|
||||
// This has to be kBlockKSmem, using kHeadDim gives wrong results for d=128
|
||||
Layout<Shape<_8, Int<kBlockKSmem>>,
|
||||
Stride<Int<kBlockKSmem>, _1>>{}));
|
||||
using SmemLayoutQ = decltype(tile_to_shape(
|
||||
SmemLayoutAtomQ{},
|
||||
Shape<Int<kBlockM>, Int<kHeadDim>>{}));
|
||||
|
||||
using SmemLayoutKV = decltype(tile_to_shape(
|
||||
SmemLayoutAtomQ{},
|
||||
Shape<Int<kBlockN>, Int<kHeadDim>>{}));
|
||||
|
||||
// This has to be kBlockN and not 8, otherwise we get wrong results for d=128
|
||||
using SmemLayoutAtomVtransposedNoSwizzle = Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
|
||||
Stride<_1, Int<kBlockKSmem>>>;
|
||||
using SmemLayoutAtomVtransposed = decltype(
|
||||
composition(Swizzle<kSwizzle, 3, 3>{}, SmemLayoutAtomVtransposedNoSwizzle{}));
|
||||
using SmemLayoutVtransposed = decltype(tile_to_shape(
|
||||
SmemLayoutAtomVtransposed{},
|
||||
Shape<Int<kHeadDim>, Int<kBlockN>>{}));
|
||||
// Maybe the VtransposeNoSwizzle just needs to have the right shape
|
||||
// And the strides don't matter?
|
||||
using SmemLayoutVtransposedNoSwizzle = decltype(tile_to_shape(
|
||||
SmemLayoutAtomVtransposedNoSwizzle{},
|
||||
Shape<Int<kHeadDim>, Int<kBlockN>>{}));
|
||||
// using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn());
|
||||
|
||||
using SmemLayoutAtomO = decltype(
|
||||
composition(Swizzle<kSwizzle, 3, 3>{},
|
||||
Layout<Shape<Int<8>, Int<kBlockKSmem>>,
|
||||
Stride<Int<kBlockKSmem>, _1>>{}));
|
||||
using SmemLayoutO = decltype(tile_to_shape(
|
||||
SmemLayoutAtomO{},
|
||||
Shape<Int<kBlockM>, Int<kHeadDim>>{}));
|
||||
using SmemCopyAtomO = Copy_Atom<DefaultCopy, elem_type>;
|
||||
|
||||
static constexpr int kSmemQCount = size(SmemLayoutQ{});
|
||||
static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2;
|
||||
static constexpr int kSmemQSize = kSmemQCount * sizeof(Element);
|
||||
static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element);
|
||||
static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize;
|
||||
|
||||
static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
|
||||
static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
|
||||
// Using kBlockKSmem here is 6-10% faster than kBlockKGmem for d=128 because of bank conflicts.
|
||||
// For example, for d=128, smem is split into 2 "pages", each page takes care of columns
|
||||
// 0-63 and 64-127. If we have 16 threads per row for gmem read, when we write to smem,
|
||||
// thread 0 - 7 will write to the first page and thread 8 - 15 will write to the second page,
|
||||
// to the same banks.
|
||||
static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
|
||||
static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
|
||||
using GmemLayoutAtom = Layout<Shape <Int<kNThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
|
||||
Stride<Int<kGmemThreadsPerRow>, _1>>;
|
||||
|
||||
// We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading
|
||||
// from the same address by the same threadblock. This is slightly faster.
|
||||
using Gmem_copy_struct = std::conditional_t<
|
||||
Has_cp_async,
|
||||
SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
|
||||
DefaultCopy
|
||||
>;
|
||||
using GmemTiledCopyQKV = decltype(
|
||||
make_tiled_copy(Copy_Atom<Gmem_copy_struct, elem_type>{},
|
||||
GmemLayoutAtom{},
|
||||
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
|
||||
using GmemTiledCopyO = decltype(
|
||||
make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
|
||||
GmemLayoutAtom{},
|
||||
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store
|
||||
static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad;
|
||||
static_assert(kNThreads % kGmemThreadsPerRowP == 0, "kNThreads must be a multiple of kGmemThreadsPerRowP");
|
||||
using GmemLayoutAtomP = Layout<Shape <Int<kNThreads / kGmemThreadsPerRowP>, Int<kGmemThreadsPerRowP>>,
|
||||
Stride<Int<kGmemThreadsPerRowP>, _1>>;
|
||||
|
||||
using GmemTiledCopyP = decltype(
|
||||
make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
|
||||
GmemLayoutAtomP{},
|
||||
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store
|
||||
|
||||
};
|
||||
|
||||
// Is_V_in_regs is an option to reduce smem usage, but will increase register pressue.
|
||||
// No_double_buffer is another option to reduce smem usage, but will slow things down.
|
||||
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_,
|
||||
int AtomLayoutMSdP_=1, int AtomLayoutNdKV=2, int AtomLayoutMdQ=2,
|
||||
bool Is_V_in_regs_=false, bool No_double_buffer_=false, typename elem_type=cutlass::half_t,
|
||||
typename Base=Flash_kernel_traits<kHeadDim_, kBlockM_, kBlockN_, kNWarps_, elem_type> >
|
||||
struct Flash_bwd_kernel_traits : public Base {
|
||||
using Element = typename Base::Element;
|
||||
using ElementAccum = typename Base::ElementAccum;
|
||||
using index_t = typename Base::index_t;
|
||||
static constexpr bool Has_cp_async = Base::Has_cp_async;
|
||||
using SmemCopyAtom = typename Base::SmemCopyAtom;
|
||||
using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed;
|
||||
|
||||
static constexpr bool Is_V_in_regs = Is_V_in_regs_;
|
||||
static constexpr bool No_double_buffer = No_double_buffer_;
|
||||
|
||||
// The number of threads.
|
||||
static constexpr int kNWarps = kNWarps_;
|
||||
static constexpr int kNThreads = kNWarps * 32;
|
||||
|
||||
static constexpr int kBlockM = kBlockM_;
|
||||
static constexpr int kBlockN = kBlockN_;
|
||||
static constexpr int kHeadDim = kHeadDim_;
|
||||
static_assert(kHeadDim % 32 == 0);
|
||||
static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
|
||||
static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32);
|
||||
static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;
|
||||
|
||||
static constexpr int AtomLayoutMSdP = AtomLayoutMSdP_;
|
||||
static_assert(kNWarps % AtomLayoutMSdP == 0);
|
||||
static_assert(kNWarps % AtomLayoutNdKV == 0);
|
||||
static_assert(kNWarps % AtomLayoutMdQ == 0);
|
||||
|
||||
using TiledMmaSdP = TiledMMA<
|
||||
typename Base::MMA_Atom_Arch,
|
||||
Layout<Shape<Int<AtomLayoutMSdP>, Int<kNWarps / AtomLayoutMSdP>, _1>>,
|
||||
typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
|
||||
|
||||
using TiledMmadKV = TiledMMA<
|
||||
typename Base::MMA_Atom_Arch,
|
||||
Layout<Shape<Int<AtomLayoutNdKV>, Int<kNWarps / AtomLayoutNdKV>, _1>>,
|
||||
typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
|
||||
|
||||
using TiledMmadQ = TiledMMA<
|
||||
typename Base::MMA_Atom_Arch,
|
||||
Layout<Shape<Int<AtomLayoutMdQ>, Int<kNWarps / AtomLayoutMdQ>, _1>>, // 2x4x1 or 4x2x1 thread group
|
||||
typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
|
||||
|
||||
using SmemLayoutAtomQdO = decltype(
|
||||
composition(Swizzle<kSwizzle, 3, 3>{},
|
||||
Layout<Shape<_8, Int<kBlockKSmem>>,
|
||||
Stride<Int<kBlockKSmem>, _1>>{}));
|
||||
using SmemLayoutQdO = decltype(tile_to_shape(
|
||||
SmemLayoutAtomQdO{},
|
||||
make_shape(Int<kBlockM>{}, Int<kHeadDim>{})));
|
||||
|
||||
using SmemLayoutAtomKV = decltype(
|
||||
composition(Swizzle<kSwizzle, 3, 3>{},
|
||||
Layout<Shape<Int<kBlockM / kNWarps>, Int<kBlockKSmem>>,
|
||||
Stride<Int<kBlockKSmem>, _1>>{}));
|
||||
using SmemLayoutKV = decltype(tile_to_shape(
|
||||
// SmemLayoutAtomQdO{},
|
||||
SmemLayoutAtomKV{},
|
||||
make_shape(Int<kBlockN>{}, Int<kHeadDim>{})));
|
||||
|
||||
using SmemLayoutAtomKtransposedNoSwizzle = Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
|
||||
Stride<_1, Int<kBlockKSmem>>>;
|
||||
using SmemLayoutAtomKtransposed = decltype(
|
||||
composition(Swizzle<kSwizzle, 3, 3>{}, SmemLayoutAtomKtransposedNoSwizzle{}));
|
||||
using SmemLayoutKtransposed = decltype(tile_to_shape(
|
||||
SmemLayoutAtomKtransposed{},
|
||||
make_shape(Int<kHeadDim>{}, Int<kBlockN>{})));
|
||||
// Maybe the KtransposeNoSwizzle just needs to have the right shape
|
||||
// And the strides don't matter?
|
||||
using SmemLayoutKtransposedNoSwizzle = decltype(tile_to_shape(
|
||||
SmemLayoutAtomKtransposedNoSwizzle{},
|
||||
make_shape(Int<kHeadDim>{}, Int<kBlockN>{})));
|
||||
// using SmemLayoutKtransposedNoSwizzle = decltype(SmemLayoutKtransposed{}.layout_fn());
|
||||
|
||||
// TODO: generalize to other values of kBlockN
|
||||
// TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2
|
||||
// static constexpr int kPBlockN = kBlockN;
|
||||
static_assert(kBlockN >= 64);
|
||||
// TD [2023-03-19]: Idk why kPBlockN = 16 and kSwizzlePdS=3 is the fastest.
|
||||
static constexpr int kPBlockN = 64;
|
||||
static_assert(kPBlockN == 16 || kPBlockN == 32 || kPBlockN == 64);
|
||||
// static constexpr int kSwizzlePdS = kPBlockN == 16 ? 1 : (kPBlockN == 32 ? 2 : 3);
|
||||
static constexpr int kSwizzlePdS = 3;
|
||||
using SmemLayoutAtomPdS = decltype(
|
||||
composition(Swizzle<kSwizzlePdS, 3, 3>{},
|
||||
Layout<Shape<Int<kBlockM>, Int<kPBlockN>>,
|
||||
Stride<Int<kPBlockN>, _1>>{}));
|
||||
using SmemLayoutPdS = decltype(tile_to_shape(
|
||||
SmemLayoutAtomPdS{},
|
||||
make_shape(Int<kBlockM>{}, Int<kBlockN>{})));
|
||||
using SmemLayoutAtomPdStransposedNoSwizzle = Layout<Shape<Int<kPBlockN>, Int<kBlockM>>,
|
||||
Stride<_1, Int<kPBlockN>>>;
|
||||
using SmemLayoutAtomPdStransposed = decltype(
|
||||
composition(Swizzle<kSwizzlePdS, 3, 3>{}, SmemLayoutAtomPdStransposedNoSwizzle{}));
|
||||
using SmemLayoutPdStransposed = decltype(tile_to_shape(
|
||||
SmemLayoutAtomPdStransposed{},
|
||||
make_shape(Int<kBlockN>{}, Int<kBlockM>{})));
|
||||
using SmemLayoutPdStransposedNoSwizzle = decltype(tile_to_shape(
|
||||
SmemLayoutAtomPdStransposedNoSwizzle{},
|
||||
make_shape(Int<kBlockN>{}, Int<kBlockM>{})));
|
||||
// using SmemLayoutPdStransposedNoSwizzle = decltype(SmemLayoutPdStransposed{}.layout_fn());
|
||||
using SmemCopyAtomPdS = Copy_Atom<DefaultCopy, elem_type>;
|
||||
using SmemLayoutAtomQdOtransposedNoSwizzle = Layout<Shape<Int<kBlockKSmem>, Int<kBlockM>>,
|
||||
Stride<_1, Int<kBlockKSmem>>>;
|
||||
|
||||
using SmemLayoutAtomQdOtransposed = decltype(
|
||||
composition(Swizzle<kSwizzle, 3, 3>{}, SmemLayoutAtomQdOtransposedNoSwizzle{}));
|
||||
using SmemLayoutQdOtransposed = decltype(tile_to_shape(
|
||||
SmemLayoutAtomQdOtransposed{},
|
||||
make_shape(Int<kHeadDim>{}, Int<kBlockM>{})));
|
||||
using SmemLayoutQdOtransposedNoSwizzle = decltype(tile_to_shape(
|
||||
SmemLayoutAtomQdOtransposedNoSwizzle{},
|
||||
make_shape(Int<kHeadDim>{}, Int<kBlockM>{})));
|
||||
// using SmemLayoutQdOtransposedNoSwizzle = decltype(SmemLayoutQdOtransposed{}.layout_fn());
|
||||
|
||||
using SmemLayoutAtomdKV = decltype(
|
||||
composition(Swizzle<kSwizzle, 3, 3>{},
|
||||
Layout<Shape<_8, Int<kBlockKSmem>>,
|
||||
Stride<Int<kBlockKSmem>, _1>>{}));
|
||||
using SmemLayoutdKV = decltype(tile_to_shape(
|
||||
SmemLayoutAtomdKV{},
|
||||
make_shape(Int<kBlockN>{}, Int<kHeadDim>{})));
|
||||
using SmemCopyAtomdKV = Copy_Atom<DefaultCopy, elem_type>;
|
||||
|
||||
using SmemLayoutAtomdQ = decltype(
|
||||
composition(Swizzle<kSwizzle, 3, 3>{},
|
||||
Layout<Shape<_8, Int<kBlockKSmem>>,
|
||||
Stride<Int<kBlockKSmem>, _1>>{}));
|
||||
using SmemLayoutdQ = decltype(tile_to_shape(
|
||||
SmemLayoutAtomdQ{},
|
||||
make_shape(Int<kBlockM>{}, Int<kHeadDim>{})));
|
||||
using SmemCopyAtomdQ = Copy_Atom<DefaultCopy, elem_type>;
|
||||
|
||||
static constexpr int kSmemQdOCount = size(SmemLayoutQdO{}) * (No_double_buffer ? 2 : 3); // Double buffer for sQ
|
||||
static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2;
|
||||
static constexpr int kSmemdSCount = size(SmemLayoutPdS{});
|
||||
static constexpr int kSmemPCount = size(SmemLayoutPdS{});
|
||||
static constexpr int kSmemdQCount = size(SmemLayoutdQ{});
|
||||
static constexpr int kSmemdPsumCount = kBlockM;
|
||||
static constexpr int kSmemQdOSize = kSmemQdOCount * sizeof(Element);
|
||||
static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element);
|
||||
static constexpr int kSmemdSSize = kSmemdSCount * sizeof(Element);
|
||||
static constexpr int kSmemPSize = kSmemPCount * sizeof(Element);
|
||||
static constexpr int kSmemdQSize = kSmemdQCount * sizeof(Element);
|
||||
static constexpr int kSmemdPsumSize = kSmemdPsumCount * sizeof(ElementAccum);
|
||||
static constexpr int kSmemSize = kSmemQdOSize
|
||||
+ (!Is_V_in_regs
|
||||
? kSmemKVSize + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize)
|
||||
: std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + std::max(kSmemPSize, kSmemdQSize)));
|
||||
static constexpr int kSmemSize1colblock = kSmemQdOSize
|
||||
+ (!Is_V_in_regs
|
||||
? kSmemKVSize + kSmemdSSize + kSmemPSize
|
||||
: std::max(kSmemKVSize, kSmemKVSize / 2 + kSmemdSSize + kSmemPSize));
|
||||
static constexpr int kSmemSize1rowblock = kSmemQdOSize / 3 * 2 + kSmemKVSize / 2 * 3
|
||||
+ kSmemdSSize + kSmemPSize;
|
||||
|
||||
|
||||
static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
|
||||
static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
|
||||
// Using kBlockKSmem instead of kHeadDim here to avoid bank conflicts, but doesn't seem
|
||||
// to affect speed in practice.
|
||||
static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
|
||||
static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
|
||||
using GmemLayoutAtom = Layout<Shape <Int<kNThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
|
||||
Stride<Int<kGmemThreadsPerRow>, _1>>;
|
||||
|
||||
// We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading
|
||||
// from the same address by the same threadblock. This is slightly faster.
|
||||
using Gmem_copy_struct = std::conditional_t<
|
||||
Has_cp_async,
|
||||
SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
|
||||
DefaultCopy
|
||||
>;
|
||||
using GmemTiledCopyQKV = decltype(
|
||||
make_tiled_copy(Copy_Atom<Gmem_copy_struct, elem_type>{},
|
||||
GmemLayoutAtom{},
|
||||
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
|
||||
using GmemTiledCopydO = decltype(
|
||||
make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
|
||||
GmemLayoutAtom{},
|
||||
Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store
|
||||
using GmemTiledCopydKV = decltype(
|
||||
make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
|
||||
GmemLayoutAtom{},
|
||||
Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store
|
||||
using GmemTiledCopydQ = decltype(
|
||||
make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
|
||||
GmemLayoutAtom{},
|
||||
Layout<Shape < _1, _8>>{})); // Val layout, 8 vals per store
|
||||
using GmemLayoutAtomdQaccum = std::conditional_t<
|
||||
kBlockKSmem == 32,
|
||||
Layout<Shape <_32, _8>, // Thread layout, 8 threads per row
|
||||
Stride< _8, _1>>,
|
||||
Layout<Shape <_16, _16>, // Thread layout, 16 threads per row
|
||||
Stride< _16, _1>>
|
||||
>;
|
||||
using GmemTiledCopydQaccum = decltype(
|
||||
make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
|
||||
GmemLayoutAtomdQaccum{},
|
||||
Layout<Shape < _1, _4>>{})); // Val layout, 4 vals per store
|
||||
|
||||
using GmemTiledCopydQaccumAtomicAdd = decltype(
|
||||
make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
|
||||
Layout<Shape <_8, _32>, // Thread layout, 8 threads per row
|
||||
Stride<_32, _1>>{},
|
||||
Layout<Shape < _1, _1>>{})); // Val layout, 1 val per store
|
||||
|
||||
// The amount of shared memory needed for Q, K, V and O.
|
||||
static constexpr int BYTES_PER_SMEM = fmha::MaxConstexpr(BYTES_PER_SMEM_QKV, BYTES_PER_SMEM_QO);
|
||||
// Make sure we have enough shared memory.
|
||||
static_assert(Smem_tile_q::BYTES_PER_TILE + Smem_tile_o::BYTES_PER_TILE <= BYTES_PER_SMEM, "");
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
} // namespace pytorch_flash
|
||||
@ -0,0 +1,163 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2023, Tri Dao.
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
#include <cute/algorithm/copy.hpp>
|
||||
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/layout/layout.h>
|
||||
#include <cutlass/numeric_types.h>
|
||||
|
||||
namespace pytorch_flash{
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename elem_type=cutlass::half_t>
|
||||
struct Flash_kernel_traits_sm90 {
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
using Element = elem_type;
|
||||
static constexpr bool Has_cp_async = true;
|
||||
#else
|
||||
using Element = cutlass::half_t;
|
||||
static constexpr bool Has_cp_async = false;
|
||||
#endif
|
||||
|
||||
using ElementAccum = float;
|
||||
using index_t = uint32_t;
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
using MMA_Atom_Arch = std::conditional_t<
|
||||
std::is_same_v<elem_type, cutlass::half_t>,
|
||||
MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>,
|
||||
MMA_Atom<SM80_16x8x16_F32BF16BF16F32_TN>
|
||||
>;
|
||||
using ValLayoutMNK = Layout<Shape<_1, _2, _1>>;
|
||||
#else
|
||||
using MMA_Atom_Arch = MMA_Atom<SM75_16x8x8_F32F16F16F32_TN>;
|
||||
using ValLayoutMNK = Layout<Shape<_1, _2, _2>>;
|
||||
#endif
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
|
||||
using SmemCopyAtom = Copy_Atom<SM75_U32x4_LDSM_N, elem_type>;
|
||||
using SmemCopyAtomTransposed = Copy_Atom<SM75_U16x8_LDSM_T, elem_type>;
|
||||
#else
|
||||
using SmemCopyAtom = Copy_Atom<DefaultCopy, elem_type>;
|
||||
using SmemCopyAtomTransposed = Copy_Atom<DefaultCopy, elem_type>;
|
||||
#endif
|
||||
};
|
||||
|
||||
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, bool Is_Q_in_regs_=false, bool Share_Q_K_smem_=false, typename elem_type=cutlass::half_t,
|
||||
typename Base=Flash_kernel_traits_sm90<kHeadDim_, kBlockM_, kBlockN_, kNWarps_, elem_type> >
|
||||
struct Flash_fwd_kernel_traits : public Base {
|
||||
using Element = typename Base::Element;
|
||||
using ElementAccum = typename Base::ElementAccum;
|
||||
using index_t = typename Base::index_t;
|
||||
static constexpr bool Has_cp_async = Base::Has_cp_async;
|
||||
using SmemCopyAtom = typename Base::SmemCopyAtom;
|
||||
using SmemCopyAtomTransposed = typename Base::SmemCopyAtomTransposed;
|
||||
|
||||
static constexpr bool Share_Q_K_smem = Share_Q_K_smem_;
|
||||
static constexpr bool Is_Q_in_regs = Is_Q_in_regs_ || Share_Q_K_smem;
|
||||
|
||||
// The number of threads.
|
||||
static constexpr int kNWarps = kNWarps_;
|
||||
static constexpr int kNThreads = kNWarps * 32;
|
||||
|
||||
static constexpr int kBlockM = kBlockM_;
|
||||
static constexpr int kBlockN = kBlockN_;
|
||||
static constexpr int kHeadDim = kHeadDim_;
|
||||
static_assert(kHeadDim % 32 == 0);
|
||||
static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
|
||||
static constexpr int kBlockKGmem = kHeadDim % 128 == 0 ? 128 : (kHeadDim % 64 == 0 ? 64 : 32);
|
||||
static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;
|
||||
|
||||
using TiledMma = TiledMMA<
|
||||
typename Base::MMA_Atom_Arch,
|
||||
Layout<Shape<Int<kNWarps>,_1,_1>>, // 4x1x1 or 8x1x1 thread group
|
||||
typename Base::ValLayoutMNK>; // 1x2x1 or 1x2x2 value group for 16x16x16 MMA and LDSM
|
||||
|
||||
using SmemLayoutAtomQ = decltype(
|
||||
composition(Swizzle<kSwizzle, 3, 3>{},
|
||||
// This has to be kBlockKSmem, using kHeadDim gives wrong results for d=128
|
||||
Layout<Shape<_8, Int<kBlockKSmem>>,
|
||||
Stride<Int<kBlockKSmem>, _1>>{}));
|
||||
using SmemLayoutQ = decltype(tile_to_shape(
|
||||
SmemLayoutAtomQ{},
|
||||
Shape<Int<kBlockM>, Int<kHeadDim>>{}));
|
||||
|
||||
using SmemLayoutKV = decltype(tile_to_shape(
|
||||
SmemLayoutAtomQ{},
|
||||
Shape<Int<kBlockN>, Int<kHeadDim>>{}));
|
||||
|
||||
using SmemLayoutAtomVtransposed = decltype(
|
||||
composition(Swizzle<kSwizzle, 3, 3>{},
|
||||
// This has to be kBlockN and not 8, otherwise we get wrong results for d=128
|
||||
Layout<Shape<Int<kBlockKSmem>, Int<kBlockN>>,
|
||||
Stride<_1, Int<kBlockKSmem>>>{}));
|
||||
using SmemLayoutVtransposed = decltype(tile_to_shape(
|
||||
SmemLayoutAtomVtransposed{},
|
||||
Shape<Int<kHeadDim>, Int<kBlockN>>{}));
|
||||
// Maybe the VtransposeNoSwizzle just needs to have the right shape
|
||||
// And the strides don't matter?
|
||||
using SmemLayoutVtransposedNoSwizzle = decltype(SmemLayoutVtransposed{}.layout_fn());
|
||||
|
||||
using SmemLayoutAtomO = decltype(
|
||||
composition(Swizzle<kSwizzle, 3, 3>{},
|
||||
Layout<Shape<Int<8>, Int<kBlockKSmem>>,
|
||||
Stride<Int<kBlockKSmem>, _1>>{}));
|
||||
using SmemLayoutO = decltype(tile_to_shape(
|
||||
SmemLayoutAtomO{},
|
||||
Shape<Int<kBlockM>, Int<kHeadDim>>{}));
|
||||
using SmemCopyAtomO = Copy_Atom<DefaultCopy, elem_type>;
|
||||
|
||||
static constexpr int kSmemQCount = size(SmemLayoutQ{});
|
||||
static constexpr int kSmemKVCount = size(SmemLayoutKV{}) * 2;
|
||||
static constexpr int kSmemQSize = kSmemQCount * sizeof(Element);
|
||||
static constexpr int kSmemKVSize = kSmemKVCount * sizeof(Element);
|
||||
static constexpr int kSmemSize = Share_Q_K_smem ? std::max(kSmemQSize, kSmemKVSize) : kSmemQSize + kSmemKVSize;
|
||||
|
||||
static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
|
||||
static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
|
||||
// Using kBlockKSmem here is 6-10% faster than kBlockKGmem for d=128 because of bank conflicts.
|
||||
// For example, for d=128, smem is split into 2 "pages", each page takes care of columns
|
||||
// 0-63 and 64-127. If we have 16 threads per row for gmem read, when we write to smem,
|
||||
// thread 0 - 7 will write to the first page and thread 8 - 15 will write to the second page,
|
||||
// to the same banks.
|
||||
static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
|
||||
static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
|
||||
using GmemLayoutAtom = Layout<Shape <Int<kNThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
|
||||
Stride<Int<kGmemThreadsPerRow>, _1>>;
|
||||
|
||||
// We use CACHEGLOBAL instead of CACHEALWAYS for both Q and K/V, since we won't be reading
|
||||
// from the same address by the same threadblock. This is slightly faster.
|
||||
using Gmem_copy_struct = std::conditional_t<
|
||||
Has_cp_async,
|
||||
SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>,
|
||||
DefaultCopy
|
||||
>;
|
||||
using GmemTiledCopyQKV = decltype(
|
||||
make_tiled_copy(Copy_Atom<Gmem_copy_struct, elem_type>{},
|
||||
GmemLayoutAtom{},
|
||||
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
|
||||
using GmemTiledCopyO = decltype(
|
||||
make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
|
||||
GmemLayoutAtom{},
|
||||
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store
|
||||
static constexpr int kGmemThreadsPerRowP = kBlockN / kGmemElemsPerLoad;
|
||||
static_assert(kNThreads % kGmemThreadsPerRowP == 0, "kNThreads must be a multiple of kGmemThreadsPerRowP");
|
||||
using GmemLayoutAtomP = Layout<Shape <Int<kNThreads / kGmemThreadsPerRowP>, Int<kGmemThreadsPerRowP>>,
|
||||
Stride<Int<kGmemThreadsPerRowP>, _1>>;
|
||||
|
||||
using GmemTiledCopyP = decltype(
|
||||
make_tiled_copy(Copy_Atom<DefaultCopy, elem_type>{},
|
||||
GmemLayoutAtomP{},
|
||||
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store
|
||||
|
||||
};
|
||||
} // namespace pytorch_flash
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -0,0 +1,14 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_bwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
template<>
|
||||
void run_mha_bwd_<cutlass::bfloat16_t, 128>(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
run_mha_bwd_hdim128<cutlass::bfloat16_t>(params, stream, configure);
|
||||
}
|
||||
} // namespace pytorch_flash
|
||||
@ -0,0 +1,14 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_bwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
template<>
|
||||
void run_mha_bwd_<cutlass::half_t, 128>(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
run_mha_bwd_hdim128<cutlass::half_t>(params, stream, configure);
|
||||
}
|
||||
} // namespace pytorch_flash
|
||||
@ -0,0 +1,14 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_bwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
template<>
|
||||
void run_mha_bwd_<cutlass::bfloat16_t, 160>(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
run_mha_bwd_hdim160<cutlass::bfloat16_t>(params, stream, configure);
|
||||
}
|
||||
} // namespace pytorch_flash
|
||||
@ -0,0 +1,14 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_bwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
template<>
|
||||
void run_mha_bwd_<cutlass::half_t, 160>(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
run_mha_bwd_hdim160<cutlass::half_t>(params, stream, configure);
|
||||
}
|
||||
} // namespace pytorch_flash
|
||||
@ -0,0 +1,14 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_bwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
template<>
|
||||
void run_mha_bwd_<cutlass::bfloat16_t, 192>(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
run_mha_bwd_hdim192<cutlass::bfloat16_t>(params, stream, configure);
|
||||
}
|
||||
} // namespace pytorch_flash
|
||||
@ -0,0 +1,14 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_bwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
template<>
|
||||
void run_mha_bwd_<cutlass::half_t, 192>(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
run_mha_bwd_hdim192<cutlass::half_t>(params, stream, configure);
|
||||
}
|
||||
} // namespace pytorch_flash
|
||||
@ -0,0 +1,14 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_bwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
template<>
|
||||
void run_mha_bwd_<cutlass::bfloat16_t, 224>(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
run_mha_bwd_hdim224<cutlass::bfloat16_t>(params, stream, configure);
|
||||
}
|
||||
} // namespace pytorch_flash
|
||||
@ -0,0 +1,14 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_bwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
template<>
|
||||
void run_mha_bwd_<cutlass::half_t, 224>(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
run_mha_bwd_hdim224<cutlass::half_t>(params, stream, configure);
|
||||
}
|
||||
} // namespace pytorch_flash
|
||||
@ -0,0 +1,14 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_bwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
template<>
|
||||
void run_mha_bwd_<cutlass::bfloat16_t, 256>(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
run_mha_bwd_hdim256<cutlass::bfloat16_t>(params, stream, configure);
|
||||
}
|
||||
} // namespace pytorch_flash
|
||||
@ -0,0 +1,14 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_bwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
template<>
|
||||
void run_mha_bwd_<cutlass::half_t, 256>(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
run_mha_bwd_hdim256<cutlass::half_t>(params, stream, configure);
|
||||
}
|
||||
} // namespace pytorch_flash
|
||||
@ -0,0 +1,14 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_bwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
template<>
|
||||
void run_mha_bwd_<cutlass::bfloat16_t, 32>(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
run_mha_bwd_hdim32<cutlass::bfloat16_t>(params, stream, configure);
|
||||
}
|
||||
} // namespace pytorch_flash
|
||||
@ -0,0 +1,14 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_bwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
template<>
|
||||
void run_mha_bwd_<cutlass::half_t, 32>(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
run_mha_bwd_hdim32<cutlass::half_t>(params, stream, configure);
|
||||
}
|
||||
} // namespace pytorch_flash
|
||||
@ -0,0 +1,14 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_bwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
template<>
|
||||
void run_mha_bwd_<cutlass::bfloat16_t, 64>(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
run_mha_bwd_hdim64<cutlass::bfloat16_t>(params, stream, configure);
|
||||
}
|
||||
} // namespace pytorch_flash
|
||||
@ -0,0 +1,14 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_bwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
template<>
|
||||
void run_mha_bwd_<cutlass::half_t, 64>(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
run_mha_bwd_hdim64<cutlass::half_t>(params, stream, configure);
|
||||
}
|
||||
} // namespace pytorch_flash
|
||||
@ -0,0 +1,14 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_bwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
template<>
|
||||
void run_mha_bwd_<cutlass::bfloat16_t, 96>(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
run_mha_bwd_hdim96<cutlass::bfloat16_t>(params, stream, configure);
|
||||
}
|
||||
} // namespace pytorch_flash
|
||||
@ -0,0 +1,14 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_bwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
template<>
|
||||
void run_mha_bwd_<cutlass::half_t, 96>(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {
|
||||
run_mha_bwd_hdim96<cutlass::half_t>(params, stream, configure);
|
||||
}
|
||||
} // namespace pytorch_flash
|
||||
@ -0,0 +1,14 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_fwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
template<>
|
||||
void run_mha_fwd_<cutlass::bfloat16_t, 128>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
run_mha_fwd_hdim128<cutlass::bfloat16_t>(params, stream);
|
||||
}
|
||||
} // namespace pytorch_flash
|
||||
@ -0,0 +1,14 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_fwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
template<>
|
||||
void run_mha_fwd_<cutlass::half_t, 128>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
run_mha_fwd_hdim128<cutlass::half_t>(params, stream);
|
||||
}
|
||||
} // namespace pytorch_flash
|
||||
@ -0,0 +1,14 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_fwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
template<>
|
||||
void run_mha_fwd_<cutlass::bfloat16_t, 160>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
run_mha_fwd_hdim160<cutlass::bfloat16_t>(params, stream);
|
||||
}
|
||||
} // namespace pytorch_flash
|
||||
@ -0,0 +1,14 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_fwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
template<>
|
||||
void run_mha_fwd_<cutlass::half_t, 160>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
run_mha_fwd_hdim160<cutlass::half_t>(params, stream);
|
||||
}
|
||||
} // namespace pytorch_flash
|
||||
@ -0,0 +1,14 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_fwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
template<>
|
||||
void run_mha_fwd_<cutlass::bfloat16_t, 192>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
run_mha_fwd_hdim192<cutlass::bfloat16_t>(params, stream);
|
||||
}
|
||||
} // namespace pytorch_flash
|
||||
@ -0,0 +1,14 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_fwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
template<>
|
||||
void run_mha_fwd_<cutlass::half_t, 192>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
run_mha_fwd_hdim192<cutlass::half_t>(params, stream);
|
||||
}
|
||||
} // namespace pytorch_flash
|
||||
@ -0,0 +1,14 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_fwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
template<>
|
||||
void run_mha_fwd_<cutlass::bfloat16_t, 224>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
run_mha_fwd_hdim224<cutlass::bfloat16_t>(params, stream);
|
||||
}
|
||||
} // namespace pytorch_flash
|
||||
@ -0,0 +1,14 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_fwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
template<>
|
||||
void run_mha_fwd_<cutlass::half_t, 224>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
run_mha_fwd_hdim224<cutlass::half_t>(params, stream);
|
||||
}
|
||||
} // namespace pytorch_flash
|
||||
@ -0,0 +1,14 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_fwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
template<>
|
||||
void run_mha_fwd_<cutlass::bfloat16_t, 256>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
run_mha_fwd_hdim256<cutlass::bfloat16_t>(params, stream);
|
||||
}
|
||||
} // namespace pytorch_flash
|
||||
@ -0,0 +1,14 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_fwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
template<>
|
||||
void run_mha_fwd_<cutlass::half_t, 256>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
run_mha_fwd_hdim256<cutlass::half_t>(params, stream);
|
||||
}
|
||||
} // namespace pytorch_flash
|
||||
@ -0,0 +1,14 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_fwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
template<>
|
||||
void run_mha_fwd_<cutlass::bfloat16_t, 32>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
run_mha_fwd_hdim32<cutlass::bfloat16_t>(params, stream);
|
||||
}
|
||||
} // namespace pytorch_flash
|
||||
@ -0,0 +1,14 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_fwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
template<>
|
||||
void run_mha_fwd_<cutlass::half_t, 32>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
run_mha_fwd_hdim32<cutlass::half_t>(params, stream);
|
||||
}
|
||||
} // namespace pytorch_flash
|
||||
@ -0,0 +1,14 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_fwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
template<>
|
||||
void run_mha_fwd_<cutlass::bfloat16_t, 64>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
run_mha_fwd_hdim64<cutlass::bfloat16_t>(params, stream);
|
||||
}
|
||||
} // namespace pytorch_flash
|
||||
@ -0,0 +1,14 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_fwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
template<>
|
||||
void run_mha_fwd_<cutlass::half_t, 64>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
run_mha_fwd_hdim64<cutlass::half_t>(params, stream);
|
||||
}
|
||||
} // namespace pytorch_flash
|
||||
@ -0,0 +1,14 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_fwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
template<>
|
||||
void run_mha_fwd_<cutlass::bfloat16_t, 96>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
run_mha_fwd_hdim96<cutlass::bfloat16_t>(params, stream);
|
||||
}
|
||||
} // namespace pytorch_flash
|
||||
@ -0,0 +1,14 @@
|
||||
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/flash_fwd_launch_template.h>
|
||||
namespace pytorch_flash{
|
||||
|
||||
template<>
|
||||
void run_mha_fwd_<cutlass::half_t, 96>(Flash_fwd_params ¶ms, cudaStream_t stream) {
|
||||
run_mha_fwd_hdim96<cutlass::half_t>(params, stream);
|
||||
}
|
||||
} // namespace pytorch_flash
|
||||
@ -0,0 +1,100 @@
|
||||
# This file is run to generate the kernel instantiations for the flash_attn kernels
|
||||
# They are written to several files in order to speed up compilation
|
||||
|
||||
import argparse
|
||||
import itertools
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
DTYPE_MAP = {
|
||||
"fp16": "cutlass::half_t",
|
||||
"bf16": "cutlass::bfloat16_t",
|
||||
}
|
||||
|
||||
SM = [80] # Sm80 kernels support up to
|
||||
HEAD_DIMENSIONS = [32, 64, 96, 128, 160, 192, 224, 256]
|
||||
KERNEL_IMPL_TEMPLATE_FWD = """
|
||||
template<>
|
||||
void run_mha_fwd_<{DTYPE}, {HEAD_DIM}>(Flash_fwd_params ¶ms, cudaStream_t stream) {{
|
||||
run_mha_fwd_hdim{HEAD_DIM}<{DTYPE}>(params, stream);
|
||||
}}
|
||||
"""
|
||||
|
||||
KERNEL_IMPL_TEMPLATE_BWD = """
|
||||
template<>
|
||||
void run_mha_bwd_<{DTYPE}, {HEAD_DIM}>(Flash_bwd_params ¶ms, cudaStream_t stream, const bool configure) {{
|
||||
run_mha_bwd_hdim{HEAD_DIM}<{DTYPE}>(params, stream, configure);
|
||||
}}
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class Kernel:
|
||||
sm: int
|
||||
dtype: str
|
||||
head_dim: int
|
||||
direction: str
|
||||
|
||||
@property
|
||||
def template(self) -> str:
|
||||
if self.direction == "fwd":
|
||||
return KERNEL_IMPL_TEMPLATE_FWD.format(
|
||||
DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim
|
||||
)
|
||||
else:
|
||||
return KERNEL_IMPL_TEMPLATE_BWD.format(
|
||||
DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim
|
||||
)
|
||||
|
||||
@property
|
||||
def filename(self) -> str:
|
||||
return f"flash_{self.direction}_hdim{self.head_dim}_{self.dtype}_sm{self.sm}.cu"
|
||||
|
||||
|
||||
def get_all_kernels() -> List[Kernel]:
|
||||
for dtype, head_dim, sm in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS, SM):
|
||||
for direction in ["fwd", "bwd"]:
|
||||
yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, direction=direction)
|
||||
|
||||
|
||||
def write_kernel(kernel: Kernel, autogen_dir: Path) -> None:
|
||||
prelude = """
|
||||
// Copyright (c) 2023, Tri Dao.
|
||||
|
||||
// Splitting the different head dimensions to different files to speed up compilation.
|
||||
// This file is auto-generated. See "generate_kernels.py"\n
|
||||
"""
|
||||
include = f"#include <ATen/native/transformers/cuda/flash_attn/flash_{kernel.direction}_launch_template.h>\n"
|
||||
namespace = "namespace pytorch_flash{\n"
|
||||
namespace_end = "} // namespace pytorch_flash\n"
|
||||
(autogen_dir / kernel.filename).write_text(
|
||||
prelude + include + namespace + kernel.template + namespace_end
|
||||
)
|
||||
|
||||
|
||||
def main(output_dir: Optional[str]) -> None:
|
||||
if output_dir is None:
|
||||
output_dir = Path(__file__).parent
|
||||
else:
|
||||
output_dir = Path(output_dir)
|
||||
|
||||
for kernel in get_all_kernels():
|
||||
write_kernel(kernel, output_dir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="generate_kernels",
|
||||
description="Generate the flash_attention kernels template instantiations",
|
||||
)
|
||||
# Set an optional output directory
|
||||
parser.add_argument(
|
||||
"-o",
|
||||
"--output_dir",
|
||||
required=False,
|
||||
help="Where to generate the kernels "
|
||||
" will default to <ATen/native/transformers/cuda/flash_attn/kernels/> ",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
main(args.output_dir)
|
||||
@ -1,92 +0,0 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright
|
||||
* notice, this list of conditions and the following disclaimer in the
|
||||
* documentation and/or other materials provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the
|
||||
* names of its contributors may be used to endorse or promote products
|
||||
* derived from this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
namespace fmha {
|
||||
|
||||
|
||||
template<typename Cta_tile, bool Is_causal=false>
|
||||
struct Mask {
|
||||
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
|
||||
|
||||
template<typename BInfo>
|
||||
__device__ Mask(const BInfo &binfo, int tidx, const int loop_step_idx_ = 0)
|
||||
: actual_seqlen_k(binfo.actual_seqlen_k - loop_step_idx_ * Cta_tile::N)
|
||||
, loop_step_idx(loop_step_idx_) {
|
||||
|
||||
const int warp = tidx / Cta_tile::THREADS_PER_WARP;
|
||||
const int lane = tidx % Cta_tile::THREADS_PER_WARP;
|
||||
|
||||
static_assert(Cta_tile::WARPS_K == 1, "");
|
||||
|
||||
// find the warp in the Cta tile
|
||||
const int warp_n = (warp / Cta_tile::WARPS_M);
|
||||
const int warp_m = (warp % Cta_tile::WARPS_M);
|
||||
// decompose warp into 8x4 tile
|
||||
const int quad = lane / 4;
|
||||
const int tid = (lane % 4) * 2;
|
||||
row = warp_m * 16 + quad;
|
||||
col = warp_n * 16 + tid;
|
||||
}
|
||||
|
||||
inline __device__ bool is_valid(const int mi, const int ni, const int ii, const int jj) const {
|
||||
|
||||
// ii and jj iterate over the 2x4 fragment
|
||||
// const int current_col = (Is_causal ? loop_step_idx * Cta_tile::N : 0) + ni * Mma_tile::N_PER_MMA_PER_CTA + col + (jj & 2) * 4 + (jj & 1);
|
||||
const int current_col = ni * Mma_tile::N_PER_MMA_PER_CTA + col + (jj & 2) * 4 + (jj & 1);
|
||||
const int current_row = row_offset + ii * 8;
|
||||
const bool col_valid = current_col < actual_seqlen_k;
|
||||
// const bool col_valid = (ni * Mma_tile::N_PER_MMA_PER_CTA + col + (jj & 2) * 4 + (jj & 1)) < actual_seqlen_k;
|
||||
//&& (row + mi * Mma_tile::M_PER_MMA_PER_CTA + ii * 8) < actual_seqlen_k;
|
||||
// bool all_valid = Is_causal ? col_valid && (current_col + loop_step_idx * Cta_tile::N <= current_row) : col_valid;
|
||||
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0) && (blockIdx.z == 1)) {
|
||||
// printf("current_col=%d, current_row=%d, actual_seqlen_k=%d, col_valid=%d, all_valid=%d\n", current_col, current_row, actual_seqlen_k, col_valid, all_valid);
|
||||
// }
|
||||
return Is_causal ? col_valid && (current_col + loop_step_idx * Cta_tile::N <= current_row) : col_valid;
|
||||
// return row_valid && col_valid;
|
||||
}
|
||||
|
||||
//BERT Mask: if upper left is invalid, none are valid
|
||||
inline __device__ bool any_valid(const int mi, const int ni) const {
|
||||
return is_valid(mi, ni, 0, 0) || is_valid(mi, ni, 1, 0);
|
||||
}
|
||||
|
||||
inline __device__ void load(const int it) {
|
||||
row_offset = it * Cta_tile::M + row;
|
||||
}
|
||||
int row_offset;
|
||||
|
||||
int row;
|
||||
int col;
|
||||
const int loop_step_idx;
|
||||
const int actual_seqlen_k;
|
||||
};
|
||||
|
||||
} // namespace fmha
|
||||
@ -1,10 +1,57 @@
|
||||
// Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/csrc/multihead_attn/philox.cuh
|
||||
// Pytorch also has an implementation of Philox RNG: https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/codegen/cuda/runtime/random_numbers.cu
|
||||
// Pytorch also has an implementation of Philox RNG: https://github.com/pytorch/pytorch/blob/8ca3c881db3e3510fcb7725389f6a0633c9b992c/torch/csrc/jit/tensorexpr/cuda_random.h
|
||||
#pragma once
|
||||
// Philox CUDA.
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
namespace pytorch_flash{
|
||||
|
||||
struct ull2 {
|
||||
unsigned long long x;
|
||||
unsigned long long y;
|
||||
};
|
||||
|
||||
inline __device__ uint2 mulhilo32(const unsigned int a, const unsigned int b) {
|
||||
uint2 *res;
|
||||
unsigned long long tmp;
|
||||
asm ("mul.wide.u32 %0, %1, %2;\n\t"
|
||||
: "=l"(tmp)
|
||||
: "r"(a), "r"(b));
|
||||
res = (uint2*)(&tmp);
|
||||
return *res;
|
||||
}
|
||||
|
||||
inline __device__ uint4 philox_single_round(const uint4 ctr, const uint2 key) {
|
||||
constexpr unsigned long kPhiloxSA = 0xD2511F53;
|
||||
constexpr unsigned long kPhiloxSB = 0xCD9E8D57;
|
||||
uint2 res0 = mulhilo32(kPhiloxSA, ctr.x);
|
||||
uint2 res1 = mulhilo32(kPhiloxSB, ctr.z);
|
||||
uint4 ret = {res1.y ^ ctr.y ^ key.x, res1.x, res0.y ^ ctr.w ^ key.y, res0.x};
|
||||
return ret;
|
||||
}
|
||||
|
||||
inline __device__ uint4 philox(unsigned long long seed,
|
||||
unsigned long long subsequence,
|
||||
unsigned long long offset) {
|
||||
constexpr unsigned long kPhilox10A = 0x9E3779B9;
|
||||
constexpr unsigned long kPhilox10B = 0xBB67AE85;
|
||||
uint2 key = reinterpret_cast<uint2&>(seed);
|
||||
uint4 counter;
|
||||
ull2 *tmp = reinterpret_cast<ull2*>(&counter);
|
||||
tmp->x = offset;
|
||||
tmp->y = subsequence;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 6; i++) {
|
||||
counter = philox_single_round(counter, key);
|
||||
key.x += (kPhilox10A);
|
||||
key.y += (kPhilox10B);
|
||||
}
|
||||
uint4 output = philox_single_round(counter, key);
|
||||
return output;
|
||||
}
|
||||
|
||||
} // namespace flash
|
||||
|
||||
namespace {
|
||||
|
||||
class Philox {
|
||||
@ -12,7 +59,10 @@ public:
|
||||
__device__ inline Philox(unsigned long long seed,
|
||||
unsigned long long subsequence,
|
||||
unsigned long long offset)
|
||||
: key(reinterpret_cast<const uint2&>(seed)) {
|
||||
: STATE(0)
|
||||
, seed_(seed)
|
||||
, offset_(offset)
|
||||
, key(reinterpret_cast<const uint2&>(seed)) {
|
||||
//key.x = (unsigned int)seed;
|
||||
//key.y = (unsigned int)(seed >> 32);
|
||||
//counter = make_uint4(0, 0, 0, 0);
|
||||
@ -21,6 +71,7 @@ public:
|
||||
//STATE = 0;
|
||||
//incr_n(offset / 4);
|
||||
|
||||
// key = reinterpret_cast<const uint2&>(seed);
|
||||
ull2 * tmp = reinterpret_cast<ull2*>(&counter);
|
||||
tmp->x = offset / 4;
|
||||
tmp->y = subsequence;
|
||||
@ -28,72 +79,64 @@ public:
|
||||
// printf("Philox counter: %d, %d, %d, %d\n", counter.x, counter.y, counter.z, counter.w);
|
||||
// }
|
||||
}
|
||||
|
||||
__device__ inline uint4 operator()() {
|
||||
uint4 counter_ = counter;
|
||||
uint2 key_ = key;
|
||||
// 7-round philox
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 6; i++) {
|
||||
counter_ = single_round(counter_, key_);
|
||||
key_.x += (kPhilox10A);
|
||||
key_.y += (kPhilox10B);
|
||||
}
|
||||
uint4 output = single_round(counter_, key_);
|
||||
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
|
||||
// printf("Philox counter: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w);
|
||||
// printf("Philox output: %u, %u, %u, %u\n", output.x, output.y, output.z, output.w);
|
||||
// }
|
||||
incr();
|
||||
return output;
|
||||
}
|
||||
|
||||
__device__ inline uint4 operator()(const unsigned long long subsequence) {
|
||||
uint4 counter_ = counter;
|
||||
ull2 * tmp = reinterpret_cast<ull2*>(&counter_);
|
||||
tmp->y = subsequence;
|
||||
// if ((threadIdx.x % 32 == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
|
||||
// printf("tidx = %d, counter_: %u, %u, %u, %u\n", threadIdx.x, counter_.x, counter_.y, counter_.z, counter_.w);
|
||||
// }
|
||||
uint2 key_ = key;
|
||||
// 7-round philox
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 6; i++) {
|
||||
counter_ = single_round(counter_, key_);
|
||||
key_.x += (kPhilox10A);
|
||||
key_.y += (kPhilox10B);
|
||||
}
|
||||
uint4 output = single_round(counter_, key_);
|
||||
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
|
||||
// printf("Philox counter: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w);
|
||||
// printf("Philox output: %u, %u, %u, %u\n", output.x, output.y, output.z, output.w);
|
||||
// }
|
||||
return output;
|
||||
// // if (STATE == 0) {
|
||||
// uint4 counter_ = counter;
|
||||
// uint2 key_ = key;
|
||||
// // 7-round philox
|
||||
// #pragma unroll
|
||||
// for (int i = 0; i < 6; i++) {
|
||||
// counter_ = pytorch_flash::philox_single_round(counter_, key_);
|
||||
// key_.x += (kPhilox10A);
|
||||
// key_.y += (kPhilox10B);
|
||||
// }
|
||||
// // output = philox_single_round(counter_, key_);
|
||||
// uint4 output = pytorch_flash::philox_single_round(counter_, key_);
|
||||
// // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
|
||||
// // printf("Philox counter: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w);
|
||||
// // printf("Philox output: %u, %u, %u, %u\n", output.x, output.y, output.z, output.w);
|
||||
// // }
|
||||
// incr();
|
||||
// // }
|
||||
// // return a float4 directly
|
||||
// // unsigned long ret;
|
||||
// // switch(STATE) {
|
||||
// // case 0: ret = output.x; break;
|
||||
// // case 1: ret = output.y; break;
|
||||
// // case 2: ret = output.z; break;
|
||||
// // case 3: ret = output.w; break;
|
||||
// //}
|
||||
// // STATE = (STATE + 1) % 4;
|
||||
// return output;
|
||||
return pytorch_flash::philox(seed_, offset_, offset_);
|
||||
}
|
||||
|
||||
private:
|
||||
unsigned long long offset_, seed_;
|
||||
struct ull2 {
|
||||
uint64_t x;
|
||||
uint64_t y;
|
||||
};
|
||||
uint4 counter;
|
||||
// uint4 output;
|
||||
const uint2 key;
|
||||
unsigned int STATE;
|
||||
__device__ inline void incr_n(unsigned long long n) {
|
||||
unsigned int nlo = (unsigned int)(n);
|
||||
unsigned int nhi = (unsigned int)(n >> 32);
|
||||
counter.x += nlo;
|
||||
if (counter.x < nlo)
|
||||
nhi++;
|
||||
counter.y += nhi;
|
||||
if (nhi <= counter.y)
|
||||
return;
|
||||
if (++counter.z)
|
||||
return;
|
||||
++counter.w;
|
||||
}
|
||||
|
||||
// __device__ inline void incr_n(unsigned long long n) {
|
||||
// unsigned int nlo = (unsigned int)(n);
|
||||
// unsigned int nhi = (unsigned int)(n >> 32);
|
||||
// counter.x += nlo;
|
||||
// if (counter.x < nlo)
|
||||
// nhi++;
|
||||
// counter.y += nhi;
|
||||
// if (nhi <= counter.y)
|
||||
// return;
|
||||
// if (++counter.z)
|
||||
// return;
|
||||
// ++counter.w;
|
||||
// }
|
||||
|
||||
__device__ uint4 incr(uint4 ctr) {
|
||||
__device__ uint4 incr128 (uint4 ctr)
|
||||
{
|
||||
uint4 res;
|
||||
asm ("add.cc.u32 %0, %4, %8;\n\t"
|
||||
"addc.cc.u32 %1, %5, %9;\n\t"
|
||||
@ -109,51 +152,16 @@ private:
|
||||
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
|
||||
// printf("Counter before: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w);
|
||||
// }
|
||||
counter = incr(counter);
|
||||
counter = incr128(counter);
|
||||
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
|
||||
// printf("Counter after: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w);
|
||||
// }
|
||||
}
|
||||
|
||||
// __device__ unsigned int mulhilo32(unsigned int a, unsigned int b,
|
||||
// unsigned int *result_high) {
|
||||
// *result_high = __umulhi(a, b);
|
||||
// return a * b;
|
||||
// }
|
||||
|
||||
__device__ uint2 mulhilo32(const unsigned int a, const unsigned int b) {
|
||||
uint2 *res;
|
||||
unsigned long long tmp;
|
||||
asm ("mul.wide.u32 %0, %1, %2;\n\t"
|
||||
: "=l"(tmp)
|
||||
: "r"(a), "r"(b));
|
||||
res = (uint2*)(&tmp);
|
||||
return *res;
|
||||
}
|
||||
|
||||
__device__ inline uint4 single_round(const uint4 ctr, const uint2 key) {
|
||||
//unsigned int hi0;
|
||||
//unsigned int hi1;
|
||||
//unsigned int lo0 = mulhilo32(kPhiloxSA, ctr.x, &hi0);
|
||||
//unsigned int lo1 = mulhilo32(kPhiloxSB, ctr.z, &hi1);
|
||||
//uint4 ret = {hi1 ^ ctr.y ^ key.x, lo1, hi0 ^ ctr.w ^ key.y, lo0};
|
||||
uint2 res0 = mulhilo32(kPhiloxSA, ctr.x);
|
||||
uint2 res1 = mulhilo32(kPhiloxSB, ctr.z);
|
||||
uint4 ret = {res1.y ^ ctr.y ^ key.x, res1.x, res0.y ^ ctr.w ^ key.y, res0.x};
|
||||
return ret;
|
||||
}
|
||||
|
||||
static const unsigned long kPhilox10A = 0x9E3779B9;
|
||||
static const unsigned long kPhilox10B = 0xBB67AE85;
|
||||
static const unsigned long kPhiloxSA = 0xD2511F53;
|
||||
static const unsigned long kPhiloxSB = 0xCD9E8D57;
|
||||
// static const unsigned long kPhiloxSA = 0xD2511F53;
|
||||
// static const unsigned long kPhiloxSB = 0xCD9E8D57;
|
||||
};
|
||||
|
||||
// Inverse of 2^32.
|
||||
constexpr float M_RAN_INVM32 = 2.3283064e-10f;
|
||||
__device__ __inline__ float4 uniform4(const uint4 x) {
|
||||
return make_float4(x.x * M_RAN_INVM32, x.y * M_RAN_INVM32, x.z * M_RAN_INVM32,
|
||||
x.w * M_RAN_INVM32);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace pytorch_flash
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -31,579 +31,264 @@
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/philox.cuh>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/utils.h>
|
||||
|
||||
namespace fmha {
|
||||
namespace pytorch_flash {
|
||||
|
||||
using namespace cute;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
inline __device__ float apply_exp_(float x, float max) {
|
||||
return __expf(x - max);
|
||||
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
|
||||
__device__ inline void thread_reduce_(Tensor<Engine0, Layout0> const &tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
|
||||
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
|
||||
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
|
||||
CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor));
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0>(tensor); mi++) {
|
||||
summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0));
|
||||
#pragma unroll
|
||||
for (int ni = 1; ni < size<1>(tensor); ni++) {
|
||||
summary(mi) = op(summary(mi), tensor(mi, ni));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
inline __device__ float apply_exp2_(float x, float max) {
|
||||
return exp2f(x - max);
|
||||
// With fast-math, this produces the same PTX instruction as the assembly below
|
||||
// float diff = x - max;
|
||||
// float res;
|
||||
// asm ("ex2.approx.ftz.f32 %0, %1;\n\t" : "=f"(res) : "f"(diff));
|
||||
// return res;
|
||||
template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
|
||||
__device__ inline void quad_allreduce_(Tensor<Engine0, Layout0> &dst, Tensor<Engine1, Layout1> &src, Operator &op) {
|
||||
CUTE_STATIC_ASSERT_V(size(dst) == size(src));
|
||||
#pragma unroll
|
||||
for (int i = 0; i < size(dst); i++){
|
||||
dst(i) = Allreduce<4>::run(src(i), op);
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
|
||||
__device__ inline void reduce_(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
|
||||
thread_reduce_<zero_init>(tensor, summary, op);
|
||||
quad_allreduce_(summary, summary, op);
|
||||
}
|
||||
|
||||
template<int COLS> struct ReadType {};
|
||||
template<> struct ReadType<4> { using T = float;};
|
||||
template<> struct ReadType<8> { using T = float2;};
|
||||
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
|
||||
__device__ inline void reduce_max(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &max){
|
||||
MaxOp<float> max_op;
|
||||
reduce_<zero_init>(tensor, max, max_op);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
template<typename Engine0, typename Layout0, typename Engine1, typename Layout1>
|
||||
__device__ inline void reduce_sum(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &sum){
|
||||
SumOp<float> sum_op;
|
||||
reduce_(tensor, sum, sum_op);
|
||||
}
|
||||
|
||||
template <typename Cta_tile, typename Kernel_traits>
|
||||
struct Smem_tile_reduce {
|
||||
// Helper class to distribute MMA tiles reduced over rows per warp over quads.
|
||||
|
||||
// The Mma tile.
|
||||
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
|
||||
|
||||
// The number of MMAs in M/N dimensions.
|
||||
static constexpr int MMAS_M = Mma_tile::MMAS_M;
|
||||
static constexpr int MMAS_N = Mma_tile::MMAS_N;
|
||||
|
||||
static constexpr int WARPS_M = Cta_tile::WARPS_M;
|
||||
static constexpr int WARPS_N = Cta_tile::WARPS_N;
|
||||
|
||||
|
||||
static constexpr int ROWS = WARPS_M * MMAS_M * 16;
|
||||
static constexpr int COLS = WARPS_N;
|
||||
static_assert(COLS == 4 || COLS == 8);
|
||||
static constexpr int ROWS_PER_XOR_PATTERN = (COLS == 8) ? 4 : 8;
|
||||
static constexpr int BYTES_PER_TILE = ROWS * COLS * sizeof(float);
|
||||
static constexpr int ELTS_PER_TILE = ROWS * COLS;
|
||||
|
||||
static constexpr int THREADS_PER_GROUP = Kernel_traits::Gmem_tile_o::THREADS_PER_ROW;
|
||||
// TD [2022-05-02]: No longer true if head_dim != 64
|
||||
// static_assert(THREADS_PER_GROUP == 16); // DEBUG
|
||||
static constexpr int ROWS_PER_WARP = 32 / THREADS_PER_GROUP;
|
||||
static constexpr int LOOPS = Kernel_traits::Gmem_tile_o::LOOPS;
|
||||
static_assert(LOOPS == 1);
|
||||
|
||||
using read_t = typename ReadType<COLS>::T;
|
||||
|
||||
__device__ inline Smem_tile_reduce(float *smem_, const int tidx) {
|
||||
|
||||
int lane = tidx % 32;
|
||||
int warp = tidx / 32;
|
||||
|
||||
int warp_m = warp % WARPS_M;
|
||||
int warp_n = warp / WARPS_M;
|
||||
|
||||
qid_ = lane % 4;
|
||||
int qp = lane / 4;
|
||||
|
||||
// Swizzle the column to avoid 2-fold bank conflicts when we have 8 warps.
|
||||
// This won't affect reading as we assume commutative reduction ops.
|
||||
const int col = warp_n ^ (qp / ROWS_PER_XOR_PATTERN);
|
||||
smem_write_ = &smem_[warp_m * 16 * MMAS_M * WARPS_N + qp * WARPS_N + col];
|
||||
smem_read_ = &reinterpret_cast<read_t *>(smem_)[warp_m * 16 * MMAS_M * 4 + qp * 4 + qid_];
|
||||
smem_read_row_ = &reinterpret_cast<read_t *>(smem_)[warp_m * 16 * MMAS_M * 4 + qid_];
|
||||
|
||||
}
|
||||
|
||||
__device__ inline void store(float (&frag)[2 * MMAS_M]) {
|
||||
if( qid_ == 0 ) {
|
||||
#pragma unroll
|
||||
for( int mi = 0; mi < MMAS_M; mi++ ) {
|
||||
int offset = mi * 16 * WARPS_N;
|
||||
smem_write_[offset + 0 * 8 * WARPS_N] = frag[mi * 2 + 0];
|
||||
smem_write_[offset + 1 * 8 * WARPS_N] = frag[mi * 2 + 1];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__device__ inline void load(read_t (&frag)[2 * MMAS_M]) {
|
||||
// Apply the exp to all the elements.
|
||||
template <bool Scale_max=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
|
||||
inline __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &max, const float scale) {
|
||||
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
|
||||
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
|
||||
CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0>(tensor); ++mi) {
|
||||
// If max is -inf, then all elements must have been -inf (possibly due to masking).
|
||||
// We don't want (-inf - (-inf)) since that would give NaN.
|
||||
// If we don't have float around M_LOG2E the multiplication is done in fp64.
|
||||
const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * (Scale_max ? scale : float(M_LOG2E));
|
||||
#pragma unroll
|
||||
for( int mi = 0; mi < MMAS_M; mi++ ) {
|
||||
int offset = mi * 16 * 4;
|
||||
frag[mi * 2 + 0] = smem_read_[offset + 0 * 8 * 4];
|
||||
frag[mi * 2 + 1] = smem_read_[offset + 1 * 8 * 4];
|
||||
}
|
||||
}
|
||||
|
||||
__device__ inline void load_row(read_t (&frag)[MMAS_M], int row) {
|
||||
#pragma unroll
|
||||
for( int mi = 0; mi < MMAS_M; mi++ ) {
|
||||
int offset = mi * 16 * 4;
|
||||
frag[mi] = smem_read_row_[offset + 0 * 8 * 4 + row * 4];
|
||||
}
|
||||
}
|
||||
|
||||
int qid_;
|
||||
float *smem_write_;
|
||||
read_t *smem_read_;
|
||||
read_t *smem_read_row_;
|
||||
|
||||
};
|
||||
|
||||
|
||||
template<typename Cta_tile, typename Kernel_traits>
|
||||
struct Softmax_base {
|
||||
|
||||
// The Mma tile.
|
||||
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
|
||||
|
||||
// The number of MMAs in M/N dimensions.
|
||||
static constexpr int MMAS_M = Mma_tile::MMAS_M;
|
||||
static constexpr int MMAS_N = Mma_tile::MMAS_N;
|
||||
|
||||
// The number of groups of warp such that we have at most 4 warps writing consecutive elements.
|
||||
static constexpr int GROUPS = fmha::DivUpConstexpr(Cta_tile::WARPS_N, 4);
|
||||
// The number of elements that we are going to store per row.
|
||||
static constexpr int ELEMENTS_PER_ROW = Cta_tile::WARPS_N / GROUPS;
|
||||
// The number of rows.
|
||||
static constexpr int ROWS = Cta_tile::M * GROUPS;
|
||||
// The total number of elements.
|
||||
static constexpr int ELEMENTS = ROWS * ELEMENTS_PER_ROW;
|
||||
|
||||
// Ctor.
|
||||
template<typename Params>
|
||||
inline __device__ Softmax_base(const Params ¶ms, void *smem, int tidx)
|
||||
: // packed_mask_ptr_(reinterpret_cast<const char*>(params.packed_mask_ptr)),
|
||||
smem_(reinterpret_cast<float *>(smem)), tidx_(tidx) {
|
||||
|
||||
// Move to the 1st mask loaded by the thread+ tidx;
|
||||
// packed_mask_ptr_ += bidb * params.packed_mask_stride_in_bytes + tidx * sizeof(uint32_t);
|
||||
|
||||
// Extract the position in the warp.
|
||||
int warp = tidx / Cta_tile::THREADS_PER_WARP;
|
||||
int lane = tidx % Cta_tile::THREADS_PER_WARP;
|
||||
|
||||
// Decompose the warp index into M and N.
|
||||
int warp_m = warp % Cta_tile::WARPS_M;
|
||||
int warp_n = warp / Cta_tile::WARPS_M;
|
||||
|
||||
// Decompose the warp-n index into group/position-inside-the-group.
|
||||
int warp_g = warp_n / ELEMENTS_PER_ROW;
|
||||
int warp_i = warp_n % ELEMENTS_PER_ROW;
|
||||
|
||||
// The location written by the threads.
|
||||
int write_row = warp_g * (ROWS / GROUPS) + warp_m * Mma_tile::M_PER_MMA + lane / 4;
|
||||
int write_col = warp_i;
|
||||
|
||||
// Assemble the write pointer.
|
||||
smem_write_ = &smem_[write_row * ELEMENTS_PER_ROW + write_col];
|
||||
|
||||
// Assemble the read pointer.
|
||||
smem_read_ = &smem_[warp_m * Mma_tile::M_PER_MMA + lane / 4];
|
||||
}
|
||||
|
||||
template<bool zero=false, typename Mask>
|
||||
inline __device__ void apply_mask(const Mask &mask) {
|
||||
#pragma unroll
|
||||
for( int mi = 0; mi < MMAS_M; ++mi ) {
|
||||
#pragma unroll
|
||||
for( int ii = 0; ii < 2; ++ii ) {
|
||||
#pragma unroll
|
||||
for( int ni = 0; ni < MMAS_N; ++ni ) {
|
||||
#pragma unroll
|
||||
for( int jj = 0; jj < 4; ++jj ) {
|
||||
if( !mask.is_valid(mi, ni, ii, jj) ) {
|
||||
elt_[2 * mi + ii][4 * ni + jj] = zero ? 0.f : -INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Apply the exp to all the elements.
|
||||
template <bool max_in_base2=false, bool elt_in_base2=false>
|
||||
inline __device__ void apply_exp(const float (&max)[MMAS_M * 2]) {
|
||||
#pragma unroll
|
||||
for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
|
||||
for (int ni = 0; ni < size<1>(tensor); ++ni) {
|
||||
// Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
|
||||
// max * log_2(e)) This allows the compiler to use the ffma
|
||||
// instruction instead of fadd and fmul separately.
|
||||
constexpr float kLog2e = M_LOG2E;
|
||||
const float max_base2 = max_in_base2 ? max[mi] : max[mi] * kLog2e;
|
||||
#pragma unroll
|
||||
for( int ni = 0; ni < MMAS_N * 4; ++ni ) {
|
||||
// elt_[mi][ni] = apply_exp_(elt_[mi][ni], max[mi]);
|
||||
elt_[mi][ni] = apply_exp2_(elt_in_base2 ? elt_[mi][ni] : elt_[mi][ni] * kLog2e,
|
||||
max_base2);
|
||||
}
|
||||
tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Apply the exp to all the elements.
|
||||
template <bool scale_max=true>
|
||||
inline __device__ void scale_apply_exp(const float (&max)[MMAS_M * 2], const float scale_) {
|
||||
const float max_scale = scale_max ? scale_ * M_LOG2E : M_LOG2E;
|
||||
const float scale = scale_ * M_LOG2E;
|
||||
// Apply the exp to all the elements.
|
||||
template <bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
|
||||
inline __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> &max, Tensor<Engine1, Layout1> &sum, const float scale) {
|
||||
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
|
||||
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
|
||||
CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0>(tensor); ++mi) {
|
||||
MaxOp<float> max_op;
|
||||
max(mi) = zero_init ? tensor(mi, 0) : max_op(max(mi), tensor(mi, 0));
|
||||
#pragma unroll
|
||||
for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
|
||||
for (int ni = 1; ni < size<1>(tensor); ni++) {
|
||||
max(mi) = max_op(max(mi), tensor(mi, ni));
|
||||
}
|
||||
max(mi) = Allreduce<4>::run(max(mi), max_op);
|
||||
// If max is -inf, then all elements must have been -inf (possibly due to masking).
|
||||
// We don't want (-inf - (-inf)) since that would give NaN.
|
||||
const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * scale;
|
||||
sum(mi) = 0;
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < size<1>(tensor); ++ni) {
|
||||
// Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
|
||||
// max * log_2(e)) This allows the compiler to use the ffma
|
||||
// instruction instead of fadd and fmul separately.
|
||||
const float max_scaled = max[mi] * max_scale;
|
||||
#pragma unroll
|
||||
for( int ni = 0; ni < MMAS_N * 4; ++ni ) {
|
||||
elt_[mi][ni] = apply_exp2_(elt_[mi][ni] * scale, max_scaled);
|
||||
}
|
||||
tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
|
||||
sum(mi) += tensor(mi, ni);
|
||||
}
|
||||
SumOp<float> sum_op;
|
||||
sum(mi) = Allreduce<4>::run(sum(mi), sum_op);
|
||||
}
|
||||
}
|
||||
|
||||
// Apply the exp to all the elements.
|
||||
template <bool max_in_base2=false>
|
||||
inline __device__ void apply_exp_col(const float (&max)[MMAS_N * 4]) {
|
||||
template <typename Engine, typename Layout>
|
||||
inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const uint32_t max_seqlen_k,
|
||||
const uint32_t col_idx_offset_ = 0) {
|
||||
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
|
||||
static_assert(Layout::rank == 2, "Only support 2D Tensor");
|
||||
const uint32_t lane_id = threadIdx.x % 32;
|
||||
const uint32_t col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
|
||||
#pragma unroll
|
||||
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
|
||||
const uint32_t col_idx_base = col_idx_offset + nj * 8;
|
||||
#pragma unroll
|
||||
for( int ni = 0; ni < MMAS_N * 4; ++ni ) {
|
||||
constexpr float kLog2e = M_LOG2E;
|
||||
const float max_base2 = max_in_base2 ? max[ni] : max[ni] * kLog2e;
|
||||
#pragma unroll
|
||||
for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
|
||||
elt_[mi][ni] = apply_exp2_(elt_[mi][ni] * kLog2e, max_base2);
|
||||
}
|
||||
}
|
||||
}
|
||||
// inline __device__ void apply_exp_col(const float (&max)[MMAS_N]) {
|
||||
// constexpr float kLog2e = M_LOG2E;
|
||||
// #pragma unroll
|
||||
// for( int ni = 0; ni < MMAS_N * 4; ++ni ) {
|
||||
// float max_base2 = max_in_base2 ? max[ni / 4] : max[ni / 4] * kLog2e;
|
||||
// max_base2 = __shfl_sync(0xffffffff, max_base2, (ni % 4) * 8 + threadIdx.x % 8);
|
||||
// #pragma unroll
|
||||
// for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
|
||||
// elt_[mi][ni] = apply_exp2_(elt_[mi][ni] * kLog2e, max_base2);
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
template <bool encode_dropout_in_sign_bit=false>
|
||||
inline __device__ void apply_dropout_16bits(Philox &ph, uint16_t p_dropout_in_uint16_t) {
|
||||
// We encode the dropout pattern in the sign bit of the non-negative
|
||||
// softmax to distinguish from pre-existing zeros
|
||||
auto encode_dropout = [](bool keep, float val) {
|
||||
return keep ? val : (encode_dropout_in_sign_bit ? -val : float(0));
|
||||
};
|
||||
#pragma unroll
|
||||
for( int mi = 0; mi < MMAS_M; mi++ ) {
|
||||
#pragma unroll
|
||||
for( int ni = 0; ni < MMAS_N; ni++ ) {
|
||||
uint16_t tmp[8];
|
||||
// fmha::uint4_to_ushort8(ph(), tmp);
|
||||
uint4 tmp_32 = ph();
|
||||
fmha::uint4_to_ushort8(tmp_32, tmp);
|
||||
// if ((threadIdx.x % 32 == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
|
||||
// printf("tidx = %d, ni = %d, ph Philox: %u, %u, %u, %u\n", threadIdx.x, ni, tmp_32.x, tmp_32.y, tmp_32.z, tmp_32.w);
|
||||
// }
|
||||
for (int j = 0; j < size<1, 0>(tensor); ++j) {
|
||||
const uint32_t col_idx = col_idx_base + j;
|
||||
if (col_idx >= max_seqlen_k) {
|
||||
// Without the "make_coord" we get wrong results
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < 2; ++ii) {
|
||||
#pragma unroll
|
||||
for (int jj = 0; jj < 4; ++jj) {
|
||||
elt_[mi * 2 + ii][4 * ni + jj] =
|
||||
encode_dropout(tmp[ii * 4 + jj] <= p_dropout_in_uint16_t, elt_[mi * 2 + ii][4 * ni + jj]);
|
||||
}
|
||||
for (int mi = 0; mi < size<0>(tensor); ++mi) {
|
||||
tensor(mi, make_coord(j, nj)) = -INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <bool encode_dropout_in_sign_bit=false>
|
||||
inline __device__ void apply_dropout_16bits(Philox &ph, uint16_t p_dropout_in_uint16_t,
|
||||
unsigned long long philox_subsequence) {
|
||||
// We encode the dropout pattern in the sign bit of the non-negative
|
||||
// softmax to distinguish from pre-existing zeros
|
||||
auto encode_dropout = [](bool keep, float val) {
|
||||
return keep ? val : (encode_dropout_in_sign_bit ? -val : float(0));
|
||||
};
|
||||
static_assert(MMAS_M == 1); // We're assuming 16x16 blocks.
|
||||
template <typename Engine, typename Layout>
|
||||
inline __device__ void apply_mask_causal(Tensor<Engine, Layout> &tensor, const uint32_t col_idx_offset_,
|
||||
const uint32_t max_seqlen_k, const uint32_t row_idx_offset_,
|
||||
const uint32_t warp_row_stride) {
|
||||
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
|
||||
static_assert(Layout::rank == 2, "Only support 2D Tensor");
|
||||
const uint32_t lane_id = threadIdx.x % 32;
|
||||
// const uint32_t row_idx_offset = row_idx_offset_ + lane_id / 4;
|
||||
const uint32_t row_idx_offset = row_idx_offset_;
|
||||
const uint32_t col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
|
||||
const uint32_t row_idx_base = row_idx_offset + mi * warp_row_stride;
|
||||
#pragma unroll
|
||||
for( int mi = 0; mi < MMAS_M; mi++ ) {
|
||||
for (int i = 0; i < size<0, 0>(tensor); ++i) {
|
||||
const uint32_t row_idx = row_idx_base + i * 8;
|
||||
const uint32_t col_idx_limit = std::min(max_seqlen_k, row_idx + 1);
|
||||
#pragma unroll
|
||||
for( int ni = 0; ni < MMAS_N; ni++ ) {
|
||||
uint16_t tmp[8];
|
||||
// fmha::uint4_to_ushort8(ph(), tmp);
|
||||
fmha::uint4_to_ushort8(ph(philox_subsequence + ni * Cta_tile::WARPS_N), tmp);
|
||||
// uint4 tmp_32 = ph(philox_subsequence + ni * Cta_tile::WARPS_N);
|
||||
// fmha::uint4_to_ushort8(tmp_32, tmp);
|
||||
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
|
||||
// printf("ni = %d, ph Philox: %u, %u, %u, %u\n", ni, tmp_32.x, tmp_32.y, tmp_32.z, tmp_32.w);
|
||||
// }
|
||||
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
|
||||
const uint32_t col_idx_base = col_idx_offset + nj * 8;
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < 2; ++ii) {
|
||||
#pragma unroll
|
||||
for (int jj = 0; jj < 4; ++jj) {
|
||||
elt_[mi * 2 + ii][4 * ni + jj] =
|
||||
encode_dropout(tmp[ii * 4 + jj] <= p_dropout_in_uint16_t, elt_[mi * 2 + ii][4 * ni + jj]);
|
||||
for (int j = 0; j < size<1, 0>(tensor); ++j) {
|
||||
const uint32_t col_idx = col_idx_base + j;
|
||||
if (col_idx >= col_idx_limit) {
|
||||
tensor(make_coord(i, mi), make_coord(j, nj)) = -INFINITY;
|
||||
}
|
||||
}
|
||||
}
|
||||
// if (cute::thread0()) {
|
||||
// printf("mi = %d, i = %d, row_idx = %d, max_seqlen_k = %d\n", mi, i, row_idx, max_seqlen_k);
|
||||
// print(tensor(make_coord(i, mi), _));
|
||||
// // print(tensor(_, j + nj * size<1, 0>(tensor)));
|
||||
// }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <bool encode_dropout_in_sign_bit=false>
|
||||
inline __device__ void apply_dropout_16bits(Philox &ph0, Philox &ph1, uint16_t p_dropout_in_uint16_t) {
|
||||
// We encode the dropout pattern in the sign bit of the non-negative
|
||||
// softmax to distinguish from pre-existing zeros
|
||||
auto encode_dropout = [](bool keep, float val) {
|
||||
return keep ? val : (encode_dropout_in_sign_bit ? -val : float(0));
|
||||
};
|
||||
template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
|
||||
inline __device__ void apply_mask_causal_w_idx(
|
||||
Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &idx_rowcol,
|
||||
const uint32_t col_idx_offset_, const uint32_t max_seqlen_k, const uint32_t row_idx_offset_)
|
||||
{
|
||||
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
|
||||
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
|
||||
static_assert(Layout1::rank == 2, "Only support 2D Tensor");
|
||||
CUTE_STATIC_ASSERT_V(size<0>(tensor) == size<0>(idx_rowcol));
|
||||
CUTE_STATIC_ASSERT_V(size<1>(tensor) == size<1>(idx_rowcol));
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < size<0>(tensor); ++mi) {
|
||||
const uint32_t col_idx_limit = std::min(max_seqlen_k, 1 + row_idx_offset_ + get<0>(idx_rowcol(mi, 0)));
|
||||
#pragma unroll
|
||||
for( int mi = 0; mi < MMAS_M; mi++ ) {
|
||||
static_assert(MMAS_N % 2 == 0);
|
||||
#pragma unroll
|
||||
for( int ni = 0; ni < MMAS_N; ni += 2 ) {
|
||||
uint16_t tmp[8];
|
||||
fmha::uint4_to_ushort8(ph0(), tmp);
|
||||
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
|
||||
// printf("ni = %d, ph Philox: %u, %u, %u, %u\n", ni, tmp.x, tmp.y, tmp.z, tmp.w);
|
||||
// }
|
||||
for (int ni = 0; ni < size<1, 1>(tensor); ++ni) {
|
||||
if (col_idx_offset_ + get<1>(idx_rowcol(0, ni)) >= col_idx_limit) {
|
||||
tensor(mi, ni) = -INFINITY;
|
||||
}
|
||||
}
|
||||
// if (cute::thread0()) {
|
||||
// printf("ni = %d, j = %d, col_idx = %d, max_seqlen_k = %d\n", ni, j, col_idx, max_seqlen_k);
|
||||
// print(tensor(_, make_coord(j, ni)));
|
||||
// // print(tensor(_, j + ni * size<1, 0>(tensor)));
|
||||
// }
|
||||
}
|
||||
}
|
||||
|
||||
template <bool encode_dropout_in_sign_bit=false, typename Engine, typename Layout>
|
||||
inline __device__ void apply_dropout(Tensor<Engine, Layout> &tensor, uint8_t p_dropout_in_uint8_t,
|
||||
unsigned long long seed, unsigned long long offset,
|
||||
uint32_t block_row_start, uint32_t block_col_start,
|
||||
uint32_t block_row_stride) {
|
||||
// tensor has shape (8, MMA_M, MMA_N / 2)
|
||||
using T = typename Engine::value_type;
|
||||
auto encode_dropout = [](bool keep, T val) {
|
||||
return keep ? val : (encode_dropout_in_sign_bit ? -val : T(0));
|
||||
};
|
||||
static_assert(decltype(size<2>(tensor))::value % 2 == 0);
|
||||
const uint16_t p_dropout_8bit_in_uint16_t = uint16_t(p_dropout_in_uint8_t);
|
||||
const uint32_t p_dropout_8bit_in_uint32_t = (uint32_t(p_dropout_8bit_in_uint16_t) << 16) | uint32_t(p_dropout_8bit_in_uint16_t);
|
||||
// if (cute::thread0()) { printf("threshold2 = 0x%x\n", p_dropout_8bit_in_uint32_t); }
|
||||
#pragma unroll
|
||||
for (int m = 0; m < size<1>(tensor); ++m, block_row_start += block_row_stride) {
|
||||
uint2 rowcol = make_uint2(block_row_start, block_col_start);
|
||||
#pragma unroll
|
||||
for (int n = 0; n < size<2>(tensor) / 2; ++n, ++rowcol.y) {
|
||||
// if (cute::thread(32, 0)) { printf("m = %d, n = %d, row = %d, col = %d\n", m, n, int(rowcol.x), int(rowcol.y));}
|
||||
uint4 random_uint4 = pytorch_flash::philox(seed, reinterpret_cast<unsigned long long&>(rowcol), offset);
|
||||
// if (cute::thread0()) { printf("philox = %u, %d, %d, %d\n", random_uint4.x, random_uint4.y, random_uint4.z, random_uint4.w);}
|
||||
uint8_t (&rnd_8)[16] = reinterpret_cast<uint8_t (&)[16]>(random_uint4);
|
||||
// Special implementation for 16-bit types: we duplicate the threshold to the
|
||||
// low and high 16 bits of a 32-bit value, then use the f16x2 comparison instruction
|
||||
// to get a mask. The low 16 bits of the mask will be either 0xffff or 0x0000,
|
||||
// and the high 16 bits will be either 0xffff or 0x0000, depending on whether
|
||||
// the random value is less than the threshold.
|
||||
// We then do a bit-wise AND between the mask and the original value (in 32-bit).
|
||||
// We're exploiting the fact that floating point comparison is equivalent to integer
|
||||
// comparison, since we're comparing unsigned integers whose top 8-bits are zero.
|
||||
if (!encode_dropout_in_sign_bit
|
||||
&& (std::is_same<T, cutlass::half_t>::value || std::is_same<T, cutlass::bfloat16_t>::value)) {
|
||||
uint16_t rnd_16[16];
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < 2; ++ii) {
|
||||
for (int i = 0; i < 16; i++) { rnd_16[i] = uint16_t(rnd_8[i]); }
|
||||
uint32_t (&rnd_32)[8] = reinterpret_cast<uint32_t (&)[8]>(rnd_16);
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 2; j++) {
|
||||
Tensor tensor_uint32 = recast<uint32_t>(tensor(_, m, n * 2 + j));
|
||||
// if (cute::thread0()) { printf("random = 0x%x, 0x%x, 0x%x, 0x%x\n", rnd_32[j * 4 + 0], rnd_32[j * 4 + 1], rnd_32[j * 4 + 2], rnd_32[j * 4 + 3]); }
|
||||
// if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
|
||||
#pragma unroll
|
||||
for (int jj = 0; jj < 4; ++jj) {
|
||||
elt_[mi * 2 + ii][4 * ni + jj] =
|
||||
encode_dropout(tmp[ii * 4 + jj] <= p_dropout_in_uint16_t, elt_[mi * 2 + ii][4 * ni + jj]);
|
||||
for (int i = 0; i < 4; i++) {
|
||||
uint32_t mask;
|
||||
asm volatile("set.le.u32.f16x2 %0, %1, %2;\n" : "=r"(mask) : "r"(rnd_32[j * 4 + i]), "r"(p_dropout_8bit_in_uint32_t));
|
||||
tensor_uint32(i) &= mask;
|
||||
}
|
||||
// if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
|
||||
}
|
||||
fmha::uint4_to_ushort8(ph1(), tmp);
|
||||
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
|
||||
// printf("ni = %d, ph Philox: %u, %u, %u, %u\n", ni, tmp.x, tmp.y, tmp.z, tmp.w);
|
||||
// }
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < 2; ++ii) {
|
||||
for (int j = 0; j < 2; j++) {
|
||||
#pragma unroll
|
||||
for (int jj = 0; jj < 4; ++jj) {
|
||||
elt_[mi * 2 + ii][4 * (ni + 1) + jj] =
|
||||
encode_dropout(tmp[ii * 4 + jj] <= p_dropout_in_uint16_t, elt_[mi * 2 + ii][4 * (ni + 1) + jj]);
|
||||
for (int i = 0; i < 8; i++) {
|
||||
tensor(i, m, n * 2 + j) = encode_dropout(rnd_8[j * 8 + i] <= p_dropout_in_uint8_t, tensor(i, m, n * 2 + j));
|
||||
}
|
||||
Tensor tensor_uint32 = recast<uint32_t>(tensor(_, m, n * 2 + j));
|
||||
// if (cute::thread0()) { printf("tensor_uint32 = 0x%x, 0x%x, 0x%x, 0x%x\n", tensor_uint32(0), tensor_uint32(1), tensor_uint32(2), tensor_uint32(3)); }
|
||||
}
|
||||
}
|
||||
// // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
|
||||
// // printf("n = %d, ph Philox: %u, %u, %u, %u\n", n, rnd_8.x, rnd_8.y, rnd_8.z, rnd_8.w);
|
||||
// // }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Scale all the elements.
|
||||
inline __device__ void scale(const float (&sum)[MMAS_M * 2]) {
|
||||
// Precompute the inverse sum to normalize. Without -use_fast_math, it makes a huge deal.
|
||||
float inv_sum[MMAS_M * 2];
|
||||
#pragma unroll
|
||||
for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
|
||||
inv_sum[mi] = (sum[mi] == 0.f || sum[mi] != sum[mi]) ? 1.f : 1.f / sum[mi];
|
||||
}
|
||||
|
||||
// Update the values.
|
||||
#pragma unroll
|
||||
for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
|
||||
#pragma unroll
|
||||
for( int ni = 0; ni < MMAS_N * 4; ++ni ) {
|
||||
elt_[mi][ni] *= inv_sum[mi];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Subtract all elements by dp_sum
|
||||
inline __device__ void subtract_dp_sum(const float (&dp_sum)[MMAS_M * 2]) {
|
||||
#pragma unroll
|
||||
for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
|
||||
#pragma unroll
|
||||
for( int ni = 0; ni < MMAS_N * 4; ++ni ) {
|
||||
elt_[mi][ni] -= dp_sum[mi];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// The pointer to the mask.
|
||||
const char *packed_mask_ptr_;
|
||||
// Shared memory for the CTA-wide reduction.
|
||||
float *smem_, *smem_write_, *smem_read_;
|
||||
// The current thread index.
|
||||
int tidx_;
|
||||
// The elements.
|
||||
float elt_[MMAS_M * 2][MMAS_N * 4];
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename Cta_tile, typename Kernel_traits>
|
||||
struct Softmax : public Softmax_base<Cta_tile, Kernel_traits> {
|
||||
|
||||
// The base class.
|
||||
using Base = Softmax_base<Cta_tile, Kernel_traits>;
|
||||
// The fragment.
|
||||
using Fragment_a = fmha::Fragment_a<fmha::Row>;
|
||||
|
||||
static_assert(Fragment_a::NUM_REGS == 4);
|
||||
|
||||
static constexpr int WARPS_M = Cta_tile::WARPS_M;
|
||||
static constexpr int WARPS_N = Cta_tile::WARPS_N;
|
||||
// The MMAs.
|
||||
static constexpr int MMAS_M = Base::MMAS_M;
|
||||
static constexpr int MMAS_N = Base::MMAS_N;
|
||||
|
||||
// The accumulators.
|
||||
using Accumulator = fmha::Fragment_accumulator;
|
||||
using Accumulator_out = Fragment<uint16_t, 8>;
|
||||
static_assert(Accumulator_out::NUM_REGS == 4);
|
||||
|
||||
static_assert(std::is_same<Accumulator::Data_type, float>::value);
|
||||
|
||||
using Smem_tile_red = Smem_tile_reduce<Cta_tile, Kernel_traits>;
|
||||
static_assert(Smem_tile_red::ELTS_PER_TILE == Cta_tile::M * WARPS_N);
|
||||
// Ctor.
|
||||
template<typename Params>
|
||||
inline __device__ Softmax(const Params ¶ms, void *smem, int tidx)
|
||||
: Base(params, smem, tidx)
|
||||
, params_scale_bmm1_(params.scale_bmm1)
|
||||
, smem_sum_(static_cast<float*>(smem), tidx)
|
||||
, smem_max_(static_cast<float*>(smem) + Smem_tile_red::ELTS_PER_TILE, tidx) {
|
||||
}
|
||||
|
||||
// Pack the data to a fragment for the next GEMM.
|
||||
template<typename elem_type=__half, int K, int M>
|
||||
inline __device__ void pack(Fragment_a (&dst)[K][M]) const {
|
||||
#pragma unroll
|
||||
for( int mi = 0; mi < M; ++mi ) {
|
||||
#pragma unroll
|
||||
for( int ki = 0; ki < K; ++ki ) {
|
||||
|
||||
// 1st row - 4 elements per row.
|
||||
float tmp_00 = this->elt_[2 * mi + 0][4 * ki + 0];
|
||||
float tmp_01 = this->elt_[2 * mi + 0][4 * ki + 1];
|
||||
float tmp_02 = this->elt_[2 * mi + 0][4 * ki + 2];
|
||||
float tmp_03 = this->elt_[2 * mi + 0][4 * ki + 3];
|
||||
|
||||
// 2nd row - 4 elements per row.
|
||||
float tmp_10 = this->elt_[2 * mi + 1][4 * ki + 0];
|
||||
float tmp_11 = this->elt_[2 * mi + 1][4 * ki + 1];
|
||||
float tmp_12 = this->elt_[2 * mi + 1][4 * ki + 2];
|
||||
float tmp_13 = this->elt_[2 * mi + 1][4 * ki + 3];
|
||||
|
||||
// Pack to 4 registers.
|
||||
dst[ki][mi].reg(0) = fmha::float2_pack<elem_type>(tmp_00, tmp_01);
|
||||
dst[ki][mi].reg(1) = fmha::float2_pack<elem_type>(tmp_10, tmp_11);
|
||||
dst[ki][mi].reg(2) = fmha::float2_pack<elem_type>(tmp_02, tmp_03);
|
||||
dst[ki][mi].reg(3) = fmha::float2_pack<elem_type>(tmp_12, tmp_13);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Scale FP32 fragments
|
||||
inline __device__ void unpack(const Accumulator (&acc)[MMAS_M][MMAS_N]) {
|
||||
const float scalef = reinterpret_cast<const float &>(this->params_scale_bmm1_);
|
||||
|
||||
#pragma unroll
|
||||
for( int mi = 0; mi < MMAS_M; ++mi ) {
|
||||
#pragma unroll
|
||||
for( int ni = 0; ni < MMAS_N; ++ni ) {
|
||||
// 1st row - 4 elements per row.
|
||||
this->elt_[2 * mi + 0][4 * ni + 0] = acc[mi][ni].elt(0) * scalef;
|
||||
this->elt_[2 * mi + 0][4 * ni + 1] = acc[mi][ni].elt(1) * scalef;
|
||||
this->elt_[2 * mi + 0][4 * ni + 2] = acc[mi][ni].elt(4) * scalef;
|
||||
this->elt_[2 * mi + 0][4 * ni + 3] = acc[mi][ni].elt(5) * scalef;
|
||||
// 2nd row - 4 elements per row.
|
||||
this->elt_[2 * mi + 1][4 * ni + 0] = acc[mi][ni].elt(2) * scalef;
|
||||
this->elt_[2 * mi + 1][4 * ni + 1] = acc[mi][ni].elt(3) * scalef;
|
||||
this->elt_[2 * mi + 1][4 * ni + 2] = acc[mi][ni].elt(6) * scalef;
|
||||
this->elt_[2 * mi + 1][4 * ni + 3] = acc[mi][ni].elt(7) * scalef;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Scale FP32 fragments
|
||||
inline __device__ void unpack_noscale(const Accumulator (&acc)[MMAS_M][MMAS_N]) {
|
||||
|
||||
#pragma unroll
|
||||
for( int mi = 0; mi < MMAS_M; ++mi ) {
|
||||
#pragma unroll
|
||||
for( int ni = 0; ni < MMAS_N; ++ni ) {
|
||||
// 1st row - 4 elements per row.
|
||||
this->elt_[2 * mi + 0][4 * ni + 0] = acc[mi][ni].elt(0);
|
||||
this->elt_[2 * mi + 0][4 * ni + 1] = acc[mi][ni].elt(1);
|
||||
this->elt_[2 * mi + 0][4 * ni + 2] = acc[mi][ni].elt(4);
|
||||
this->elt_[2 * mi + 0][4 * ni + 3] = acc[mi][ni].elt(5);
|
||||
// 2nd row - 4 elements per row.
|
||||
this->elt_[2 * mi + 1][4 * ni + 0] = acc[mi][ni].elt(2);
|
||||
this->elt_[2 * mi + 1][4 * ni + 1] = acc[mi][ni].elt(3);
|
||||
this->elt_[2 * mi + 1][4 * ni + 2] = acc[mi][ni].elt(6);
|
||||
this->elt_[2 * mi + 1][4 * ni + 3] = acc[mi][ni].elt(7);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<bool zero_init=true, typename Operator>
|
||||
__device__ inline void thread_reduce_(float (&frag)[2 * MMAS_M], Operator &op) {
|
||||
#pragma unroll
|
||||
for( int mi = 0; mi < 2 * MMAS_M; mi++ ) {
|
||||
frag[mi] = zero_init ? this->elt_[mi][0] : op(frag[mi], this->elt_[mi][0]);
|
||||
#pragma unroll
|
||||
for( int ni = 1; ni < 4 * MMAS_N; ni++ ) {
|
||||
frag[mi] = op(frag[mi], this->elt_[mi][ni]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<bool zero_init=true, typename Operator>
|
||||
__device__ inline void reduce_(float (&frag)[2 * MMAS_M], Operator &op, Smem_tile_red & smem_red) {
|
||||
thread_reduce_<zero_init>(frag, op);
|
||||
quad_reduce(frag, frag, op);
|
||||
smem_red.store(frag);
|
||||
__syncthreads();
|
||||
typename Smem_tile_red::read_t tmp[2 * MMAS_M];
|
||||
smem_red.load(tmp);
|
||||
quad_allreduce(frag, tmp, op);
|
||||
}
|
||||
|
||||
template<bool zero_init=true>
|
||||
__device__ inline void reduce_max(float (&frag)[2 * MMAS_M]){
|
||||
MaxOp<float> max;
|
||||
reduce_<zero_init>(frag, max, smem_max_);
|
||||
}
|
||||
|
||||
__device__ inline void reduce_sum(float (&frag)[2 * MMAS_M]){
|
||||
SumOp<float> sum;
|
||||
reduce_(frag, sum, smem_sum_);
|
||||
}
|
||||
|
||||
template<bool zero_init=true>
|
||||
__device__ inline void reduce_sum_before_sync_(float (&frag)[2 * MMAS_M]){
|
||||
SumOp<float> sum;
|
||||
thread_reduce_<zero_init>(frag, sum);
|
||||
quad_reduce(frag, frag, sum);
|
||||
smem_sum_.store(frag);
|
||||
}
|
||||
|
||||
template<int NROWS, typename Operator>
|
||||
__device__ inline void reduce_after_sync_(float (&frag)[NROWS][MMAS_M],
|
||||
const int (&rows)[NROWS],
|
||||
Operator &op, Smem_tile_red & smem_red) {
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < NROWS; ii++) {
|
||||
typename Smem_tile_red::read_t tmp[MMAS_M];
|
||||
smem_red.load_row(tmp, rows[ii]);
|
||||
quad_allreduce(frag[ii], tmp, op);
|
||||
}
|
||||
}
|
||||
|
||||
template<int NROWS>
|
||||
__device__ inline void reduce_sum_after_sync_(float (&frag)[NROWS][MMAS_M],
|
||||
const int (&rows)[NROWS]){
|
||||
SumOp<float> sum;
|
||||
reduce_after_sync_(frag, rows, sum, smem_sum_);
|
||||
}
|
||||
|
||||
template<int NROWS>
|
||||
__device__ inline void reduce_max_after_sync_(float (&frag)[NROWS][MMAS_M],
|
||||
const int (&rows)[NROWS]){
|
||||
MaxOp<float> max;
|
||||
reduce_after_sync_(frag, rows, max, smem_max_);
|
||||
}
|
||||
|
||||
const uint32_t params_scale_bmm1_;
|
||||
Smem_tile_red smem_max_;
|
||||
Smem_tile_red smem_sum_;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace fmha
|
||||
} // namespace pytorch_flash
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
|
||||
// Inspired by
|
||||
// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
|
||||
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
|
||||
// and https://github.com/facebookresearch/xformers/blob/main/xformers/csrc/attention/cuda/fmha/gemm_kernel_utils.h#L8
|
||||
|
||||
#pragma once
|
||||
|
||||
@ -10,31 +10,57 @@
|
||||
///
|
||||
/// Usage:
|
||||
/// ```
|
||||
/// BOOL_SWITCH(flag, BoolConst, ([&] {
|
||||
/// BOOL_SWITCH(flag, BoolConst, [&] {
|
||||
/// some_function<BoolConst>(...);
|
||||
/// }));
|
||||
/// });
|
||||
/// ```
|
||||
/// We need "({" and "})" to make sure that the code is a single argument being passed to the macro.
|
||||
#define BOOL_SWITCH(COND, CONST_NAME, F) \
|
||||
{ \
|
||||
if (COND) { \
|
||||
constexpr bool CONST_NAME = true; \
|
||||
F(); \
|
||||
} else { \
|
||||
constexpr bool CONST_NAME = false; \
|
||||
F(); \
|
||||
} \
|
||||
}
|
||||
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
|
||||
[&] { \
|
||||
if (COND) { \
|
||||
constexpr static bool CONST_NAME = true; \
|
||||
return __VA_ARGS__(); \
|
||||
} else { \
|
||||
constexpr static bool CONST_NAME = false; \
|
||||
return __VA_ARGS__(); \
|
||||
} \
|
||||
}()
|
||||
|
||||
// modified from BOOL_SWITCH
|
||||
// because MSVC cannot handle std::conditional with constexpr variable
|
||||
#define FP16_SWITCH(COND, F) \
|
||||
{ \
|
||||
if (COND) { \
|
||||
using elem_type = __nv_bfloat16; \
|
||||
F(); \
|
||||
} else { \
|
||||
using elem_type = __half; \
|
||||
F(); \
|
||||
} \
|
||||
}
|
||||
#define FP16_SWITCH(COND, ...) \
|
||||
[&] { \
|
||||
if (COND) { \
|
||||
using elem_type = cutlass::half_t; \
|
||||
return __VA_ARGS__(); \
|
||||
} else { \
|
||||
using elem_type = cutlass::bfloat16_t; \
|
||||
return __VA_ARGS__(); \
|
||||
} \
|
||||
}()
|
||||
|
||||
#define FWD_HEADDIM_SWITCH(HEADDIM, ...) \
|
||||
[&] { \
|
||||
if (HEADDIM <= 32) { \
|
||||
constexpr static int kHeadDim = 32; \
|
||||
return __VA_ARGS__(); \
|
||||
} else if (HEADDIM <= 64) { \
|
||||
constexpr static int kHeadDim = 64; \
|
||||
return __VA_ARGS__(); \
|
||||
} else if (HEADDIM <= 96) { \
|
||||
constexpr static int kHeadDim = 96; \
|
||||
return __VA_ARGS__(); \
|
||||
} else if (HEADDIM <= 128) { \
|
||||
constexpr static int kHeadDim = 128; \
|
||||
return __VA_ARGS__(); \
|
||||
} else if (HEADDIM <= 160) { \
|
||||
constexpr static int kHeadDim = 160; \
|
||||
return __VA_ARGS__(); \
|
||||
} else if (HEADDIM <= 192) { \
|
||||
constexpr static int kHeadDim = 192; \
|
||||
return __VA_ARGS__(); \
|
||||
} else if (HEADDIM <= 224) { \
|
||||
constexpr static int kHeadDim = 224; \
|
||||
return __VA_ARGS__(); \
|
||||
} else if (HEADDIM <= 256) { \
|
||||
constexpr static int kHeadDim = 256; \
|
||||
return __VA_ARGS__(); \
|
||||
} \
|
||||
}()
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -41,75 +41,15 @@
|
||||
|
||||
namespace sdp {
|
||||
namespace {
|
||||
// flash_attention V2 is universally faster than efficient_attention and Math
|
||||
std::array<SDPBackend, num_backends> priority_order(sdp_params params) {
|
||||
constexpr std::array<SDPBackend, num_backends> default_order{
|
||||
SDPBackend::flash_attention,
|
||||
SDPBackend::efficient_attention,
|
||||
SDPBackend::math};
|
||||
|
||||
constexpr std::array<SDPBackend, num_backends> efficient_first{
|
||||
SDPBackend::efficient_attention,
|
||||
SDPBackend::flash_attention,
|
||||
SDPBackend::math};
|
||||
// Logic is taken from xformers
|
||||
// FlashAttention parallelizes across "batch_size * num_heads"
|
||||
// MemEff parallelizes across "batch_size * num_heads * num_queries" and can
|
||||
// be more efficient. batch_size, q_len, num_heads, k = inp.query.shape
|
||||
|
||||
if (has_for_nested_inputs(params)) {
|
||||
return efficient_first;
|
||||
}
|
||||
if (params.query.dim() != 4) {
|
||||
return default_order;
|
||||
}
|
||||
const auto batch_size{params.query.sym_size(0)},
|
||||
num_heads{params.query.sym_size(1)},
|
||||
query_lengths{params.query.sym_size(2)},
|
||||
head_dim{params.query.sym_size(3)};
|
||||
if (batch_size > 0) {
|
||||
const auto threads_flash = batch_size * num_heads;
|
||||
const auto threads_cutlass =
|
||||
threads_flash * (query_lengths / c10::SymInt(64));
|
||||
bool more_threads_cutlass = (threads_cutlass / 2) >= threads_flash;
|
||||
bool small_threads_flash = threads_flash < 60;
|
||||
bool large_head_dim = head_dim.max(params.key.sym_size(3)) == 128;
|
||||
|
||||
// The training heuristic is taken from
|
||||
// https://github.com/pytorch/pytorch/pull/99644 Revisit when updated
|
||||
// cutlass kernel is upstreamed.
|
||||
if (input_requires_grad(params)) {
|
||||
if (6 * threads_flash > query_lengths)
|
||||
return efficient_first;
|
||||
} else if ((small_threads_flash && more_threads_cutlass) || large_head_dim)
|
||||
return efficient_first;
|
||||
}
|
||||
return default_order;
|
||||
}
|
||||
|
||||
bool check_head_dim_size(sdp_params params, bool debug) {
|
||||
const auto query_size_last = params.query.sym_size(-1);
|
||||
const auto key_size_last = params.key.sym_size(-1);
|
||||
const auto value_size_last = params.value.sym_size(-1);
|
||||
if (!(query_size_last == key_size_last &&
|
||||
query_size_last == value_size_last && query_size_last % 8 == 0 &&
|
||||
query_size_last <= 128 && value_size_last % 8 == 0 &&
|
||||
value_size_last <= 128)) {
|
||||
if (debug) {
|
||||
TORCH_WARN(
|
||||
"Flash attention requires q,k,v to have the same last dimension and to be a multiple of 8 and less than or equal to 128.",
|
||||
" Got Query.size(-1): ",
|
||||
query_size_last,
|
||||
", Key.size(-1): ",
|
||||
params.key.sym_size(-1),
|
||||
", Value.size(-1): ",
|
||||
params.value.sym_size(-1),
|
||||
" instead.");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool use_tensor_cores(sdp_params params, cudaDeviceProp* dprops, bool is_half) {
|
||||
if (dprops->major >= 8) {
|
||||
return true;
|
||||
@ -135,6 +75,48 @@ int64_t minimum_gemm_alignment(sdp_params params) {
|
||||
return matmul_alignment_mn;
|
||||
}
|
||||
|
||||
bool check_head_dim_size_flash(sdp_params params, bool debug) {
|
||||
// All head_dim sizes must be equal and less than 256
|
||||
const auto max_size = c10::SymInt(256);
|
||||
const auto query_size_last = params.query.sym_size(-1);
|
||||
const auto key_size_last = params.key.sym_size(-1);
|
||||
const auto value_size_last = params.value.sym_size(-1);
|
||||
bool same_head_dim_size =
|
||||
query_size_last == key_size_last && query_size_last == value_size_last;
|
||||
if (has_for_nested_inputs(params)) {
|
||||
if (!(same_head_dim_size && (query_size_last % 8 == 0) &&
|
||||
(query_size_last <= max_size))) {
|
||||
if (debug) {
|
||||
TORCH_WARN(
|
||||
"For NestedTensor inputs, Flash attention requires q,k,v to have the same last dimension and to be a multiple of 8 and less than or equal to 256.",
|
||||
" Got Query.size(-1): ",
|
||||
query_size_last,
|
||||
", Key.size(-1): ",
|
||||
params.key.sym_size(-1),
|
||||
", Value.size(-1): ",
|
||||
params.value.sym_size(-1),
|
||||
" instead.");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if (!(same_head_dim_size && (query_size_last <= max_size))) {
|
||||
if (debug) {
|
||||
TORCH_WARN(
|
||||
"Flash attention requires q,k,v to have the same last dimension and to be less than or equal to 256.",
|
||||
" Got Query.size(-1): ",
|
||||
query_size_last,
|
||||
", Key.size(-1): ",
|
||||
key_size_last,
|
||||
", Value.size(-1): ",
|
||||
value_size_last,
|
||||
" instead.");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool check_head_dim_size_mem_efficient(sdp_params params, bool debug) {
|
||||
const auto query_size_last = params.query.sym_size(-1);
|
||||
const auto value_size_last = params.value.sym_size(-1);
|
||||
@ -186,15 +168,15 @@ bool check_sm_version(cudaDeviceProp * dprops) {
|
||||
return is_gte_lower_bound && is_lte_upper_bound;
|
||||
}
|
||||
|
||||
bool check_gpu_sm75_or_greater(sdp_params params, bool debug) {
|
||||
bool check_flash_attention_hardware_support(sdp_params params, bool debug) {
|
||||
// Check that the gpu is capable of running flash attention
|
||||
using sm75 = SMVersion<7, 5>;
|
||||
using sm80 = SMVersion<8, 0>;
|
||||
using sm90 = SMVersion<9, 0>;
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
if (!check_sm_version<sm75, sm90>(dprops)) {
|
||||
if (!check_sm_version<sm80, sm90>(dprops)) {
|
||||
if (debug) {
|
||||
TORCH_WARN(
|
||||
"Flash attention only supports gpu architectures in the range [sm75, sm90]. Attempting to run on a sm ",
|
||||
"Flash attention only supports gpu architectures in the range [sm80, sm90]. Attempting to run on a sm ",
|
||||
dprops->major,
|
||||
".",
|
||||
dprops->minor,
|
||||
@ -224,7 +206,7 @@ bool check_mem_efficient_hardware_support(sdp_params params, bool debug) {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool check_requires_grad_and_head_dim_gt64_and_sm_ge86_lt90(
|
||||
bool check_requires_grad_and_head_dim_gt192_and_sm_ge86_lt90(
|
||||
sdp_params params,
|
||||
bool debug) {
|
||||
// Flash Attention will raise an error in the backward pass if the head_dim
|
||||
@ -233,11 +215,11 @@ bool check_requires_grad_and_head_dim_gt64_and_sm_ge86_lt90(
|
||||
using sm89 = SMVersion<8, 9>;
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
bool is_sm86_or_sm89 = check_sm_version<sm86, sm89>(dprops);
|
||||
bool is_head_dim_gt64 = params.query.sym_size(-1) > 64;
|
||||
if (input_requires_grad(params) && is_sm86_or_sm89 && is_head_dim_gt64) {
|
||||
bool is_head_dim_gt192 = params.query.sym_size(-1) > 192;
|
||||
if (input_requires_grad(params) && is_sm86_or_sm89 && is_head_dim_gt192) {
|
||||
if (debug) {
|
||||
TORCH_WARN(
|
||||
"Flash attention currently doesn't support training with head_dim greater than 64 on gpu architectures in the range[sm86, sm89].",
|
||||
"Flash attention currently doesn't support training with head_dim greater than 192 on gpu architectures in the range[sm86, sm89].",
|
||||
"Attempting to run with head_dim: ",
|
||||
params.query.sym_size(-1), " on a sm ", dprops->major, ".",
|
||||
dprops->minor, " gpu.");
|
||||
@ -249,7 +231,7 @@ bool check_requires_grad_and_head_dim_gt64_and_sm_ge86_lt90(
|
||||
|
||||
bool use_flash_attention(sdp_params params, bool debug) {
|
||||
#ifndef USE_FLASH_ATTENTION
|
||||
TORCH_CHECK(!debug, "Torch was not compiled with flash attention.");
|
||||
TORCH_WARN(!debug, "Torch was not compiled with flash attention.");
|
||||
return false;
|
||||
#endif
|
||||
|
||||
@ -260,12 +242,13 @@ bool use_flash_attention(sdp_params params, bool debug) {
|
||||
check_tensor_shapes,
|
||||
check_batch_size_and_num_heads,
|
||||
check_for_attn_mask,
|
||||
check_head_dim_size,
|
||||
check_gpu_sm75_or_greater,
|
||||
check_requires_grad_and_head_dim_gt64_and_sm_ge86_lt90,
|
||||
check_head_dim_size_flash,
|
||||
check_for_seq_len_0_nested_tensor,
|
||||
check_nonzero_sequence_lengths,
|
||||
check_last_dim_stride_equals_1);
|
||||
check_last_dim_stride_equals_1,
|
||||
check_flash_attention_hardware_support,
|
||||
check_requires_grad_and_head_dim_gt192_and_sm_ge86_lt90,
|
||||
check_for_seq_len_0_nested_tensor);
|
||||
for (auto& constraint : constraints) {
|
||||
if (!constraint(params, debug)) {
|
||||
return false;
|
||||
@ -285,7 +268,7 @@ bool use_flash_attention(sdp_params params, bool debug) {
|
||||
|
||||
bool use_mem_efficient_attention(sdp_params params, bool debug) {
|
||||
#ifndef USE_MEM_EFF_ATTENTION
|
||||
TORCH_CHECK(!debug, "Torch was not compiled with memory efficient attention.");
|
||||
TORCH_WARN(!debug, "Torch was not compiled with memory efficient attention.");
|
||||
return false;
|
||||
#endif
|
||||
// Constraints specific to mem efficient attention
|
||||
|
||||
@ -81,6 +81,8 @@ SKIP = {
|
||||
"hf_Bert_large", # Error: RelaxedUnspecConstraint(L['input_ids'].size()[0]) - inferred constant (4)
|
||||
# takes too long, extreme slowdown (< .001)
|
||||
"maml",
|
||||
# Failing in eager mode
|
||||
"clip",
|
||||
}
|
||||
|
||||
SKIP_FOR_CPU = {
|
||||
|
||||
@ -1439,7 +1439,7 @@ aten_cuda_cu_source_list = [
|
||||
"aten/src/ATen/native/sparse/cuda/SparseBlasImpl.cpp",
|
||||
"aten/src/ATen/native/sparse/cuda/SparseBlasLegacy.cpp",
|
||||
"aten/src/ATen/native/sparse/cuda/SparseCUDABlas.cpp",
|
||||
"aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.cpp",
|
||||
"aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp",
|
||||
]
|
||||
|
||||
# Files using thrust::sort_by_key need to be linked last
|
||||
|
||||
@ -9,10 +9,7 @@ import torch.onnx.operators
|
||||
from torch._dynamo.testing import EagerAndRecordGraphs, normalize_gm, same
|
||||
|
||||
from torch.nn import functional as F
|
||||
from torch.testing._internal.common_cuda import (
|
||||
PLATFORM_SUPPORTS_FUSED_SDPA,
|
||||
SM80OrLater,
|
||||
)
|
||||
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION
|
||||
|
||||
|
||||
class CutomizedCtxManager:
|
||||
@ -288,7 +285,7 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase):
|
||||
self.assertTrue(same(ref, res))
|
||||
|
||||
@unittest.skipIf(
|
||||
not PLATFORM_SUPPORTS_FUSED_SDPA or not SM80OrLater,
|
||||
not PLATFORM_SUPPORTS_FLASH_ATTENTION,
|
||||
"Can't run fused SDPA on this platform",
|
||||
)
|
||||
def test_autocast_sdpa(self):
|
||||
|
||||
@ -50,7 +50,7 @@ from torch.ao.quantization.quantize_fx import prepare_qat_fx
|
||||
from torch.fx.experimental.symbolic_shapes import ConstraintViolationError
|
||||
from torch.nn import functional as F
|
||||
from torch.testing._internal.common_cuda import (
|
||||
PLATFORM_SUPPORTS_FUSED_SDPA,
|
||||
PLATFORM_SUPPORTS_FLASH_ATTENTION,
|
||||
SM80OrLater,
|
||||
TEST_CUDA,
|
||||
TEST_MULTIGPU,
|
||||
@ -3919,7 +3919,7 @@ def fn():
|
||||
self.assertEqual(cnts.op_count, 3)
|
||||
|
||||
@unittest.skipIf(
|
||||
not PLATFORM_SUPPORTS_FUSED_SDPA or not SM80OrLater,
|
||||
not PLATFORM_SUPPORTS_FLASH_ATTENTION,
|
||||
"Can't run fused SDPA on this platform",
|
||||
)
|
||||
def test_parsing_sdpa(self):
|
||||
|
||||
@ -266,15 +266,10 @@ ALLOW_LIST = [
|
||||
("aten::_upsample_nearest_exact2d_backward", datetime.date(2022, 12, 15)),
|
||||
("aten::_scaled_dot_product_attention", datetime.date(2023, 8, 1)),
|
||||
("aten::_chunk_grad_outputs_efficient_attention", datetime.date(2023, 8, 1)),
|
||||
("aten::_scaled_dot_product_flash_attention", datetime.date(2023, 5, 15)),
|
||||
("aten::_scaled_dot_product_efficient_attention", datetime.date(2023, 8, 15)),
|
||||
("aten::_scaled_dot_product_efficient_attention_backward", datetime.date(2023, 8, 15)),
|
||||
("aten::_scaled_dot_product_flash_attention", datetime.date(2023, 9, 1)),
|
||||
("aten::_flash_attention_forward", datetime.date(2023, 9, 1)),
|
||||
("aten::_flash_attention_backward", datetime.date(2023, 9, 1)),
|
||||
("aten::_sparse_mask_helper", datetime.date(2023, 3, 15)),
|
||||
("aten::_fused_sdp_choice", datetime.date(2023, 3, 15)),
|
||||
("aten::_flash_attention_forward", datetime.date(2023, 5, 15)),
|
||||
("aten::_flash_attention_backward", datetime.date(2023, 5, 15)),
|
||||
("aten::_efficient_attention_forward", datetime.date(2023, 7, 1)),
|
||||
("aten::_efficient_attention_backward", datetime.date(2023, 8, 1)),
|
||||
("mkldnn::_convolution_pointwise.binary", datetime.date(2022, 12, 15)),
|
||||
("prim::CudaFusionIvalGuard", datetime.date(2023, 2, 1)),
|
||||
("prim::CudaFusionGuard", datetime.date(2023, 2, 1)),
|
||||
|
||||
@ -10,7 +10,7 @@ from torch._dynamo.utils import counters
|
||||
from torch._inductor import config
|
||||
from torch._inductor.utils import run_and_get_code
|
||||
from torch.testing._internal.common_cuda import (
|
||||
PLATFORM_SUPPORTS_FUSED_SDPA,
|
||||
PLATFORM_SUPPORTS_FLASH_ATTENTION,
|
||||
SM80OrLater,
|
||||
)
|
||||
from torch.testing._internal.common_utils import IS_LINUX, skipIfRocm
|
||||
@ -237,8 +237,8 @@ class TestSDPAPatternRewriterTemplate(TestCase):
|
||||
div = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
|
||||
div = div.to(torch.float32)
|
||||
attn_weight = torch.softmax(div, dim=-1)
|
||||
# very small dropout to make sure test passes
|
||||
attn_weight = torch.dropout(attn_weight, 0.00001, True)
|
||||
# Set to False
|
||||
attn_weight = torch.dropout(attn_weight, 0.00000000001, True)
|
||||
attn_weight = attn_weight.to(torch.float16)
|
||||
return attn_weight @ v
|
||||
|
||||
@ -295,7 +295,7 @@ class TestSDPAPatternRewriterTemplate(TestCase):
|
||||
div = div.to(torch.float32)
|
||||
attn_weight = torch.softmax(div, dim=-1)
|
||||
# very low dropout to make test pass
|
||||
attn_weight = torch.dropout(attn_weight, 0.9999, True)
|
||||
attn_weight = torch.dropout(attn_weight, 0.00000000001, True)
|
||||
attn_weight = attn_weight.to(torch.float16)
|
||||
return attn_weight @ v
|
||||
|
||||
@ -496,7 +496,7 @@ class TestSDPAPatternRewriterTemplate(TestCase):
|
||||
self._check_common(dot_prod_attention)
|
||||
|
||||
|
||||
if HAS_CUDA and PLATFORM_SUPPORTS_FUSED_SDPA:
|
||||
if HAS_CUDA and PLATFORM_SUPPORTS_FLASH_ATTENTION:
|
||||
|
||||
class SDPAPatternRewriterCudaTests(TestSDPAPatternRewriterTemplate):
|
||||
device = "cuda"
|
||||
|
||||
@ -18,7 +18,7 @@ from torch._subclasses.fake_tensor import (
|
||||
from torch.testing._internal.custom_op_db import custom_op_db
|
||||
from torch.testing._internal.common_device_type import ops
|
||||
from torch.testing._internal.common_device_type import instantiate_device_type_tests, OpDTypes
|
||||
from torch.testing._internal.common_cuda import SM80OrLater, PLATFORM_SUPPORTS_FUSED_SDPA
|
||||
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION
|
||||
from torch.fx.passes.fake_tensor_prop import FakeTensorProp
|
||||
from torch._dynamo.testing import rand_strided
|
||||
from torch.testing import FileCheck
|
||||
@ -1011,7 +1011,7 @@ class FakeTensorOperatorInvariants(TestCase):
|
||||
self.assertEqual(ref.size(), meta_out.size())
|
||||
|
||||
@skipIfRocm
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA or not SM80OrLater, "Does not support SDPA or pre-SM80 hardware")
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support SDPA or pre-SM80 hardware")
|
||||
def test_flash_attention(self):
|
||||
class Repro(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
|
||||
import torch
|
||||
from torch.testing._internal.common_utils import TestCase, run_tests, TEST_WITH_TORCHDYNAMO
|
||||
from torch.testing._internal.common_cuda import SM80OrLater, PLATFORM_SUPPORTS_FUSED_SDPA
|
||||
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION
|
||||
import torch.utils.flop_counter
|
||||
import torch.nn.functional as F
|
||||
import unittest
|
||||
@ -161,7 +161,7 @@ class TestFlopCounter(TestCase):
|
||||
T(4, 5).cos()
|
||||
|
||||
@unittest.skipIf(not HAS_CUDA, "CUDA not available")
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA or not SM80OrLater, "Does not support SDPA or pre-SM80 hardware")
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support SDPA or pre-SM80 hardware")
|
||||
def test_sdpa(self):
|
||||
batch_size = 4
|
||||
n_heads = 8
|
||||
|
||||
@ -30,7 +30,11 @@ from torch.testing._internal.common_utils import (
|
||||
|
||||
|
||||
from torch.testing._internal.common_methods_invocations import wrapper_set_seed
|
||||
from torch.testing._internal.common_cuda import SM75OrLater, SM80OrLater, PLATFORM_SUPPORTS_FUSED_SDPA
|
||||
from torch.testing._internal.common_cuda import (
|
||||
SM80OrLater, PLATFORM_SUPPORTS_FLASH_ATTENTION,
|
||||
PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
|
||||
PLATFORM_SUPPORTS_FUSED_ATTENTION
|
||||
)
|
||||
|
||||
if TEST_FAIRSEQ:
|
||||
import fairseq.models.transformer as fairseq_transformer
|
||||
@ -1186,7 +1190,7 @@ class TestTransformers(NNTestCase):
|
||||
torch.cuda.synchronize()
|
||||
|
||||
@unittest.skipIf(
|
||||
not PLATFORM_SUPPORTS_FUSED_SDPA or not SM80OrLater, "Platform does not supposrt fused SDPA or pre-SM80 hardware"
|
||||
not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Platform does not supposrt fused SDPA or pre-SM80 hardware"
|
||||
)
|
||||
def test_is_causal_gpu(self):
|
||||
device = 'cuda'
|
||||
@ -1207,9 +1211,9 @@ class TestSDPAFailureModes(NNTestCase):
|
||||
_do_cuda_non_default_stream = True
|
||||
|
||||
@onlyCUDA
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA or not isSM86or89Device,
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION or not isSM86or89Device,
|
||||
"Does not support fused SDPA or not SM86+ hardware")
|
||||
@parametrize("head_dim", [72, 96, 128])
|
||||
@parametrize("head_dim", [193, 204, 256])
|
||||
def test_flash_backward_failure_sm86plus(self, device, head_dim: int):
|
||||
dtype = torch.float16
|
||||
make_tensor = partial(rand_sdpa_tensor, type="dense", device=device, dtype=dtype)
|
||||
@ -1247,11 +1251,11 @@ class TestSDPAFailureModes(NNTestCase):
|
||||
lambda: torch.nn.functional.scaled_dot_product_attention(q, k, v))
|
||||
|
||||
@onlyCUDA
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA, "Does not support fused scaled dot product attention")
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Does not support fused scaled dot product attention")
|
||||
@parametrize(
|
||||
"kernel",
|
||||
[SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]
|
||||
if SM80OrLater
|
||||
if PLATFORM_SUPPORTS_FLASH_ATTENTION
|
||||
else [SDPBackend.EFFICIENT_ATTENTION],
|
||||
)
|
||||
def test_invalid_fused_inputs_dim_3(self, device, kernel: SDPBackend):
|
||||
@ -1267,11 +1271,11 @@ class TestSDPAFailureModes(NNTestCase):
|
||||
q, k, v, None, 0.0, False))
|
||||
|
||||
@onlyCUDA
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA, "Does not support fused scaled dot product attention")
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Does not support fused scaled dot product attention")
|
||||
@parametrize(
|
||||
"kernel",
|
||||
[SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]
|
||||
if SM80OrLater
|
||||
if PLATFORM_SUPPORTS_FLASH_ATTENTION
|
||||
else [SDPBackend.EFFICIENT_ATTENTION],
|
||||
)
|
||||
def test_invalid_fused_inputs_broadcast(self, device, kernel: SDPBackend):
|
||||
@ -1287,9 +1291,9 @@ class TestSDPAFailureModes(NNTestCase):
|
||||
q, k, v, None, 0.0, False))
|
||||
|
||||
@onlyCUDA
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA, "Does not support fused scaled dot product attention")
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Does not support fused scaled dot product attention")
|
||||
@parametrize("kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION] if
|
||||
SM80OrLater else [SDPBackend.EFFICIENT_ATTENTION])
|
||||
PLATFORM_SUPPORTS_FLASH_ATTENTION else [SDPBackend.EFFICIENT_ATTENTION])
|
||||
def test_invalid_sequence_lengths(self, device, kernel: SDPBackend):
|
||||
with sdp_kernel(**backend_map[kernel]):
|
||||
# Passing in a q,k,v with 0 length sequences will error
|
||||
@ -1303,9 +1307,9 @@ class TestSDPAFailureModes(NNTestCase):
|
||||
q, k, v, None, 0.0, False))
|
||||
|
||||
@onlyCUDA
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA, "Does not support fused scaled dot product attention")
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Does not support fused scaled dot product attention")
|
||||
@parametrize("kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION] if
|
||||
SM80OrLater else [SDPBackend.EFFICIENT_ATTENTION])
|
||||
PLATFORM_SUPPORTS_FLASH_ATTENTION else [SDPBackend.EFFICIENT_ATTENTION])
|
||||
def test_invalid_last_dim_stride(self, device, kernel: SDPBackend):
|
||||
with sdp_kernel(**backend_map[kernel]):
|
||||
# Passing in a q,k,v with 0 length sequences will error
|
||||
@ -1319,24 +1323,24 @@ class TestSDPAFailureModes(NNTestCase):
|
||||
q, k, v, None, 0.0, False))
|
||||
|
||||
@onlyCUDA
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA or not SM80OrLater, "Does not support fused scaled dot product attention")
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not flash_attention fused scaled dot product attention")
|
||||
@parametrize("kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION])
|
||||
def test_invalid_fused_inputs_head_dim(self, device, kernel: SDPBackend):
|
||||
with sdp_kernel(**backend_map[kernel]):
|
||||
# The embed dim per head is not divisible by 8 for flash attention
|
||||
dtype = torch.float16
|
||||
make_tensor = partial(rand_sdpa_tensor, type="dense", device=device, dtype=dtype)
|
||||
size = (2, 2, 3, 9)
|
||||
size = (2, 2, 3, 9) if kernel == SDPBackend.EFFICIENT_ATTENTION else (2, 2, 3, 257)
|
||||
q, k, v = make_tensor(size), make_tensor(size), make_tensor(size)
|
||||
self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
|
||||
q, k, v, None, 0.0, False))
|
||||
|
||||
@onlyCUDA
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA, "Does not support fused scaled dot product attention")
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Does not support fused scaled dot product attention")
|
||||
@parametrize(
|
||||
"kernel",
|
||||
[SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]
|
||||
if SM80OrLater
|
||||
if PLATFORM_SUPPORTS_FLASH_ATTENTION
|
||||
else [SDPBackend.EFFICIENT_ATTENTION],
|
||||
)
|
||||
def test_invalid_fused_inputs_invalid_dtype(self, device, kernel: SDPBackend):
|
||||
@ -1349,7 +1353,7 @@ class TestSDPAFailureModes(NNTestCase):
|
||||
q, k, v, None, 0.0, False))
|
||||
|
||||
@onlyCUDA
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA or not SM80OrLater, "Does not support fused scaled dot product attention")
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support flash attention")
|
||||
@parametrize("kernel", [SDPBackend.FLASH_ATTENTION])
|
||||
def test_invalid_fused_inputs_attn_mask_present(self, device, kernel: SDPBackend):
|
||||
with sdp_kernel(**backend_map[kernel]):
|
||||
@ -1363,7 +1367,7 @@ class TestSDPAFailureModes(NNTestCase):
|
||||
q, k, v, mask, 0.0, False))
|
||||
|
||||
@onlyCUDA
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA or not SM80OrLater, "Does not support fused SDPA or pre-SM80 hardware")
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support fused SDPA or pre-SM80 hardware")
|
||||
def test_unaligned_tensors(self, device):
|
||||
# The alignment is depdent on arch so we specifiy SM80OrLater
|
||||
dtype = torch.float16
|
||||
@ -1375,7 +1379,7 @@ class TestSDPAFailureModes(NNTestCase):
|
||||
q, k, v, None, 0.0, False))
|
||||
|
||||
@onlyCUDA
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA or not SM80OrLater, "Does not support fused SDPA or pre-SM80 hardware")
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support fused SDPA or pre-SM80 hardware")
|
||||
def test_flash_fail_fp32(self, device):
|
||||
dtype = torch.float
|
||||
shape = (16, 16, 32, 32)
|
||||
@ -1387,7 +1391,7 @@ class TestSDPAFailureModes(NNTestCase):
|
||||
q, k, v, None, 0.0, False))
|
||||
|
||||
@onlyCUDA
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA or not SM80OrLater, "Does not support SDPA or pre-SM80 hardware")
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support SDPA or pre-SM80 hardware")
|
||||
def test_flash_autocast_fp32_float16(self, device):
|
||||
dtype = torch.float
|
||||
shape = (16, 16, 32, 32)
|
||||
@ -1399,7 +1403,7 @@ class TestSDPAFailureModes(NNTestCase):
|
||||
q, k, v, None, 0.0, False)
|
||||
|
||||
@onlyCUDA
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA or not SM80OrLater, "Does not support SDPA or pre-SM80 hardware")
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support SDPA or pre-SM80 hardware")
|
||||
def test_flash_autocast_fp32_bfloat16(self, device):
|
||||
dtype = torch.float
|
||||
shape = (16, 16, 32, 32)
|
||||
@ -1441,7 +1445,7 @@ class TestSDPAFailureModes(NNTestCase):
|
||||
self.assertRaises(RuntimeError, lambda: F.scaled_dot_product_attention(query, key, value))
|
||||
|
||||
@onlyCUDA
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA, "Fused SDPA was not built for this system")
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system")
|
||||
def test_fused_kernels_nested_broadcasting_error_cases(self, device):
|
||||
# one of k,v needs to be broadcasted and other has non consistent seq_len dim
|
||||
rand_nested_tensor = partial(rand_sdpa_tensor, type="nested", device=device, dtype=torch.float32)
|
||||
@ -1463,7 +1467,21 @@ class TestSDPAFailureModes(NNTestCase):
|
||||
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
|
||||
|
||||
@onlyCUDA
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA or not isSM5xDevice, "Does not support fused SDPA or not SM50 hardware")
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Fused SDPA was not built for this system")
|
||||
def test_nested_fails_on_padding_head_dim(self, device):
|
||||
dtype = torch.bfloat16
|
||||
seq_len_list = [2, 4, 5, 6, 7]
|
||||
shape = (5, seq_len_list, 8, 57)
|
||||
make_tensor = partial(rand_sdpa_tensor, shape=shape, type="nested", device=device, dtype=dtype)
|
||||
q, k, v = make_tensor(), make_tensor(), make_tensor()
|
||||
with torch.backends.cuda.sdp_kernel(enable_math=False, enable_flash=True, enable_mem_efficient=False):
|
||||
with self.assertWarnsRegex(UserWarning, "For NestedTensor inputs, Flash attention requires"):
|
||||
self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
|
||||
q, k, v, None, 0.0, False))
|
||||
|
||||
|
||||
@onlyCUDA
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION or not isSM5xDevice, "Does not support fused SDPA or not SM50 hardware")
|
||||
def test_mem_efficient_fail_bfloat16_sm50(self, device):
|
||||
dtype = torch.bfloat16
|
||||
shape = (16, 16, 32, 32)
|
||||
@ -1474,6 +1492,50 @@ class TestSDPAFailureModes(NNTestCase):
|
||||
self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
|
||||
q, k, v, None, 0.0, False))
|
||||
|
||||
def _get_block_size(device, head_dim, is_causal):
|
||||
# This should match the block sizes in the CUDA kernel
|
||||
# Mask is only interesting when we are setting dropout
|
||||
is_dropout = True
|
||||
assert head_dim <= 256
|
||||
major, minor = torch.cuda.get_device_capability(device)
|
||||
is_sm8x = major == 8 and minor > 0 # Only include sm86 and sm89, exclude sm80 (A100)
|
||||
is_sm80 = major == 8 and minor == 0
|
||||
is_sm90 = major == 9 and minor == 0
|
||||
if head_dim <= 32:
|
||||
return 128, 128
|
||||
if head_dim <= 64:
|
||||
return (128, 128) if not is_dropout else (128, 64)
|
||||
elif head_dim <= 96:
|
||||
return (64, 64) if (is_sm8x and is_causal) else (128, 64)
|
||||
elif head_dim <= 128:
|
||||
if is_sm8x:
|
||||
return (64, 64) if (not is_dropout and is_causal) else (128, 32)
|
||||
else:
|
||||
return 128, (64 if not is_dropout else 32)
|
||||
elif head_dim <= 160:
|
||||
if is_sm8x:
|
||||
return (128, 64) if not is_causal else (64, 64)
|
||||
else:
|
||||
return 128, 32
|
||||
elif head_dim <= 192:
|
||||
return (128, 64) if not is_dropout else (64, 64)
|
||||
elif head_dim <= 224:
|
||||
return (128, 64) if (is_sm80 or is_sm90) else (64, 64)
|
||||
elif head_dim <= 256:
|
||||
return (128, 64) if is_sm80 else (64, 64)
|
||||
|
||||
|
||||
def pad_last_dim(input_tensor, alignment_size, slice: bool = False):
|
||||
last_dim_size = input_tensor.size(-1)
|
||||
if (last_dim_size % alignment_size == 0):
|
||||
return input_tensor, last_dim_size
|
||||
pad_count = alignment_size - (last_dim_size % alignment_size)
|
||||
padded_tensor = F.pad(input_tensor, (0, pad_count))
|
||||
if slice:
|
||||
return padded_tensor[..., :last_dim_size], last_dim_size
|
||||
return padded_tensor, last_dim_size
|
||||
|
||||
|
||||
class TestSDPA(NNTestCase):
|
||||
""" Used to test generic functionality of scaled_dot_product_attention
|
||||
Summary:
|
||||
@ -1620,7 +1682,7 @@ class TestSDPACudaOnly(NNTestCase):
|
||||
Quarks:
|
||||
There is some trickiness with this function. It's runtime behavior
|
||||
is dependent on the CUDA architecture you are testing it on. See
|
||||
`PLATFORM_SUPPORTS_FUSED_SDPA` at the top of the file.
|
||||
`PLATFORM_SUPPORTS_FUSED_ATTENTION` at the top of the file.
|
||||
Summary:
|
||||
Math: always supported
|
||||
FlashAttention: Supported on sm80 or newer hardware
|
||||
@ -1636,45 +1698,43 @@ class TestSDPACudaOnly(NNTestCase):
|
||||
query_padding_mask: (batch_size, seqlen_q)
|
||||
key_padding_mask: (batch_size, seqlen_k)
|
||||
"""
|
||||
def _get_block_size(head_dim):
|
||||
assert head_dim % 8 == 0 and head_dim <= 128
|
||||
return 256 if head_dim <= 64 else 128
|
||||
S_flat = S.view(S.shape[0], S.shape[1], S.shape[2] * S.shape[3])
|
||||
seqlen_q, seqlen_k = S.shape[-2:]
|
||||
block_size = _get_block_size(head_dim)
|
||||
loop_steps = math.ceil(seqlen_k / block_size)
|
||||
b, h, seqlen_q, seqlen_k = S.shape
|
||||
warps_n = 4
|
||||
mmas_n = (seqlen_k // warps_n //
|
||||
16) if seqlen_k <= block_size else (block_size // warps_n // 16)
|
||||
blocksize_m, blocksize_n = _get_block_size(S.device, head_dim, causal)
|
||||
nblocks_m = (seqlen_q + blocksize_m - 1) // blocksize_m
|
||||
nblocks_n = (seqlen_k + blocksize_n - 1) // blocksize_n
|
||||
mmas_n = (blocksize_n + 16 - 1) // 16
|
||||
|
||||
S_converted = S_flat.view(S_flat.shape[0], S_flat.shape[1], loop_steps,
|
||||
seqlen_q // 16, mmas_n, warps_n, 8, 4, 2, 2, 2)
|
||||
S_converted = S_converted.permute(0, 1, 3, 8, 6, 2, 4, 5, 9, 7, 10)
|
||||
S_converted = S_converted.reshape(S_flat.shape[0],
|
||||
S_flat.shape[1], (seqlen_q // 16 * 2 * 8), (loop_steps * mmas_n * warps_n * 2 * 4 * 2))
|
||||
# Need to zero out things not in attention_mask in case S was initialized with random values
|
||||
# and some of those values aren't overwritten.
|
||||
seqlen_q_og = query_padding_mask.shape[-1]
|
||||
if seqlen_q_og < seqlen_q:
|
||||
query_padding_mask = F.pad(
|
||||
query_padding_mask, (0, seqlen_q - seqlen_q_og))
|
||||
else:
|
||||
query_padding_mask = query_padding_mask[:, :seqlen_q]
|
||||
q_mask_fill = ~query_padding_mask.view(query_padding_mask.shape[0], 1, query_padding_mask.shape[1], 1)
|
||||
S_converted = S_converted.masked_fill(q_mask_fill, 0.0)
|
||||
seqlen_k_og = key_padding_mask.shape[-1]
|
||||
if seqlen_k_og < seqlen_k:
|
||||
key_padding_mask = F.pad(key_padding_mask, (0, seqlen_k - seqlen_k_og))
|
||||
else:
|
||||
key_padding_mask = key_padding_mask[:, :seqlen_k]
|
||||
|
||||
k_mask_fill = ~key_padding_mask.view(key_padding_mask.shape[0], 1, 1, key_padding_mask.shape[1])
|
||||
S_converted = S_converted.masked_fill(k_mask_fill, 0.0)
|
||||
# Reshape S using PyTorch native functions
|
||||
S_flat = S.view(b, h, nblocks_m, blocksize_m, nblocks_n, blocksize_n)
|
||||
S_flat = S_flat.permute(0, 1, 2, 4, 3, 5)
|
||||
S_flat = S_flat.reshape(b, h, nblocks_m, nblocks_n, (blocksize_m * blocksize_n))
|
||||
S_converted = S_flat.view(b, h, nblocks_m, nblocks_n, mmas_n, -1, warps_n, 8, 4, 2, 2, 2)
|
||||
S_converted = S_converted.permute(0, 1, 2, 5, 6, 10, 7, 3, 4, 9, 8, 11)
|
||||
S_converted = S_converted.reshape(b, h, (nblocks_m * S_converted.size(3) *
|
||||
warps_n * 2 * 8), (nblocks_n * mmas_n * 2 * 4 * 2))
|
||||
|
||||
if causal:
|
||||
causal_mask = torch.triu(torch.ones(
|
||||
seqlen_q, seqlen_k, dtype=torch.bool, device=S.device), 1)
|
||||
causal_mask = torch.triu(torch.ones(seqlen_q, seqlen_k, dtype=torch.bool, device=S.device), 1)
|
||||
S_converted.masked_fill_(causal_mask, 0.0)
|
||||
# Need to zero out things not in attention_mask in case S was initialized with random values
|
||||
# and some of those values aren't overwritten.
|
||||
seqlen_q_og = query_padding_mask.shape[-1] if query_padding_mask is not None else seqlen_q
|
||||
if query_padding_mask is not None:
|
||||
if seqlen_q_og < seqlen_q:
|
||||
query_padding_mask = F.pad(query_padding_mask, (0, seqlen_q - seqlen_q_og))
|
||||
else:
|
||||
query_padding_mask = query_padding_mask[:, :seqlen_q]
|
||||
q_mask_fill = ~query_padding_mask.view(query_padding_mask.shape[0], 1, query_padding_mask.shape[1], 1)
|
||||
S_converted = S_converted.masked_fill(q_mask_fill, 0.0)
|
||||
seqlen_k_og = key_padding_mask.shape[-1] if key_padding_mask is not None else seqlen_k
|
||||
if key_padding_mask is not None:
|
||||
if seqlen_k_og < seqlen_k:
|
||||
key_padding_mask = F.pad(key_padding_mask, (0, seqlen_k - seqlen_k_og))
|
||||
else:
|
||||
key_padding_mask = key_padding_mask[:, :seqlen_k]
|
||||
k_mask_fill = ~key_padding_mask.view(key_padding_mask.shape[0], 1, 1, key_padding_mask.shape[1])
|
||||
S_converted = S_converted.masked_fill(k_mask_fill, 0.0)
|
||||
if seqlen_q_og < seqlen_q:
|
||||
S_converted = S_converted[:, :, :seqlen_q_og, :]
|
||||
else:
|
||||
@ -1692,7 +1752,7 @@ class TestSDPACudaOnly(NNTestCase):
|
||||
value_ref = value.clone().detach().to(dtype).requires_grad_(value.requires_grad)
|
||||
return query_ref, key_ref, value_ref
|
||||
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA, "Fused SDPA was not built for this system")
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system")
|
||||
@parametrize("mask_dim", [1, 2, 3, 4])
|
||||
def test_mem_efficient_attetntion_mask_variants(self, device, mask_dim: List[int]):
|
||||
dtype = torch.float16
|
||||
@ -1715,7 +1775,7 @@ class TestSDPACudaOnly(NNTestCase):
|
||||
out = F.scaled_dot_product_attention(query, key, value, mask)
|
||||
out.sum().backward()
|
||||
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA, "Fused SDPA was not built for this system")
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system")
|
||||
@parametrize("dtype", [torch.float, torch.float16])
|
||||
def test_mem_eff_attention_pad_mask(self, device, dtype):
|
||||
make_tensor = partial(rand_sdpa_tensor, type=type, device=device, dtype=dtype, requires_grad=True)
|
||||
@ -1729,7 +1789,7 @@ class TestSDPACudaOnly(NNTestCase):
|
||||
out = F.scaled_dot_product_attention(query, key, value, mask)
|
||||
out.sum().backward()
|
||||
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA, "Fused SDPA was not built for this system")
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system")
|
||||
@parametrize("dtype", [torch.float, torch.float16])
|
||||
def test_mem_eff_attention_non_contiguous_mask(self, device, dtype):
|
||||
make_tensor = partial(rand_sdpa_tensor, type=type, device=device, dtype=dtype, requires_grad=True)
|
||||
@ -1744,7 +1804,7 @@ class TestSDPACudaOnly(NNTestCase):
|
||||
out = F.scaled_dot_product_attention(query, key, value, mask)
|
||||
out.sum().backward()
|
||||
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA, "Fused SDPA was not built for this system")
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system")
|
||||
@parametrize("dtype", [torch.float, torch.float16])
|
||||
def test_mem_eff_attention_long_sequence_mask(self, device, dtype):
|
||||
if torch.cuda.get_device_properties('cuda').total_memory < 80 * 2**30:
|
||||
@ -1762,7 +1822,7 @@ class TestSDPACudaOnly(NNTestCase):
|
||||
out.sum().backward()
|
||||
|
||||
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA, "Fused SDPA was not built for this system")
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system")
|
||||
@parametrize("type", ["dense", "nested"])
|
||||
@parametrize("is_contiguous", [True, False])
|
||||
@parametrize("head_dims_match", [True, False])
|
||||
@ -1802,7 +1862,7 @@ class TestSDPACudaOnly(NNTestCase):
|
||||
|
||||
self.assertEqual(actual[0].contiguous(), math_ref[0].contiguous(), atol=1e-3, rtol=1e-2)
|
||||
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA, "Fused SDPA was not built for this system")
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system")
|
||||
@parametrize("type", ["dense", "nested"])
|
||||
@parametrize("is_contiguous", [True, False])
|
||||
def test_scaled_dot_product_attention_fused_kernels_packed(self, device, type: str, is_contiguous: bool):
|
||||
@ -1834,13 +1894,11 @@ class TestSDPACudaOnly(NNTestCase):
|
||||
|
||||
self.assertEqual(actual.contiguous(), math_ref.contiguous(), atol=2e-3, rtol=1e-2)
|
||||
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA, "Fused SDPA was not built for this system")
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system")
|
||||
@parametrize("type", ["dense", "nested"])
|
||||
@parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION])
|
||||
@parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION] if
|
||||
PLATFORM_SUPPORTS_FLASH_ATTENTION else [SDPBackend.EFFICIENT_ATTENTION])
|
||||
def test_scaled_dot_product_attention_fused_kernels_packed_accuracy(self, device, type: str, fused_kernel: str):
|
||||
if (not SM80OrLater) and fused_kernel == SDPBackend.FLASH_ATTENTION:
|
||||
return
|
||||
|
||||
def rand_nt(shape):
|
||||
batch, seq_len, num_heads, head_dim = shape
|
||||
tensors = [6 * torch.rand((seq_len, 3 * num_heads * head_dim), device=device, dtype=torch.float32) - 3
|
||||
@ -1901,7 +1959,7 @@ class TestSDPACudaOnly(NNTestCase):
|
||||
self.assertEqual(math_ref_test, math_ref_lp_test, atol=7e-3, rtol=7e-3)
|
||||
self.assertEqual(actual_test, math_ref_test, atol=5e-3, rtol=5e-3)
|
||||
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA, "Flash Attention was not built for this system")
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Flash Attention was not built for this system")
|
||||
@parametrize("contiguous_inputs", [True, False])
|
||||
@parametrize("is_causal", [True, False])
|
||||
def test_sdp_mem_efficient_grad_against_math(self, device, contiguous_inputs: bool, is_causal: bool):
|
||||
@ -1948,7 +2006,7 @@ class TestSDPACudaOnly(NNTestCase):
|
||||
# Cast up and compare
|
||||
self.assertEqual(qkv.grad, qkv_lp.grad.to(torch.float64), atol=1e-5, rtol=1e-5)
|
||||
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA or not SM80OrLater, "Flash Attention was not built for this system")
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Flash Attention was not built for this system")
|
||||
@parametrize("contiguous_inputs", [True, False])
|
||||
@parametrize("is_causal", [True, False])
|
||||
@parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
@ -2000,38 +2058,38 @@ class TestSDPACudaOnly(NNTestCase):
|
||||
rtol = 7e-4 if dtype == torch.float16 else 7e-3
|
||||
self.assertEqual(qkv.grad, qkv_lp.grad.to(torch.float64), atol=atol, rtol=rtol)
|
||||
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Platform does not support fused SDPA")
|
||||
@parametrize("type", ["dense", "nested"])
|
||||
def test_fused_sdp_choice(self, device, type: str):
|
||||
if PLATFORM_SUPPORTS_FUSED_SDPA:
|
||||
batch_size, seq_len, num_heads, head_dim = 2, 128, 8, 64
|
||||
shape = (batch_size, seq_len, num_heads, head_dim)
|
||||
make_tensor = partial(rand_sdpa_tensor, device=device, dtype=torch.float16, packed=True, requires_grad=True)
|
||||
batch_size, seq_len, num_heads, head_dim = 2, 128, 8, 64
|
||||
shape = (batch_size, seq_len, num_heads, head_dim)
|
||||
make_tensor = partial(rand_sdpa_tensor, device=device, dtype=torch.float16, packed=True, requires_grad=True)
|
||||
|
||||
qkv = make_tensor(shape, type=type)
|
||||
query, key, value = qkv.chunk(3, dim=-1)
|
||||
qkv = make_tensor(shape, type=type)
|
||||
query, key, value = qkv.chunk(3, dim=-1)
|
||||
|
||||
query = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
|
||||
key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
|
||||
|
||||
if SM75OrLater and not type == "nested":
|
||||
assert torch._fused_sdp_choice(query, key, value) == SDPBackend.FLASH_ATTENTION
|
||||
else:
|
||||
assert torch._fused_sdp_choice(query, key, value) == SDPBackend.EFFICIENT_ATTENTION
|
||||
|
||||
# Change dtype to float32 so that efficient attention should get chosen
|
||||
make_tensor = partial(rand_sdpa_tensor, device=device, dtype=torch.float32, packed=True)
|
||||
|
||||
qkv = make_tensor(shape, type=type)
|
||||
query, key, value = qkv.chunk(3, dim=-1)
|
||||
|
||||
query = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
|
||||
key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
|
||||
query = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
|
||||
key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
|
||||
|
||||
if PLATFORM_SUPPORTS_FLASH_ATTENTION:
|
||||
assert torch._fused_sdp_choice(query, key, value) == SDPBackend.FLASH_ATTENTION
|
||||
else:
|
||||
assert torch._fused_sdp_choice(query, key, value) == SDPBackend.EFFICIENT_ATTENTION
|
||||
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA, "Platform does not support fused SDPA")
|
||||
# Change dtype to float32 so that efficient attention should get chosen
|
||||
make_tensor = partial(rand_sdpa_tensor, device=device, dtype=torch.float32, packed=True)
|
||||
|
||||
qkv = make_tensor(shape, type=type)
|
||||
query, key, value = qkv.chunk(3, dim=-1)
|
||||
|
||||
query = query.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
|
||||
key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
|
||||
|
||||
assert torch._fused_sdp_choice(query, key, value) == SDPBackend.EFFICIENT_ATTENTION
|
||||
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Platform does not support fused SDPA")
|
||||
@parametrize("warn_only", [True, False])
|
||||
def test_sdp_choice_with_determinism(self, device, warn_only):
|
||||
batch_size, seq_len, num_heads, head_dim = 1, 64, 8, 64
|
||||
@ -2043,7 +2101,7 @@ class TestSDPACudaOnly(NNTestCase):
|
||||
with sdp_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=True):
|
||||
assert torch._fused_sdp_choice(query, key, value) == SDPBackend.EFFICIENT_ATTENTION
|
||||
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA, "Platform does not support fused SDPA")
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Platform does not support fused SDPA")
|
||||
@parametrize("warn_only", [True, False])
|
||||
def test_mem_eff_backwards_throws_determinism_warning(self, device, warn_only):
|
||||
batch_size, seq_len, num_heads, head_dim = 1, 64, 8, 64
|
||||
@ -2065,7 +2123,7 @@ class TestSDPACudaOnly(NNTestCase):
|
||||
torch.nn.functional.scaled_dot_product_attention(query, key, value).sum().backward()
|
||||
|
||||
@unittest.skip("This test is not behaving deterministaclly non-deterministaclly on CI/CD")
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA or not SM80OrLater, "Platform does not support fused SDPA")
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Platform does not support fused SDPA")
|
||||
def test_mem_eff_backwards_determinism(self, device):
|
||||
# Need big seq_len to ensure that num_splits > 1
|
||||
dtype = torch.float32
|
||||
@ -2116,7 +2174,7 @@ class TestSDPACudaOnly(NNTestCase):
|
||||
self.assertFalse(diff_anwser_once)
|
||||
|
||||
# verified passing successfully on H100
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA, "Does not support SDPA")
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support SDPA")
|
||||
@parametrize("batch_size", [1, 8])
|
||||
@parametrize("seq_len_q", [4, 8, 64, 128, 256, 512, 1024, 2048] if SM80OrLater else [4, 8, 64, 128, 256, 512])
|
||||
@parametrize("seq_len_k", [4, 8, 64, 128, 256, 512, 1024, 2048] if SM80OrLater else [4, 8, 64, 128, 256, 512])
|
||||
@ -2217,7 +2275,7 @@ class TestSDPACudaOnly(NNTestCase):
|
||||
atol=grad_v_ref_atol, rtol=grad_v_ref_rtol)
|
||||
|
||||
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA, "Does not support SDPA")
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support SDPA")
|
||||
@parametrize("batch_size", [1, 8])
|
||||
@parametrize("seq_len_q", [4, 8, 64, 128, 256, 312, 512, 1024, 2048] if SM80OrLater else [4, 8, 64, 128, 152, 256, 512])
|
||||
@parametrize("seq_len_k", [4, 8, 64, 65, 128, 256, 408, 512, 1024, 2048] if SM80OrLater else [4, 8, 37, 64, 128, 256, 512])
|
||||
@ -2334,11 +2392,11 @@ class TestSDPACudaOnly(NNTestCase):
|
||||
self.assertEqual(attn_mask.grad, attn_mask_ref.grad.to(attn_mask.grad.dtype),
|
||||
atol=grad_attn_mask_atol, rtol=grad_attn_mask_rtol)
|
||||
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA or not SM80OrLater, "Does not support SDPA or pre-SM80 hardware")
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support SDPA or pre-SM80 hardware")
|
||||
@parametrize("batch_size", [1, 8])
|
||||
@parametrize("seq_len_q", [4, 8, 64, 128, 256, 512, 1024, 2048])
|
||||
@parametrize("seq_len_k", [4, 8, 64, 128, 256, 512, 1024, 2048])
|
||||
@parametrize("head_dim", [8, 16, 32, 64, 72, 96, 128])
|
||||
@parametrize("seq_len_q", [4, 8, 64, 143, 256, 512, 1024, 2048])
|
||||
@parametrize("seq_len_k", [4, 8, 64, 128, 256, 587, 1024, 2048])
|
||||
@parametrize("head_dim", [8, 16, 21, 32, 64, 72, 96, 128, 160, 192, 203, 256])
|
||||
@parametrize("is_causal", [True, False])
|
||||
@parametrize("dropout_p", [0.0, 0.22, 0.48])
|
||||
@parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
@ -2347,6 +2405,8 @@ class TestSDPACudaOnly(NNTestCase):
|
||||
head_dim: int, is_causal: bool, dropout_p: float, dtype: torch.dtype,
|
||||
scale: str):
|
||||
|
||||
if isSM86or89Device and head_dim in range(193, 256 + 1):
|
||||
self.skipTest("Flash attention on sm86 and sm89 for headdim > 192 currently disabled")
|
||||
scale = scale if scale is None else (1 / head_dim)
|
||||
n_heads = 4
|
||||
query = torch.rand(batch_size, n_heads, seq_len_q, head_dim,
|
||||
@ -2364,22 +2424,11 @@ class TestSDPACudaOnly(NNTestCase):
|
||||
|
||||
is_dropout = dropout_p > 0.0
|
||||
|
||||
# Create real output
|
||||
output_tuple = torch.ops.aten._scaled_dot_product_flash_attention(
|
||||
query, key, value, dropout_p=dropout_p, is_causal=is_causal, scale=scale, return_debug_mask=True)
|
||||
out = output_tuple[0]
|
||||
dbug_mask = output_tuple[-1]
|
||||
|
||||
query_padding_mask = torch.ones(
|
||||
1, seq_len_q, device=device, dtype=torch.bool)
|
||||
key_padding_mask = torch.ones(
|
||||
1, seq_len_k, device=device, dtype=torch.bool)
|
||||
|
||||
softmax_mask = self.convert_flash_attn_S_to_softmax(
|
||||
dbug_mask, query_padding_mask, key_padding_mask, head_dim=head_dim, causal=is_causal)
|
||||
dropout_mask = softmax_mask >= 0
|
||||
|
||||
if not is_dropout:
|
||||
# Problem: We pad sizes in the composite region of the top level SDPA. But we need the
|
||||
# Debug mask when have dropout. So I am going to manualy pad up here when testing dropout
|
||||
with sdp_kernel(enable_math=False, enable_flash=True, enable_mem_efficient=False):
|
||||
out = F.scaled_dot_product_attention(query, key, value, dropout_p=dropout_p, is_causal=is_causal, scale=scale)
|
||||
with sdp_kernel(enable_math=True, enable_flash=False, enable_mem_efficient=False):
|
||||
# High Precision Math Reference
|
||||
out_ref = F.scaled_dot_product_attention(
|
||||
@ -2388,6 +2437,27 @@ class TestSDPACudaOnly(NNTestCase):
|
||||
out_lp_ref = F.scaled_dot_product_attention(
|
||||
query_ref_lp, key_ref_lp, value_ref_lp, is_causal=is_causal, scale=scale)
|
||||
else:
|
||||
q_padded, q_og_size = pad_last_dim(query, 8)
|
||||
k_padded, k_og_size = pad_last_dim(key, 8)
|
||||
v_padded, v_og_size = pad_last_dim(value, 8)
|
||||
# scale needs to be calculated on the og head_size
|
||||
if scale is None:
|
||||
scale = 1 / math.sqrt(q_og_size)
|
||||
output_tuple = torch.ops.aten._scaled_dot_product_flash_attention(
|
||||
q_padded, k_padded, v_padded, dropout_p=dropout_p, is_causal=is_causal, scale=scale, return_debug_mask=is_dropout)
|
||||
out = output_tuple[0]
|
||||
out = out[..., :v_og_size]
|
||||
# Build dropout_mask
|
||||
dbug_mask = output_tuple[-1]
|
||||
query_padding_mask = torch.ones(
|
||||
batch_size, seq_len_q, device=device, dtype=torch.bool)
|
||||
key_padding_mask = torch.ones(
|
||||
batch_size, seq_len_k, device=device, dtype=torch.bool)
|
||||
|
||||
softmax_mask = self.convert_flash_attn_S_to_softmax(
|
||||
dbug_mask, query_padding_mask, key_padding_mask, head_dim=head_dim,
|
||||
causal=is_causal)[:, :, :seq_len_q, :seq_len_k]
|
||||
dropout_mask = softmax_mask >= 0
|
||||
# High Precision Math Reference
|
||||
out_ref = torch.ops.aten._scaled_dot_product_attention_math(
|
||||
query_ref, key_ref, value_ref, dropout_p=dropout_p, is_causal=is_causal, scale=scale, dropout_mask=dropout_mask)[0]
|
||||
@ -2399,7 +2469,7 @@ class TestSDPACudaOnly(NNTestCase):
|
||||
upstream_grad = torch.rand_like(out, requires_grad=False)
|
||||
|
||||
# backward for flash attention on sm86 and sm89 for headdim > 64 currently disabled
|
||||
if isSM86or89Device and head_dim in range(65, 129):
|
||||
if isSM86or89Device and head_dim in range(193, 256):
|
||||
self.assertRaises(RuntimeError, lambda: out.backward(upstream_grad))
|
||||
return
|
||||
out.backward(upstream_grad)
|
||||
@ -2407,15 +2477,16 @@ class TestSDPACudaOnly(NNTestCase):
|
||||
out_lp_ref.backward(upstream_grad.to(out_lp_ref.dtype))
|
||||
|
||||
# See [Note] Fused Tolerances above
|
||||
output_ref_atol, output_ref_rtol = get_tolerances(out_ref, out_lp_ref)
|
||||
output_fudge_factor = 3 if head_dim % 8 != 0 else 1
|
||||
output_ref_atol, output_ref_rtol = get_tolerances(out_ref, out_lp_ref, output_fudge_factor)
|
||||
|
||||
# TODO: Investigate why grad_q needs larger tolerances
|
||||
query_fudge_factor = 4
|
||||
grad_q_ref_atol, grad_q_ref_rtol = get_tolerances(query_ref.grad, query_ref_lp.grad, query_fudge_factor)
|
||||
|
||||
grad_k_ref_atol, grad_k_ref_rtol = get_tolerances(key_ref.grad, key_ref_lp.grad)
|
||||
|
||||
grad_v_ref_atol, grad_v_ref_rtol = get_tolerances(value_ref.grad, value_ref_lp.grad)
|
||||
value_fudge_factor = 2
|
||||
grad_v_ref_atol, grad_v_ref_rtol = get_tolerances(value_ref.grad, value_ref_lp.grad, value_fudge_factor)
|
||||
|
||||
self.assertEqual(out, out_ref.to(out.dtype), atol=output_ref_atol, rtol=output_ref_rtol)
|
||||
self.assertEqual(query.grad, query_ref.grad.to(query.grad.dtype),
|
||||
@ -2425,7 +2496,7 @@ class TestSDPACudaOnly(NNTestCase):
|
||||
self.assertEqual(value.grad, value_ref.grad.to(value.grad.dtype),
|
||||
atol=grad_v_ref_atol, rtol=grad_v_ref_rtol)
|
||||
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA or not SM80OrLater, "Does not support SDPA or pre-SM80 hardware")
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support SDPA or pre-SM80 hardware")
|
||||
@parametrize("batch_size", [1, 8])
|
||||
@parametrize("seq_len_q", [256, 512, 1024])
|
||||
@parametrize("seq_len_k", [256, 512, 1024])
|
||||
@ -2456,11 +2527,12 @@ class TestSDPACudaOnly(NNTestCase):
|
||||
return _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len,
|
||||
dropout_p, output_seed, output_offset, device=device)
|
||||
else:
|
||||
dbug_mask = output[-1]
|
||||
# Build dropout_mask
|
||||
dbug_mask = output_tuple[-1]
|
||||
query_padding_mask = torch.ones(
|
||||
1, seq_len_q, device="cuda", dtype=torch.bool)
|
||||
batch_size, seq_len_q, device=device, dtype=torch.bool)
|
||||
key_padding_mask = torch.ones(
|
||||
1, seq_len_k, device="cuda", dtype=torch.bool)
|
||||
batch_size, seq_len_k, device=device, dtype=torch.bool)
|
||||
|
||||
softmax_mask = self.convert_flash_attn_S_to_softmax(
|
||||
dbug_mask, query_padding_mask, key_padding_mask, head_dim=head_dim, causal=is_causal)
|
||||
@ -2495,7 +2567,7 @@ class TestSDPACudaOnly(NNTestCase):
|
||||
kwargs["compute_log_sumexp"] = True
|
||||
kwargs["attn_bias"] = None
|
||||
if fused_kernel == SDPBackend.FLASH_ATTENTION:
|
||||
kwargs['return_debug_mask'] = True
|
||||
kwargs['return_debug_mask'] = dropout_p > 0.0
|
||||
with torch.cuda.stream(s):
|
||||
# Create real output
|
||||
output_tuple = fused_op(query, key, value, **kwargs)
|
||||
@ -2583,11 +2655,10 @@ class TestSDPACudaOnly(NNTestCase):
|
||||
self.assertEqual(value.grad, value_ref.grad.to(value.grad.dtype),
|
||||
atol=grad_v_ref_atol, rtol=grad_v_ref_rtol)
|
||||
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA, "Fused SDPA was not built for this system")
|
||||
@parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION])
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system")
|
||||
@parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION] if
|
||||
PLATFORM_SUPPORTS_FLASH_ATTENTION else [SDPBackend.EFFICIENT_ATTENTION])
|
||||
def test_fused_kernels_seq_len_1_inputs(self, device, fused_kernel):
|
||||
if (not SM80OrLater) and fused_kernel == SDPBackend.FLASH_ATTENTION:
|
||||
return
|
||||
rand_nested_tensor = partial(rand_sdpa_tensor, type="nested", device=device, dtype=torch.float16)
|
||||
batch, num_heads, head_dim = 32, 16, 64
|
||||
seq_lens = torch.randint(low=1, high=32, size=(batch,))
|
||||
@ -2617,11 +2688,10 @@ class TestSDPACudaOnly(NNTestCase):
|
||||
|
||||
self.assertEqual(actual.contiguous(), math_ref.contiguous().to(torch.float16), atol=1e-3, rtol=1e-2)
|
||||
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA, "Fused SDPA was not built for this system")
|
||||
@parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION])
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system")
|
||||
@parametrize("fused_kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION] if
|
||||
PLATFORM_SUPPORTS_FLASH_ATTENTION else [SDPBackend.EFFICIENT_ATTENTION])
|
||||
def test_fused_kernels_seq_len_0_inputs(self, device, fused_kernel):
|
||||
if (not SM80OrLater) and fused_kernel == SDPBackend.FLASH_ATTENTION:
|
||||
return
|
||||
rand_nested_tensor = partial(rand_sdpa_tensor, type="nested", device=device, dtype=torch.float16)
|
||||
batch, num_heads, head_dim = 32, 16, 64
|
||||
seq_lens = torch.randint(low=1, high=32, size=(batch,))
|
||||
@ -2644,8 +2714,9 @@ class TestSDPACudaOnly(NNTestCase):
|
||||
torch.nn.functional.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False)
|
||||
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA, "Fused SDPA was not built for this system")
|
||||
@parametrize("kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION])
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system")
|
||||
@parametrize("kernel", [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION] if
|
||||
PLATFORM_SUPPORTS_FLASH_ATTENTION else [SDPBackend.EFFICIENT_ATTENTION])
|
||||
@parametrize("expand_q_batch", [True, False])
|
||||
@parametrize("expand_k_batch", [True, False])
|
||||
@parametrize("expand_v_batch", [True, False])
|
||||
@ -2663,8 +2734,6 @@ class TestSDPACudaOnly(NNTestCase):
|
||||
expand_k_num_heads,
|
||||
expand_v_num_heads,
|
||||
):
|
||||
if (not SM80OrLater) and kernel == SDPBackend.FLASH_ATTENTION:
|
||||
return
|
||||
is_efficient = kernel == SDPBackend.EFFICIENT_ATTENTION
|
||||
dtype = torch.float32 if is_efficient else torch.float16
|
||||
rand_nested_tensor = partial(rand_sdpa_tensor, type="nested", device=device, dtype=dtype)
|
||||
@ -2733,7 +2802,7 @@ class TestSDPACudaOnly(NNTestCase):
|
||||
|
||||
self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=1e-3, rtol=1e-2)
|
||||
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA, "Fused SDPA was not built for this system")
|
||||
@unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system")
|
||||
def test_fused_kernels_nested_broadcasting_query_dense(self, device):
|
||||
rand_nested_tensor = partial(rand_sdpa_tensor, type="nested", device=device, dtype=torch.float32)
|
||||
batch, num_heads, head_dim, head_dim_v = 32, 16, 64, 96
|
||||
|
||||
@ -2764,13 +2764,13 @@
|
||||
output_differentiability: [True, False, False, False]
|
||||
query, key, value, attn_bias: _scaled_dot_product_efficient_attention_backward(grad, query, key, value, attn_bias, output, log_sumexp, philox_seed, philox_offset, dropout_p, grad_input_mask, is_causal, scale)
|
||||
|
||||
- name: _scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor ouput, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
|
||||
- name: _scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
|
||||
output_differentiability: [True, False, False, False, False, False, False, False, False]
|
||||
query, key, value: _scaled_dot_product_flash_attention_backward(grad, query, key, value, ouput, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, philox_seed, philox_offset, scale)
|
||||
query, key, value: _scaled_dot_product_flash_attention_backward(grad, query, key, value, output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, philox_seed, philox_offset, scale)
|
||||
|
||||
- name: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
|
||||
output_differentiability: [True, False, False, False, False]
|
||||
query, key, value: _flash_attention_backward(grad, query, key, value, output, softmax_logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, philox_seed, philox_offset, scale)
|
||||
# - name: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, int? max_q, int? max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None) -> (Tensor output, Tensor query_padded, Tensor key_padded, Tensor value_padded, Tensor softmax_logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
|
||||
# output_differentiability: [True, False, False, False, False, False, False, False]
|
||||
# query, key, value: _flash_attention_backward(grad, query, key, value, output, softmax_logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, philox_seed, philox_offset, scale)
|
||||
|
||||
# fft
|
||||
- name: _fft_r2c(Tensor self, int[] dim, int normalization, bool onesided) -> Tensor
|
||||
|
||||
@ -4725,16 +4725,14 @@ def meta__scaled_dot_product_flash(
|
||||
head_dim = query.size(3)
|
||||
|
||||
max_seqlen_batch_k = key.size(2)
|
||||
Nnz_q = batch_size * max_seqlen_batch_q
|
||||
|
||||
query_t = query.transpose(1, 2)
|
||||
query_reshaped = query_t.reshape(Nnz_q, num_heads, head_dim)
|
||||
attention = torch.empty_like(query_reshaped, device=query.device)
|
||||
attention = attention.view(
|
||||
batch_size, max_seqlen_batch_q, num_heads, head_dim
|
||||
).transpose(1, 2)
|
||||
|
||||
if device_hint(query) == "cpu":
|
||||
Nnz_q = batch_size * max_seqlen_batch_q
|
||||
query_t = query.transpose(1, 2)
|
||||
query_reshaped = query_t.reshape(Nnz_q, num_heads, head_dim)
|
||||
attention = torch.empty_like(query_reshaped, device=query.device)
|
||||
attention = attention.view(
|
||||
batch_size, max_seqlen_batch_q, num_heads, head_dim
|
||||
).transpose(1, 2)
|
||||
logsumexp = torch.empty(
|
||||
(
|
||||
batch_size,
|
||||
@ -4755,9 +4753,12 @@ def meta__scaled_dot_product_flash(
|
||||
torch.empty((), dtype=torch.long, device="meta"),
|
||||
torch.empty((), dtype=query.dtype, device=query.device),
|
||||
)
|
||||
max_seqlen_q = math.ceil(max_seqlen_batch_q / 16) * 16
|
||||
|
||||
# Cuda Path
|
||||
query_t = query.transpose(1, 2)
|
||||
attention = torch.empty_like(query_t).transpose(1, 2)
|
||||
logsumexp = torch.empty(
|
||||
(batch_size, num_heads, max_seqlen_q),
|
||||
(batch_size, num_heads, max_seqlen_batch_q),
|
||||
dtype=torch.float,
|
||||
device=query.device,
|
||||
)
|
||||
@ -4776,7 +4777,7 @@ def meta__scaled_dot_product_flash(
|
||||
elif max_seqlen_batch_k <= 256:
|
||||
max_seqlen_k = 256
|
||||
debug_mask = torch.empty(
|
||||
(batch_size, num_heads, max_seqlen_q, max_seqlen_k),
|
||||
(batch_size, num_heads, max_seqlen_batch_q, max_seqlen_k),
|
||||
dtype=query.dtype,
|
||||
device=query.device,
|
||||
)
|
||||
@ -4791,8 +4792,8 @@ def meta__scaled_dot_product_flash(
|
||||
return (
|
||||
attention,
|
||||
logsumexp,
|
||||
cumulative_sequence_length_q,
|
||||
cumulative_sequence_length_k,
|
||||
None,
|
||||
None,
|
||||
max_seqlen_batch_q,
|
||||
max_seqlen_batch_k,
|
||||
torch.empty((), dtype=torch.long, device="meta"),
|
||||
|
||||
@ -3,7 +3,7 @@ r"""This file is allowed to initialize CUDA context when imported."""
|
||||
import functools
|
||||
import torch
|
||||
import torch.cuda
|
||||
from torch.testing._internal.common_utils import LazyVal, TEST_NUMBA, TEST_WITH_ROCM, TEST_CUDA
|
||||
from torch.testing._internal.common_utils import LazyVal, TEST_NUMBA, TEST_WITH_ROCM, TEST_CUDA, IS_WINDOWS
|
||||
import inspect
|
||||
import contextlib
|
||||
|
||||
@ -27,7 +27,10 @@ SM70OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_devic
|
||||
SM75OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (7, 5))
|
||||
SM80OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 0))
|
||||
|
||||
PLATFORM_SUPPORTS_FUSED_SDPA: bool = TEST_CUDA and not TEST_WITH_ROCM
|
||||
PLATFORM_SUPPORTS_FLASH_ATTENTION: bool = LazyVal(lambda: TEST_CUDA and (not TEST_WITH_ROCM) and (not IS_WINDOWS) and SM80OrLater)
|
||||
PLATFORM_SUPPORTS_MEM_EFF_ATTENTION: bool = LazyVal(lambda: TEST_CUDA and not TEST_WITH_ROCM)
|
||||
# This condition always evaluates to PLATFORM_SUPPORTS_MEM_EFF_ATTENTION but for logical clarity we keep it separate
|
||||
PLATFORM_SUPPORTS_FUSED_ATTENTION: bool = LazyVal(lambda: PLATFORM_SUPPORTS_FLASH_ATTENTION or PLATFORM_SUPPORTS_MEM_EFF_ATTENTION)
|
||||
|
||||
if TEST_NUMBA:
|
||||
import numba.cuda
|
||||
|
||||
@ -27,7 +27,7 @@ from torch.testing._internal.common_device_type import \
|
||||
toleranceOverride, tol)
|
||||
from torch.testing._internal.common_cuda import (
|
||||
SM53OrLater, SM60OrLater, SM80OrLater, with_tf32_off, TEST_CUDNN,
|
||||
_get_torch_cuda_version, _get_torch_rocm_version, PLATFORM_SUPPORTS_FUSED_SDPA
|
||||
_get_torch_cuda_version, _get_torch_rocm_version, PLATFORM_SUPPORTS_FUSED_ATTENTION,
|
||||
)
|
||||
from torch.testing._internal.common_utils import (
|
||||
make_fullrank_matrices_with_distinct_singular_values,
|
||||
@ -13521,9 +13521,9 @@ op_db: List[OpInfo] = [
|
||||
DecorateInfo(unittest.skip("Skipped!"), 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
|
||||
# See [Note] SDPA returns Philox Offset and Seed as tensors that will live on CPU when not in cuda graph capture
|
||||
DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake_crossref_backward_amp',
|
||||
device_type='cuda', dtypes=(torch.float32,), active_if=PLATFORM_SUPPORTS_FUSED_SDPA),
|
||||
device_type='cuda', dtypes=(torch.float32,), active_if=PLATFORM_SUPPORTS_FUSED_ATTENTION),
|
||||
DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake_crossref_backward_no_amp',
|
||||
device_type='cuda', dtypes=(torch.float32,), active_if=PLATFORM_SUPPORTS_FUSED_SDPA),
|
||||
device_type='cuda', dtypes=(torch.float32,), active_if=PLATFORM_SUPPORTS_FUSED_ATTENTION),
|
||||
# TODO Need to understand what this is testing and why it doesn't work
|
||||
DecorateInfo(unittest.skip("Skipped"), 'TestDecomp', 'test_comprehensive'),
|
||||
DecorateInfo(unittest.skip('output is non-deterministic (when dropout_p > 0)'), 'TestCommon', 'test_compare_cpu'),
|
||||
|
||||
Reference in New Issue
Block a user