flash_attention integration (#81434)

# Summary:
- I added a new submodule, CUTLASS, pinned to the 2.10 release. Inclusion of the flash_attention code is gated behind the USE_FLASH_ATTENTION flag, which defaults to OFF, so flash is not built anywhere. This is intentional, since we don't have A100 machines to compile and test on.

- Only the CMake build has been updated; Bazel and Buck have not been attempted yet.

- I included the mha_fwd kernel from flash_attention, refactored to use CUTLASS 2.10. There is currently no backward kernel on this branch; that would be a good follow-up. A usage sketch for the new op follows below.
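A minimal usage sketch of the new op, assuming a CUDA build with USE_FLASH_ATTENTION=ON, an SM75/SM8x GPU, and fp16 inputs; the function name is illustrative, and the shapes follow the dense helper added in this change:

#include <ATen/ATen.h>

// Hypothetical example: run the new private op on dense, same-length sequences.
// Inputs are (total_tokens, num_heads, head_dim) with int32 cumulative sequence
// lengths of size batch + 1, matching the schema added to native_functions.yaml.
at::Tensor flash_sdpa_dense_sketch() {
  const int64_t batch = 2, seq_len = 128, num_heads = 8, head_dim = 64;
  auto opts = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
  auto q = at::randn({batch * seq_len, num_heads, head_dim}, opts);
  auto k = at::randn({batch * seq_len, num_heads, head_dim}, opts);
  auto v = at::randn({batch * seq_len, num_heads, head_dim}, opts);
  // Every sequence has the same length, so the offsets are multiples of seq_len: {0, 128, 256}.
  auto cu_seqlens = at::arange(0, (batch + 1) * seq_len, seq_len,
                               at::TensorOptions().dtype(at::kInt).device(at::kCUDA));
  auto out = at::_flash_scaled_dot_product_attention(
      q, k, v, cu_seqlens, cu_seqlens, seq_len, seq_len,
      /*dropout_p=*/0.0, /*causal=*/false);
  return out.reshape({batch, seq_len, num_heads, head_dim});
}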

Pull Request resolved: https://github.com/pytorch/pytorch/pull/81434
Approved by: https://github.com/cpuhrsch
Driss Guessous
2022-09-09 20:11:26 +00:00
committed by PyTorch MergeBot
parent 219ff26172
commit 0fc02dbba4
29 changed files with 4395 additions and 5 deletions

.gitmodules

@ -151,3 +151,6 @@
[submodule "third_party/VulkanMemoryAllocator"] [submodule "third_party/VulkanMemoryAllocator"]
path = third_party/VulkanMemoryAllocator path = third_party/VulkanMemoryAllocator
url = https://github.com/GPUOpen-LibrariesAndSDKs/VulkanMemoryAllocator.git url = https://github.com/GPUOpen-LibrariesAndSDKs/VulkanMemoryAllocator.git
[submodule "third_party/cutlass"]
path = third_party/cutlass
url = https://github.com/NVIDIA/cutlass.git


@ -721,6 +721,13 @@ set(BUILD_ONEDNN_GRAPH OFF)
include(cmake/Dependencies.cmake)
# Moved this cmake set option down here because CMAKE_CUDA_COMPILER_VERSION is not available until now
option(USE_FLASH_ATTENTION "Whether to build the flash_attention kernel for scaled dot product attention" OFF)
if(USE_FLASH_ATTENTION)
ADD_DEFINITIONS(-DUSE_FLASH_ATTENTION)
ENDIF()
if(USE_CUDA AND (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 10.2) AND (CMAKE_HOST_SYSTEM_NAME MATCHES "Windows"))
# CUDA < 10.2 doesn't support compiling and extracting header dependencies in
# one call, so instead CMake calls nvcc twice with && in between.


@ -130,15 +130,13 @@ file(GLOB native_cuda_h "native/cuda/*.h" "native/cuda/*.cuh")
file(GLOB native_cuda_linalg_cpp "native/cuda/linalg/*.cpp")
file(GLOB native_hip_h "native/hip/*.h" "native/hip/*.cuh")
file(GLOB native_cudnn_cpp "native/cudnn/*.cpp")
file(GLOB native_nested_cuda_cu "native/nested/cuda/*.cu")
file(GLOB native_nested_cuda_cpp "native/nested/cuda/*.cpp")
file(GLOB native_sparse_cuda_cu "native/sparse/cuda/*.cu")
file(GLOB native_sparse_cuda_cpp "native/sparse/cuda/*.cpp")
file(GLOB native_quantized_cuda_cu "native/quantized/cuda/*.cu")
file(GLOB native_quantized_cuda_cpp "native/quantized/cuda/*.cpp")
file(GLOB native_quantized_cudnn_cpp "native/quantized/cudnn/*.cpp")
file(GLOB native_transformers_cuda_cu "native/transformers/cuda/*.cu")
file(GLOB native_transformers_cuda_cpp "native/transformers/cuda/*.cpp")
file(GLOB native_nested_cuda_cu "native/nested/cuda/*.cu")
file(GLOB native_nested_cuda_cpp "native/nested/cuda/*.cpp")
file(GLOB native_hip_hip "native/hip/*.hip")
file(GLOB native_hip_cpp "native/hip/*.cpp")
@ -151,11 +149,22 @@ file(GLOB native_sparse_hip_hip "native/sparse/hip/*.hip")
file(GLOB native_sparse_hip_cpp "native/sparse/hip/*.cpp")
file(GLOB native_quantized_hip_hip "native/quantized/hip/*.hip")
file(GLOB native_quantized_hip_cpp "native/quantized/hip/*.cpp")
file(GLOB native_transformers_cuda_cu "native/transformers/cuda/*.cu")
file(GLOB native_transformers_cuda_cpp "native/transformers/cuda/*.cpp")
file(GLOB native_transformers_hip_hip "native/transformers/hip/*.hip")
file(GLOB native_transformers_hip_cpp "native/transformers/hip/*.cpp")
file(GLOB native_quantized_cudnn_hip_cpp "native/quantized/cudnn/hip/*.cpp")
file(GLOB native_utils_cpp "native/utils/*.cpp")
# flash_attention sources
file(GLOB flash_attention_cuda_cu "native/transformers/cuda/flash_attn/*.cu")
file(GLOB flash_attention_cuda_cpp "native/transformers/cuda/flash_attn/*.cpp")
if(USE_FLASH_ATTENTION)
list(APPEND native_transformers_cuda_cu ${flash_attention_cuda_cu})
list(APPEND native_transformers_cuda_cpp ${flash_attention_cuda_cpp})
endif()
# XNNPACK
file(GLOB native_xnnpack "native/xnnpack/*.cpp")
@ -415,6 +424,9 @@ if(NOT MSVC AND NOT EMSCRIPTEN AND NOT INTERN_BUILD_MOBILE)
endif()
if(USE_CUDA AND NOT USE_ROCM)
if(USE_FLASH_ATTENTION)
list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/include)
endif()
if($ENV{ATEN_STATIC_CUDA})
list(APPEND ATen_CUDA_DEPENDENCY_LIBS
${CUDA_LIBRARIES}


@ -13136,6 +13136,11 @@
structured: True
variants: function
- func: _flash_scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, float dropout_p, bool causal) -> Tensor
variants: function
dispatch:
CUDA: flash_scaled_dot_product_attention
- func: _transformer_decoder_only_layer_fwd(Tensor src, int embed_dim, int num_heads, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, bool use_gelu, bool norm_first, float eps, Tensor norm_weight_1, Tensor norm_bias_1, Tensor norm_weight_2, Tensor norm_bias_2, Tensor ffn_weight_1, Tensor ffn_bias_1, Tensor ffn_weight_2, Tensor ffn_bias_2, Tensor? mask=None, Tensor? incr_key=None, Tensor? incr_value=None) -> (Tensor, Tensor, Tensor)
variants: function
dispatch:


@ -7,6 +7,7 @@
#include <c10/util/string_view.h>
#include <c10/util/Exception.h>
#include <c10/util/Optional.h>
namespace at {
namespace native {
@ -243,5 +244,196 @@ Tensor NestedTensor_to_mask(const Tensor& nt, c10::optional<int64_t> mask_dim, c
}
return result;
}
std::tuple<Tensor, int64_t> cumulative_and_max_seq_len(Tensor qkv) {
TORCH_CHECK(
qkv.is_nested(),
"QKV must be nested for flash cumulative_seq_len calculation.")
auto* nt_impl = get_nested_tensor_impl(qkv);
const auto& sizes = nt_impl->get_nested_size_tensor();
auto size_tensor_stride = sizes.stride(0);
const int64_t batch_size = qkv.size(0);
auto cumulative_seqlen = at::zeros(
{batch_size + 1}, TensorOptions().device(at::kCPU).dtype(at::kInt));
auto* sizes_ptr = sizes.data_ptr<int64_t>();
auto* cumulative_seqlen_ptr = cumulative_seqlen.data_ptr<int32_t>();
int32_t sum = 0;
int64_t max_seqlen = -1;
cumulative_seqlen_ptr[0] = sum;
for (const auto i : c10::irange(batch_size)) {
// Calculate the cumulative sum of the sequence lengths
auto current_seq_len = sizes_ptr[i * size_tensor_stride];
sum += current_seq_len;
cumulative_seqlen_ptr[i + 1] = sum;
// Find the max element while we traverse
max_seqlen = std::max(max_seqlen, current_seq_len);
}
// Send to GPU, this is pretty light weight calc for normal batch size
// but maybe this needs to be on gpu
cumulative_seqlen = cumulative_seqlen.to(TensorOptions().device(at::kCUDA));
return std::tuple<Tensor, int64_t>{cumulative_seqlen, max_seqlen};
}
Tensor flash_attention_helper(
const Tensor& query,
const Tensor& key,
const Tensor& value,
double dropout_p,
bool causal) {
// Query is of size (batch_size x ragged_seq_len x (3 or 1) x n_heads x
// head_dim)
int64_t head_dim{query.size(-1)};
int64_t num_heads{query.size(-2)};
auto cumulative_and_max_q = cumulative_and_max_seq_len(query);
Tensor cumulative_sequence_length_q = std::get<0>(cumulative_and_max_q);
int64_t max_seqlen_batch_q = std::get<1>(cumulative_and_max_q);
if (key.is_same(value) || query.is_same(key) || query.is_same(value)) {
int64_t Nnz_q{cumulative_sequence_length_q[-1].item<int64_t>()};
// For the packed case we need to set the output size for dim 2 to 1
auto atten_size = get_nested_size_tensor(query);
atten_size.index({at::indexing::Slice(), 1}) = 1;
auto qkv_buffer_reshaped =
get_buffer(query).view({Nnz_q, 3, num_heads, head_dim});
// If we are passing in query, key, value all the same tensors then we have
// packed them into one tensor and need to slice for flash attention
Tensor atten_buffer = at::_flash_scaled_dot_product_attention(
qkv_buffer_reshaped.index({at::indexing::Slice(), 0}),
qkv_buffer_reshaped.index({at::indexing::Slice(), 1}),
qkv_buffer_reshaped.index({at::indexing::Slice(), 2}),
cumulative_sequence_length_q,
cumulative_sequence_length_q,
max_seqlen_batch_q,
max_seqlen_batch_q,
dropout_p,
causal);
// Output of flash_attention is a regular tensor, let's wrap it back up to
// form a nested tensor
return wrap_buffer(atten_buffer.view(-1), atten_size);
}
// Query, Key, and Value are not all the same tensor and therefore need to
// calculate K meta data
// The nested tensors will be of shape {Batch_size x ragged_seq_len x
// num_heads * head_dim }
auto cumulative_and_max_k = cumulative_and_max_seq_len(key);
Tensor cumulative_sequence_length_k = std::get<0>(cumulative_and_max_k);
int64_t max_seqlen_batch_k = std::get<1>(cumulative_and_max_k);
// K and V have to have the same Nnz, should probably torch_check before now
// assume in order to not iterate over v
int64_t Nnz_q{cumulative_sequence_length_q[-1].item<int64_t>()};
int64_t Nnz_kv{cumulative_sequence_length_k[-1].item<int64_t>()};
auto query_buffer_reshaped =
get_buffer(query).view({Nnz_q, num_heads, head_dim});
auto key_buffer_reshaped =
get_buffer(key).view({Nnz_kv, num_heads, head_dim});
auto value_buffer_reshaped =
get_buffer(value).view({Nnz_kv, num_heads, head_dim});
Tensor atten_buffer = at::_flash_scaled_dot_product_attention(
query_buffer_reshaped,
key_buffer_reshaped,
value_buffer_reshaped,
cumulative_sequence_length_q,
cumulative_sequence_length_k,
max_seqlen_batch_q,
max_seqlen_batch_k,
dropout_p,
causal);
// Output of flash_attention is a regular tensor, let's wrap it back up to
// form a nested tensor, the size of which should match the query tensor
return wrap_buffer(atten_buffer.view(-1), get_nested_size_tensor(query));
}
Tensor flash_attention_helper_dense(
const Tensor& query,
const Tensor& key,
const Tensor& value,
double dropout_p,
bool causal) {
TORCH_INTERNAL_ASSERT(
!query.is_nested() && !key.is_nested() && !value.is_nested());
// Query is of size (batch_size x dense_seq_len x 3 x n_heads
// head_dim)
const auto batch_size = query.size(0);
auto max_seqlen_batch_q = query.size(1);
int64_t head_dim{query.size(-1)};
int64_t num_heads{query.size(-2)};
auto cumulative_sequence_length_q = at::arange(
0,
(batch_size + 1) * max_seqlen_batch_q,
max_seqlen_batch_q,
TensorOptions().device(at::kCUDA).dtype(at::kInt));
int64_t Nnz_q{batch_size * max_seqlen_batch_q};
if (key.is_same(value) || query.is_same(key) || query.is_same(value)) {
// In the dense case flash attention expects an input that is
// (b*s) x num_heads x head_dim
auto query_reshaped = query.reshape({Nnz_q, 3, num_heads, head_dim});
// If we are passing in query, key, value all the same tensors then we have
// packed them into one tensor and need to slice for flash attention
Tensor atten_buffer = at::_flash_scaled_dot_product_attention(
query_reshaped.index({at::indexing::Slice(), 0}),
query_reshaped.index({at::indexing::Slice(), 1}),
query_reshaped.index({at::indexing::Slice(), 2}),
cumulative_sequence_length_q,
cumulative_sequence_length_q,
max_seqlen_batch_q,
max_seqlen_batch_q,
dropout_p,
causal);
// Reshape output to convert nnz to batch_size and seq_len
return atten_buffer.reshape(
{batch_size, max_seqlen_batch_q, num_heads, head_dim});
}
// Query, Key, and Value are not all the same tensor and therefore need to
// calculate K meta data
auto max_seqlen_batch_k = key.size(1);
auto cumulative_sequence_length_k = at::arange(
0,
(batch_size + 1) * max_seqlen_batch_k,
max_seqlen_batch_k,
TensorOptions().device(at::kCUDA).dtype(at::kInt));
// K and V have to have the same Nnz, should probably torch_check before
// assume for now in order to not iterate over v
int64_t Nnz_kv{batch_size * max_seqlen_batch_k};
// Calculate head dim
TORCH_INTERNAL_ASSERT(query.size(-1) == key.size(-1));
TORCH_INTERNAL_ASSERT(query.size(-1) == value.size(-1));
auto query_reshaped = query.reshape({Nnz_q, num_heads, head_dim});
auto key_reshaped = key.reshape({Nnz_kv, num_heads, head_dim});
auto value_reshaped = value.reshape({Nnz_kv, num_heads, head_dim});
Tensor atten_buffer = at::_flash_scaled_dot_product_attention(
query_reshaped,
key_reshaped,
value_reshaped,
cumulative_sequence_length_q,
cumulative_sequence_length_k,
max_seqlen_batch_q,
max_seqlen_batch_k,
dropout_p,
causal);
// Reshape output to convert nnz to batch_size and seq_len
return atten_buffer.reshape(
{batch_size, max_seqlen_batch_q, num_heads, head_dim});
}
} // namespace native
} // namespace at
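A small worked example of the bookkeeping done by cumulative_and_max_seq_len above, written on plain integers so it needs no nested tensor: per-sequence lengths {2, 3, 1} produce the offsets {0, 2, 5, 6} and a max length of 3, which is what the flash kernel later receives as cu_seqlens and max_seqlen.

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

void cumulative_seqlen_sketch() {
  std::vector<int64_t> seq_lens = {2, 3, 1};  // first dimension of each nested entry
  std::vector<int32_t> cu_seqlens = {0};      // cumulative_seqlen_ptr[0] = 0
  int32_t sum = 0;
  int64_t max_seqlen = -1;
  for (int64_t len : seq_lens) {
    sum += static_cast<int32_t>(len);         // running sum of sequence lengths
    cu_seqlens.push_back(sum);
    max_seqlen = std::max<int64_t>(max_seqlen, len);  // track the longest sequence
  }
  assert((cu_seqlens == std::vector<int32_t>{0, 2, 5, 6}));
  assert(max_seqlen == 3);
}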


@ -83,5 +83,19 @@ void add_padding_kernelLauncher(
const std::vector<int64_t>& output_sizes,
const int batch_size,
const int output_batch_size);
Tensor flash_attention_helper_dense(
const Tensor& query,
const Tensor& key,
const Tensor& value,
double dropout_p,
bool causal);
Tensor flash_attention_helper(
const Tensor& query,
const Tensor& key,
const Tensor& value,
double dropout_p,
bool causal);
} // namespace native
} // namespace at
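To make the packed-QKV convention behind these declarations concrete, a hedged sketch (assuming a flash-attention-enabled CUDA build, with the declarations in ATen/native/nested/NestedTensorTransformerFunctions.h as the includes later in this diff suggest): when the same (batch, seq_len, 3, num_heads, head_dim) tensor is passed as query, key, and value, flash_attention_helper_dense detects the aliasing and slices Q, K, and V out of dim 2 itself.

#include <ATen/ATen.h>
#include <ATen/native/nested/NestedTensorTransformerFunctions.h>

at::Tensor packed_dense_sketch() {
  // Packed projection output: (batch, seq_len, 3, num_heads, head_dim), fp16 on CUDA.
  auto qkv = at::randn({2, 128, 3, 8, 64},
                       at::TensorOptions().dtype(at::kHalf).device(at::kCUDA));
  // Passing the same tensor three times takes the key.is_same(value) branch,
  // which slices Q, K, V from dim 2 before calling the flash kernel.
  return at::native::flash_attention_helper_dense(
      qkv, qkv, qkv, /*dropout_p=*/0.0, /*causal=*/true);
}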


@ -1,4 +1,5 @@
#include <type_traits>
#include <c10/util/Exception.h>
#include <ATen/ATen.h>
#include <ATen/NestedTensorImpl.h>
@ -9,10 +10,18 @@
#include <ATen/ops/_nested_from_padded.h>
#endif
// TODO Consider moving all flash_attention code, nested tensor included to
// Transformer library
#ifdef USE_FLASH_ATTENTION
#include <ATen/native/transformers/cuda/flash_attn/fmha_api.h>
#endif
#include <ATen/native/nested/NestedTensorTransformerFunctions.h>
#include <ATen/native/nested/NestedTensorMath.h>
#include <ATen/native/nested/NestedTensorUtils.h>
#include <ATen/cuda/CUDAContext.h>
namespace at {
namespace native {
namespace {
@ -207,5 +216,37 @@ Tensor NestedTensor_to_padded_tensor_cuda(
return NestedTensor_to_padded_tensor_generic(t, padding, output_size);
}
Tensor flash_scaled_dot_product_attention(
const Tensor& query,
const Tensor& key,
const Tensor& value,
const Tensor& cumulative_sequence_length_q,
const Tensor& cumulative_sequence_length_k,
const int64_t max_seqlen_batch_q,
const int64_t max_seqlen_batch_k,
double dropout_p,
bool causal) {
#if defined(USE_FLASH_ATTENTION)
auto softmax_scale = std::pow(query.size(-1), -0.5);
std::vector<Tensor> output = fmha::mha_fwd(
query,
key,
value,
cumulative_sequence_length_q,
cumulative_sequence_length_k,
max_seqlen_batch_q,
max_seqlen_batch_k,
dropout_p,
softmax_scale,
false,
causal,
false,
c10::nullopt);
return output[0];
#endif
TORCH_CHECK(false, "USE_FLASH_ATTENTION was not enabled for build.")
return Tensor{};
}
} // namespace native
} // namespace at
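One detail worth noting in the wrapper above: the softmax scale is derived from the head dimension as head_dim^(-1/2), the usual 1/sqrt(d_k) attention scaling. A tiny sketch of the arithmetic (function name illustrative):

#include <cassert>
#include <cmath>
#include <cstdint>

void softmax_scale_sketch() {
  int64_t head_dim = 64;
  // Mirrors std::pow(query.size(-1), -0.5) in flash_scaled_dot_product_attention.
  double softmax_scale = std::pow(static_cast<double>(head_dim), -0.5);
  assert(std::abs(softmax_scale - 0.125) < 1e-12);  // 1 / sqrt(64)
}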


@ -0,0 +1,149 @@
/******************************************************************************
* Copyright (c) 2022, Tri Dao.
******************************************************************************/
#pragma once
#include <third_party/cutlass/include/cutlass/cutlass.h>
#include <third_party/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h>
#include <third_party/cutlass/include/cutlass/epilogue/threadblock/default_thread_map_tensor_op.h>
#include <third_party/cutlass/include/cutlass/epilogue/warp/fragment_iterator_tensor_op.h>
#include <third_party/cutlass/include/cutlass/gemm/warp/default_mma_tensor_op.h>
#include <third_party/cutlass/include/cutlass/layout/layout.h>
#include <third_party/cutlass/include/cutlass/arch/mma.h>
#include <third_party/cutlass/include/cutlass/array.h>
#include <third_party/cutlass/include/cutlass/numeric_types.h>
#include <ATen/native/transformers/cuda/flash_attn/gemm.h>
#include <ATen/native/transformers/cuda/flash_attn/epilogue_predicated_tile_iterator.h>
namespace fmha {
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename MmaCore>
struct FMHAEpilogue {
using ThreadblockShape = typename MmaCore::Shape;
using WarpMma = typename MmaCore::MmaTensorOp;
using LayoutC = typename MmaCore::LayoutC;
using Element = typename MmaCore::ElementA;
using ElementC = typename MmaCore::ElementC;
static constexpr int kPartitionsK = ThreadblockShape::kK / MmaCore::WarpShape::kK;
using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorTensorOp<
typename WarpMma::Shape,
typename WarpMma::Policy::Operator::Shape,
typename WarpMma::Policy::Operator::ElementC,
typename WarpMma::Policy::Operator::FragmentC,
LayoutC>;
using AccumulatorTile = typename AccumulatorFragmentIterator::AccumulatorTile;
static constexpr int kIterationsStore = AccumulatorFragmentIterator::kIterations;
// Maybe elementsPerAccess should vary: 4 for d=64, 2 for d=32?
using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapTensorOp<
ThreadblockShape, typename WarpMma::Shape, kPartitionsK, Element, /*ElementsPerAccess=*/4>::Type;
using OutputTileThreadMapAccum = typename cutlass::epilogue::threadblock::DefaultThreadMapTensorOp<
ThreadblockShape, typename WarpMma::Shape, kPartitionsK, ElementC, /*ElementsPerAccess=*/4>::Type;
using GmemIterator = fmha::EpiloguePredicatedTileIterator<
OutputTileThreadMap,
Element
>;
// which ThreadMap should we use?
using GmemIteratorAccum = fmha::EpiloguePredicatedTileIterator<
// OutputTileThreadMapAccum,
OutputTileThreadMap,
ElementC
>;
using DefaultIterators = cutlass::epilogue::threadblock::detail::DefaultIteratorsTensorOp<
Element, ElementC, /*ElementsPerAccess=*/4, ThreadblockShape, typename WarpMma::Shape,
typename WarpMma::Policy::Operator::Shape, typename OutputTileThreadMap::CompactedThreadMap>;
using WarpTileIterator = typename DefaultIterators::WarpTileIterator;
static_assert(WarpTileIterator::kIterations == kIterationsStore);
using SharedLoadIterator = typename DefaultIterators::SharedLoadIterator;
using OutputFragment = typename SharedLoadIterator::Fragment;
// using Padding = cutlass::MatrixShape<0, 0>;
using Padding = cutlass::MatrixShape<0, 64 / cutlass::sizeof_bits<ElementC>::value * 4>;
static constexpr int kFragmentsPerIteration = kIterationsStore; // TODO: could be 1 for Volta?
/*Using kIterationsStore here so that we get the right storage size*/
using EpilogueBase = typename cutlass::epilogue::threadblock::EpilogueBase<
ThreadblockShape, typename WarpMma::Shape, kPartitionsK, AccumulatorFragmentIterator, WarpTileIterator,
Padding, kIterationsStore>;
using SharedStorage = typename EpilogueBase::SharedStorage;
static constexpr int kSmemTiles = EpilogueBase::kFragmentsPerIteration;
static constexpr int kSmemPointerOffset = SharedStorage::StorageShape::kCount / kSmemTiles;
static constexpr int kSmemPointerOffsetPerWarp = SharedStorage::StorageShape::kCount / (kSmemTiles * kPartitionsK);
SharedStorage *shared_storage;
WarpTileIterator warp_tile_iterator;
inline __device__ FMHAEpilogue(void *smem, const int tidx)
: shared_storage(reinterpret_cast<SharedStorage *>(smem))
, warp_tile_iterator(shared_storage->reference(), threadIdx.x % 32) {
// const int warp_idx = tidx / 32;
// Broadcast the warp_id computed by lane 0 to ensure dependent code
// is compiled as warp-uniform.
// https://github.com/NVIDIA/cutlass/blob/e66bfcb1f880792caa46b1e983c4114e23afa5f3/include/cutlass/gemm/kernel/gemm_with_fused_epilogue.h#L520
const int warp_idx = __shfl_sync(0xffffffff, tidx / 32, 0);
cutlass::MatrixCoord warp_offset{kIterationsStore * warp_idx, 0};
warp_tile_iterator.add_tile_offset(warp_offset);
}
// Store the accumulators.
inline __device__ void store(const AccumulatorTile &acc) {
AccumulatorFragmentIterator accum_fragment_iterator(acc);
CUTLASS_PRAGMA_UNROLL
for (int p = 0; p < kIterationsStore; ++p) {
typename AccumulatorFragmentIterator::Fragment accum_fragment;
accum_fragment_iterator.load(accum_fragment);
++accum_fragment_iterator;
warp_tile_iterator.store(accum_fragment);
if (p < kIterationsStore - 1) {
warp_tile_iterator.add_pointer_offset(kSmemPointerOffsetPerWarp);
}
}
if (kIterationsStore > 1) {
warp_tile_iterator.add_pointer_offset((1 - kIterationsStore) * kSmemPointerOffsetPerWarp);
}
}
// Load the accumulators
template<bool zero_init=true>
inline __device__ void load(OutputFragment (&out)[kFragmentsPerIteration],
const int tidx) {
SharedLoadIterator shared_load_iterator(shared_storage->reference(), tidx);
CUTLASS_PRAGMA_UNROLL
for (int p = 0; p < EpilogueBase::kFragmentsPerIteration; ++p) {
OutputFragment aligned_accum_fragment[kPartitionsK];
shared_load_iterator.load(aligned_accum_fragment[0]);
cutlass::plus<OutputFragment> add_fragments;
if (kPartitionsK > 1) {
CUTLASS_PRAGMA_UNROLL
for ( int i = 1; i < kPartitionsK; ++i) {
shared_load_iterator.add_pointer_offset(kSmemPointerOffsetPerWarp * kIterationsStore);
shared_load_iterator.load(aligned_accum_fragment[i]);
aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]);
}
shared_load_iterator.add_pointer_offset((1 - kPartitionsK) * kSmemPointerOffsetPerWarp * kIterationsStore);
}
if (p < EpilogueBase::kFragmentsPerIteration - 1) {
shared_load_iterator.add_pointer_offset(kSmemPointerOffsetPerWarp);
}
out[p] = zero_init ? aligned_accum_fragment[0] : add_fragments(out[p], aligned_accum_fragment[0]);
}
}
};
} // namespace fmha


@ -0,0 +1,493 @@
// Adapted from cutlass/epilogue/threadblock/predicated_tile_iterator.h
// We just want to add the move() function, but idk how to do it without
// copying the code here.
/******************************************************************************
* Copyright (c) 2022, Tri Dao.
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include <third_party/cutlass/include/cutlass/cutlass.h>
#include <third_party/cutlass/include/cutlass/arch/arch.h>
#include <third_party/cutlass/include/cutlass/arch/memory.h>
#include <third_party/cutlass/include/cutlass/array.h>
#include <third_party/cutlass/include/cutlass/epilogue/threadblock/output_tile_thread_map.h>
#include <third_party/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_params.h>
#include <third_party/cutlass/include/cutlass/layout/matrix.h>
#include <third_party/cutlass/include/cutlass/layout/tensor.h>
#include <third_party/cutlass/include/cutlass/matrix_shape.h>
#include <third_party/cutlass/include/cutlass/numeric_types.h>
#include <third_party/cutlass/include/cutlass/tensor_ref.h>
#include <third_party/cutlass/include/cutlass/transform/pitch_linear_thread_map.h>
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace fmha {
////////////////////////////////////////////////////////////////////////////////
using namespace cutlass;
using namespace cutlass::epilogue::threadblock;
////////////////////////////////////////////////////////////////////////////////
/// Tile iterator used to load and store output tile from global memory in epilogue.
///
/// Satisfies: ReadableTileIterator | PredicatedTileIterator | ForwardTileIterator
///
template <
typename ThreadMap_, ///< Thread map (concept: OutputTileThreadMap)
typename Element_, ///< Element data type
bool ScatterD = false, ///< Scatter D operand or not
bool UseCUDAStore = false
>
class EpiloguePredicatedTileIterator {
public:
using ThreadMap = ThreadMap_;
using Shape = typename ThreadMap::Shape;
using Element = Element_;
using Layout = layout::RowMajor;
using TensorRef = TensorRef<Element, Layout>;
using ConstTensorRef = typename TensorRef::ConstTensorRef;
using Index = typename Layout::Index;
using LongIndex = typename Layout::LongIndex;
using TensorCoord = MatrixCoord;
static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;
static int const kThreads = ThreadMap::kThreads;
static int const kIterations = ThreadMap::Count::kTile;
static_assert( ThreadMap::Iterations::kRow > 0,"ThreadMap::Iterations::kRow must be > 0");
static_assert( ThreadMap::Iterations::kGroup > 0,"ThreadMap::Iterations::kGroup must be > 0");
static_assert( ThreadMap::Iterations::kCluster > 0,"ThreadMap::Iterations::kCluster must be > 0");
static_assert( ThreadMap::Iterations::kColumn > 0,"ThreadMap::Iterations::kColumn must be > 0");
/// Fragment object
using Fragment = Array<
Element,
ThreadMap::Iterations::kColumn *
ThreadMap::Iterations::kRow *
ThreadMap::Iterations::kGroup *
ThreadMap::Iterations::kCluster * ThreadMap::kElementsPerAccess>;
/// Memory access size
using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
//
// Parameters struct
//
/// Uses a non-template class
struct Params : PredicatedTileIteratorParams {
using Base = PredicatedTileIteratorParams;
CUTLASS_HOST_DEVICE
Params() { }
CUTLASS_HOST_DEVICE
Params(Layout const &layout):
PredicatedTileIteratorParams(
layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess,
make_OutputTileThreadMapDesc<ThreadMap>()
)
{ }
CUTLASS_HOST_DEVICE
Params(Base const &base) :
Base(base) { }
};
/// Mask object
struct Mask {
static int const kCount = ThreadMap::Iterations::kColumn;
/// Predicate state
bool predicates[kCount];
//
// Mask
//
CUTLASS_HOST_DEVICE
Mask() {
enable();
}
///< Efficiently disables all accesses guarded by mask
CUTLASS_HOST_DEVICE void clear() {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kCount; ++i) {
predicates[i] = false;
}
}
///< Efficiently enables all accesses guarded by mask
CUTLASS_DEVICE void enable() {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kCount; ++i) {
predicates[i] = true;
}
}
};
private:
//
// Data members
//
/// Parameters structure containing reference and precomputed state.
PredicatedTileIteratorParams params_;
/// Byte-level pointer
uint8_t *byte_pointer_;
/// Array of boolean values to contain steady-state predicates
Mask mask_;
/// Extent of the matrix tile in rows
Index extent_row_;
/// Extent of the matrix tile in columns
Index extent_column_;
/// A thread's starting row position (assuming steady-state predicates have been computed)
Index thread_start_row_;
/// A thread's starting column
Index thread_start_column_;
/// Internal state counter
int state_[3];
/// Scatter indices
int const *indices_;
//
// Static asserts about internal strides
//
static_assert(sizeof(extent_row_) == 4, "Expected 32b extents");
static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents");
static_assert(sizeof(PredicatedTileIteratorParams::stride) == 8, "Expected 64b strides");
private:
//
// Methods
//
public:
//
// Methods
//
/// Constructor
CUTLASS_DEVICE
EpiloguePredicatedTileIterator(
PredicatedTileIteratorParams const & params,
Element *pointer,
TensorCoord extent,
int thread_idx,
TensorCoord threadblock_offset = TensorCoord(),
int const *indices = nullptr
):
params_(params), indices_(indices)
{
TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx) + threadblock_offset;
extent_row_ = extent.row();
extent_column_ = extent.column();
thread_start_row_ = thread_offset.row();
thread_start_column_ = thread_offset.column();
// Initialize predicates
CUTLASS_PRAGMA_UNROLL
for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) {
mask_.predicates[c] = ((thread_offset.column()
+ ThreadMap::Delta::kColumn * c) < extent.column());
}
// Null pointer performs no accesses
if (!pointer) {
mask_.clear();
}
if (ScatterD && !indices) {
mask_.clear();
}
// Initialize pointer
byte_pointer_ = reinterpret_cast<uint8_t *>(pointer) +
LongIndex(thread_offset.row()) * LongIndex(params_.stride) +
LongIndex(thread_offset.column()) * sizeof(AccessType) / kElementsPerAccess;
if (ScatterD) {
byte_pointer_ = reinterpret_cast<uint8_t *>(pointer) +
LongIndex(thread_offset.column()) * sizeof(AccessType) / kElementsPerAccess;
}
// Initialize internal state counter
state_[0] = state_[1] = state_[2] = 0;
}
/// Adds a pointer offset in units of Element
CUTLASS_HOST_DEVICE
void add_pointer_offset(LongIndex pointer_offset) {
byte_pointer_ += pointer_offset * sizeof_bits<Element>::value / 8;
}
/// Loads a fragment from memory
CUTLASS_DEVICE
void load_with_byte_offset(Fragment &frag, int64_t byte_offset) const {
uint8_t *byte_pointer = byte_pointer_;
AccessType *frag_ptr = reinterpret_cast<AccessType *>(&frag);
CUTLASS_PRAGMA_UNROLL
for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) {
CUTLASS_PRAGMA_UNROLL
for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) {
CUTLASS_PRAGMA_UNROLL
for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) {
int frag_row_idx =
(row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster));
int row_offset = row * ThreadMap::Delta::kRow
+ group * ThreadMap::Delta::kGroup
+ cluster * ThreadMap::Delta::kCluster;
bool row_guard = ((row_offset + thread_start_row_) < extent_row_);
AccessType *memory_pointer = reinterpret_cast<AccessType *>(byte_pointer + byte_offset);
if (ScatterD && row_guard) {
assert(indices_);
memory_pointer = reinterpret_cast<AccessType *>(byte_pointer + byte_offset +
LongIndex(indices_[row_offset + thread_start_row_]) * LongIndex(params_.stride));
}
CUTLASS_PRAGMA_UNROLL
for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) {
bool guard = row_guard && mask_.predicates[column];
cutlass::arch::global_load<
AccessType,
sizeof(AccessType)
>(
frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn +
column],
(void *)&memory_pointer[column * ThreadMap::Delta::kColumn /
kElementsPerAccess],
guard);
}
if (row + 1 < ThreadMap::Iterations::kRow) {
if (!ScatterD) {
byte_pointer += params_.increment_row;
}
}
}
if (group + 1 < ThreadMap::Iterations::kGroup) {
byte_pointer += params_.increment_group;
}
}
if (cluster + 1 < ThreadMap::Iterations::kCluster) {
byte_pointer += params_.increment_cluster;
}
}
}
/// Loads a fragment from memory
CUTLASS_DEVICE
void load(Fragment &frag) const {
load_with_byte_offset(frag, 0);
}
/// Stores a fragment to memory
CUTLASS_DEVICE
void store_with_byte_offset(Fragment const &frag, int64_t byte_offset) const {
uint8_t *byte_pointer = byte_pointer_;
AccessType const *frag_ptr = reinterpret_cast<AccessType const *>(&frag);
CUTLASS_PRAGMA_UNROLL
for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) {
CUTLASS_PRAGMA_UNROLL
for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) {
CUTLASS_PRAGMA_UNROLL
for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) {
int frag_row_idx =
(row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster));
int row_offset = row * ThreadMap::Delta::kRow
+ group * ThreadMap::Delta::kGroup
+ cluster * ThreadMap::Delta::kCluster;
bool row_guard = ((row_offset + thread_start_row_) < extent_row_);
AccessType *memory_pointer = reinterpret_cast<AccessType *>(byte_pointer + byte_offset);
if (ScatterD && row_guard) {
assert(indices_);
memory_pointer = reinterpret_cast<AccessType *>(byte_pointer + byte_offset +
LongIndex(indices_[row_offset + thread_start_row_]) * LongIndex(params_.stride));
}
CUTLASS_PRAGMA_UNROLL
for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) {
bool guard = row_guard && mask_.predicates[column];
if (UseCUDAStore) {
if (guard) {
memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess] =
frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column];
}
} else {
cutlass::arch::global_store<AccessType, sizeof(AccessType)>(
frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column],
(void *)&memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess],
guard);
}
}
if (row + 1 < ThreadMap::Iterations::kRow) {
if (!ScatterD) {
byte_pointer += params_.increment_row;
}
}
}
if (group + 1 < ThreadMap::Iterations::kGroup) {
byte_pointer += params_.increment_group;
}
}
if (cluster + 1 < ThreadMap::Iterations::kCluster) {
byte_pointer += params_.increment_cluster;
}
}
}
/// Stores a fragment to memory
CUTLASS_DEVICE
void store(Fragment const &frag) const {
store_with_byte_offset(frag, 0);
}
CUTLASS_DEVICE
MatrixCoord thread_start() const {
return MatrixCoord(thread_start_row_, thread_start_column_);
}
/// Need to get the thread start row from the tile iterator
CUTLASS_DEVICE
int32_t thread_start_row() const {
return thread_start_row_;
}
/// Need to get the thread start column from the tile iterator
CUTLASS_DEVICE
int32_t thread_start_column() const {
return thread_start_column_;
}
/// Extent of the matrix in rows
CUTLASS_DEVICE
Index extent_row() const {
return extent_row_;
}
/// Extent of the matrix in columns
CUTLASS_DEVICE
Index extent_column() const {
return extent_column_;
}
/// Advances to the next position to load or store
CUTLASS_HOST_DEVICE
void move(const int step=1) {
if (!ScatterD) {
byte_pointer_ += step * params_.advance_row;
}
thread_start_row_ += step * ThreadMap::Shape::kRow;
}
///< Efficiently disables all accesses guarded by mask
CUTLASS_DEVICE void clear_mask() {
mask_.clear();
}
///< Efficiently enables all accesses guarded by mask
CUTLASS_DEVICE void enable_mask() {
mask_.enable();
}
///< Gets the mask
CUTLASS_DEVICE void get_mask(Mask &mask) const {
mask = mask_;
}
///< Sets the mask
CUTLASS_DEVICE void set_mask(Mask const &mask) {
mask_ = mask;
}
};
////////////////////////////////////////////////////////////////////////////////
} // namespace fmha


@ -0,0 +1,154 @@
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include <cuda.h>
#include <vector>
#include <ATen/cuda/CUDAGeneratorImpl.h>
#include <ATen/cuda/CUDAGraphsUtils.cuh>
#include <ATen/native/transformers/cuda/flash_attn/fmha_utils.h>
constexpr int TOTAL_DIM = 0;
constexpr int H_DIM = 1;
constexpr int D_DIM = 2;
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Qkv_params {
// The QKV matrices.
void *__restrict__ q_ptr;
void *__restrict__ k_ptr;
void *__restrict__ v_ptr;
// The stride between rows of the Q, K and V matrices.
// size_t qkv_stride_in_elts;
// size_t qkv_stride_in_bytes;
// TD [2022-04-16]: We're using 32-bit indexing to save registers.
// The code probably won't work for arrays larger than 2GB.
uint32_t q_row_stride_in_elts;
uint32_t k_row_stride_in_elts;
uint32_t v_row_stride_in_elts;
uint32_t q_head_stride_in_elts;
uint32_t k_head_stride_in_elts;
uint32_t v_head_stride_in_elts;
// The number of heads.
int h;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
struct FMHA_fprop_params : public Qkv_params {
// The O matrix (output).
void * __restrict__ o_ptr;
// The stride between rows of O.
// size_t o_stride_in_elts;
// size_t o_stride_in_bytes;
uint32_t o_row_stride_in_elts;
uint32_t o_head_stride_in_elts;
// The pointer to the O_tmp matrix, which holds O intermediate value during
// the loop;
void *__restrict__ o_tmp_ptr;
// The pointer to the S matrix.
void * __restrict__ s_ptr;
// The stride between rows of the S matrix.
// int64_t s_stride_in_bytes;
uint32_t s_stride_in_bytes;
// The pointer to the softmax sum.
void * __restrict__ softmax_lse_ptr;
// The dimensions.
int b, seqlen_q, seqlen_k, d;
// The scaling factors for the kernel.
float scale_bmm1;
// array of length b+1 holding starting offset of each sequence.
int * __restrict__ cu_seqlens_q;
int * __restrict__ cu_seqlens_k;
int *__restrict__ blockmask;
// The dropout probability (probability of keeping an activation).
float p_dropout;
uint32_t p_dropout_in_uint;
uint16_t p_dropout_in_uint16_t;
// Scale factor of 1 / (1 - p_dropout).
float rp_dropout;
float scale_bmm1_rp_dropout;
// Random state.
at::PhiloxCudaState philox_args;
bool is_bf16;
bool is_causal;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_params>
struct Launch_params{
Launch_params(cudaDeviceProp * props_,
cudaStream_t stream_,
bool is_dropout_,
bool return_softmax_)
: elts_per_thread(0)
, props(props_)
, stream(stream_)
, is_dropout(is_dropout_)
, return_softmax(return_softmax_) {
}
size_t elts_per_thread;
cudaDeviceProp * props;
cudaStream_t stream;
bool is_dropout;
bool return_softmax;
Kernel_params params;
int num_full_heads;
int num_main_groups;
int heads_last_wave;
int main_steps;
int rest_steps;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
void run_fmha_fprop(Launch_params<FMHA_fprop_params> &launch_params, const bool configure);


@ -0,0 +1,244 @@
/******************************************************************************
* Copyright (c) 2022, Tri Dao.
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/NativeFunctions.h>
#include <ATen/native/transformers/cuda/flash_attn/fmha.h>
#include <ATen/native/transformers/cuda/flash_attn/fmha_api.h>
#include <c10/util/Exception.h>
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == at::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
namespace fmha {
void set_params_fprop(FMHA_fprop_params &params,
// sizes
const size_t b,
const size_t seqlen_q,
const size_t seqlen_k,
const size_t h,
const size_t d,
// device pointers
const at::Tensor q,
const at::Tensor k,
const at::Tensor v,
void *cu_seqlens_q_d,
void *cu_seqlens_k_d,
void *o_packed_d,
void *o_tmp_d,
void *s_d,
void *softmax_lse_d,
float p_dropout,
float softmax_scale,
bool is_causal) {
// Reset the parameters
memset(&params, 0, sizeof(params));
params.is_bf16 = q.dtype() == at::kBFloat16;
// Set the pointers and strides.
params.q_ptr = q.data_ptr();
params.k_ptr = k.data_ptr();
params.v_ptr = v.data_ptr();
params.q_row_stride_in_elts = q.stride(0);
params.k_row_stride_in_elts = k.stride(0);
params.v_row_stride_in_elts = v.stride(0);
params.q_head_stride_in_elts = q.stride(1);
params.k_head_stride_in_elts = k.stride(1);
params.v_head_stride_in_elts = v.stride(1);
params.o_ptr = o_packed_d;
params.o_row_stride_in_elts = h * d;
params.o_head_stride_in_elts = d;
params.o_tmp_ptr = o_tmp_d;
params.cu_seqlens_q = static_cast<int *>(cu_seqlens_q_d);
params.cu_seqlens_k = static_cast<int *>(cu_seqlens_k_d);
// S = softmax(P)
params.s_ptr = s_d;
params.s_stride_in_bytes = b * h * seqlen_k * 2; // 2 = sizeof(Element)
// Softmax sum
params.softmax_lse_ptr = softmax_lse_d;
// Set the dimensions.
params.b = b;
params.h = h;
params.seqlen_q = seqlen_q;
params.seqlen_k = seqlen_k;
params.d = d;
// Set the different scale values.
params.scale_bmm1 = softmax_scale;
// Set this to probability of keeping an element to simplify things.
params.p_dropout = 1.f - p_dropout;
// Convert p from float to int so we don't have to convert the random uint to float to compare.
// [Minor] We want to round down since when we do the comparison we use <= instead of <
params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0));
params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0));
params.rp_dropout = 1.f / params.p_dropout;
params.scale_bmm1_rp_dropout = params.rp_dropout * params.scale_bmm1;
TORCH_CHECK(p_dropout < 1.f);
params.is_causal = is_causal;
}
std::vector<at::Tensor>
mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
const at::Tensor &k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &v, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &cu_seqlens_q, // b+1
const at::Tensor &cu_seqlens_k, // b+1
const int max_seqlen_q_,
const int max_seqlen_k_,
const float p_dropout,
const float softmax_scale,
const bool zero_tensors,
const bool is_causal,
const bool return_softmax,
c10::optional<at::Generator> gen_) {
auto dprops = at::cuda::getCurrentDeviceProperties();
bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
bool is_sm80 = dprops->major == 8 && dprops->minor == 0;
bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
TORCH_CHECK(is_sm8x || is_sm75);
auto stream = at::cuda::getCurrentCUDAStream().stream();
bool is_dropout = p_dropout > 0.0;
Launch_params<FMHA_fprop_params> launch_params(dprops, stream, is_dropout, return_softmax);
auto q_dtype = q.dtype();
TORCH_CHECK(q_dtype == at::kHalf || (is_sm8x && q_dtype == at::kBFloat16));
TORCH_CHECK(k.dtype() == q_dtype);
TORCH_CHECK(v.dtype() == q_dtype);
TORCH_CHECK(cu_seqlens_q.dtype() == at::kInt);
TORCH_CHECK(cu_seqlens_k.dtype() == at::kInt);
TORCH_CHECK(q.is_cuda());
TORCH_CHECK(k.is_cuda());
TORCH_CHECK(v.is_cuda());
TORCH_CHECK(cu_seqlens_q.is_cuda());
TORCH_CHECK(cu_seqlens_k.is_cuda());
TORCH_CHECK(q.stride(-1) == 1);
TORCH_CHECK(k.stride(-1) == 1);
TORCH_CHECK(v.stride(-1) == 1);
TORCH_CHECK(cu_seqlens_q.is_contiguous());
TORCH_CHECK(cu_seqlens_k.is_contiguous());
const auto sizes = q.sizes();
const int batch_size = cu_seqlens_q.numel() - 1;
const int total_q = sizes[TOTAL_DIM];
const int num_heads = sizes[H_DIM];
const int head_size = sizes[D_DIM];
const int total_k = k.size(TOTAL_DIM);
TORCH_CHECK(batch_size > 0);
TORCH_CHECK((head_size % 8 == 0) && (head_size <= 128));
const int head_size_rounded = head_size <= 64 ? 64 : 128;
CHECK_SHAPE(q, total_q, num_heads, head_size);
CHECK_SHAPE(k, total_k, num_heads, head_size);
CHECK_SHAPE(v, total_k, num_heads, head_size);
CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
int blocksize_c = ((head_size_rounded == 128 && (is_dropout || !is_sm80)) || (is_sm75 && head_size_rounded == 64 && is_dropout)) ? 128 : 256;
// Need to round max_seqlen_k to multiples of blocksize_c
int max_seqlen_k = ((max_seqlen_k_ + blocksize_c - 1) / blocksize_c) * blocksize_c;
if( max_seqlen_k_ <= 128 ) {
max_seqlen_k = 128;
} else if( max_seqlen_k_ <= 256 ) {
max_seqlen_k = 256;
}
int max_seqlen_q = ((max_seqlen_q_ + 16 - 1) / 16) * 16;
bool loop = max_seqlen_k > blocksize_c;
auto opts = q.options();
auto o = at::empty({ total_q, num_heads, head_size }, opts);
at::Tensor o_tmp;
if (loop) { o_tmp = at::empty({total_q, num_heads, head_size}, opts.dtype(at::kFloat)); }
auto softmax_lse = at::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
// auto softmax_lse = torch::full({batch_size, num_heads, max_seqlen_k}, -std::numeric_limits<float>::infinity(), opts.dtype(at::kFloat));
at::Tensor s;
if (return_softmax) { s = at::empty({ batch_size, num_heads, max_seqlen_q, max_seqlen_k }, opts); }
if( zero_tensors ) {
o.zero_();
softmax_lse.fill_(-std::numeric_limits<float>::infinity());
if (return_softmax) {s.zero_();}
}
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
gen_, at::cuda::detail::getDefaultCUDAGenerator());
set_params_fprop(launch_params.params,
batch_size,
max_seqlen_q,
max_seqlen_k,
num_heads,
head_size,
q, k, v,
cu_seqlens_q.data_ptr(),
cu_seqlens_k.data_ptr(),
o.data_ptr(),
loop ? o_tmp.data_ptr() : nullptr,
return_softmax ? s.data_ptr() : nullptr,
softmax_lse.data_ptr(),
p_dropout,
softmax_scale,
is_causal);
run_fmha_fprop(launch_params, /*configure=*/ true);
// number of times random will be generated per thread, to offset philox counter in thc random
// state
int64_t counter_offset = launch_params.elts_per_thread;
at::PhiloxCudaState rng_engine_inputs;
if( is_dropout ) {
// See Note [Acquire lock when using random generators]
std::lock_guard<std::mutex> lock(gen->mutex_);
launch_params.params.philox_args = gen->philox_cuda_state(counter_offset);
}
run_fmha_fprop(launch_params, /*configure=*/false);
std::vector<at::Tensor> result = {o, softmax_lse};
if (return_softmax) {result.push_back(s);}
return result;
}
} // namespace fmha
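The dropout handling in set_params_fprop above is easy to misread: params.p_dropout stores the keep probability, and the uint32/uint16 thresholds let the kernel compare raw Philox bits with <= instead of converting them to float. A sketch of that encoding (function name illustrative):

#include <cassert>
#include <cmath>
#include <cstdint>

void dropout_encoding_sketch() {
  float p_dropout = 0.1f;            // probability of dropping an activation
  float p_keep = 1.f - p_dropout;    // what set_params_fprop stores in params.p_dropout
  // Scale the keep probability to the full integer range; an activation is kept
  // when its random bits fall at or below this threshold.
  uint32_t keep_threshold_u32 = uint32_t(std::floor(p_keep * 4294967295.0));
  uint16_t keep_threshold_u16 = uint16_t(std::floor(p_keep * 65535.0));
  float rp_dropout = 1.f / p_keep;   // rescale factor applied to kept activations
  uint32_t random_bits = 123456789u; // stand-in for a Philox-generated value
  assert(random_bits <= keep_threshold_u32);
  (void)keep_threshold_u16;
  (void)rp_dropout;
}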


@ -0,0 +1,24 @@
#pragma once
#include <cstddef>
#include <ATen/ATen.h>
#include <c10/util/Exception.h>
namespace fmha {
std::vector<at::Tensor>
mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
const at::Tensor &k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &v, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
const at::Tensor &cu_seqlens_q, // b+1
const at::Tensor &cu_seqlens_k, // b+1
const int max_seqlen_q_,
const int max_seqlen_k_,
const float p_dropout,
const float softmax_scale,
const bool zero_tensors,
const bool is_causal,
const bool return_softmax,
c10::optional<at::Generator> gen_);
} // namespace fmha


@ -0,0 +1,722 @@
/***************************************************************************************************
* Copyright (c) 2022, Tri Dao.
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include <ATen/cuda/CUDAContext.h>
#include <ATen/native/transformers/cuda/flash_attn/philox.cuh>
#include <ATen/native/transformers/cuda/flash_attn/mask.h>
#include <ATen/native/transformers/cuda/flash_attn/fmha_kernel.h>
#include <ATen/native/transformers/cuda/flash_attn/softmax.h>
#include <ATen/native/transformers/cuda/flash_attn/epilogue.h>
#include <third_party/cutlass/include/cutlass/cutlass.h>
#include <third_party/cutlass/include/cutlass/layout/layout.h>
#include <third_party/cutlass/include/cutlass/array.h>
#include <third_party/cutlass/include/cutlass/numeric_types.h>
#include <third_party/cutlass/include/cutlass/numeric_conversion.h>
#include <third_party/cutlass/include/cutlass/arch/mma.h>
#include <third_party/cutlass/include/cutlass/gemm/warp/default_mma_tensor_op.h>
#include <third_party/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator.h>
#include <third_party/cutlass/include/cutlass/gemm/threadblock/default_mma_core.h>
#include <third_party/cutlass/include/cutlass/gemm/threadblock/default_mma_core_sm75.h>
#include <third_party/cutlass/include/cutlass/gemm/threadblock/default_mma_core_sm80.h>
#include <third_party/cutlass/include/cutlass/epilogue/warp/fragment_iterator_tensor_op.h>
#include <third_party/cutlass/include/cutlass/epilogue/warp/tile_iterator_tensor_op.h>
#include <third_party/cutlass/include/cutlass/epilogue/threadblock/default_thread_map_tensor_op.h>
#include <third_party/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h>
#include <third_party/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h>
namespace fmha {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits>
struct Gemm_Q_K_base {
using Smem_O = fmha::FMHAEpilogue<typename Kernel_traits::MmaCorePV>;
using WarpMma = typename Kernel_traits::MmaCoreQK::MmaTensorOp;
// The description of the CTA tile for the 1st batched GEMM.
using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
static constexpr size_t SMEM_BYTES_SOFTMAX = Cta_tile_p::M * Cta_tile_p::WARPS_N * sizeof(float) * 2;
__device__ inline Gemm_Q_K_base(char * smem_ptr_q, char * smem_ptr_k)
: smem_q_ptr(smem_ptr_q)
, smem_k_ptr(smem_ptr_k) {
}
__device__ inline void load_q(int byte_offset=0) {
typename WarpMma::LayoutA layout_A = WarpMma::LayoutA::packed({Cta_tile_p::M, Cta_tile_p::K});
typename WarpMma::IteratorA iter_A({reinterpret_cast<typename WarpMma::ElementA *>(smem_q_ptr + byte_offset), layout_A}, threadIdx.x % 32);
iter_A.load(frag_q[0]);
}
__device__ inline void reload_q(int byte_offset=0) {
typename WarpMma::LayoutA layout_A = WarpMma::LayoutA::packed({Cta_tile_p::M, Cta_tile_p::K});
typename WarpMma::IteratorA iter_A({reinterpret_cast<typename WarpMma::ElementA *>(smem_q_ptr + byte_offset), layout_A}, threadIdx.x % 32);
iter_A.load(frag_q[0]);
}
typename WarpMma::FragmentA frag_q[2];
char *smem_q_ptr;
char *smem_k_ptr;
};
template<typename Kernel_traits, bool K_in_regs>
struct Gemm_Q_K : public Gemm_Q_K_base<Kernel_traits> {
using Base = Gemm_Q_K_base<Kernel_traits>;
using Cta_tile_p = typename Base::Cta_tile_p;
using Smem_O = typename Base::Smem_O;
using WarpMma = typename Base::WarpMma;
static constexpr int kIterations = WarpMma::Shape::kK / WarpMma::InstructionShape::kK;
static constexpr bool SHARE_SMEM_FOR_K_AND_V = Kernel_traits::SHARE_SMEM_FOR_K_AND_V;
// If V is stored in shared memory, we can't load K using the same shared memory.
static_assert(Kernel_traits::V_IN_REGS);
static constexpr size_t SMEM_OFFSET_O = Kernel_traits::BYTES_PER_SMEM_Q;
static constexpr size_t SMEM_OFFSET_SOFTMAX = SMEM_OFFSET_O + sizeof(typename Smem_O::SharedStorage);
static constexpr size_t SMEM_OFFSET_V = Kernel_traits::BYTES_PER_SMEM_Q + (SHARE_SMEM_FOR_K_AND_V ? 0 : Kernel_traits::BYTES_PER_SMEM_K);
// Q | K / V
// | O | SOFTMAX
static constexpr size_t SMEM_BYTES = Kernel_traits::BYTES_PER_SMEM_Q
+ std::max((size_t)(SHARE_SMEM_FOR_K_AND_V ? 1 : 2) * Kernel_traits::BYTES_PER_SMEM_K,
sizeof(typename Smem_O::SharedStorage) + Base::SMEM_BYTES_SOFTMAX);
__device__ inline Gemm_Q_K(char * smem_)
: Base(smem_, smem_ + Kernel_traits::BYTES_PER_SMEM_Q) {
}
__device__ inline void load_k(){
typename WarpMma::LayoutB layout_B = WarpMma::LayoutB::packed({Cta_tile_p::K, Cta_tile_p::N});
typename WarpMma::IteratorB iter_B({reinterpret_cast<typename WarpMma::ElementB *>(Base::smem_k_ptr), layout_B}, threadIdx.x % 32);
const int warp_idx = threadIdx.x / 32;
iter_B.add_tile_offset({0, warp_idx});
#pragma unroll
for( int ki = 0; ki < kIterations; ++ki ) {
iter_B.load(frag_k[ki]);
++iter_B;
}
}
__device__ inline void operator()(WarpMma warp_mma, typename WarpMma::FragmentC &acc_p, int byte_offset_q=0){
typename WarpMma::LayoutA layout_A = WarpMma::LayoutA::packed({Base::Cta_tile_p::M, Base::Cta_tile_p::K});
typename WarpMma::IteratorA iter_A({reinterpret_cast<typename WarpMma::ElementA *>(Base::smem_q_ptr + byte_offset_q), layout_A}, threadIdx.x % 32);
++iter_A;
// Do this part of P^T = (Q * K^T)^T.
#pragma unroll
for( int ki = 0; ki < kIterations; ++ki ) {
// Trigger the load from shared memory for the next series of Q values.
if (ki + 1 < kIterations) { iter_A.load(Base::frag_q[(ki + 1) % 2]); ++iter_A; }
// Do the math for the values already in registers.
warp_mma(acc_p, Base::frag_q[ki % 2], frag_k[ki], acc_p);
}
}
__device__ inline void reload_k(){
// Noop.
}
typename WarpMma::FragmentB frag_k[kIterations];
};
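The operator() above overlaps the shared-memory loads with the tensor-core math by ping-ponging between the two Q fragments. A standalone sketch of that pipelining pattern (load_smem()/compute() are stand-ins for iter_A.load() and warp_mma(), and kIterations is just an example value, none of these are part of the kernel):

// Standalone illustration of the two-fragment ping-pong used in operator() above.
#include <cstdio>

static float load_smem(int ki) { return static_cast<float>(ki); }          // stand-in for iter_A.load()
static void compute(float v) { std::printf("mma on fragment %.0f\n", v); } // stand-in for warp_mma()

int main() {
    constexpr int kIterations = 4;
    float frag[2];
    frag[0] = load_smem(0);  // prologue: the first fragment is staged (load_q() in the kernel)
    for (int ki = 0; ki < kIterations; ++ki) {
        // Prefetch the next fragment into the other buffer while this one is consumed.
        if (ki + 1 < kIterations) { frag[(ki + 1) % 2] = load_smem(ki + 1); }
        compute(frag[ki % 2]);
    }
    return 0;
}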
template<typename Kernel_traits>
struct Gemm_Q_K<Kernel_traits, false> : public Gemm_Q_K_base<Kernel_traits> {
using Base = Gemm_Q_K_base<Kernel_traits>;
using Cta_tile_p = typename Base::Cta_tile_p;
using Smem_O = typename Base::Smem_O;
using WarpMma = typename Base::WarpMma;
static constexpr bool SHARE_SMEM_FOR_K_AND_V = Kernel_traits::SHARE_SMEM_FOR_K_AND_V;
static constexpr bool V_IN_REGS = Kernel_traits::V_IN_REGS;
static_assert(V_IN_REGS || !SHARE_SMEM_FOR_K_AND_V);
static constexpr size_t SMEM_OFFSET_V = Kernel_traits::BYTES_PER_SMEM_Q + (SHARE_SMEM_FOR_K_AND_V ? 0 : Kernel_traits::BYTES_PER_SMEM_K);
static constexpr size_t SMEM_OFFSET_O = SMEM_OFFSET_V + Kernel_traits::BYTES_PER_SMEM_V;
static constexpr size_t SMEM_OFFSET_SOFTMAX = SMEM_OFFSET_O + sizeof(typename Smem_O::SharedStorage);
// If V_IN_REGS and SHARE_SMEM_FOR_K_AND_V: Q | K/V | O | SOFTMAX
// If !V_IN_REGS (then !SHARE_SMEM_FOR_K_AND_V): Q | K | V | O | SOFTMAX
static constexpr size_t SMEM_BYTES = Kernel_traits::BYTES_PER_SMEM_Q
+ (SHARE_SMEM_FOR_K_AND_V ? 1 : 2) * Kernel_traits::BYTES_PER_SMEM_K
+ sizeof(typename Smem_O::SharedStorage) + Base::SMEM_BYTES_SOFTMAX;
__device__ inline Gemm_Q_K(char * smem_)
: Base(smem_, smem_ + Kernel_traits::BYTES_PER_SMEM_Q) {
}
__device__ inline void load_k(){
typename WarpMma::LayoutB layout_B = WarpMma::LayoutB::packed({Cta_tile_p::K, Cta_tile_p::N});
typename WarpMma::IteratorB iter_B({reinterpret_cast<typename WarpMma::ElementB *>(Base::smem_k_ptr), layout_B}, threadIdx.x % 32);
const int warp_idx = threadIdx.x / 32;
iter_B.add_tile_offset({0, warp_idx});
iter_B.load(frag_k[0]);
}
__device__ inline void operator()(WarpMma warp_mma, typename WarpMma::FragmentC &acc_p, int byte_offset_q=0){
typename WarpMma::LayoutA layout_A = WarpMma::LayoutA::packed({Base::Cta_tile_p::M, Base::Cta_tile_p::K});
typename WarpMma::IteratorA iter_A({reinterpret_cast<typename WarpMma::ElementA *>(Base::smem_q_ptr + byte_offset_q), layout_A}, threadIdx.x % 32);
++iter_A;
typename WarpMma::LayoutB layout_B = WarpMma::LayoutB::packed({Cta_tile_p::K, Cta_tile_p::N});
typename WarpMma::IteratorB iter_B({reinterpret_cast<typename WarpMma::ElementB *>(Base::smem_k_ptr), layout_B}, threadIdx.x % 32);
const int warp_idx = threadIdx.x / 32;
iter_B.add_tile_offset({0, warp_idx});
++iter_B;
// Do this part of P^T = (Q * K^T)^T.
constexpr int kIterations = WarpMma::Shape::kK / WarpMma::InstructionShape::kK;
#pragma unroll
for( int ki = 0; ki < kIterations; ++ki ) {
// Trigger the load from shared memory for the next series of Q values.
if (ki + 1 < kIterations) {
iter_A.load(Base::frag_q[(ki + 1) % 2]); ++iter_A;
iter_B.load(frag_k[(ki + 1) % 2]); ++iter_B;
}
// Do the math for the values already in registers.
warp_mma(acc_p, Base::frag_q[ki % 2], frag_k[ki % 2], acc_p);
}
}
__device__ inline void reload_k(){
typename WarpMma::LayoutB layout_B = WarpMma::LayoutB::packed({Cta_tile_p::K, Cta_tile_p::N});
typename WarpMma::IteratorB iter_B({reinterpret_cast<typename WarpMma::ElementB *>(Base::smem_k_ptr), layout_B}, threadIdx.x % 32);
const int warp_idx = threadIdx.x / 32;
iter_B.add_tile_offset({0, warp_idx});
iter_B.load(frag_k[0]);
}
typename WarpMma::FragmentB frag_k[2];
};
template<typename Kernel_traits>
constexpr size_t get_dynamic_smem_size(){
return Gemm_Q_K<Kernel_traits, Kernel_traits::K_IN_REGS>::SMEM_BYTES;
}
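To make the layout comments above concrete, here is a small host-side sketch of how SMEM_BYTES differs between the two Gemm_Q_K specializations, using hypothetical byte sizes (the real values come from Kernel_traits and from sizeof(Smem_O::SharedStorage)):

// Hypothetical sizes: 2 KB for Q, 16 KB for K (and V), 4 KB for the O epilogue storage,
// 1 KB for the softmax scratch.
#include <algorithm>
#include <cstddef>
#include <cstdio>

int main() {
    constexpr std::size_t kQ = 2 * 1024, kK = 16 * 1024, kO = 4 * 1024, kSoftmax = 1 * 1024;
    constexpr bool share_k_and_v = true;
    // K_IN_REGS == true: after the prologue K lives in registers, so O and the softmax
    // scratch can reuse the K/V region -> Q | max(K/V, O + SOFTMAX).
    constexpr std::size_t bytes_k_in_regs =
        kQ + std::max((share_k_and_v ? std::size_t(1) : std::size_t(2)) * kK, kO + kSoftmax);
    // K_IN_REGS == false: K must stay resident in smem -> Q | K (/V) | O | SOFTMAX.
    constexpr std::size_t bytes_k_in_smem =
        kQ + (share_k_and_v ? std::size_t(1) : std::size_t(2)) * kK + kO + kSoftmax;
    std::printf("K in regs: %zu bytes, K in smem: %zu bytes\n", bytes_k_in_regs, bytes_k_in_smem);
    return 0;
}

When K can live in registers, the epilogue storage and the softmax scratch reuse the K/V region, which is why the two terms are combined with a max rather than summed.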
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Return_softmax, bool Is_first, bool Is_last, typename Params, typename Prng>
inline __device__ void device_1xN_(const Params &params, const int bidb, const int bidh, int begin, int steps, Prng &ph0, Prng &ph1, const int loop_step_idx) {
// The description of the CTA tile for the 1st batched GEMM.
using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
// The description of the CTA tile for the 2nd batched GEMM.
using Cta_tile_o = typename Kernel_traits::Cta_tile_o;
// The MMA tile for the 1st GEMM.
using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;
// The MMA tile for the 2nd GEMM.
using Mma_tile_o = fmha::Hmma_tile<Cta_tile_o>;
using InstructionShape = typename Kernel_traits::MmaInstructionShape;
using Element = typename Kernel_traits::Element;
using ElementAccum = typename Kernel_traits::ElementAccum;
using ThreadblockShapeQK = typename Kernel_traits::ThreadblockShapeQK;
using LayoutQ = typename Kernel_traits::LayoutQ;
using LayoutK = typename Kernel_traits::LayoutK;
using LayoutP = typename Kernel_traits::LayoutP;
using MmaCoreQK = typename Kernel_traits::MmaCoreQK;
using WarpMmaQK = typename MmaCoreQK::MmaTensorOp;
using SmemLayoutQ = typename MmaCoreQK::SmemLayoutA;
using SmemLayoutK = typename MmaCoreQK::SmemLayoutB;
using SmemIteratorQ = typename MmaCoreQK::SmemIteratorA;
using SmemIteratorK = typename MmaCoreQK::SmemIteratorB;
using ThreadblockShapePV = typename Kernel_traits::ThreadblockShapePV;
using LayoutV = typename Kernel_traits::LayoutV;
using LayoutO = typename Kernel_traits::LayoutO;
using MmaCorePV = typename Kernel_traits::MmaCorePV;
using WarpMmaPV = typename MmaCorePV::MmaTensorOp;
using WarpIteratorV = typename WarpMmaPV::IteratorB;
using SmemLayoutV = typename MmaCorePV::SmemLayoutB;
using SmemIteratorV = typename MmaCorePV::SmemIteratorB;
constexpr int kIterationsPV = WarpMmaPV::Shape::kK / WarpMmaPV::InstructionShape::kK;
// The global memory tile to load Q.
// Copied from mma_pipelined_testbed.h
using GmemIteratorQ = typename Kernel_traits::GmemIteratorQ;
// The global memory tile to load K.
using GmemIteratorK = typename Kernel_traits::GmemIteratorK;
// The global memory tile to load V.
using GmemIteratorV = typename Kernel_traits::GmemIteratorV;
// The global memory tile to store O.
using GmemIteratorO = typename fmha::FMHAEpilogue<MmaCorePV>::GmemIterator;
using GmemIteratorOAccum = typename fmha::FMHAEpilogue<MmaCorePV>::GmemIteratorAccum;
using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s;
using Gmem_softmax_sum = typename Kernel_traits::Gmem_softmax_sum;
using Smem_softmax_lse = typename Kernel_traits::Smem_softmax_lse;
using Gemm1 = Gemm_Q_K<Kernel_traits, Kernel_traits::K_IN_REGS>;
using Softmax = fmha::Softmax<Cta_tile_p, Kernel_traits>;
// Shared memory.
extern __shared__ char smem_[];
// The thread index.
const int tidx = threadIdx.x;
const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);
if( binfo.stop_early(loop_step_idx * Cta_tile_p::N) ) return;
Gemm1 gemm_q_k(smem_);
// Allocate the global memory tile loader for S.
Gmem_tile_s gmem_s(params, binfo, tidx);
Gmem_softmax_sum gmem_softmax_lse(params.softmax_lse_ptr, params, tidx);
// Wind gmem tiles to the correct position.
static_assert(Cta_tile_p::N % Cta_tile_p::M == 0);
const int begin_og = begin;
begin = Is_causal ? std::max(begin, loop_step_idx * Cta_tile_p::N / Cta_tile_p::M) : begin;
const int steps_og = steps;
steps -= begin - begin_og;
if (Return_softmax) { gmem_s.move(begin); }
gmem_softmax_lse.move(begin);
fmha::Mask<Cta_tile_p, Is_causal> mask(binfo, tidx, loop_step_idx);
// The base pointer of smem_v.
char *smem_v_addr = &smem_[Gemm1::SMEM_OFFSET_V];
// Allocate the shared memory tile loader for V. It may share the same buffer as K, so be careful!
SmemLayoutQ layout_Q = SmemLayoutQ::packed({ThreadblockShapeQK::kM, ThreadblockShapeQK::kK});
SmemIteratorQ smem_q({reinterpret_cast<Element *>(smem_), layout_Q}, tidx);
SmemLayoutK layout_K = SmemLayoutK::packed({ThreadblockShapeQK::kK, ThreadblockShapeQK::kN});
SmemIteratorK smem_k({reinterpret_cast<Element *>(smem_ + Kernel_traits::BYTES_PER_SMEM_Q), layout_K}, tidx);
SmemLayoutV layout_V = SmemLayoutV::packed({ThreadblockShapePV::kK, ThreadblockShapePV::kN});
// SmemIterator stores to smem and WarpIterator loads from smem
SmemIteratorV smem_v({reinterpret_cast<Element *>(smem_v_addr), layout_V}, tidx);
WarpIteratorV iter_V({reinterpret_cast<Element *>(smem_v_addr), layout_V}, threadIdx.x % 32);
// Allocate the shared memory tile for O. Its buffer may overlap the K/V region, so be careful!
using Smem_O = fmha::FMHAEpilogue<MmaCorePV>;
Smem_O smem_o(&smem_[Gemm1::SMEM_OFFSET_O], tidx);
// Allocate the global memory tile loader for Q.
// cutlass::transform::threadblock::PredicatedTileIterator deals with seqlen not divisible
// by 16 in a different way than we want. If the seqlen_q is 36, the first iteration would
// load 4 rows and the next two iterations would load 16 rows each. Instead we round the
// actual_seqlen_q to be multiple of 16, then change the mask in the last iteration, so
// that in this case we would load 16, 16, 4.
LayoutQ gmem_layout_Q(params.q_row_stride_in_elts);
typename GmemIteratorQ::Params gmem_Q_params(gmem_layout_Q);
const uint32_t row_offset_q = (binfo.sum_s_q + begin * ThreadblockShapeQK::kM) * params.q_row_stride_in_elts + binfo.bidh * params.q_head_stride_in_elts;
const int actual_seqlen_q = binfo.actual_seqlen_q - begin * ThreadblockShapeQK::kM;
const int seqlen_q_remainder = actual_seqlen_q % ThreadblockShapeQK::kM;
const int extent_q = ((actual_seqlen_q <= ThreadblockShapeQK::kM) || (seqlen_q_remainder == 0)) ? actual_seqlen_q : actual_seqlen_q + ThreadblockShapeQK::kM - seqlen_q_remainder;
GmemIteratorQ gmem_q(gmem_Q_params,
reinterpret_cast<Element *>(params.q_ptr) + row_offset_q,
{extent_q, params.d},
tidx);
// Allocate the global memory tile loader for K.
LayoutK gmem_layout_K(params.k_row_stride_in_elts);
typename GmemIteratorK::Params gmem_K_params(gmem_layout_K);
const uint32_t row_offset_k = (binfo.sum_s_k + loop_step_idx * ThreadblockShapeQK::kN) * params.k_row_stride_in_elts + binfo.bidh * params.k_head_stride_in_elts;
const int extent_k = min(binfo.actual_seqlen_k - loop_step_idx * ThreadblockShapeQK::kN, ThreadblockShapeQK::kN);
GmemIteratorK gmem_k(gmem_K_params,
reinterpret_cast<Element *>(params.k_ptr) + row_offset_k,
{params.d, extent_k},
tidx);
// Allocate the global memory tile loader for V.
LayoutV gmem_layout_V(params.v_row_stride_in_elts);
typename GmemIteratorV::Params gmem_V_params(gmem_layout_V);
const uint32_t row_offset_v = (binfo.sum_s_k + loop_step_idx * ThreadblockShapePV::kK) * params.v_row_stride_in_elts + binfo.bidh * params.v_head_stride_in_elts;
// extent_v is the same as extent_k
GmemIteratorV gmem_v(gmem_V_params,
reinterpret_cast<Element *>(params.v_ptr) + row_offset_v,
{extent_k, params.d},
tidx);
// Allocate the global memory tile loader for O.
LayoutO gmem_layout_O(params.o_row_stride_in_elts);
typename GmemIteratorO::Params gmem_O_params(gmem_layout_O);
const uint32_t row_offset_o = (binfo.sum_s_q + begin * ThreadblockShapeQK::kM) * params.o_row_stride_in_elts + binfo.bidh * params.o_head_stride_in_elts;
GmemIteratorO gmem_o(gmem_O_params,
reinterpret_cast<Element *>(params.o_ptr) + row_offset_o,
{actual_seqlen_q, params.d},
tidx);
typename GmemIteratorOAccum::Params gmem_Oaccum_params(gmem_layout_O);
GmemIteratorOAccum gmem_o_accum(gmem_Oaccum_params,
reinterpret_cast<ElementAccum *>(params.o_tmp_ptr) + row_offset_o,
{actual_seqlen_q, params.d},
tidx);
// Create the object to do the softmax.
Softmax softmax(params, &smem_[Gemm1::SMEM_OFFSET_SOFTMAX], tidx);
Smem_softmax_lse smem_softmax_lse(reinterpret_cast<float *>(&smem_[Gemm1::SMEM_BYTES]));
if (!Is_first) {
if (Return_softmax) { gmem_s.move(loop_step_idx * steps_og); }
}
if (!Is_first) { __syncthreads(); }
// Trigger the loads for V.
typename GmemIteratorV::Fragment gmem_frag_v;
gmem_frag_v.clear();
gmem_v.load(gmem_frag_v);
// Trigger the loads for Q.
typename GmemIteratorQ::Fragment gmem_frag_q;
gmem_frag_q.clear();
gmem_q.load(gmem_frag_q);
// Trigger the loads for K.
typename GmemIteratorK::Fragment gmem_frag_k;
gmem_frag_k.clear();
gmem_k.load(gmem_frag_k);
float p_prev_lse[Mma_tile_p::MMAS_M * 2];
if (!Is_first) {
gmem_softmax_lse.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(p_prev_lse));
}
// Commit the data for Q and V to shared memory.
smem_v.store(gmem_frag_v);
smem_q.store(gmem_frag_q);
// Commit the data for K to shared memory.
if( !Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
smem_k.store(gmem_frag_k);
}
__syncthreads();
// Load the fragments for Q.
gemm_q_k.load_q();
// Load the fragments for V. We keep the data in registers during the entire
// kernel. Copied from mma_pipelined.h.
const int warp_idx = threadIdx.x / 32;
iter_V.add_tile_offset({kIterationsPV * warp_idx, 0});
typename WarpIteratorV::Fragment frag_v[kIterationsPV];
static_assert(WarpIteratorV::Fragment::kStorageElements == 4 * Mma_tile_o::MMAS_N || WarpIteratorV::Fragment::kStorageElements == 2 * Mma_tile_o::MMAS_N );
#pragma unroll
for( int ki = 0; ki < kIterationsPV; ++ki ) {
iter_V.load(frag_v[ki]);
++iter_V;
}
// Commit the data for K to shared memory if it has not been done already.
if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
// Make sure we are done loading the fragments for K.
__syncthreads();
// Commit the data to shared memory for K.
smem_k.store(gmem_frag_k);
// Make sure the data is in shared memory.
__syncthreads();
}
// Load the fragments for K.
gemm_q_k.load_k();
// Load over the entire sequence length.
for( int l = 0; l < steps; l++ ) {
if((begin + l) * Cta_tile_p::M >= binfo.actual_seqlen_q) break;
// Declare the accumulators for the 1st gemm.
WarpMmaQK mma_qk;
typename WarpMmaQK::FragmentC acc_p;
acc_p.clear();
// Do this part of P = Q * K^T.
gemm_q_k(mma_qk, acc_p);
typename Smem_O::OutputFragment out[Smem_O::kIterationsStore];
static_assert(GmemIteratorOAccum::kIterations == Smem_O::kIterationsStore);
static_assert(GmemIteratorO::kIterations == Smem_O::kIterationsStore);
if (!Is_first) {
#pragma unroll
for (int iter = 0; iter < GmemIteratorOAccum::kIterations; ++iter) {
gmem_o_accum.load(out[iter]);
gmem_o_accum.move();
}
}
// Trigger the load for the next Q values.
if( l < steps - 1) {
++gmem_q;
// If actual_seqlen_q is not a multiple of 16, we change the mask in the last iteration
// to load the "residue" tile.
if ((l + 1 == steps - 1) && (actual_seqlen_q % ThreadblockShapeQK::kM != 0)) {
// TODO: this probably only works for head_dim = 64 and head_dim = 128, which is
// what we have right now. Maybe for head_dim = 32 or 96, this could be different.
const int row_idx = tidx / (GmemIteratorQ::Shape::kColumn / GmemIteratorQ::Fragment::kElements);
if (row_idx >= actual_seqlen_q - (l + 1) * ThreadblockShapeQK::kM) {
gmem_q.clear_mask();
}
}
gmem_q.load(gmem_frag_q);
}
// Load the mask for that iteration.
mask.load(begin + l);
// Convert from the accumulator type to FP32 for Softmax.
softmax.unpack_noscale(acc_p);
// Apply the mask.
softmax.apply_mask(mask);
if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V && l == 0 ) {
// If we share K and V, V may not have been fully read yet before we write into smem for the reduction.
__syncthreads();
}
// Compute the max.
float p_max[Mma_tile_p::MMAS_M * 2];
if (!Is_first) {
smem_softmax_lse.store_pair(p_prev_lse);
for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) { p_max[mi] = p_prev_lse[mi] / params.scale_bmm1; }
}
// Trigger the load for the next LSE values.
if( l < steps - 1) {
if (!Is_first) {
gmem_softmax_lse.load_next(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(p_prev_lse));
}
}
softmax.template reduce_max</*zero_init=*/Is_first>(p_max);
// Compute the exponential value.
softmax.scale_apply_exp(p_max, params.scale_bmm1);
// We don't finalize the sum reduction here, as that would incur an extra __syncthreads().
// Instead, we reduce the sum from each warp, write it to smem, and wait for the __syncthreads()
// that follows storing acc_o. Then we read each warp's sum back from smem and finalize the reduction.
// As a consequence, we don't scale acc_p by the inverse sum; we scale the output by the inverse sum
// instead (see the scalar online-softmax sketch after this kernel).
// Compute the sum.
float p_sum[Mma_tile_p::MMAS_M * 2];
// softmax.reduce_sum(p_sum);
softmax.reduce_sum_before_sync_(p_sum);
constexpr bool encode_dropout_in_sign_bit = Return_softmax;
if (Is_dropout) {
softmax.template apply_dropout_16bits<encode_dropout_in_sign_bit>(ph0, ph1, params.p_dropout_in_uint16_t);
}
static_assert(Mma_tile_o::MMAS_M == Mma_tile_p::MMAS_M);
static_assert(Mma_tile_o::MMAS_K == Mma_tile_p::MMAS_N);
softmax.pack_noconvert(acc_p);
cutlass::NumericArrayConverter<Element, ElementAccum, decltype(acc_p)::kElements, cutlass::FloatRoundStyle::round_to_nearest> convert_p;
auto frag_p = convert_p(acc_p);
if (Return_softmax) {
gmem_s.store(reinterpret_cast<const cutlass::Array<Element, 8>(&)[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M]>(frag_p), mask);
gmem_s.move();
}
// Commit the values for Q into shared memory.
if (l < steps - 1) { smem_q.store(gmem_frag_q); }
if (Is_dropout && encode_dropout_in_sign_bit) {
cutlass::epilogue::thread::ReLu<decltype(frag_p)> relu;
frag_p = relu(frag_p);
}
// Declare the accumulators for the 2nd gemm.
WarpMmaPV mma_pv;
typename WarpMmaPV::FragmentC acc_o;
static_assert(WarpMmaPV::FragmentC::kElements == Mma_tile_o::MMAS_M * Mma_tile_o::MMAS_N * 8);
acc_o.clear();
// For some reason, WarpMmaPV::FragmentA has length K * N * (8|4) instead of just N * (8|4).
// We have to first cast frag_p to be array of k x (N * (8|4)), then cast each row to be
// an array of WarpMmaPV::FragmentA (which is what mma_pv expects).
static_assert(decltype(frag_p)::kElements == kIterationsPV * Mma_tile_o::MMAS_M * WarpMmaPV::FragmentA::kElements);
const auto frag_p_reshaped = reinterpret_cast<const cutlass::Array<Element, WarpMmaPV::FragmentA::kElements> (&)[kIterationsPV]>(frag_p);
#pragma unroll
for( int ki = 0; ki < kIterationsPV; ++ki ) {
mma_pv(acc_o, reinterpret_cast<const typename WarpMmaPV::FragmentA(&)>(frag_p_reshaped[ki]), frag_v[ki], acc_o);
}
// Swizzle the elements and do the final reduction.
smem_o.store(acc_o);
// The mapping from tidx to rows changes between the softmax and the
// O-reduction. So we recalculate the max.
using OutputTileThreadMap = typename Smem_O::OutputTileThreadMap;
constexpr int kOutputRowsPerThread = OutputTileThreadMap::Iterations::kRow * Smem_O::kIterationsStore;
float p_max_o[kOutputRowsPerThread][Mma_tile_o::MMAS_M];
int rows[kOutputRowsPerThread];
cutlass::MatrixCoord output_thread_offset = OutputTileThreadMap::initial_offset(tidx);
const int output_thread_start_row = output_thread_offset.row();
const int output_thread_start_column = output_thread_offset.column();
for (int iter = 0; iter < Smem_O::kIterationsStore; ++iter) {
for (int row = 0; row < OutputTileThreadMap::Iterations::kRow; ++row) {
rows[iter * OutputTileThreadMap::Iterations::kRow + row] = output_thread_start_row + iter * OutputTileThreadMap::Shape::kRow + row;
}
}
softmax.reduce_max_after_sync_(p_max_o, rows);
static_assert(Mma_tile_o::MMAS_M == 1);
for (int jj = 0; jj < kOutputRowsPerThread; jj++) {
p_max_o[jj][0] *= params.scale_bmm1;
}
float p_prev_scale_o[kOutputRowsPerThread];
if (!Is_first) {
smem_softmax_lse.load(p_prev_scale_o, rows);
}
// Make sure the data is in shared memory.
__syncthreads();
static_assert(Mma_tile_o::MMAS_M == 1);
float p_sum_o[kOutputRowsPerThread][Mma_tile_o::MMAS_M];
softmax.reduce_sum_after_sync_(p_sum_o, rows);
if (!Is_first) {
for (int jj = 0; jj < kOutputRowsPerThread; jj++) {
p_prev_scale_o[jj] = expf(p_prev_scale_o[jj] - p_max_o[jj][0]);
p_sum_o[jj][0] += p_prev_scale_o[jj];
}
}
float p_sum_log[kOutputRowsPerThread][Mma_tile_o::MMAS_M];
#pragma unroll
for (int jj = 0; jj < kOutputRowsPerThread; jj++) {
float sum = p_sum_o[jj][0];
p_sum_log[jj][0] = (sum == 0.f || sum != sum) ? -INFINITY : p_max_o[jj][0] + __logf(sum);
if (output_thread_start_column == 0) {
gmem_softmax_lse.store_row(
reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M]>(p_sum_log[jj]), rows[jj]);
}
}
gmem_softmax_lse.move();
// Load from shared memory.
using ArrayTypeO = cutlass::Array<ElementAccum, OutputTileThreadMap::kElementsPerAccess>;
static_assert(OutputTileThreadMap::kElementsPerAccess * kOutputRowsPerThread == Smem_O::kIterationsStore * Smem_O::OutputFragment::kElements);
cutlass::multiplies<ArrayTypeO> multiply_fragments;
if (!Is_first) {
auto out_reshaped = reinterpret_cast<ArrayTypeO (&)[kOutputRowsPerThread]>(out);
for (int jj = 0; jj < kOutputRowsPerThread; jj++) {
out_reshaped[jj] = multiply_fragments(out_reshaped[jj], p_prev_scale_o[jj]);
}
}
smem_o.template load</*zero_init=*/Is_first>(out, tidx);
const bool is_final_write =
Is_last
|| ((loop_step_idx + 1) * Cta_tile_p::N >= binfo.actual_seqlen_k)
|| ((Is_causal) && ((begin + l) * Cta_tile_p::M < (loop_step_idx + 1) * Cta_tile_p::N));
auto out_reshaped = reinterpret_cast<ArrayTypeO (&)[kOutputRowsPerThread]>(out);
#pragma unroll
for (int jj = 0; jj < kOutputRowsPerThread; jj++) {
float sum = p_sum_o[jj][0];
float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
if (Is_dropout && is_final_write) {
inv_sum *= params.rp_dropout;
}
out_reshaped[jj] = multiply_fragments(out_reshaped[jj], inv_sum);
}
// Output the values.
if (is_final_write) {
typename GmemIteratorO::Fragment out_converted;
cutlass::NumericArrayConverter<Element, ElementAccum, decltype(out_converted)::kElements, cutlass::FloatRoundStyle::round_to_nearest> convert_o;
#pragma unroll
for (int iter = 0; iter < GmemIteratorO::kIterations; ++iter) {
out_converted = convert_o(out[iter]);
gmem_o.store(out_converted);
gmem_o.move();
}
// We also need to move gmem_o_accum. For example, if Is_causal=true and seqlen=512,
// in the first loop, we write the first 256 rows to gmem_o and the last 256 rows to gmem_o_accum.
if (Is_first && !Is_last) { gmem_o_accum.move(GmemIteratorOAccum::kIterations); }
} else {
if (!Is_first) { gmem_o_accum.move(-GmemIteratorOAccum::kIterations); }
#pragma unroll
for (int iter = 0; iter < GmemIteratorOAccum::kIterations; ++iter) {
gmem_o_accum.store(out[iter]);
gmem_o_accum.move();
}
}
gemm_q_k.reload_k();
// Trigger the load from shared memory for the next series of Q values.
if(l < steps - 1) {
gemm_q_k.reload_q();
}
} // Outer loop over the sequence length.
}
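The running statistics in the loop above (p_prev_lse, p_prev_scale_o, p_sum_o) implement the usual online-softmax recurrence across K blocks. A scalar, single-row sketch with made-up scores and values shows why rescaling the previously accumulated numerator by exp(m_prev - m_new) reproduces the full softmax; the kernel keeps the log-sum-exp (lse = m + log(l)) and a normalized o_tmp instead, but the rescale factor plays the same role:

// Standalone scalar sketch of the online softmax-attention recurrence (illustrative values only).
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    // One query row: scores s[i] against 6 keys with scalar "values" v[i],
    // processed in key blocks of 3, like the loop over Cta_tile_p::N columns above.
    std::vector<float> s = {0.5f, 2.0f, -1.0f, 3.0f, 0.0f, 1.5f};
    std::vector<float> v = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
    float m = -INFINITY;  // running max
    float l = 0.f;        // running sum of exp(s - m)
    float acc = 0.f;      // running numerator sum(exp(s - m) * v)
    for (std::size_t start = 0; start < s.size(); start += 3) {
        float m_new = m;
        for (std::size_t i = start; i < start + 3; ++i) m_new = std::max(m_new, s[i]);
        const float rescale = std::exp(m - m_new);  // same role as p_prev_scale_o above
        l *= rescale;
        acc *= rescale;
        for (std::size_t i = start; i < start + 3; ++i) {
            const float p = std::exp(s[i] - m_new);
            l += p;
            acc += p * v[i];
        }
        m = m_new;
    }
    // Reference: plain softmax over all 6 scores.
    float ref_l = 0.f, ref_acc = 0.f;
    for (std::size_t i = 0; i < s.size(); ++i) {
        ref_l += std::exp(s[i]);
        ref_acc += std::exp(s[i]) * v[i];
    }
    std::printf("blocked: %f  reference: %f\n", acc / l, ref_acc / ref_l);
    return 0;
}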
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Return_softmax, typename Params>
inline __device__ void device_1xN_loop(const Params &params) {
// The block index for the batch.
const int bidb = blockIdx.x;
// The block index for the head.
const int bidh = blockIdx.y;
// The thread index.
const int tidx = threadIdx.x;
const int tidx_global = (bidb * params.h + bidh) * blockDim.x * 2 + tidx;
auto seeds = at::cuda::philox::unpack(params.philox_args);
// We use 2 Philox generators to match the dropout pattern in the backward pass.
// Forward pass uses 128 threads while backward pass uses 256 threads, so each thread
// in the forward pass simulates the dropout pattern of 2 threads in the backward pass.
Philox ph0(std::get<0>(seeds), tidx_global, std::get<1>(seeds));
Philox ph1(std::get<0>(seeds), tidx_global + blockDim.x, std::get<1>(seeds));
constexpr int M = Kernel_traits::Cta_tile_p::M;
const int STEPS = (params.seqlen_q + M - 1) / M;
constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N;
if (params.seqlen_k == blocksize_c) {
fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, true, true>(params, bidb, bidh, 0, STEPS, ph0, ph1, 0);
} else {
const int max_loop_steps = (params.seqlen_k + blocksize_c - 1) / blocksize_c;
fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, true, false>(params, bidb, bidh, 0, STEPS, ph0, ph1, 0);
for (int loop_step_idx = 1; loop_step_idx < max_loop_steps - 1; loop_step_idx++) {
fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, false, false>(params, bidb, bidh, 0, STEPS, ph0, ph1, loop_step_idx);
}
fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, false, true>(params, bidb, bidh, 0, STEPS, ph0, ph1, max_loop_steps - 1);
}
}
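For a concrete shape, the dispatch above unrolls as follows; a small host-side trace for a hypothetical seqlen_k = 512 with blocksize_c = 128 (the Is_first/Is_last template flags control whether o_tmp and the saved LSE are read and whether the final normalized O is written):

// Hypothetical trace of the K-block loop structure chosen by device_1xN_loop.
#include <cstdio>

int main() {
    const int seqlen_k = 512, blocksize_c = 128;
    const int max_loop_steps = (seqlen_k + blocksize_c - 1) / blocksize_c;  // 4 K-blocks
    for (int loop_step_idx = 0; loop_step_idx < max_loop_steps; ++loop_step_idx) {
        const bool is_first = loop_step_idx == 0;
        const bool is_last  = loop_step_idx == max_loop_steps - 1;
        std::printf("step %d: Is_first=%d Is_last=%d\n", loop_step_idx, is_first, is_last);
    }
    // Prints <true,false>, <false,false>, <false,false>, <false,true>; the single-block case
    // (seqlen_k == blocksize_c) collapses to one <true,true> call.
    return 0;
}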
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace fmha

View File

@ -0,0 +1,134 @@
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#include <ATen/native/transformers/cuda/flash_attn/philox.cuh>
#include <ATen/native/transformers/cuda/flash_attn/static_switch.h>
#include <ATen/native/transformers/cuda/flash_attn/fmha.h>
#include <ATen/native/transformers/cuda/flash_attn/kernel_traits.h>
#include <ATen/native/transformers/cuda/flash_attn/fmha_fprop_kernel_1xN.h>
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Return_softmax>
__global__ void fmha_fprop_loop_kernel(FMHA_fprop_params params) {
fmha::device_1xN_loop<Kernel_traits, Is_dropout, Is_causal, Return_softmax>(params);
}
template<typename Kernel_traits>
void run_fmha_loop_(Launch_params<FMHA_fprop_params> &launch_params,
const bool configure) {
constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N;
const int loop_steps = (launch_params.params.seqlen_k + blocksize_c - 1) / blocksize_c;
if (configure) {
using Mma_tile_p = fmha::Hmma_tile<typename Kernel_traits::Cta_tile_p>;
constexpr int M = Kernel_traits::Cta_tile_p::M;
size_t STEPS = (launch_params.params.seqlen_q + M - 1) / M;
constexpr size_t MMAS_M = Mma_tile_p::MMAS_M;
constexpr size_t MMAS_N = Mma_tile_p::MMAS_N;
size_t elts_per_head = STEPS * MMAS_M * MMAS_N * 8 * loop_steps;
launch_params.elts_per_thread = elts_per_head;
return;
}
constexpr size_t smem_size_softmax_lse = Kernel_traits::Smem_softmax_lse::BYTES_PER_TILE;
// Don't need smem_size_softmax_lse if we're not looping
const size_t smem_size = fmha::get_dynamic_smem_size<Kernel_traits>()
+ (loop_steps > 1 ? smem_size_softmax_lse : 0);
// printf("smem_size = %d\n", smem_size);
// Work-around for gcc 7. It doesn't like nested BOOL_SWITCH.
// https://github.com/kokkos/kokkos-kernels/issues/349
// https://github.com/HazyResearch/flash-attention/issues/21
BOOL_SWITCH(launch_params.is_dropout, IsDropoutConst, [&] {
auto kernel = launch_params.params.is_causal
? (launch_params.return_softmax
? &fmha_fprop_loop_kernel<Kernel_traits, IsDropoutConst, true, true>
: &fmha_fprop_loop_kernel<Kernel_traits, IsDropoutConst, true, false>)
: (launch_params.return_softmax
? &fmha_fprop_loop_kernel<Kernel_traits, IsDropoutConst, false, true>
: &fmha_fprop_loop_kernel<Kernel_traits, IsDropoutConst, false, false>);
// constexpr bool IsDropoutConstTmp = false;
// auto kernel = launch_params.params.is_causal
// ? (launch_params.return_softmax
// ? &fmha_fprop_loop_kernel<Kernel_traits, IsDropoutConstTmp, true, true>
// : &fmha_fprop_loop_kernel<Kernel_traits, IsDropoutConstTmp, true, false>)
// : (launch_params.return_softmax
// ? &fmha_fprop_loop_kernel<Kernel_traits, IsDropoutConstTmp, false, true>
// : &fmha_fprop_loop_kernel<Kernel_traits, IsDropoutConstTmp, false, false>);
if( smem_size >= 48L * 1024 ) {
FMHA_CHECK_CUDA(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
}
dim3 grid(launch_params.params.b, launch_params.params.h);
kernel<<<grid, Kernel_traits::THREADS, smem_size, launch_params.stream>>>(
launch_params.params);
FMHA_CHECK_CUDA(cudaPeekAtLastError());
});
}
void run_fmha_fprop(Launch_params<FMHA_fprop_params> &launch_params,
const bool configure) {
BOOL_SWITCH(launch_params.params.is_bf16, IsBf16Const, [&] {
using elem_type = std::conditional<IsBf16Const, cutlass::bfloat16_t, cutlass::half_t>::type;
auto dprops = at::cuda::getCurrentDeviceProperties();
if (launch_params.params.d <= 64) {
if( launch_params.params.seqlen_k == 128 ) {
// TD [2022-08-20]: One might expect that not sharing the smem between K & V
// could be faster, but seems like it's the same speed.
using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
run_fmha_loop_<Kernel_traits>(launch_params, configure);
} else if( launch_params.params.seqlen_k >= 256 ) {
if (dprops->major == 8 && dprops->minor >= 0) {
using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>;
run_fmha_loop_<Kernel_traits>(launch_params, configure);
} else if (dprops->major == 7 && dprops->minor == 5) {
if (launch_params.is_dropout) { // Need to use the same block size as backward
using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
run_fmha_loop_<Kernel_traits>(launch_params, configure);
} else {
using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>;
run_fmha_loop_<Kernel_traits>(launch_params, configure);
}
}
}
} else if (launch_params.params.d <= 128) {
if( launch_params.params.seqlen_k == 128 ) {
using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>;
run_fmha_loop_<Kernel_traits>(launch_params, configure);
} else {
if (dprops->major == 8 && dprops->minor == 0 && !launch_params.is_dropout) {
// TD [2022-06-05] Keep K in smem to reduce register spilling
// Gives about 6% speedup compared to using block size 128.
using Kernel_traits = FMHA_kernel_traits<256, 128, 16, 1, 4, 0x18u, elem_type>;
// using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>;
run_fmha_loop_<Kernel_traits>(launch_params, configure);
} else { // Need to use the same block size as backward
using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>;
run_fmha_loop_<Kernel_traits>(launch_params, configure);
}
}
}
});
}

View File

@ -0,0 +1,71 @@
/******************************************************************************
* Copyright (c) 2022, Tri Dao.
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include <ATen/cuda/CUDAContext.h>
namespace fmha {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int THREADS_PER_CTA>
struct BlockInfoPadded {
template<typename Params>
__device__ BlockInfoPadded(const Params &params,
const int bidb,
const int bidh,
const int tidx)
: bidb(bidb), bidh(bidh), h(params.h) {
// The block index.
sum_s_k = params.cu_seqlens_k[bidb];
actual_seqlen_k = params.cu_seqlens_k[bidb + 1] - sum_s_k;
sum_s_q = params.cu_seqlens_q[bidb];
actual_seqlen_q = params.cu_seqlens_q[bidb + 1] - sum_s_q;
tidx_global = (bidb * params.h + bidh) * THREADS_PER_CTA + tidx;
}
__device__ bool stop_early(const int start_col = 0) const {
return actual_seqlen_k <= start_col;
}
uint32_t actual_seqlen_q;
uint32_t actual_seqlen_k;
uint32_t sum_s_q;
uint32_t sum_s_k;
uint32_t bidh;
uint32_t bidb;
uint32_t tidx_global;
uint32_t h;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace fmha
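A small illustration of the varlen bookkeeping above, using made-up cumulative sequence lengths (the real arrays come from params.cu_seqlens_q / params.cu_seqlens_k):

// Hypothetical example: 3 sequences packed back to back with lengths 36, 64, 128.
// cu_seqlens is the exclusive prefix sum of the lengths, with one entry past the end.
#include <cstdio>

int main() {
    const int cu_seqlens_q[] = {0, 36, 100, 228};
    for (int bidb = 0; bidb < 3; ++bidb) {
        const int sum_s_q = cu_seqlens_q[bidb];                        // row offset of this sequence
        const int actual_seqlen_q = cu_seqlens_q[bidb + 1] - sum_s_q;  // its length
        std::printf("batch %d: offset %d, length %d\n", bidb, sum_s_q, actual_seqlen_q);
    }
    return 0;
}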

View File

@ -0,0 +1,52 @@
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include <cassert>
#include <cstdio>
#include <cstdlib>
#include <cuda_runtime_api.h>
////////////////////////////////////////////////////////////////////////////////////////////////////
#define FMHA_CHECK_CUDA( call ) \
do { \
cudaError_t status_ = call; \
if( status_ != cudaSuccess ) { \
fprintf( stderr, \
"CUDA error (%s:%d): %s\n", \
__FILE__, \
__LINE__, \
cudaGetErrorString( status_ ) ); \
exit( 1 ); \
} \
} while( 0 )
////////////////////////////////////////////////////////////////////////////////////////////////////
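Typical usage of the macro, for illustration only (assumes this header is included and a CUDA device is present):

// Any CUDA runtime call returning cudaError_t can be wrapped.
int main() {
    float *buf = nullptr;
    FMHA_CHECK_CUDA(cudaMalloc(reinterpret_cast<void **>(&buf), 1024 * sizeof(float)));
    FMHA_CHECK_CUDA(cudaMemset(buf, 0, 1024 * sizeof(float)));
    FMHA_CHECK_CUDA(cudaFree(buf));
    return 0;
}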

View File

@ -0,0 +1,95 @@
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include <ATen/native/transformers/cuda/flash_attn/utils.h>
#include <third_party/cutlass/include/cutlass/cutlass.h>
#include <third_party/cutlass/include/cutlass/gemm/warp/default_mma_tensor_op.h>
#include <third_party/cutlass/include/cutlass/layout/layout.h>
#include <third_party/cutlass/include/cutlass/arch/mma.h>
#include <third_party/cutlass/include/cutlass/array.h>
#include <third_party/cutlass/include/cutlass/numeric_types.h>
namespace fmha {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
// The number of rows in the CTA tile.
int M_,
// The number of cols in the CTA tile.
int N_,
// The number of elements in the K dimension of the GEMM loop.
int K_,
// The number of rows of warps.
int WARPS_M_,
// The number of cols of warps.
int WARPS_N_,
// The number of warps in the K dimension of the GEMM loop.
int WARPS_K_>
struct Cta_tile_ {
static constexpr int M = M_, N = N_, K = K_;
// The number of warps.
static constexpr int WARPS_M = WARPS_M_, WARPS_N = WARPS_N_, WARPS_K = WARPS_K_;
// The number of warps per CTA.
static constexpr int WARPS_PER_CTA = WARPS_M * WARPS_N * WARPS_K;
// The number of threads per warp.
static constexpr int THREADS_PER_WARP = 32;
// The number of threads per CTA.
static constexpr int THREADS_PER_CTA = WARPS_PER_CTA * THREADS_PER_WARP;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Cta_tile>
struct Hmma_tile {
// The number of elements computed with a single warp-MMA.
static constexpr int M_PER_MMA = 16, N_PER_MMA = 16, K_PER_MMA = 16;
// The number of elements computed with a single CTA-MMA.
static constexpr int M_PER_MMA_PER_CTA = M_PER_MMA * Cta_tile::WARPS_M,
N_PER_MMA_PER_CTA = N_PER_MMA * Cta_tile::WARPS_N,
K_PER_MMA_PER_CTA = K_PER_MMA * Cta_tile::WARPS_K;
// The number of MMAs needed to compute the GEMM.
static constexpr int MMAS_M = DivUpConstexpr(Cta_tile::M, M_PER_MMA_PER_CTA),
MMAS_N = DivUpConstexpr(Cta_tile::N, N_PER_MMA_PER_CTA),
MMAS_K = DivUpConstexpr(Cta_tile::K, K_PER_MMA_PER_CTA);
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int M, int N, int K, int WARPS_M, int WARPS_N, int WARPS_K>
using Cta_tile_extd = Cta_tile_<M, N, K, WARPS_M, WARPS_N, WARPS_K>;
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace fmha
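As a sanity check of the tile arithmetic above: the Cta_tile_p of the S=128, D=64, STEP=16 configuration used by the launcher is Cta_tile_extd<16, 128, 64, 1, 4, 1>. Assuming this header is included and DivUpConstexpr is a ceiling division, the following compile-time checks hold:

using TileP = fmha::Cta_tile_extd<16, 128, 64, 1, 4, 1>;
using MmaP  = fmha::Hmma_tile<TileP>;
static_assert(MmaP::M_PER_MMA_PER_CTA == 16, "16 x WARPS_M (1)");
static_assert(MmaP::N_PER_MMA_PER_CTA == 64, "16 x WARPS_N (4)");
static_assert(MmaP::K_PER_MMA_PER_CTA == 16, "16 x WARPS_K (1)");
static_assert(MmaP::MMAS_M == 1 && MmaP::MMAS_N == 2 && MmaP::MMAS_K == 4,
              "ceil(16/16), ceil(128/64), ceil(64/16)");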

View File

@ -0,0 +1,272 @@
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include <ATen/native/transformers/cuda/flash_attn/gemm.h>
namespace fmha {
////////////////////////////////////////////////////////////////////////////////////////////////////
template< typename Cta_tile, int BYTES_PER_ELEMENT >
struct Gmem_tile_mma_sd {
// The mma tile.
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
// Each STG stores 8 elements.
static constexpr int BYTES_PER_STG = BYTES_PER_ELEMENT * 8;
// The number of MMAs in the M dimension.
static constexpr int MMAS_M = Mma_tile::MMAS_M;
// The number of MMAs in the N dimension.
static constexpr int MMAS_N = Mma_tile::MMAS_N;
// The number of rows computed per MMA per thread block.
static constexpr int M_PER_MMA_PER_CTA = Mma_tile::M_PER_MMA_PER_CTA;
// The number of cols computed per MMA per thread block.
static constexpr int N_PER_MMA_PER_CTA = Mma_tile::N_PER_MMA_PER_CTA;
// The number of threads per block.
static constexpr int THREADS_PER_CTA = Cta_tile::THREADS_PER_CTA;
// The size of each "row" in bytes, i.e. how many bytes the CTA as a whole stores per (mi, ni) step (one STG per thread).
static constexpr int BYTES_PER_ROW = THREADS_PER_CTA * BYTES_PER_STG;
// The distance between elements stored per loop (in bytes).
static constexpr int LOOP_STRIDE_BYTES = MMAS_M * MMAS_N * BYTES_PER_ROW;
// The type of elements stored per STG.
using Type = typename fmha::Uint_from_size_in_bytes<BYTES_PER_STG>::Type;
// Ctor.
template<typename Params>
inline __device__ Gmem_tile_mma_sd(void *ptr, const Params &params, const int bidb, const int bidh, const int tidx)
: ptr_(static_cast<char *>(ptr)) {
// The block index.
// size_t bidx = bidb * params.h + bidh;
uint32_t bidx = bidb * params.h + bidh;
// The distance between two blocks (in bytes).
// const size_t block_stride_bytes = params.seqlen_q * params.seqlen_k * BYTES_PER_ELEMENT;
const uint32_t block_stride_bytes = params.seqlen_q * params.seqlen_k * BYTES_PER_ELEMENT;
// Set store location for each thread at the beginning of the loop
ptr_ += bidx * block_stride_bytes + tidx * BYTES_PER_STG;
}
// Store to global memory.
inline __device__ void store(const Type &data, const int mi, const int ni) {
// size_t offset = (mi * MMAS_N + ni) * BYTES_PER_ROW;
uint32_t offset = (mi * MMAS_N + ni) * BYTES_PER_ROW;
fmha::stg(ptr_ + offset, data);
}
// Load from global memory.
inline __device__ void load(Type &data, const int mi, const int ni) {
// size_t offset = (mi * MMAS_N + ni) * BYTES_PER_ROW;
uint32_t offset = (mi * MMAS_N + ni) * BYTES_PER_ROW;
fmha::ldg(data, ptr_ + offset);
}
// Move to the next tile.
inline __device__ void move(const int steps = 1) {
ptr_ += LOOP_STRIDE_BYTES * steps;
}
// The pointer in global memory.
char *ptr_;
};
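Plugging concrete numbers into the constants above (assuming an fp16 S matrix, i.e. BYTES_PER_ELEMENT = 2, and the 128-thread Cta_tile_extd<16, 128, 64, 1, 4, 1> tile) gives:

// Worked example: 128 threads each issue one 16-byte STG per (mi, ni) pair,
// and there are MMAS_M x MMAS_N = 1 x 2 pairs per loop iteration.
constexpr int kBytesPerElement = 2;
constexpr int kBytesPerStg = kBytesPerElement * 8;           // 16 bytes
constexpr int kThreadsPerCta = 128;
constexpr int kBytesPerRow = kThreadsPerCta * kBytesPerStg;  // 2048 bytes per (mi, ni)
constexpr int kLoopStrideBytes = 1 /*MMAS_M*/ * 2 /*MMAS_N*/ * kBytesPerRow;
static_assert(kLoopStrideBytes == 4096, "one loop iteration covers a 16x128 fp16 tile");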
////////////////////////////////////////////////////////////////////////////////////////////////////
template< typename Cta_tile, typename Base = Gmem_tile_mma_sd<Cta_tile, sizeof(uint16_t)> >
struct Gmem_tile_mma_s : public Base {
// The number of mmas in the vertical dimension.
static constexpr int M = Base::MMAS_M;
// The number of mmas in the horizontal dimension.
static constexpr int N = Base::MMAS_N;
// The type of the vectors stored by each STG.
using Type = typename Base::Type;
// Ctor.
template< typename Params, typename Block_info >
inline __device__ Gmem_tile_mma_s(const Params &params, const Block_info& binfo, const int tidx)
: Base(params.s_ptr, params, binfo.bidb, binfo.bidh, tidx) {
}
// Store to global memory.
template<typename Mask, typename Fragment>
inline __device__ void store(const Fragment (&frag)[N][M], const Mask& mask){
static_assert(Fragment::kStorageElements == 4);
#pragma unroll
for( int mi = 0; mi < M; mi++ ) {
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
uint4 dst;
dst.x = frag[ni][mi].raw_data()[0];
dst.y = frag[ni][mi].raw_data()[2];
dst.z = frag[ni][mi].raw_data()[1];
dst.w = frag[ni][mi].raw_data()[3];
if( mask.any_valid(mi, ni) ) {
Base::store(dst, mi, ni);
}
}
}
}
// Load from global memory.
template<typename Mask>
inline __device__ void load(uint4 (&regs)[M][N], const Mask &mask) {
#pragma unroll
for( int mi = 0; mi < M; mi++ ) {
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
regs[mi][ni] = make_uint4(0, 0, 0, 0);
if( mask.any_valid(mi, ni) ) {
Base::load(regs[mi][ni], mi, ni);
}
}
}
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<
// The dimensions of the tile computed by the CTA.
typename Cta_tile
>
struct Gmem_summary_stats {
// The Mma tile.
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
// The number of MMAs in M/N dimensions.
static constexpr int MMAS_M = Mma_tile::MMAS_M;
// The size of each element.
static constexpr int BYTES_PER_ELEMENT = 4;
static constexpr int BYTES_PER_MMA = (Cta_tile::THREADS_PER_WARP / 4) * 2 * BYTES_PER_ELEMENT;
static constexpr int ROWS = Cta_tile::M;
// Ctor.
template<typename Params>
inline __device__ Gmem_summary_stats(void *ptr, const Params &params, const int tidx)
: ptr_(reinterpret_cast<char *>(ptr)), tidx_(tidx) {
// The block index for the batch.
const int bidb = blockIdx.x;
// The block index for the head.
const int bidh = blockIdx.y;
// The block index.
// size_t bidx = bidb * params.h + bidh;
uint32_t bidx = bidb * params.h + bidh;
// Extract the position in the warp.
int warp = tidx / Cta_tile::THREADS_PER_WARP;
int lane = tidx % Cta_tile::THREADS_PER_WARP;
// The distance between two blocks (in bytes).
// size_t block_stride_bytes = params.seqlen_q * BYTES_PER_ELEMENT;
uint32_t block_stride_bytes = params.seqlen_q * BYTES_PER_ELEMENT;
// Set store location for each thread at the beginning of the loop
ptr_row_ = ptr_ + bidx * block_stride_bytes;
ptr_ += bidx * block_stride_bytes + (lane / 4) * BYTES_PER_ELEMENT;
}
// Store data to global memory.
inline __device__ void store(const uint32_t (&data)[MMAS_M * 2]) {
int warp = tidx_ / Cta_tile::THREADS_PER_WARP;
int lane = tidx_ % Cta_tile::THREADS_PER_WARP;
if ((warp == 0) && (lane % 4 == 0)) {
#pragma unroll
for (int mi = 0; mi < MMAS_M; ++mi) {
// TODO: Not sure if it's right for MMAS_M > 1
fmha::stg(ptr_ + mi * BYTES_PER_MMA + 0 * BYTES_PER_ELEMENT, data[mi * 2 + 0]);
fmha::stg(ptr_ + mi * BYTES_PER_MMA + 8 * BYTES_PER_ELEMENT, data[mi * 2 + 1]);
}
}
}
// Store data to global memory.
inline __device__ void store_row(const uint32_t (&data)[MMAS_M], const int row) {
#pragma unroll
for (int mi = 0; mi < MMAS_M; ++mi) {
// TODO: Not sure if it's right for MMAS_M > 1
fmha::stg(ptr_row_ + mi * BYTES_PER_MMA + row * BYTES_PER_ELEMENT, data[mi]);
}
}
// Load from global memory.
inline __device__ void load(uint32_t (&data)[MMAS_M * 2]) {
#pragma unroll
for (int mi = 0; mi < MMAS_M; ++mi) {
// TODO: Not sure if it's right for MMAS_M > 1
fmha::ldg(data[mi * 2 + 0], ptr_ + mi * BYTES_PER_MMA + 0 * BYTES_PER_ELEMENT);
fmha::ldg(data[mi * 2 + 1], ptr_ + mi * BYTES_PER_MMA + 8 * BYTES_PER_ELEMENT);
}
}
// Load from global memory.
inline __device__ void load_next(uint32_t (&data)[MMAS_M * 2], int move_steps=1) {
char *ptr_next = ptr_ + move_steps * ROWS * BYTES_PER_ELEMENT;
#pragma unroll
for (int mi = 0; mi < MMAS_M; ++mi) {
// TODO: Not sure if it's right for MMAS_M > 1
fmha::ldg(data[mi * 2 + 0], ptr_next + mi * BYTES_PER_MMA + 0 * BYTES_PER_ELEMENT);
fmha::ldg(data[mi * 2 + 1], ptr_next + mi * BYTES_PER_MMA + 8 * BYTES_PER_ELEMENT);
}
}
// Load data from global memory.
template <int N>
inline __device__ void load_row(uint32_t (&data)[N], const int row[N]) {
#pragma unroll
for (int ni = 0; ni < N; ++ni) {
fmha::ldg(data[ni], ptr_row_ + row[ni] * BYTES_PER_ELEMENT);
}
}
// Move the pointer to the next location.
inline __device__ void move() {
ptr_ += ROWS * BYTES_PER_ELEMENT;
ptr_row_ += ROWS * BYTES_PER_ELEMENT;
}
// Move the pointer to the next location.
inline __device__ void move(const int steps) {
ptr_ += ROWS * BYTES_PER_ELEMENT * steps;
ptr_row_ += ROWS * BYTES_PER_ELEMENT * steps;
}
// The pointer.
char *ptr_;
char *ptr_row_;
const int tidx_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace fmha

View File

@ -0,0 +1,154 @@
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include <third_party/cutlass/include/cutlass/cutlass.h>
#include <third_party/cutlass/include/cutlass/gemm/gemm.h>
#include <third_party/cutlass/include/cutlass/layout/layout.h>
#include <third_party/cutlass/include/cutlass/numeric_types.h>
#include <third_party/cutlass/include/cutlass/transform/threadblock/predicated_tile_iterator.h>
#include <ATen/native/transformers/cuda/flash_attn/gemm.h>
#include <ATen/native/transformers/cuda/flash_attn/gmem_tile.h>
#include <ATen/native/transformers/cuda/flash_attn/summary_stats.h>
#include <ATen/native/transformers/cuda/flash_attn/mma_core_sm75.h>
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int S, int D, int STEP, int WARPS_M, int WARPS_N, uint32_t FLAGS = 0x08u, typename elem_type=cutlass::half_t>
struct FMHA_kernel_traits {
// The CTA description for the 1st GEMM.
using Cta_tile_p = fmha::Cta_tile_extd<STEP, S, D, WARPS_M, WARPS_N, 1>;
// The CTA description for the 2nd GEMM.
using Cta_tile_o = fmha::Cta_tile_extd<STEP, D, S, WARPS_M, 1, WARPS_N>;
// Do we use one buffer for K and V?
static constexpr bool SHARE_SMEM_FOR_K_AND_V = (FLAGS & 0x08u) != 0u;
// Do we keep K in registers?
static constexpr bool K_IN_REGS = (FLAGS & 0x10u) == 0u;
// Do we keep V in registers?
static constexpr bool V_IN_REGS = (FLAGS & 0x100u) == 0u;
// The global memory tile to load/store S.
using Gmem_tile_s = fmha::Gmem_tile_mma_s<Cta_tile_p>;
// The global memory tile to store the softmax sum.
using Gmem_softmax_sum = fmha::Gmem_summary_stats<Cta_tile_p>;
// The number of threads.
static constexpr int THREADS = Cta_tile_p::THREADS_PER_CTA;
// Make sure the number of threads matches both CTAs.
static_assert(THREADS == Cta_tile_o::THREADS_PER_CTA, "");
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
using MmaInstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
#elif defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
using MmaInstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
#else
// using MmaInstructionShape = cutlass::gemm::GemmShape<8, 8, 4>;
using MmaInstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
// TD [2022-06-02] We don't support Volta (SM70) yet.
#endif
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
using Element = elem_type;
#else
using Element = cutlass::half_t;
#endif
using ElementAccum = float;
static_assert(WARPS_M == 1);
using ThreadblockShapeQK = cutlass::gemm::GemmShape<STEP, S, D>;
using WarpCountQK = cutlass::gemm::GemmShape<WARPS_M, WARPS_N, 1>;
using WarpShapeQK = cutlass::gemm::GemmShape<
ThreadblockShapeQK::kM,
ThreadblockShapeQK::kN / WarpCountQK::kN, ThreadblockShapeQK::kK>;
using LayoutQ = cutlass::layout::RowMajor;
using LayoutK = cutlass::layout::ColumnMajor;
using LayoutP = cutlass::layout::RowMajor;
using MmaCoreQK = typename fmha::FMHAMmaCore<
ThreadblockShapeQK, WarpShapeQK, MmaInstructionShape, Element, LayoutQ,
Element, LayoutK, ElementAccum, LayoutP,
cutlass::arch::OpClassTensorOp>;
using ThreadblockShapePV = cutlass::gemm::GemmShape<STEP, D, S>;
using WarpCountPV = cutlass::gemm::GemmShape<WARPS_M, 1, WARPS_N>;
using WarpShapePV = cutlass::gemm::GemmShape<ThreadblockShapePV::kM, ThreadblockShapePV::kN, ThreadblockShapePV::kK / WarpCountPV::kK>;
using LayoutV = cutlass::layout::RowMajor;
using LayoutO = cutlass::layout::RowMajor;
using MmaCorePV = typename fmha::FMHAMmaCore<
ThreadblockShapePV, WarpShapePV, MmaInstructionShape, Element, LayoutP,
Element, LayoutV, ElementAccum, LayoutO,
cutlass::arch::OpClassTensorOp>;
// The global memory tile to load Q.
// Copied from mma_pipelined_testbed.h
using GmemIteratorQ = cutlass::transform::threadblock::PredicatedTileIterator<
cutlass::MatrixShape<ThreadblockShapeQK::kM, ThreadblockShapeQK::kK>,
Element,
LayoutQ,
0,
typename MmaCoreQK::IteratorThreadMapA
>;
// The global memory tile to load K.
using GmemIteratorK = cutlass::transform::threadblock::PredicatedTileIterator<
cutlass::MatrixShape<ThreadblockShapeQK::kK, ThreadblockShapeQK::kN>,
Element,
LayoutK,
1,
typename MmaCoreQK::IteratorThreadMapB
>;
// The global memory tile to load V.
using GmemIteratorV = cutlass::transform::threadblock::PredicatedTileIterator<
cutlass::MatrixShape<ThreadblockShapePV::kK, ThreadblockShapePV::kN>,
Element,
LayoutV,
0,
typename MmaCorePV::IteratorThreadMapB
>;
// The shared memory tile to store softmax lse.
using Smem_softmax_lse = fmha::Smem_tile_softmax_lse<ThreadblockShapeQK::kM, MmaInstructionShape::kM, WarpCountQK::kM>;
// The amount of shared memory needed to load Q and K.
static constexpr size_t BYTES_PER_SMEM_Q = ThreadblockShapeQK::kM * ThreadblockShapeQK::kK * sizeof(Element);
static constexpr size_t BYTES_PER_SMEM_K = ThreadblockShapeQK::kN * ThreadblockShapeQK::kK * sizeof(Element);
static constexpr size_t BYTES_PER_SMEM_V = ThreadblockShapePV::kN * ThreadblockShapePV::kK * sizeof(Element);
static_assert(BYTES_PER_SMEM_K == BYTES_PER_SMEM_V);
static constexpr size_t BYTES_PER_SMEM_QK = BYTES_PER_SMEM_Q + BYTES_PER_SMEM_K;
// The extra amount of shared memory needed to load V.
static constexpr size_t BYTES_PER_SMEM_V_EXTRA = SHARE_SMEM_FOR_K_AND_V ? 0u : BYTES_PER_SMEM_V;
// The amount of shared memory needed for Q, K and V.
static constexpr size_t BYTES_PER_SMEM_QKV = BYTES_PER_SMEM_QK + BYTES_PER_SMEM_V_EXTRA;
};
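For orientation, a hypothetical instantiation of these traits for a fp16 head dimension of 64 and a 128-wide K/V tile could look like the sketch below; the alias name and the tile sizes are chosen for illustration and are not taken from the launch code in this PR.

// Illustrative only (not part of the diff): S = 128 K/V columns per tile,
// D = 64 head dimension, STEP = 16 Q rows per iteration, a 1x4 warp layout,
// default flags, fp16 storage.
using Fwd_kernel_traits_example =
    FMHA_kernel_traits</*S=*/128, /*D=*/64, /*STEP=*/16,
                       /*WARPS_M=*/1, /*WARPS_N=*/4,
                       /*FLAGS=*/0x08u, cutlass::half_t>;
// With a 1x4 warp tile this resolves to 4 warps, i.e. 128 threads per CTA.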

View File

@ -0,0 +1,92 @@
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include <ATen/cuda/CUDAContext.h>
namespace fmha {
template<typename Cta_tile, bool Is_causal=false>
struct Mask {
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
template<typename BInfo>
__device__ Mask(const BInfo &binfo, int tidx, const int loop_step_idx_ = 0)
: actual_seqlen_k(binfo.actual_seqlen_k - loop_step_idx_ * Cta_tile::N)
, loop_step_idx(loop_step_idx_) {
const int warp = tidx / Cta_tile::THREADS_PER_WARP;
const int lane = tidx % Cta_tile::THREADS_PER_WARP;
static_assert(Cta_tile::WARPS_K == 1, "");
// find the warp in the Cta tile
const int warp_n = (warp / Cta_tile::WARPS_M);
const int warp_m = (warp % Cta_tile::WARPS_M);
// decompose warp into 8x4 tile
const int quad = lane / 4;
const int tid = (lane % 4) * 2;
row = warp_m * 16 + quad;
// col = warp_n * 16 + tid;
col = warp_n * Mma_tile::N_PER_MMA * Mma_tile::MMAS_N + tid;
}
inline __device__ bool is_valid(const int mi, const int ni, const int ii, const int jj) const {
// ii and jj iterate over the 2x4 fragment
// const int current_col = (Is_causal ? loop_step_idx * Cta_tile::N : 0) + ni * Mma_tile::N_PER_MMA_PER_CTA + col + (jj & 2) * 4 + (jj & 1);
// const int current_col = ni * Mma_tile::N_PER_MMA_PER_CTA + col + (jj & 2) * 4 + (jj & 1);
const int current_col = ni * Mma_tile::N_PER_MMA + col + (jj & 2) * 4 + (jj & 1);
const int current_row = row_offset + ii * 8;
const bool col_valid = current_col < actual_seqlen_k;
// const bool col_valid = (ni * Mma_tile::N_PER_MMA_PER_CTA + col + (jj & 2) * 4 + (jj & 1)) < actual_seqlen_k;
//&& (row + mi * Mma_tile::M_PER_MMA_PER_CTA + ii * 8) < actual_seqlen_k;
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("current_col=%d, current_row=%d, actual_seqlen_k=%d, col_valid=%d, all_valid=%d\n", current_col, current_row, actual_seqlen_k, col_valid, all_valid);
// }
return Is_causal ? col_valid && (current_col + loop_step_idx * Cta_tile::N <= current_row) : col_valid;
// return row_valid && col_valid;
}
//BERT Mask: if upper left is invalid, none are valid
inline __device__ bool any_valid(const int mi, const int ni) const {
return is_valid(mi, ni, 0, 0) || is_valid(mi, ni, 1, 0);
}
inline __device__ void load(const int it) {
row_offset = it * Cta_tile::M + row;
}
int row_offset;
int row;
int col;
const int loop_step_idx;
const int actual_seqlen_k;
};
} // namespace fmha
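The causal branch of is_valid above keeps a score only when its key column, mapped back to global coordinates, is not past the query row. A minimal host-side restatement of that predicate follows (plain C++ with hypothetical parameter names; seqlen_k_shifted plays the role of actual_seqlen_k after the constructor's per-tile shift).

// Hedged sketch of the causal test in Mask::is_valid, written with explicit
// coordinates instead of the warp/quad decomposition used in the kernel.
bool causal_is_valid_sketch(int current_row,         // global query row
                            int col_in_tile,         // column within the current K/V tile
                            int loop_step_idx,       // index of the current K/V tile
                            int tile_n,              // Cta_tile::N, the K/V tile width
                            int seqlen_k_shifted) {  // actual_seqlen_k after the constructor's shift
  const bool col_valid = col_in_tile < seqlen_k_shifted;
  const int global_col = col_in_tile + loop_step_idx * tile_n;
  return col_valid && (global_col <= current_row);   // never attend to future keys
}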

View File

@ -0,0 +1,382 @@
// Adapted from cutlass/gemm/threadblock/default_mma_core_sm75.h
// This is very similar, except we make it work for head_dim=128.
// The original cutlass version only allows the thread block's kK to be
// at most 64. Here we instead cap the crosswise size: kCrosswise = min(ThreadblockShape::kK, 64).
/******************************************************************************
* Copyright (c) 2022, Tri Dao.
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include <third_party/cutlass/include/cutlass/cutlass.h>
#include <third_party/cutlass/include/cutlass/array.h>
#include <third_party/cutlass/include/cutlass/platform/platform.h>
#include <third_party/cutlass/include/cutlass/numeric_types.h>
#include <third_party/cutlass/include/cutlass/matrix_shape.h>
#include <third_party/cutlass/include/cutlass/layout/tensor_op_multiplicand_sm75.h>
#include <third_party/cutlass/include/cutlass/transform/pitch_linear_thread_map.h>
#include <third_party/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator_tensor_op.h>
#include <third_party/cutlass/include/cutlass/gemm/warp/default_mma_tensor_op.h>
#include <third_party/cutlass/include/cutlass/gemm/threadblock/default_mma_core.h>
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace fmha {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Template defining default matrix multiply operators inferred from threadblock tile size,
/// global memory data layout, and target math instruction.
template <
/// Shape of threadblock-scoped matrix multiply operator
typename Shape,
/// Shape of warp-level matrix multiply operator
typename WarpShape,
/// Shape of one matrix product operation (concept: GemmShape)
typename InstructionShape,
/// Element data type of A operand
typename ElementA,
/// Layout of operand A
typename LayoutA,
/// Element data type of B operand
typename ElementB,
/// Layout of operand B
typename LayoutB,
/// Data type of accumulator
typename ElementC,
/// Layout of accumulator
typename LayoutC,
/// Indicates type of math operator (arch::OpClassSimt or arch::OpClassTensorOp)
typename OperatorClass,
/// Operation performed by MMA
typename Operator = cutlass::arch::OpMultiplyAdd
>
struct FMHAMmaCore;
////////////////////////////////////////////////////////////////////////////////
/// Partial specialization:
///
/// A: row-major
/// B: column-major
/// Operator: tensor op class
///
/// This uses the default warp-level operator given tile sizes
template <
/// Shape of threadblock-scoped matrix multiply operator (concept:
/// GemmShape)
typename Shape_,
/// Shape of warp-level matrix multiply operator (concept: GemmShape)
typename WarpShape_,
/// Shape of one matrix product operation (concept: GemmShape)
typename InstructionShape_,
/// Data type of A operand
typename ElementA_,
/// Data type of B operand
typename ElementB_,
/// Data type of accumulator
typename ElementC_,
/// Layout of accumulator
typename LayoutC_,
/// Operation performed by MMA
typename Operator_>
struct FMHAMmaCore<Shape_, WarpShape_, InstructionShape_, ElementA_,
cutlass::layout::RowMajor, ElementB_, cutlass::layout::ColumnMajor,
ElementC_, LayoutC_, cutlass::arch::OpClassTensorOp, Operator_
> {
using Shape = Shape_;
using WarpShape = WarpShape_;
using InstructionShape = InstructionShape_;
using ElementA = ElementA_;
using LayoutA = cutlass::layout::RowMajor;
using ElementB = ElementB_;
using LayoutB = cutlass::layout::ColumnMajor;
using ElementC = ElementC_;
using LayoutC = LayoutC_;
using OperatorClass = cutlass::arch::OpClassTensorOp;
/// Number of warps present
using WarpCount = cutlass::gemm::GemmShape<
Shape::kM / WarpShape::kM,
Shape::kN / WarpShape::kN,
Shape::kK / WarpShape::kK
>;
// Divisibility requirements
static_assert(
!(Shape::kM % WarpShape::kM) &&
!(Shape::kN % WarpShape::kN),
"Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."
);
/// Number of threads per warp
static int const kWarpSize = cutlass::gemm::warp::WarpSize<cutlass::arch::OpClassTensorOp>::value;
/// Number of threads total
static int const kThreads = WarpCount::kCount * kWarpSize;
/// Size of a threadblock-scoped access
static int const kAccessSizeInBits = 128;
/// Cutlass only supports Crosswise at most 64
static int const kCrosswise = std::min(Shape::kK, 64);
/// Default Operator
using Operator = Operator_;
// Warp thread arrangement
static int const kWarpThreadArrangementContiguousA =
kCrosswise / (kAccessSizeInBits / cutlass::sizeof_bits<ElementA>::value);
static int const kWarpThreadArrangementStridedA =
kWarpSize / kWarpThreadArrangementContiguousA;
static int const kWarpThreadArrangementContiguousB =
kCrosswise / (kAccessSizeInBits / cutlass::sizeof_bits<ElementB>::value);
static int const kWarpThreadArrangementStridedB =
kWarpSize / kWarpThreadArrangementContiguousB;
//
// Shared memory layouts
//
using SmemLayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<
cutlass::sizeof_bits<ElementA>::value, kCrosswise>;
// Shared memory layout
using SmemLayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<
cutlass::sizeof_bits<ElementB>::value, kCrosswise>;
//
// Iterators to write to shared memory
//
/// ThreadMap of iterator A
using IteratorThreadMapA = cutlass::transform::PitchLinearWarpRakedThreadMap<
cutlass::layout::PitchLinearShape<Shape::kK, Shape::kM>, kThreads,
cutlass::layout::PitchLinearShape<kWarpThreadArrangementContiguousA,
kWarpThreadArrangementStridedA>,
kAccessSizeInBits / cutlass::sizeof_bits<ElementA>::value>;
/// Shared memory iterator to A operand
using SmemIteratorA = cutlass::transform::threadblock::RegularTileIterator<
cutlass::MatrixShape<Shape::kM, Shape::kK>,
ElementA,
SmemLayoutA,
0,
IteratorThreadMapA
>;
/// ThreadMap of iterator B
using IteratorThreadMapB = cutlass::transform::PitchLinearWarpRakedThreadMap<
cutlass::layout::PitchLinearShape<Shape::kK, Shape::kN>, kThreads,
cutlass::layout::PitchLinearShape<kWarpThreadArrangementContiguousB,
kWarpThreadArrangementStridedB>,
kAccessSizeInBits / cutlass::sizeof_bits<ElementB>::value>;
/// Shared memory iterator to B operand
using SmemIteratorB = cutlass::transform::threadblock::RegularTileIterator<
cutlass::MatrixShape<Shape::kK, Shape::kN>,
ElementB,
SmemLayoutB,
1,
IteratorThreadMapB
>;
//
// Warp-level matrix multiply operator
//
// Define the warp-level tensor op
using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB,
ElementC, LayoutC, Operator, WarpCount::kK>::Type;
/// Policy used to define MmaPipelined
using MmaPolicy = cutlass::gemm::threadblock::MmaPolicy<
MmaTensorOp,
cutlass::MatrixShape<0, 0>,
cutlass::MatrixShape<0, 0>,
WarpCount::kK
>;
};
////////////////////////////////////////////////////////////////////////////////
/// Partial specialization:
///
/// A: row-major
/// B: row-major
/// Operator: tensor op class
///
/// This uses the default warp-level operator given tile sizes
template <
/// Shape of threadblock-scoped matrix multiply operator (concept:
/// GemmShape)
typename Shape_,
/// Shape of warp-level matrix multiply operator (concept: GemmShape)
typename WarpShape_,
/// Shape of one matrix product operation (concept: GemmShape)
typename InstructionShape_,
/// Data type of A operand
typename ElementA_,
/// Data type of B operand
typename ElementB_,
/// Data type of accumulator
typename ElementC_,
/// Layout of accumulator
typename LayoutC_,
/// Operation performed by MMA
typename Operator_>
struct FMHAMmaCore<Shape_, WarpShape_, InstructionShape_, ElementA_,
cutlass::layout::RowMajor, ElementB_, cutlass::layout::RowMajor, ElementC_,
LayoutC_, cutlass::arch::OpClassTensorOp, Operator_
> {
using Shape = Shape_;
using WarpShape = WarpShape_;
using InstructionShape = InstructionShape_;
using ElementA = ElementA_;
using LayoutA = cutlass::layout::RowMajor;
using ElementB = ElementB_;
using LayoutB = cutlass::layout::RowMajor;
using ElementC = ElementC_;
using LayoutC = LayoutC_;
using OperatorClass = cutlass::arch::OpClassTensorOp;
/// Number of warps present
using WarpCount = cutlass::gemm::GemmShape<
Shape::kM / WarpShape::kM,
Shape::kN / WarpShape::kN,
Shape::kK / WarpShape::kK
>;
// Divisibility requirements
static_assert(
!(Shape::kM % WarpShape::kM) &&
!(Shape::kN % WarpShape::kN),
"Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."
);
/// Number of threads per warp
static int const kWarpSize = cutlass::gemm::warp::WarpSize<cutlass::arch::OpClassTensorOp>::value;
/// Number of threads total
static int const kThreads = WarpCount::kCount * kWarpSize;
/// Size of a threadblock-scoped access
static int const kAccessSizeInBits = 128;
/// Cutlass only supports Crosswise at most 64
static int const kCrosswise = std::min(Shape::kK, 64);
/// Default Operator
using Operator = Operator_;
// Warp thread arrangement
static int const kWarpThreadArrangementContiguousA =
kCrosswise / (kAccessSizeInBits / cutlass::sizeof_bits<ElementA>::value);
static int const kWarpThreadArrangementStridedA =
kWarpSize / kWarpThreadArrangementContiguousA;
//
// Shared memory layouts
//
using SmemLayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<
cutlass::sizeof_bits<ElementA>::value, kCrosswise>;
// Shared memory layout
using SmemLayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous<
cutlass::sizeof_bits<ElementB>::value, int(128 / sizeof(ElementB))>;
//
// Iterators to write to shared memory
//
/// ThreadMap of iterator A
using IteratorThreadMapA = cutlass::transform::PitchLinearWarpRakedThreadMap<
cutlass::layout::PitchLinearShape<Shape::kK, Shape::kM>, kThreads,
cutlass::layout::PitchLinearShape<kWarpThreadArrangementContiguousA,
kWarpThreadArrangementStridedA>,
kAccessSizeInBits / cutlass::sizeof_bits<ElementA>::value>;
/// Shared memory iterator to A operand
using SmemIteratorA = cutlass::transform::threadblock::RegularTileIterator<
cutlass::MatrixShape<Shape::kM, Shape::kK>,
ElementA,
SmemLayoutA,
0,
IteratorThreadMapA
>;
/// ThreadMap of iterator B
using IteratorThreadMapB = cutlass::transform::PitchLinearWarpRakedThreadMap<
cutlass::layout::PitchLinearShape<Shape::kN, Shape::kK>,
kThreads,
cutlass::layout::PitchLinearShape<8, 4>,
kAccessSizeInBits / cutlass::sizeof_bits<ElementB>::value
>;
/// Shared memory iterator to B operand
using SmemIteratorB = cutlass::transform::threadblock::RegularTileIterator<
cutlass::MatrixShape<Shape::kK, Shape::kN>,
ElementB,
SmemLayoutB,
0,
IteratorThreadMapB
>;
//
// Warp-level matrix multiply operator
//
// Define the warp-level tensor op
using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB,
ElementC, LayoutC, Operator, WarpCount::kK>::Type;
/// Policy used to define MmaPipelined
using MmaPolicy = cutlass::gemm::threadblock::MmaPolicy<
MmaTensorOp,
cutlass::MatrixShape<0, 0>,
cutlass::MatrixShape<0, 0>,
WarpCount::kK
>;
};
} // namespace fmha
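As a concrete illustration of the head_dim = 128 case this file is meant to unlock, the Q*K^T core below has Shape::kK = 128, so kCrosswise is capped at 64. The alias and shapes are examples only and are not hard-coded anywhere in this PR.

// Illustrative only: Q (row-major) x K^T (column-major) threadblock core for
// head_dim = 128. Shape::kK = 128 > 64, so kCrosswise = min(128, 64) = 64.
using MmaCoreQK_d128_example = fmha::FMHAMmaCore<
    cutlass::gemm::GemmShape<16, 128, 128>,           // threadblock tile (M, N, K)
    cutlass::gemm::GemmShape<16, 32, 128>,            // warp tile: 4 warps along N
    cutlass::gemm::GemmShape<16, 8, 16>,              // SM80 tensor-op instruction
    cutlass::half_t, cutlass::layout::RowMajor,       // Q
    cutlass::half_t, cutlass::layout::ColumnMajor,    // K
    float, cutlass::layout::RowMajor,                 // accumulator layout
    cutlass::arch::OpClassTensorOp>;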

View File

@ -0,0 +1,146 @@
// PyTorch also has an implementation of Philox RNG: https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/codegen/cuda/runtime/random_numbers.cu
#pragma once
// Philox CUDA.
#include <ATen/cuda/CUDAContext.h>
namespace {
class Philox {
public:
__device__ inline Philox(unsigned long long seed,
unsigned long long subsequence,
unsigned long long offset)
: STATE(0)
, key(reinterpret_cast<const uint2&>(seed)) {
//key.x = (unsigned int)seed;
//key.y = (unsigned int)(seed >> 32);
//counter = make_uint4(0, 0, 0, 0);
//counter.z = (unsigned int)(subsequence);
//counter.w = (unsigned int)(subsequence >> 32);
//STATE = 0;
//incr_n(offset / 4);
// key = reinterpret_cast<const uint2&>(seed);
ull2 * tmp = reinterpret_cast<ull2*>(&counter);
tmp->x = offset / 4;
tmp->y = subsequence;
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("Philox counter: %d, %d, %d, %d\n", counter.x, counter.y, counter.z, counter.w);
// }
}
__device__ inline uint4 operator()() {
// if (STATE == 0) {
uint4 counter_ = counter;
uint2 key_ = key;
// 7-round philox
#pragma unroll
for (int i = 0; i < 6; i++) {
counter_ = single_round(counter_, key_);
key_.x += (kPhilox10A);
key_.y += (kPhilox10B);
}
// output = single_round(counter_, key_);
uint4 output = single_round(counter_, key_);
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("Philox counter: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w);
// printf("Philox output: %u, %u, %u, %u\n", output.x, output.y, output.z, output.w);
// }
incr();
// }
// return a float4 directly
// unsigned long ret;
// switch(STATE) {
// case 0: ret = output.x; break;
// case 1: ret = output.y; break;
// case 2: ret = output.z; break;
// case 3: ret = output.w; break;
//}
// STATE = (STATE + 1) % 4;
return output;
}
private:
struct ull2 {
uint64_t x;
uint64_t y;
};
uint4 counter;
// uint4 output;
const uint2 key;
unsigned int STATE;
__device__ inline void incr_n(unsigned long long n) {
unsigned int nlo = (unsigned int)(n);
unsigned int nhi = (unsigned int)(n >> 32);
counter.x += nlo;
if (counter.x < nlo)
nhi++;
counter.y += nhi;
if (nhi <= counter.y)
return;
if (++counter.z)
return;
++counter.w;
}
__device__ uint4 incr128 (uint4 ctr)
{
uint4 res;
asm ("add.cc.u32 %0, %4, %8;\n\t"
"addc.cc.u32 %1, %5, %9;\n\t"
"addc.cc.u32 %2, %6, %10;\n\t"
"addc.u32 %3, %7, %11;\n\t"
: "=r"(res.x), "=r"(res.y), "=r"(res.z), "=r"(res.w)
: "r"(ctr.x), "r"(ctr.y), "r"(ctr.z), "r"(ctr.w),
"n"(1), "n"(0), "n"(0), "n"(0));
return res;
}
__device__ inline void incr() {
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("Counter before: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w);
// }
counter = incr128(counter);
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("Counter after: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w);
// }
}
__device__ unsigned int mulhilo32(unsigned int a, unsigned int b,
unsigned int *result_high) {
*result_high = __umulhi(a, b);
return a * b;
}
__device__ uint2 mulhilo32_v2 (const unsigned int a, const unsigned int b)
{
uint2 *res;
unsigned long long tmp;
asm ("mul.wide.u32 %0, %1, %2;\n\t"
: "=l"(tmp)
: "r"(a), "r"(b));
res = (uint2*)(&tmp);
return *res;
}
__device__ inline uint4 single_round(const uint4 ctr, const uint2 key) {
//unsigned int hi0;
//unsigned int hi1;
//unsigned int lo0 = mulhilo32(kPhiloxSA, ctr.x, &hi0);
//unsigned int lo1 = mulhilo32(kPhiloxSB, ctr.z, &hi1);
//uint4 ret = {hi1 ^ ctr.y ^ key.x, lo1, hi0 ^ ctr.w ^ key.y, lo0};
uint2 res0 = mulhilo32_v2(kPhiloxSA, ctr.x);
uint2 res1 = mulhilo32_v2(kPhiloxSB, ctr.z);
uint4 ret = {res1.y ^ ctr.y ^ key.x, res1.x, res0.y ^ ctr.w ^ key.y, res0.x};
return ret;
}
static const unsigned long kPhilox10A = 0x9E3779B9;
static const unsigned long kPhilox10B = 0xBB67AE85;
static const unsigned long kPhiloxSA = 0xD2511F53;
static const unsigned long kPhiloxSB = 0xCD9E8D57;
};
// Inverse of 2^32.
constexpr float M_RAN_INVM32 = 2.3283064e-10f;
__device__ __inline__ float4 uniform4(const uint4 x) {
return make_float4(x.x * M_RAN_INVM32, x.y * M_RAN_INVM32, x.z * M_RAN_INVM32,
x.w * M_RAN_INVM32);
}
} // namespace
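A short device-side sketch of how the generator above is meant to be consumed: one call yields four 32-bit words, uniform4 maps them to floats in [0, 1), and a score is kept when its draw falls at or below the keep threshold. In the real kernel, seed/subsequence/offset come from PyTorch's CUDA RNG state; the values and the function name here are placeholders.

// Hedged illustration only; not part of the kernel in this diff.
__device__ void philox_dropout_sketch(float (&scores)[4], float p_keep,
                                      unsigned long long seed,
                                      unsigned long long offset) {
  Philox ph(seed, /*subsequence=*/threadIdx.x, offset);
  const float4 r = uniform4(ph());        // four uniforms in [0, 1) per call
  scores[0] = (r.x <= p_keep) ? scores[0] : 0.f;
  scores[1] = (r.y <= p_keep) ? scores[1] : 0.f;
  scores[2] = (r.z <= p_keep) ? scores[2] : 0.f;
  scores[3] = (r.w <= p_keep) ? scores[3] : 0.f;
}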

View File

@ -0,0 +1,446 @@
/******************************************************************************
* Copyright (c) 2022, Tri Dao.
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include <cmath>
#include <ATen/native/transformers/cuda/flash_attn/philox.cuh>
#include <third_party/cutlass/include/cutlass/cutlass.h>
#include <third_party/cutlass/include/cutlass/array.h>
namespace fmha {
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float apply_exp_(float x, float max) {
return __expf(x - max);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float apply_exp2_(float x, float max) {
return exp2f(x - max);
// With fast-math, this produces the same PTX instruction as the assembly below
// float diff = x - max;
// float res;
// asm ("ex2.approx.ftz.f32 %0, %1;\n\t" : "=f"(res) : "f"(diff));
// return res;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int COLS> struct ReadType {};
template<> struct ReadType<4> { using T = float;};
template<> struct ReadType<8> { using T = float2;};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename Cta_tile, typename Kernel_traits>
struct Smem_tile_reduce {
// Helper class that stores each warp's per-row partial reductions of the MMA tiles in shared memory and redistributes them across quads for the final cross-warp reduction.
// The Mma tile.
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
// The number of MMAs in M/N dimensions.
static constexpr int MMAS_M = Mma_tile::MMAS_M;
static constexpr int MMAS_N = Mma_tile::MMAS_N;
static constexpr int WARPS_M = Cta_tile::WARPS_M;
static constexpr int WARPS_N = Cta_tile::WARPS_N;
static constexpr int ROWS = WARPS_M * MMAS_M * 16;
static constexpr int COLS = WARPS_N;
static_assert(COLS == 4 || COLS == 8);
static constexpr int ROWS_PER_XOR_PATTERN = (COLS == 8) ? 4 : 8;
static constexpr int BYTES_PER_TILE = ROWS * COLS * sizeof(float);
static constexpr int ELTS_PER_TILE = ROWS * COLS;
using read_t = typename ReadType<COLS>::T;
__device__ inline Smem_tile_reduce(float *smem_, const int tidx) {
int lane = tidx % 32;
int warp = tidx / 32;
int warp_m = warp % WARPS_M;
int warp_n = warp / WARPS_M;
qid_ = lane % 4;
int qp = lane / 4;
// Swizzle the column to avoid 2-fold bank conflicts when we have 8 warps.
// This won't affect reading as we assume commutative reduction ops.
const int col = warp_n ^ (qp / ROWS_PER_XOR_PATTERN);
smem_write_ = &smem_[warp_m * 16 * MMAS_M * WARPS_N + qp * WARPS_N + col];
smem_read_ = &reinterpret_cast<read_t *>(smem_)[warp_m * 16 * MMAS_M * 4 + qp * 4 + qid_];
smem_read_row_ = &reinterpret_cast<read_t *>(smem_)[warp_m * 16 * MMAS_M * 4 + qid_];
}
__device__ inline void store(float (&frag)[2 * MMAS_M]) {
if( qid_ == 0 ) {
#pragma unroll
for( int mi = 0; mi < MMAS_M; mi++ ) {
int offset = mi * 16 * WARPS_N;
smem_write_[offset + 0 * 8 * WARPS_N] = frag[mi * 2 + 0];
smem_write_[offset + 1 * 8 * WARPS_N] = frag[mi * 2 + 1];
}
}
}
__device__ inline void load(read_t (&frag)[2 * MMAS_M]) {
#pragma unroll
for( int mi = 0; mi < MMAS_M; mi++ ) {
int offset = mi * 16 * 4;
frag[mi * 2 + 0] = smem_read_[offset + 0 * 8 * 4];
frag[mi * 2 + 1] = smem_read_[offset + 1 * 8 * 4];
}
}
__device__ inline void load_row(read_t (&frag)[MMAS_M], int row) {
#pragma unroll
for( int mi = 0; mi < MMAS_M; mi++ ) {
int offset = mi * 16 * 4;
frag[mi] = smem_read_row_[offset + 0 * 8 * 4 + row * 4];
}
}
int qid_;
float *smem_write_;
read_t *smem_read_;
read_t *smem_read_row_;
};
template<typename Cta_tile, typename Kernel_traits>
struct Softmax_base {
// The Mma tile.
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
// The number of MMAs in M/N dimensions.
static constexpr int MMAS_M = Mma_tile::MMAS_M;
static constexpr int MMAS_N = Mma_tile::MMAS_N;
// The number of groups of warps such that we have at most 4 warps writing consecutive elements.
static constexpr int GROUPS = fmha::DivUpConstexpr(Cta_tile::WARPS_N, 4);
// The number of elements that we are going to store per row.
static constexpr int ELEMENTS_PER_ROW = Cta_tile::WARPS_N / GROUPS;
// The number of rows.
static constexpr int ROWS = Cta_tile::M * GROUPS;
// The total number of elements.
static constexpr int ELEMENTS = ROWS * ELEMENTS_PER_ROW;
// Ctor.
template<typename Params>
inline __device__ Softmax_base(const Params &params, void *smem, int tidx)
: // packed_mask_ptr_(reinterpret_cast<const char*>(params.packed_mask_ptr)),
smem_(reinterpret_cast<float *>(smem)), tidx_(tidx) {
// Extract the position in the warp.
int warp = tidx / Cta_tile::THREADS_PER_WARP;
int lane = tidx % Cta_tile::THREADS_PER_WARP;
// Decompose the warp index into M and N.
int warp_m = warp % Cta_tile::WARPS_M;
int warp_n = warp / Cta_tile::WARPS_M;
// Decompose the warp-n index into group/position-inside-the-group.
int warp_g = warp_n / ELEMENTS_PER_ROW;
int warp_i = warp_n % ELEMENTS_PER_ROW;
// The location written by the threads.
int write_row = warp_g * (ROWS / GROUPS) + warp_m * Mma_tile::M_PER_MMA + lane / 4;
int write_col = warp_i;
// Assemble the write pointer.
smem_write_ = &smem_[write_row * ELEMENTS_PER_ROW + write_col];
// Assemble the read pointer.
smem_read_ = &smem_[warp_m * Mma_tile::M_PER_MMA + lane / 4];
}
template<bool zero=false, typename Mask>
inline __device__ void apply_mask(const Mask &mask) {
#pragma unroll
for( int mi = 0; mi < MMAS_M; ++mi ) {
#pragma unroll
for( int ii = 0; ii < 2; ++ii ) {
#pragma unroll
for( int ni = 0; ni < MMAS_N; ++ni ) {
#pragma unroll
for( int jj = 0; jj < 4; ++jj ) {
if( !mask.is_valid(mi, ni, ii, jj) ) {
elt_[2 * mi + ii][4 * ni + jj] = zero ? 0.f : -INFINITY;
}
}
}
}
}
}
// Apply the exp to all the elements.
template <bool scale_max=true>
inline __device__ void scale_apply_exp(const float (&max)[MMAS_M * 2], const float scale_) {
const float max_scale = scale_max ? scale_ * M_LOG2E : M_LOG2E;
const float scale = scale_ * M_LOG2E;
#pragma unroll
for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
// Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
// max * log_2(e)). This allows the compiler to use the ffma
// instruction instead of fadd and fmul separately.
const float max_scaled = max[mi] * max_scale;
#pragma unroll
for( int ni = 0; ni < MMAS_N * 4; ++ni ) {
elt_[mi][ni] = apply_exp2_(elt_[mi][ni] * scale, max_scaled);
}
}
}
template <bool encode_dropout_in_sign_bit=false>
inline __device__ void apply_dropout_16bits(Philox &ph, uint16_t p_dropout_in_uint16_t) {
// We encode the dropout pattern in the sign bit of the non-negative
// softmax to distinguish from pre-existing zeros
auto encode_dropout = [](bool keep, float val) {
return keep ? val : (encode_dropout_in_sign_bit ? -val : float(0));
};
#pragma unroll
for( int mi = 0; mi < MMAS_M; mi++ ) {
#pragma unroll
for( int ni = 0; ni < MMAS_N; ni++ ) {
uint4 random_uint4 = ph();
uint16_t (&rnd)[8] = reinterpret_cast<uint16_t (&)[8]>(random_uint4);
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("ni = %d, ph Philox: %u, %u, %u, %u\n", ni, rnd.x, rnd.y, rnd.z, rnd.w);
// }
#pragma unroll
for (int ii = 0; ii < 2; ++ii) {
#pragma unroll
for (int jj = 0; jj < 4; ++jj) {
elt_[mi * 2 + ii][4 * ni + jj] =
encode_dropout(rnd[ii * 4 + jj] <= p_dropout_in_uint16_t, elt_[mi * 2 + ii][4 * ni + jj]);
}
}
}
}
}
template <bool encode_dropout_in_sign_bit=false>
inline __device__ void apply_dropout_16bits(Philox &ph0, Philox &ph1, uint16_t p_dropout_in_uint16_t) {
// We encode the dropout pattern in the sign bit of the non-negative
// softmax to distinguish from pre-existing zeros
auto encode_dropout = [](bool keep, float val) {
return keep ? val : (encode_dropout_in_sign_bit ? -val : float(0));
};
#pragma unroll
for( int mi = 0; mi < MMAS_M; mi++ ) {
static_assert(MMAS_N % 2 == 0);
#pragma unroll
for( int ni = 0; ni < MMAS_N; ni += 2 ) {
uint4 random_uint4 = ph0();
uint16_t (&rnd0)[8] = reinterpret_cast<uint16_t (&)[8]>(random_uint4);
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("ni = %d, ph Philox: %u, %u, %u, %u\n", ni, rnd0.x, rnd0.y, rnd0.z, rnd0.w);
// }
#pragma unroll
for (int ii = 0; ii < 2; ++ii) {
#pragma unroll
for (int jj = 0; jj < 4; ++jj) {
elt_[mi * 2 + ii][4 * ni + jj] =
encode_dropout(rnd0[ii * 4 + jj] <= p_dropout_in_uint16_t, elt_[mi * 2 + ii][4 * ni + jj]);
}
}
random_uint4 = ph1();
uint16_t (&rnd1)[8] = reinterpret_cast<uint16_t (&)[8]>(random_uint4);
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("ni = %d, ph Philox: %u, %u, %u, %u\n", ni, rnd1.x, rnd1.y, rnd1.z, rnd1.w);
// }
#pragma unroll
for (int ii = 0; ii < 2; ++ii) {
#pragma unroll
for (int jj = 0; jj < 4; ++jj) {
elt_[mi * 2 + ii][4 * (ni + 1) + jj] =
encode_dropout(rnd1[ii * 4 + jj] <= p_dropout_in_uint16_t, elt_[mi * 2 + ii][4 * (ni + 1) + jj]);
}
}
}
}
}
// Shared memory for the CTA-wide reduction.
float *smem_, *smem_write_, *smem_read_;
// The current thread index.
int tidx_;
// The elements.
float elt_[MMAS_M * 2][MMAS_N * 4];
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Cta_tile, typename Kernel_traits>
struct Softmax : public Softmax_base<Cta_tile, Kernel_traits> {
// The base class.
using Base = Softmax_base<Cta_tile, Kernel_traits>;
static constexpr int WARPS_M = Cta_tile::WARPS_M;
static constexpr int WARPS_N = Cta_tile::WARPS_N;
// The MMAs.
static constexpr int MMAS_M = Base::MMAS_M;
static constexpr int MMAS_N = Base::MMAS_N;
using Smem_tile_red = Smem_tile_reduce<Cta_tile, Kernel_traits>;
static_assert(Smem_tile_red::ELTS_PER_TILE == Cta_tile::M * WARPS_N);
// Ctor.
template<typename Params>
inline __device__ Softmax(const Params &params, void *smem, int tidx)
: Base(params, smem, tidx)
, smem_sum_(static_cast<float*>(smem), tidx)
, smem_max_(static_cast<float*>(smem) + Smem_tile_red::ELTS_PER_TILE, tidx) {
}
// Pack the data to a fragment for the next GEMM.
inline __device__ void pack_noconvert(cutlass::Array<float, MMAS_M * MMAS_N * 8> &frag) const {
#pragma unroll
for( int mi = 0; mi < MMAS_M; ++mi ) {
#pragma unroll
for( int ki = 0; ki < MMAS_N; ++ki ) {
// 1st row - 4 elements per row.
frag[ki * MMAS_M * 8 + mi * 8 + 0] = this->elt_[2 * mi + 0][4 * ki + 0];
frag[ki * MMAS_M * 8 + mi * 8 + 1] = this->elt_[2 * mi + 0][4 * ki + 1];
frag[ki * MMAS_M * 8 + mi * 8 + 4] = this->elt_[2 * mi + 0][4 * ki + 2];
frag[ki * MMAS_M * 8 + mi * 8 + 5] = this->elt_[2 * mi + 0][4 * ki + 3];
// 2nd row - 4 elements per row.
frag[ki * MMAS_M * 8 + mi * 8 + 2] = this->elt_[2 * mi + 1][4 * ki + 0];
frag[ki * MMAS_M * 8 + mi * 8 + 3] = this->elt_[2 * mi + 1][4 * ki + 1];
frag[ki * MMAS_M * 8 + mi * 8 + 6] = this->elt_[2 * mi + 1][4 * ki + 2];
frag[ki * MMAS_M * 8 + mi * 8 + 7] = this->elt_[2 * mi + 1][4 * ki + 3];
}
}
}
template <typename FragmentC>
inline __device__ void unpack_noscale(const FragmentC (&acc)) {
static_assert(FragmentC::kElements == MMAS_M * MMAS_N * 8, "");
#pragma unroll
for( int mi = 0; mi < MMAS_M; ++mi ) {
#pragma unroll
for( int ni = 0; ni < MMAS_N; ++ni ) {
// 1st row - 4 elements per row.
this->elt_[2 * mi + 0][4 * ni + 0] = acc[mi * MMAS_N * 8 + ni * 8 + 0];
this->elt_[2 * mi + 0][4 * ni + 1] = acc[mi * MMAS_N * 8 + ni * 8 + 1];
this->elt_[2 * mi + 0][4 * ni + 2] = acc[mi * MMAS_N * 8 + ni * 8 + 4];
this->elt_[2 * mi + 0][4 * ni + 3] = acc[mi * MMAS_N * 8 + ni * 8 + 5];
// 2nd row - 4 elements per row.
this->elt_[2 * mi + 1][4 * ni + 0] = acc[mi * MMAS_N * 8 + ni * 8 + 2];
this->elt_[2 * mi + 1][4 * ni + 1] = acc[mi * MMAS_N * 8 + ni * 8 + 3];
this->elt_[2 * mi + 1][4 * ni + 2] = acc[mi * MMAS_N * 8 + ni * 8 + 6];
this->elt_[2 * mi + 1][4 * ni + 3] = acc[mi * MMAS_N * 8 + ni * 8 + 7];
}
}
}
template<bool zero_init=true, typename Operator>
__device__ inline void thread_reduce_(float (&frag)[2 * MMAS_M], Operator &op) {
#pragma unroll
for( int mi = 0; mi < 2 * MMAS_M; mi++ ) {
frag[mi] = zero_init ? this->elt_[mi][0] : op(frag[mi], this->elt_[mi][0]);
#pragma unroll
for( int ni = 1; ni < 4 * MMAS_N; ni++ ) {
frag[mi] = op(frag[mi], this->elt_[mi][ni]);
}
}
}
template<bool zero_init=true, typename Operator>
__device__ inline void reduce_(float (&frag)[2 * MMAS_M], Operator &op, Smem_tile_red & smem_red) {
thread_reduce_<zero_init>(frag, op);
quad_reduce(frag, frag, op);
smem_red.store(frag);
__syncthreads();
typename Smem_tile_red::read_t tmp[2 * MMAS_M];
smem_red.load(tmp);
quad_allreduce(frag, tmp, op);
}
template<bool zero_init=true>
__device__ inline void reduce_max(float (&frag)[2 * MMAS_M]){
MaxOp<float> max;
reduce_<zero_init>(frag, max, smem_max_);
}
__device__ inline void reduce_sum(float (&frag)[2 * MMAS_M]){
SumOp<float> sum;
reduce_(frag, sum, smem_sum_);
}
template<bool zero_init=true>
__device__ inline void reduce_sum_before_sync_(float (&frag)[2 * MMAS_M]){
SumOp<float> sum;
thread_reduce_<zero_init>(frag, sum);
quad_reduce(frag, frag, sum);
smem_sum_.store(frag);
}
template<int NROWS, typename Operator>
__device__ inline void reduce_after_sync_(float (&frag)[NROWS][MMAS_M],
const int (&rows)[NROWS],
Operator &op, Smem_tile_red & smem_red) {
#pragma unroll
for (int ii = 0; ii < NROWS; ii++) {
typename Smem_tile_red::read_t tmp[MMAS_M];
smem_red.load_row(tmp, rows[ii]);
quad_allreduce(frag[ii], tmp, op);
}
}
template<int NROWS>
__device__ inline void reduce_sum_after_sync_(float (&frag)[NROWS][MMAS_M],
const int (&rows)[NROWS]){
SumOp<float> sum;
reduce_after_sync_(frag, rows, sum, smem_sum_);
}
template<int NROWS>
__device__ inline void reduce_max_after_sync_(float (&frag)[NROWS][MMAS_M],
const int (&rows)[NROWS]){
MaxOp<float> max;
reduce_after_sync_(frag, rows, max, smem_max_);
}
Smem_tile_red smem_max_;
Smem_tile_red smem_sum_;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace fmha
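The comment inside scale_apply_exp is the key numerical trick: exp(scale * (x - max)) is rewritten as exp2(x * scale * log2(e) - max * scale * log2(e)), so the scale and shift fold into a single FFMA before the ex2 approximation. Below is a tiny standalone host-side check of that identity (illustrative only, not part of the diff; the sample values are arbitrary).

// Host-side sanity check of the rewrite used in Softmax_base::scale_apply_exp.
#include <cmath>
#include <cstdio>

int main() {
  const float x = 1.7f, mx = 3.2f, scale = 0.125f;       // arbitrary sample values
  const float kLog2e = 1.4426950408889634f;              // log2(e)
  const float reference = std::exp(scale * (x - mx));
  const float rewritten = std::exp2(x * (scale * kLog2e) - mx * (scale * kLog2e));
  std::printf("reference=%.7f rewritten=%.7f\n", reference, rewritten);  // should agree to fp32 precision
  return 0;
}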

View File

@ -0,0 +1,25 @@
// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
#pragma once
/// @param COND - a boolean expression to switch by
/// @param CONST_NAME - a name given for the constexpr bool variable.
/// @param ... - code to execute for true and false
///
/// Usage:
/// ```
/// BOOL_SWITCH(flag, BoolConst, [&] {
/// some_function<BoolConst>(...);
/// });
/// ```
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
[&] { \
if (COND) { \
constexpr bool CONST_NAME = true; \
return __VA_ARGS__(); \
} else { \
constexpr bool CONST_NAME = false; \
return __VA_ARGS__(); \
} \
}()
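The usage comment above already sketches the pattern; here is a self-contained, illustrative expansion with placeholder function names, showing how a runtime flag becomes a compile-time constant.

// Illustrative only: the runtime flag becomes a constexpr template argument
// inside the lambda, so code guarded by it can be specialized at compile time.
#include <cstdio>

template <bool IsCausal>
void run_demo(int n) {
  std::printf("n=%d causal=%d\n", n, int(IsCausal));
}

void dispatch_demo(int n, bool is_causal) {
  BOOL_SWITCH(is_causal, IsCausalConst, [&] { run_demo<IsCausalConst>(n); });
}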

View File

@ -0,0 +1,55 @@
/******************************************************************************
* Copyright (c) 2022, Tri Dao.
******************************************************************************/
#pragma once
namespace fmha {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int kRows, int kRowsPerMma, int kWarpCountM>
struct Smem_tile_softmax_lse {
static constexpr int kMmaM = (kRows / kWarpCountM) / kRowsPerMma;
static_assert(kMmaM * kRowsPerMma * kWarpCountM == kRows);
// static_assert(kWarpCountM == 1);
// Otherwise we might need to check warp_idx / kWarpCountM == 0 instead of just warp_idx == 0
// The size of one buffer in bytes in shared memory.
static constexpr size_t BYTES_PER_TILE = kRows * sizeof(float);
inline __device__ Smem_tile_softmax_lse(float *smem) : smem_(smem) {
}
inline __device__ void store_pair(const float (&sum)[kMmaM * 2]) {
// Broadcast the warp_id computed by lane 0 to ensure dependent code
// is compiled as warp-uniform.
// This makes a difference of 50us for BERT.
// const int warp_idx = threadIdx.x / 32;
const int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
const int lane_idx = threadIdx.x % 32;
const int warp_n = warp_idx / kWarpCountM;
// Extract the position in the warp.
const int row = lane_idx / 4;
if ((lane_idx % 4 == 0) && (warp_n == 0)) {
#pragma unroll
for (int mi = 0; mi < kMmaM; ++mi) {
smem_[mi * kRowsPerMma + row + 0] = sum[mi * 2 + 0];
smem_[mi * kRowsPerMma + row + 8] = sum[mi * 2 + 1];
}
}
}
template<int N>
inline __device__ void load(float (&sum)[N], const int (&row)[N]) {
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
sum[ni] = smem_[row[ni]];
}
}
float * const smem_;
};
} // namespace fmha
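For scale, an illustrative instantiation mirroring how the kernel traits above use this tile (a 16-row Q step, 16-row MMAs, one warp along M): kMmaM resolves to 1 and the buffer holds one float of log-sum-exp per query row.

// Illustrative only; matches Smem_tile_softmax_lse<ThreadblockShapeQK::kM,
// MmaInstructionShape::kM, WarpCountQK::kM> for STEP = 16.
using Smem_softmax_lse_example =
    fmha::Smem_tile_softmax_lse</*kRows=*/16, /*kRowsPerMma=*/16, /*kWarpCountM=*/1>;
static_assert(Smem_softmax_lse_example::BYTES_PER_TILE == 16 * sizeof(float), "");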

View File

@ -0,0 +1,404 @@
/******************************************************************************
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
#include <cassert>
#include <cstdint>
#include <cstdlib>
#include <ATen/cuda/CUDAContext.h>
// #include <cuda_fp16.h>
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#include <cuda_bf16.h>
#endif
extern "C" __device__ uint32_t __nvvm_get_smem_pointer(void *ptr);
////////////////////////////////////////////////////////////////////////////////////////////////////
namespace fmha {
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Row {};
struct Col {};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int M, int N >
struct Div_up {
enum { VALUE = (M + N-1) / N };
};
constexpr int DivUpConstexpr(int M, int N) { return (M + N - 1) / N; }
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int A, int B >
struct Max {
enum { VALUE = A >= B ? A : B };
};
constexpr int MaxConstexpr(int A, int B) { return A >= B ? A : B; }
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int A, int B, int C >
struct Max_3 {
enum { VALUE = Max<Max<A, B>::VALUE, C>::VALUE };
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int A, int B >
struct Min {
enum { VALUE = A <= B ? A : B };
};
constexpr int MinConstexpr(int A, int B) { return A <= B ? A : B; }
////////////////////////////////////////////////////////////////////////////////////////////////////
template< int SIZE_IN_BYTES >
struct Uint_from_size_in_bytes {
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
struct Uint_from_size_in_bytes<1> {
using Type = uint8_t;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
struct Uint_from_size_in_bytes<2> {
using Type = uint16_t;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
struct Uint_from_size_in_bytes<4> {
using Type = uint32_t;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
struct Uint_from_size_in_bytes<8> {
using Type = uint2;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
struct Uint_from_size_in_bytes<16> {
using Type = uint4;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template< typename T >
inline __device__ __host__ T div_up(T m, T n) {
return (m + n-1) / n;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
inline __device__ uint32_t hrelu2(uint32_t x);
template<>
inline __device__ uint32_t hrelu2<__half>(uint32_t x) {
uint32_t res;
const uint32_t zero = 0u;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile( "max.f16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero));
#else
asm volatile( \
"{\n" \
"\t .reg .f16x2 sela;\n" \
"\t set.gtu.u32.f16x2 sela, %1, %2;\n" \
"\t and.b32 %0, sela, %1;\n"
"}\n" : "=r"(res) : "r"(x), "r"(zero));
#endif
return res;
}
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
template<>
inline __device__ uint32_t hrelu2<__nv_bfloat16>(uint32_t x) {
uint32_t res;
const uint32_t zero = 0u;
asm volatile( "max.bf16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero));
return res;
}
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
static inline __device__ uint16_t float_to_half(float f) {
uint16_t h;
asm volatile("cvt.rn.f16.f32 %0, %1;" : "=h"(h) : "f"(f));
return h;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
inline __device__ uint32_t float2_pack(float a, float b);
template <>
inline __device__ uint32_t float2_pack<__half>(float a, float b) {
__half2 result = __floats2half2_rn(a, b);
return reinterpret_cast<uint32_t(&)>(result);
}
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
template <>
inline __device__ uint32_t float2_pack<__nv_bfloat16>(float a, float b) {
__nv_bfloat162 result = __floats2bfloat162_rn(a, b);
return reinterpret_cast<uint32_t(&)>(result);
}
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
inline __device__ uint2 float4_pack(float x, float y, float z, float w) {
uint2 d;
d.x = float2_pack<T>(x, y);
d.y = float2_pack<T>(z, w);
return d;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
inline __device__ float2 half2_unpack(uint32_t a);
template <>
inline __device__ float2 half2_unpack<__half>(uint32_t a) {
return __half22float2(reinterpret_cast<__half2 (&)>(a));
}
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
template <>
inline __device__ float2 half2_unpack<__nv_bfloat16>(uint32_t a) {
return __bfloat1622float2(reinterpret_cast<__nv_bfloat162 (&)>(a));
}
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
// Convert two half2's or bf162's into float, then take their dot product.
template <typename T>
inline __device__ float hfma2_to_float(const uint32_t a, const uint32_t b) {
float2 af = fmha::half2_unpack<T>(a);
float2 bf = fmha::half2_unpack<T>(b);
return af.x * bf.x + af.y * bf.y;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Convert two vectors of 8 half's or bf16's into float, then take their dot product.
template<typename T>
inline __device__ float hmulsum8(const uint4 a, const uint4 b) {
float sum;
sum = fmha::hfma2_to_float<T>(a.x, b.x);
sum += fmha::hfma2_to_float<T>(a.y, b.y);
sum += fmha::hfma2_to_float<T>(a.z, b.z);
sum += fmha::hfma2_to_float<T>(a.w, b.w);
return sum;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// L D G
//
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void ldg(uint8_t &dst, const void *ptr) {
dst = *reinterpret_cast<const uint8_t*>(ptr);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void ldg(uint16_t &dst, const void *ptr) {
dst = *reinterpret_cast<const uint16_t*>(ptr);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void ldg(uint32_t &dst, const void *ptr) {
dst = *reinterpret_cast<const uint32_t*>(ptr);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void ldg(uint2 &dst, const void *ptr) {
dst = *reinterpret_cast<const uint2*>(ptr);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void ldg(uint4 &dst, const void *ptr) {
dst = *reinterpret_cast<const uint4*>(ptr);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// S T G
//
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void stg(void *ptr, uint8_t val) {
*reinterpret_cast<uint8_t*>(ptr) = val;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void stg(void *ptr, uint16_t val) {
*reinterpret_cast<uint16_t*>(ptr) = val;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void stg(void *ptr, uint32_t val) {
*reinterpret_cast<uint32_t*>(ptr) = val;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void stg(void *ptr, uint2 val) {
*reinterpret_cast<uint2*>(ptr) = val;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void stg(void *ptr, uint4 val) {
*reinterpret_cast<uint4*>(ptr) = val;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct MaxOp {
__device__ inline T operator()(T const & x, T const & y) { return x > y ? x : y; }
};
template <>
struct MaxOp<float> {
// This is slightly faster
__device__ inline float operator()(float const &x, float const &y) { return max(x, y); }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct SumOp {
__device__ inline T operator()(T const & x, T const & y) { return x + y; }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int THREADS>
struct Allreduce {
static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
template<typename T, typename Operator>
static __device__ inline T run(T x, Operator &op) {
constexpr int OFFSET = THREADS / 2;
x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
return Allreduce<OFFSET>::run(x, op);
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
struct Allreduce<2> {
template<typename T, typename Operator>
static __device__ inline T run(T x, Operator &op) {
x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
return x;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Operator, int M>
__device__ inline void quad_reduce(float (&dst)[M], float (&src)[M], Operator &op) {
#pragma unroll
for(int mi=0; mi < M; mi++){
dst[mi] = src[mi];
dst[mi] = op(dst[mi], __shfl_down_sync(uint32_t(-1), dst[mi], 2));
dst[mi] = op(dst[mi], __shfl_down_sync(uint32_t(-1), dst[mi], 1));
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Operator, int M>
__device__ inline void quad_reduce(float (&dst)[M], float2 (&src)[M], Operator &op) {
float tmp[M];
#pragma unroll
for(int mi=0; mi < M; mi++){
tmp[mi] = op(src[mi].x, src[mi].y);
}
quad_reduce(dst, tmp, op);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Operator, int M>
__device__ inline void quad_allreduce(float (&dst)[M], float (&src)[M], Operator &op) {
#pragma unroll
for(int mi=0; mi < M; mi++){
dst[mi] = src[mi];
dst[mi] = Allreduce<4>::run(dst[mi], op);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Operator, int M>
__device__ inline void quad_allreduce(float (&dst)[M], float2 (&src)[M], Operator &op) {
float tmp[M];
#pragma unroll
for(int mi=0; mi < M; mi++){
tmp[mi] = op(src[mi].x, src[mi].y);
}
quad_allreduce(dst, tmp, op);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace fmha
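The Allreduce helper at the end implements a butterfly reduction with __shfl_xor_sync (quad_reduce does the analogous thing over a quad with __shfl_down_sync); after Allreduce<32>::run, every lane of the warp holds the combined value. A minimal device-side illustration, with a hypothetical function name:

// Illustrative only: warp-wide max using the helpers above. Every lane of the
// warp returns the same maximum after the butterfly exchange.
__device__ float warp_max_sketch(float thread_val) {
  fmha::MaxOp<float> max_op;
  return fmha::Allreduce<32>::run(thread_val, max_op);
}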

View File

@ -77,6 +77,7 @@ function(caffe2_print_configuration_summary)
message(STATUS " USE_CUDNN : ${USE_CUDNN}") message(STATUS " USE_CUDNN : ${USE_CUDNN}")
message(STATUS " USE_EXPERIMENTAL_CUDNN_V8_API: ${USE_EXPERIMENTAL_CUDNN_V8_API}") message(STATUS " USE_EXPERIMENTAL_CUDNN_V8_API: ${USE_EXPERIMENTAL_CUDNN_V8_API}")
message(STATUS " CUDA version : ${CUDA_VERSION}") message(STATUS " CUDA version : ${CUDA_VERSION}")
message(STATUS " USE_FLASH_ATTENTION : ${USE_FLASH_ATTENTION}")
if(${USE_CUDNN}) if(${USE_CUDNN})
message(STATUS " cuDNN version : ${CUDNN_VERSION}") message(STATUS " cuDNN version : ${CUDNN_VERSION}")
endif() endif()

View File

@ -322,7 +322,7 @@ def get_submodule_folders():
git_modules_path = os.path.join(cwd, ".gitmodules")
default_modules_path = [os.path.join(third_party_path, name) for name in [
    "gloo", "cpuinfo", "tbb", "onnx",
    "foxi", "QNNPACK", "fbgemm", "cutlass"
]]
if not os.path.exists(git_modules_path):
    return default_modules_path

1
third_party/cutlass vendored Submodule

Submodule third_party/cutlass added at b72cbf957d