Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 12:54:11 +08:00
flash_attention integration (#81434)
# Summary:
- I added a new submodule, Cutlass, pointing to the 2.10 release. Inclusion of the flash_attention code is gated by the flag USE_FLASH_ATTENTION, which defaults to off, so flash attention is not built anywhere. This is intentional, since we don't have A100 machines to compile and test on.
- Only CMake was updated; Bazel and Buck have not been attempted yet.
- I included mha_fwd from flash_attention, which has been refactored to use Cutlass 2.10. There is currently no backwards kernel on this branch; that would be a good follow-up.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81434
Approved by: https://github.com/cpuhrsch
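Since USE_FLASH_ATTENTION defaults to OFF, every entry point into the new kernel is compile-time guarded. The following is a minimal sketch of the guard pattern this PR relies on; the wrapper name and body are illustrative only (the real code dispatches into the cutlass-based fmha::mha_fwd), not part of the PR itself:

#include <ATen/ATen.h>
#include <c10/util/Exception.h>

// Hypothetical wrapper showing the gating pattern used in this PR: the CUDA-only
// kernel body is compiled only when the build passes -DUSE_FLASH_ATTENTION;
// otherwise callers get a clear error instead of a missing symbol.
at::Tensor flash_forward_or_error(const at::Tensor& query) {
#if defined(USE_FLASH_ATTENTION)
  // In the real code this calls into fmha::mha_fwd; returning the input
  // keeps the sketch self-contained.
  return query;
#endif
  TORCH_CHECK(false, "USE_FLASH_ATTENTION was not enabled for build.");
  return at::Tensor{};
}

Enabling the flag at configure time (e.g. cmake -DUSE_FLASH_ATTENTION=ON) makes the ADD_DEFINITIONS call in the CMake hunk below pass -DUSE_FLASH_ATTENTION to the compiler, which activates the guarded branch.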
Committed by PyTorch MergeBot
parent 219ff26172
commit 0fc02dbba4
3	.gitmodules (vendored)
@@ -151,3 +151,6 @@
[submodule "third_party/VulkanMemoryAllocator"]
    path = third_party/VulkanMemoryAllocator
    url = https://github.com/GPUOpen-LibrariesAndSDKs/VulkanMemoryAllocator.git
[submodule "third_party/cutlass"]
    path = third_party/cutlass
    url = https://github.com/NVIDIA/cutlass.git
@@ -721,6 +721,13 @@ set(BUILD_ONEDNN_GRAPH OFF)

include(cmake/Dependencies.cmake)

# Moved this cmake set option down here because CMAKE_CUDA_COMPILER_VERSION is not available until now
option(USE_FLASH_ATTENTION "Whether to build the flash_attention kernel for scaled dot product attention" OFF)
if(USE_FLASH_ATTENTION)
  ADD_DEFINITIONS(-DUSE_FLASH_ATTENTION)
ENDIF()


if(USE_CUDA AND (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 10.2) AND (CMAKE_HOST_SYSTEM_NAME MATCHES "Windows"))
  # CUDA < 10.2 doesn't support compiling and extracting header dependencies in
  # one call, so instead CMake calls nvcc twice with && in between.
@@ -130,15 +130,13 @@ file(GLOB native_cuda_h "native/cuda/*.h" "native/cuda/*.cuh")
file(GLOB native_cuda_linalg_cpp "native/cuda/linalg/*.cpp")
file(GLOB native_hip_h "native/hip/*.h" "native/hip/*.cuh")
file(GLOB native_cudnn_cpp "native/cudnn/*.cpp")
file(GLOB native_nested_cuda_cu "native/nested/cuda/*.cu")
file(GLOB native_nested_cuda_cpp "native/nested/cuda/*.cpp")
file(GLOB native_sparse_cuda_cu "native/sparse/cuda/*.cu")
file(GLOB native_sparse_cuda_cpp "native/sparse/cuda/*.cpp")
file(GLOB native_quantized_cuda_cu "native/quantized/cuda/*.cu")
file(GLOB native_quantized_cuda_cpp "native/quantized/cuda/*.cpp")
file(GLOB native_quantized_cudnn_cpp "native/quantized/cudnn/*.cpp")
file(GLOB native_transformers_cuda_cu "native/transformers/cuda/*.cu")
file(GLOB native_transformers_cuda_cpp "native/transformers/cuda/*.cpp")
file(GLOB native_nested_cuda_cu "native/nested/cuda/*.cu")
file(GLOB native_nested_cuda_cpp "native/nested/cuda/*.cpp")

file(GLOB native_hip_hip "native/hip/*.hip")
file(GLOB native_hip_cpp "native/hip/*.cpp")

@@ -151,11 +149,22 @@ file(GLOB native_sparse_hip_hip "native/sparse/hip/*.hip")
file(GLOB native_sparse_hip_cpp "native/sparse/hip/*.cpp")
file(GLOB native_quantized_hip_hip "native/quantized/hip/*.hip")
file(GLOB native_quantized_hip_cpp "native/quantized/hip/*.cpp")
file(GLOB native_transformers_cuda_cu "native/transformers/cuda/*.cu")
file(GLOB native_transformers_cuda_cpp "native/transformers/cuda/*.cpp")
file(GLOB native_transformers_hip_hip "native/transformers/hip/*.hip")
file(GLOB native_transformers_hip_cpp "native/transformers/hip/*.cpp")
file(GLOB native_quantized_cudnn_hip_cpp "native/quantized/cudnn/hip/*.cpp")
file(GLOB native_utils_cpp "native/utils/*.cpp")

# flash_attention sources
file(GLOB flash_attention_cuda_cu "native/transformers/cuda/flash_attn/*.cu")
file(GLOB flash_attention_cuda_cpp "native/transformers/cuda/flash_attn/*.cpp")

if(USE_FLASH_ATTENTION)
  list(APPEND native_transformers_cuda_cu ${flash_attention_cuda_cu})
  list(APPEND native_transformers_cuda_cpp ${flash_attention_cuda_cpp})
endif()

# XNNPACK
file(GLOB native_xnnpack "native/xnnpack/*.cpp")
@@ -415,6 +424,9 @@ if(NOT MSVC AND NOT EMSCRIPTEN AND NOT INTERN_BUILD_MOBILE)
endif()

if(USE_CUDA AND NOT USE_ROCM)
  if(USE_FLASH_ATTENTION)
    list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/include)
  endif()
  if($ENV{ATEN_STATIC_CUDA})
    list(APPEND ATen_CUDA_DEPENDENCY_LIBS
      ${CUDA_LIBRARIES}
@@ -13136,6 +13136,11 @@
  structured: True
  variants: function

- func: _flash_scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor cum_seq_q, Tensor cum_seq_k, int max_q, int max_k, float dropout_p, bool causal) -> Tensor
  variants: function
  dispatch:
    CUDA: flash_scaled_dot_product_attention

- func: _transformer_decoder_only_layer_fwd(Tensor src, int embed_dim, int num_heads, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, bool use_gelu, bool norm_first, float eps, Tensor norm_weight_1, Tensor norm_bias_1, Tensor norm_weight_2, Tensor norm_bias_2, Tensor ffn_weight_1, Tensor ffn_bias_1, Tensor ffn_weight_2, Tensor ffn_bias_2, Tensor? mask=None, Tensor? incr_key=None, Tensor? incr_value=None) -> (Tensor, Tensor, Tensor)
  variants: function
  dispatch:
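For orientation, here is a minimal sketch of how the new private op could be invoked on dense, pre-flattened inputs, mirroring the dense helper added later in this diff. The shapes, dtype, and the wrapper function name are illustrative assumptions, and the call only works on a CUDA build configured with USE_FLASH_ATTENTION=ON:

#include <ATen/ATen.h>

// Hypothetical example, not part of the PR: pack (batch * seq) rows of
// num_heads x head_dim queries/keys/values and call the new op directly.
at::Tensor flash_sdpa_dense_example() {
  const int64_t batch = 2, seq = 128, num_heads = 8, head_dim = 64;
  auto opts = at::TensorOptions().device(at::kCUDA).dtype(at::kHalf);
  auto q = at::randn({batch * seq, num_heads, head_dim}, opts);
  auto k = at::randn({batch * seq, num_heads, head_dim}, opts);
  auto v = at::randn({batch * seq, num_heads, head_dim}, opts);
  // Cumulative sequence lengths for equal-length sequences: {0, seq, 2 * seq}.
  auto cu_seqlens = at::arange(
      0, (batch + 1) * seq, seq,
      at::TensorOptions().device(at::kCUDA).dtype(at::kInt));
  return at::_flash_scaled_dot_product_attention(
      q, k, v, cu_seqlens, cu_seqlens,
      /*max_q=*/seq, /*max_k=*/seq, /*dropout_p=*/0.0, /*causal=*/false);
}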
@ -7,6 +7,7 @@
|
||||
|
||||
#include <c10/util/string_view.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/Optional.h>
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
@ -243,5 +244,196 @@ Tensor NestedTensor_to_mask(const Tensor& nt, c10::optional<int64_t> mask_dim, c
|
||||
}
|
||||
return result;
|
||||
}
|
||||
std::tuple<Tensor, int64_t> cumulative_and_max_seq_len(Tensor qkv) {
|
||||
TORCH_CHECK(
|
||||
qkv.is_nested(),
|
||||
"QKV must be nested for flash cumulative_seq_len calculation.")
|
||||
auto* nt_impl = get_nested_tensor_impl(qkv);
|
||||
const auto& sizes = nt_impl->get_nested_size_tensor();
|
||||
auto size_tensor_stride = sizes.stride(0);
|
||||
|
||||
const int64_t batch_size = qkv.size(0);
|
||||
auto cumulative_seqlen = at::zeros(
|
||||
{batch_size + 1}, TensorOptions().device(at::kCPU).dtype(at::kInt));
|
||||
|
||||
auto* sizes_ptr = sizes.data_ptr<int64_t>();
|
||||
auto* cumulative_seqlen_ptr = cumulative_seqlen.data_ptr<int32_t>();
|
||||
|
||||
int32_t sum = 0;
|
||||
int64_t max_seqlen = -1;
|
||||
cumulative_seqlen_ptr[0] = sum;
|
||||
for (const auto i : c10::irange(batch_size)) {
|
||||
// Calculate the cumulative sum of the sequence lengths
|
||||
auto current_seq_len = sizes_ptr[i * size_tensor_stride];
|
||||
sum += current_seq_len;
|
||||
cumulative_seqlen_ptr[i + 1] = sum;
|
||||
|
||||
// Find the max element while we traverse
|
||||
max_seqlen = std::max(max_seqlen, current_seq_len);
|
||||
}
|
||||
// Send to GPU; this is a fairly lightweight calculation for typical batch
// sizes, but it may eventually need to run on the GPU.
|
||||
cumulative_seqlen = cumulative_seqlen.to(TensorOptions().device(at::kCUDA));
|
||||
return std::tuple<Tensor, int64_t>{cumulative_seqlen, max_seqlen};
|
||||
}
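// Worked example for the function above: a nested qkv whose per-sequence
// lengths are {2, 3, 5} yields cumulative_seqlen = {0, 2, 5, 10} (moved to
// CUDA) and max_seqlen = 5.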
|
||||
|
||||
Tensor flash_attention_helper(
|
||||
const Tensor& query,
|
||||
const Tensor& key,
|
||||
const Tensor& value,
|
||||
double dropout_p,
|
||||
bool causal) {
|
||||
// Query is of size (batch_size x ragged_seq_len x (3 or 1) x n_heads x
// head_dim)
|
||||
int64_t head_dim{query.size(-1)};
|
||||
int64_t num_heads{query.size(-2)};
|
||||
|
||||
auto cumulative_and_max_q = cumulative_and_max_seq_len(query);
|
||||
Tensor cumulative_sequence_length_q = std::get<0>(cumulative_and_max_q);
|
||||
int64_t max_seqlen_batch_q = std::get<1>(cumulative_and_max_q);
|
||||
|
||||
if (key.is_same(value) || query.is_same(key) || query.is_same(value)) {
|
||||
int64_t Nnz_q{cumulative_sequence_length_q[-1].item<int64_t>()};
|
||||
|
||||
// For the packed case we need to set the output size for dim 2 to 1
|
||||
auto atten_size = get_nested_size_tensor(query);
|
||||
atten_size.index({at::indexing::Slice(), 1}) = 1;
|
||||
|
||||
auto qkv_buffer_reshaped =
|
||||
get_buffer(query).view({Nnz_q, 3, num_heads, head_dim});
|
||||
|
||||
// If query, key, and value are all the same tensor, then they have been
// packed into one tensor and we need to slice it for flash attention.
|
||||
Tensor atten_buffer = at::_flash_scaled_dot_product_attention(
|
||||
qkv_buffer_reshaped.index({at::indexing::Slice(), 0}),
|
||||
qkv_buffer_reshaped.index({at::indexing::Slice(), 1}),
|
||||
qkv_buffer_reshaped.index({at::indexing::Slice(), 2}),
|
||||
cumulative_sequence_length_q,
|
||||
cumulative_sequence_length_q,
|
||||
max_seqlen_batch_q,
|
||||
max_seqlen_batch_q,
|
||||
dropout_p,
|
||||
causal);
|
||||
// The output of flash_attention is a regular tensor; let's wrap it back up
// to form a nested tensor.
|
||||
return wrap_buffer(atten_buffer.view(-1), atten_size);
|
||||
}
|
||||
|
||||
// Query, Key, and Value are not all the same tensor, so we need to
// calculate the K metadata.
|
||||
|
||||
// The nested tensors will be of shape {Batch_size x ragged_seq_len x
|
||||
// num_heads * head_dim }
|
||||
auto cumulative_and_max_k = cumulative_and_max_seq_len(key);
|
||||
Tensor cumulative_sequence_length_k = std::get<0>(cumulative_and_max_k);
|
||||
int64_t max_seqlen_batch_k = std::get<1>(cumulative_and_max_k);
|
||||
|
||||
// K and V have to have the same Nnz; this should probably be TORCH_CHECKed
// earlier. Assume it holds here so we don't have to iterate over v.
|
||||
int64_t Nnz_q{cumulative_sequence_length_q[-1].item<int64_t>()};
|
||||
int64_t Nnz_kv{cumulative_sequence_length_k[-1].item<int64_t>()};
|
||||
|
||||
auto query_buffer_reshaped =
|
||||
get_buffer(query).view({Nnz_q, num_heads, head_dim});
|
||||
auto key_buffer_reshaped =
|
||||
get_buffer(key).view({Nnz_kv, num_heads, head_dim});
|
||||
auto value_buffer_reshaped =
|
||||
get_buffer(value).view({Nnz_kv, num_heads, head_dim});
|
||||
|
||||
Tensor atten_buffer = at::_flash_scaled_dot_product_attention(
|
||||
query_buffer_reshaped,
|
||||
key_buffer_reshaped,
|
||||
value_buffer_reshaped,
|
||||
cumulative_sequence_length_q,
|
||||
cumulative_sequence_length_k,
|
||||
max_seqlen_batch_q,
|
||||
max_seqlen_batch_k,
|
||||
dropout_p,
|
||||
causal);
|
||||
// The output of flash_attention is a regular tensor; let's wrap it back up
// to form a nested tensor whose size should match the query tensor.
|
||||
return wrap_buffer(atten_buffer.view(-1), get_nested_size_tensor(query));
|
||||
}
|
||||
|
||||
Tensor flash_attention_helper_dense(
|
||||
const Tensor& query,
|
||||
const Tensor& key,
|
||||
const Tensor& value,
|
||||
double dropout_p,
|
||||
bool causal) {
|
||||
TORCH_INTERNAL_ASSERT(
|
||||
!query.is_nested() && !key.is_nested() && !value.is_nested());
|
||||
// Query is of size (batch_size x dense_seq_len x 3 x n_heads x
// head_dim)
|
||||
const auto batch_size = query.size(0);
|
||||
auto max_seqlen_batch_q = query.size(1);
|
||||
int64_t head_dim{query.size(-1)};
|
||||
int64_t num_heads{query.size(-2)};
|
||||
|
||||
auto cumulative_sequence_length_q = at::arange(
|
||||
0,
|
||||
(batch_size + 1) * max_seqlen_batch_q,
|
||||
max_seqlen_batch_q,
|
||||
TensorOptions().device(at::kCUDA).dtype(at::kInt));
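// e.g. batch_size = 2 and max_seqlen_batch_q = 4 gives
// cumulative_sequence_length_q = {0, 4, 8}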
|
||||
int64_t Nnz_q{batch_size * max_seqlen_batch_q};
|
||||
|
||||
if (key.is_same(value) || query.is_same(key) || query.is_same(value)) {
|
||||
// In the dense case flash attention expects an input that is
|
||||
// (b*s) x num_heads x head_dim
|
||||
auto query_reshaped = query.reshape({Nnz_q, 3, num_heads, head_dim});
|
||||
// If query, key, and value are all the same tensor, then they have been
// packed into one tensor and we need to slice it for flash attention.
|
||||
|
||||
Tensor atten_buffer = at::_flash_scaled_dot_product_attention(
|
||||
query_reshaped.index({at::indexing::Slice(), 0}),
|
||||
query_reshaped.index({at::indexing::Slice(), 1}),
|
||||
query_reshaped.index({at::indexing::Slice(), 2}),
|
||||
cumulative_sequence_length_q,
|
||||
cumulative_sequence_length_q,
|
||||
max_seqlen_batch_q,
|
||||
max_seqlen_batch_q,
|
||||
dropout_p,
|
||||
causal);
|
||||
// Reshape output to convert nnz to batch_size and seq_len
|
||||
return atten_buffer.reshape(
|
||||
{batch_size, max_seqlen_batch_q, num_heads, head_dim});
|
||||
}
|
||||
|
||||
// Query, Key, and Value are not all the same tensor, so we need to
// calculate the K metadata.
|
||||
auto max_seqlen_batch_k = key.size(1);
|
||||
auto cumulative_sequence_length_k = at::arange(
|
||||
0,
|
||||
(batch_size + 1) * max_seqlen_batch_k,
|
||||
max_seqlen_batch_k,
|
||||
TensorOptions().device(at::kCUDA).dtype(at::kInt));
|
||||
|
||||
// K and V have to have the same Nnz; this should probably be TORCH_CHECKed
// earlier. Assume it holds for now so we don't have to iterate over v.
|
||||
int64_t Nnz_kv{batch_size * max_seqlen_batch_k};
|
||||
|
||||
// Calculate head dim
|
||||
TORCH_INTERNAL_ASSERT(query.size(-1) == key.size(-1));
|
||||
TORCH_INTERNAL_ASSERT(query.size(-1) == value.size(-1));
|
||||
|
||||
auto query_reshaped = query.reshape({Nnz_q, num_heads, head_dim});
|
||||
auto key_reshaped = key.reshape({Nnz_kv, num_heads, head_dim});
|
||||
auto value_reshaped = value.reshape({Nnz_kv, num_heads, head_dim});
|
||||
|
||||
Tensor atten_buffer = at::_flash_scaled_dot_product_attention(
|
||||
query_reshaped,
|
||||
key_reshaped,
|
||||
value_reshaped,
|
||||
cumulative_sequence_length_q,
|
||||
cumulative_sequence_length_k,
|
||||
max_seqlen_batch_q,
|
||||
max_seqlen_batch_k,
|
||||
dropout_p,
|
||||
causal);
|
||||
// Reshape output to convert nnz to batch_size and seq_len
|
||||
return atten_buffer.reshape(
|
||||
{batch_size, max_seqlen_batch_q, num_heads, head_dim});
|
||||
}
|
||||
|
||||
} // namespace native
|
||||
} // namespace at
|
||||
|
@ -83,5 +83,19 @@ void add_padding_kernelLauncher(
|
||||
const std::vector<int64_t>& output_sizes,
|
||||
const int batch_size,
|
||||
const int output_batch_size);
|
||||
|
||||
Tensor flash_attention_helper_dense(
|
||||
const Tensor& query,
|
||||
const Tensor& key,
|
||||
const Tensor& value,
|
||||
double dropout_p,
|
||||
bool causal);
|
||||
|
||||
Tensor flash_attention_helper(
|
||||
const Tensor& query,
|
||||
const Tensor& key,
|
||||
const Tensor& value,
|
||||
double dropout_p,
|
||||
bool causal);
|
||||
} // namespace native
|
||||
} // namespace at
|
||||
|
@ -1,4 +1,5 @@
|
||||
#include <type_traits>
|
||||
#include <c10/util/Exception.h>
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/NestedTensorImpl.h>
|
||||
@ -9,10 +10,18 @@
|
||||
#include <ATen/ops/_nested_from_padded.h>
|
||||
#endif
|
||||
|
||||
// TODO: Consider moving all flash_attention code, nested tensor included,
// to the transformer library.
|
||||
#ifdef USE_FLASH_ATTENTION
|
||||
#include <ATen/native/transformers/cuda/flash_attn/fmha_api.h>
|
||||
#endif
|
||||
|
||||
#include <ATen/native/nested/NestedTensorTransformerFunctions.h>
|
||||
#include <ATen/native/nested/NestedTensorMath.h>
|
||||
#include <ATen/native/nested/NestedTensorUtils.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
namespace {
|
||||
@ -207,5 +216,37 @@ Tensor NestedTensor_to_padded_tensor_cuda(
|
||||
return NestedTensor_to_padded_tensor_generic(t, padding, output_size);
|
||||
}
|
||||
|
||||
Tensor flash_scaled_dot_product_attention(
|
||||
const Tensor& query,
|
||||
const Tensor& key,
|
||||
const Tensor& value,
|
||||
const Tensor& cumulative_sequence_length_q,
|
||||
const Tensor& cumulative_sequence_length_k,
|
||||
const int64_t max_seqlen_batch_q,
|
||||
const int64_t max_seqlen_batch_k,
|
||||
double dropout_p,
|
||||
bool causal) {
|
||||
#if defined(USE_FLASH_ATTENTION)
|
||||
auto softmax_scale = std::pow(query.size(-1), -0.5);
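// This is the standard 1/sqrt(head_dim) attention scaling; e.g. head_dim = 64 gives 0.125.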
|
||||
std::vector<Tensor> output = fmha::mha_fwd(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
cumulative_sequence_length_q,
|
||||
cumulative_sequence_length_k,
|
||||
max_seqlen_batch_q,
|
||||
max_seqlen_batch_k,
|
||||
dropout_p,
|
||||
softmax_scale,
|
||||
false,
|
||||
causal,
|
||||
false,
|
||||
c10::nullopt);
|
||||
return output[0];
|
||||
#endif
|
||||
TORCH_CHECK(false, "USE_FLASH_ATTENTION was not enabled for build.")
|
||||
return Tensor{};
|
||||
}
|
||||
|
||||
} // namespace native
|
||||
} // namespace at
|
||||
|
149	aten/src/ATen/native/transformers/cuda/flash_attn/epilogue.h	Normal file
@ -0,0 +1,149 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2022, Tri Dao.
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <third_party/cutlass/include/cutlass/cutlass.h>
|
||||
#include <third_party/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h>
|
||||
#include <third_party/cutlass/include/cutlass/epilogue/threadblock/default_thread_map_tensor_op.h>
|
||||
#include <third_party/cutlass/include/cutlass/epilogue/warp/fragment_iterator_tensor_op.h>
|
||||
#include <third_party/cutlass/include/cutlass/gemm/warp/default_mma_tensor_op.h>
|
||||
#include <third_party/cutlass/include/cutlass/layout/layout.h>
|
||||
#include <third_party/cutlass/include/cutlass/arch/mma.h>
|
||||
#include <third_party/cutlass/include/cutlass/array.h>
|
||||
#include <third_party/cutlass/include/cutlass/numeric_types.h>
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/gemm.h>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/epilogue_predicated_tile_iterator.h>
|
||||
|
||||
namespace fmha {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename MmaCore>
|
||||
struct FMHAEpilogue {
|
||||
|
||||
using ThreadblockShape = typename MmaCore::Shape;
|
||||
using WarpMma = typename MmaCore::MmaTensorOp;
|
||||
using LayoutC = typename MmaCore::LayoutC;
|
||||
using Element = typename MmaCore::ElementA;
|
||||
using ElementC = typename MmaCore::ElementC;
|
||||
|
||||
static constexpr int kPartitionsK = ThreadblockShape::kK / MmaCore::WarpShape::kK;
|
||||
|
||||
using AccumulatorFragmentIterator = cutlass::epilogue::warp::FragmentIteratorTensorOp<
|
||||
typename WarpMma::Shape,
|
||||
typename WarpMma::Policy::Operator::Shape,
|
||||
typename WarpMma::Policy::Operator::ElementC,
|
||||
typename WarpMma::Policy::Operator::FragmentC,
|
||||
LayoutC>;
|
||||
using AccumulatorTile = typename AccumulatorFragmentIterator::AccumulatorTile;
|
||||
static constexpr int kIterationsStore = AccumulatorFragmentIterator::kIterations;
|
||||
|
||||
// Maybe elementsPerAccess should vary: 4 for d=64, 2 for d=32?
|
||||
using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapTensorOp<
|
||||
ThreadblockShape, typename WarpMma::Shape, kPartitionsK, Element, /*ElementsPerAccess=*/4>::Type;
|
||||
using OutputTileThreadMapAccum = typename cutlass::epilogue::threadblock::DefaultThreadMapTensorOp<
|
||||
ThreadblockShape, typename WarpMma::Shape, kPartitionsK, ElementC, /*ElementsPerAccess=*/4>::Type;
|
||||
|
||||
using GmemIterator = fmha::EpiloguePredicatedTileIterator<
|
||||
OutputTileThreadMap,
|
||||
Element
|
||||
>;
|
||||
// which ThreadMap should we use?
|
||||
using GmemIteratorAccum = fmha::EpiloguePredicatedTileIterator<
|
||||
// OutputTileThreadMapAccum,
|
||||
OutputTileThreadMap,
|
||||
ElementC
|
||||
>;
|
||||
|
||||
|
||||
using DefaultIterators = cutlass::epilogue::threadblock::detail::DefaultIteratorsTensorOp<
|
||||
Element, ElementC, /*ElementsPerAccess=*/4, ThreadblockShape, typename WarpMma::Shape,
|
||||
typename WarpMma::Policy::Operator::Shape, typename OutputTileThreadMap::CompactedThreadMap>;
|
||||
using WarpTileIterator = typename DefaultIterators::WarpTileIterator;
|
||||
static_assert(WarpTileIterator::kIterations == kIterationsStore);
|
||||
using SharedLoadIterator = typename DefaultIterators::SharedLoadIterator;
|
||||
using OutputFragment = typename SharedLoadIterator::Fragment;
|
||||
|
||||
// using Padding = cutlass::MatrixShape<0, 0>;
|
||||
using Padding = cutlass::MatrixShape<0, 64 / cutlass::sizeof_bits<ElementC>::value * 4>;
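// e.g. for a 32-bit accumulator ElementC this pads each row by 64 / 32 * 4 = 8 elements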
|
||||
static constexpr int kFragmentsPerIteration = kIterationsStore; // TODO: could be 1 for Volta?
|
||||
/*Using kIterationsStore here so that we get the right storage size*/
|
||||
using EpilogueBase = typename cutlass::epilogue::threadblock::EpilogueBase<
|
||||
ThreadblockShape, typename WarpMma::Shape, kPartitionsK, AccumulatorFragmentIterator, WarpTileIterator,
|
||||
Padding, kIterationsStore>;
|
||||
|
||||
using SharedStorage = typename EpilogueBase::SharedStorage;
|
||||
static constexpr int kSmemTiles = EpilogueBase::kFragmentsPerIteration;
|
||||
static constexpr int kSmemPointerOffset = SharedStorage::StorageShape::kCount / kSmemTiles;
|
||||
static constexpr int kSmemPointerOffsetPerWarp = SharedStorage::StorageShape::kCount / (kSmemTiles * kPartitionsK);
|
||||
|
||||
SharedStorage *shared_storage;
|
||||
WarpTileIterator warp_tile_iterator;
|
||||
|
||||
inline __device__ FMHAEpilogue(void *smem, const int tidx)
|
||||
: shared_storage(reinterpret_cast<SharedStorage *>(smem))
|
||||
, warp_tile_iterator(shared_storage->reference(), threadIdx.x % 32) {
|
||||
|
||||
// const int warp_idx = tidx / 32;
|
||||
// Broadcast the warp_id computed by lane 0 to ensure dependent code
|
||||
// is compiled as warp-uniform.
|
||||
// https://github.com/NVIDIA/cutlass/blob/e66bfcb1f880792caa46b1e983c4114e23afa5f3/include/cutlass/gemm/kernel/gemm_with_fused_epilogue.h#L520
|
||||
const int warp_idx = __shfl_sync(0xffffffff, tidx / 32, 0);
|
||||
|
||||
cutlass::MatrixCoord warp_offset{kIterationsStore * warp_idx, 0};
|
||||
|
||||
warp_tile_iterator.add_tile_offset(warp_offset);
|
||||
}
|
||||
|
||||
// Store the accumulators.
|
||||
inline __device__ void store(const AccumulatorTile &acc) {
|
||||
AccumulatorFragmentIterator accum_fragment_iterator(acc);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int p = 0; p < kIterationsStore; ++p) {
|
||||
typename AccumulatorFragmentIterator::Fragment accum_fragment;
|
||||
accum_fragment_iterator.load(accum_fragment);
|
||||
++accum_fragment_iterator;
|
||||
|
||||
warp_tile_iterator.store(accum_fragment);
|
||||
if (p < kIterationsStore - 1) {
|
||||
warp_tile_iterator.add_pointer_offset(kSmemPointerOffsetPerWarp);
|
||||
}
|
||||
}
|
||||
if (kIterationsStore > 1) {
|
||||
warp_tile_iterator.add_pointer_offset((1 - kIterationsStore) * kSmemPointerOffsetPerWarp);
|
||||
}
|
||||
}
|
||||
|
||||
// Load the accumulators
|
||||
template<bool zero_init=true>
|
||||
inline __device__ void load(OutputFragment (&out)[kFragmentsPerIteration],
|
||||
const int tidx) {
|
||||
SharedLoadIterator shared_load_iterator(shared_storage->reference(), tidx);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int p = 0; p < EpilogueBase::kFragmentsPerIteration; ++p) {
|
||||
OutputFragment aligned_accum_fragment[kPartitionsK];
|
||||
shared_load_iterator.load(aligned_accum_fragment[0]);
|
||||
cutlass::plus<OutputFragment> add_fragments;
|
||||
if (kPartitionsK > 1) {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for ( int i = 1; i < kPartitionsK; ++i) {
|
||||
shared_load_iterator.add_pointer_offset(kSmemPointerOffsetPerWarp * kIterationsStore);
|
||||
shared_load_iterator.load(aligned_accum_fragment[i]);
|
||||
aligned_accum_fragment[0] = add_fragments(aligned_accum_fragment[0], aligned_accum_fragment[i]);
|
||||
}
|
||||
shared_load_iterator.add_pointer_offset((1 - kPartitionsK) * kSmemPointerOffsetPerWarp * kIterationsStore);
|
||||
}
|
||||
if (p < EpilogueBase::kFragmentsPerIteration - 1) {
|
||||
shared_load_iterator.add_pointer_offset(kSmemPointerOffsetPerWarp);
|
||||
}
|
||||
|
||||
out[p] = zero_init ? aligned_accum_fragment[0] : add_fragments(out[p], aligned_accum_fragment[0]);
|
||||
}
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
} // namespace fmha
|
@ -0,0 +1,493 @@
|
||||
// Adapted from cutlass/epilogue/threadblock/predicated_tile_iterator.h
|
||||
// We just want to add the move() function, but it is unclear how to do that
// without copying the code here.
|
||||
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2022, Tri Dao.
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <third_party/cutlass/include/cutlass/cutlass.h>
|
||||
#include <third_party/cutlass/include/cutlass/arch/arch.h>
|
||||
#include <third_party/cutlass/include/cutlass/arch/memory.h>
|
||||
#include <third_party/cutlass/include/cutlass/array.h>
|
||||
#include <third_party/cutlass/include/cutlass/epilogue/threadblock/output_tile_thread_map.h>
|
||||
#include <third_party/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator_params.h>
|
||||
#include <third_party/cutlass/include/cutlass/layout/matrix.h>
|
||||
#include <third_party/cutlass/include/cutlass/layout/tensor.h>
|
||||
#include <third_party/cutlass/include/cutlass/matrix_shape.h>
|
||||
#include <third_party/cutlass/include/cutlass/numeric_types.h>
|
||||
#include <third_party/cutlass/include/cutlass/tensor_ref.h>
|
||||
#include <third_party/cutlass/include/cutlass/transform/pitch_linear_thread_map.h>
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace fmha {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
using namespace cutlass;
|
||||
using namespace cutlass::epilogue::threadblock;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Tile iterator used to load and store output tile from global memory in epilogue.
|
||||
///
|
||||
/// Satisfies: ReadableTileIterator | PredicatedTileIterator | ForwardTileIterator
|
||||
///
|
||||
template <
|
||||
typename ThreadMap_, ///< Thread map (concept: OutputTileThreadMap)
|
||||
typename Element_, ///< Element data type
|
||||
bool ScatterD = false, ///< Scatter D operand or not
|
||||
bool UseCUDAStore = false
|
||||
>
|
||||
class EpiloguePredicatedTileIterator {
|
||||
public:
|
||||
using ThreadMap = ThreadMap_;
|
||||
using Shape = typename ThreadMap::Shape;
|
||||
|
||||
using Element = Element_;
|
||||
|
||||
using Layout = layout::RowMajor;
|
||||
using TensorRef = TensorRef<Element, Layout>;
|
||||
using ConstTensorRef = typename TensorRef::ConstTensorRef;
|
||||
|
||||
using Index = typename Layout::Index;
|
||||
using LongIndex = typename Layout::LongIndex;
|
||||
using TensorCoord = MatrixCoord;
|
||||
|
||||
static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;
|
||||
static int const kThreads = ThreadMap::kThreads;
|
||||
static int const kIterations = ThreadMap::Count::kTile;
|
||||
|
||||
static_assert( ThreadMap::Iterations::kRow > 0,"ThreadMap::Iterations::kRow must be > 0");
|
||||
static_assert( ThreadMap::Iterations::kGroup > 0,"ThreadMap::Iterations::kGroup must be > 0");
|
||||
static_assert( ThreadMap::Iterations::kCluster > 0,"ThreadMap::Iterations::kCluster must be > 0");
|
||||
static_assert( ThreadMap::Iterations::kColumn > 0,"ThreadMap::Iterations::kColumn must be > 0");
|
||||
|
||||
/// Fragment object
|
||||
using Fragment = Array<
|
||||
Element,
|
||||
ThreadMap::Iterations::kColumn *
|
||||
ThreadMap::Iterations::kRow *
|
||||
ThreadMap::Iterations::kGroup *
|
||||
ThreadMap::Iterations::kCluster * ThreadMap::kElementsPerAccess>;
|
||||
|
||||
/// Memory access size
|
||||
using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess>;
|
||||
|
||||
//
|
||||
// Parameters struct
|
||||
//
|
||||
|
||||
/// Uses a non-template class
|
||||
struct Params : PredicatedTileIteratorParams {
|
||||
using Base = PredicatedTileIteratorParams;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params() { }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Layout const &layout):
|
||||
PredicatedTileIteratorParams(
|
||||
layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess,
|
||||
make_OutputTileThreadMapDesc<ThreadMap>()
|
||||
)
|
||||
{ }
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Base const &base) :
|
||||
Base(base) { }
|
||||
};
|
||||
|
||||
/// Mask object
|
||||
struct Mask {
|
||||
|
||||
static int const kCount = ThreadMap::Iterations::kColumn;
|
||||
|
||||
/// Predicate state
|
||||
bool predicates[kCount];
|
||||
|
||||
//
|
||||
// Mask
|
||||
//
|
||||
CUTLASS_HOST_DEVICE
|
||||
Mask() {
|
||||
enable();
|
||||
}
|
||||
|
||||
///< Efficiently disables all accesses guarded by mask
|
||||
CUTLASS_HOST_DEVICE void clear() {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kCount; ++i) {
|
||||
predicates[i] = false;
|
||||
}
|
||||
}
|
||||
|
||||
///< Efficiently enables all accesses guarded by mask
|
||||
CUTLASS_DEVICE void enable() {
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kCount; ++i) {
|
||||
predicates[i] = true;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
private:
|
||||
|
||||
//
|
||||
// Data members
|
||||
//
|
||||
|
||||
/// Parameters structure containing reference and precomputed state.
|
||||
PredicatedTileIteratorParams params_;
|
||||
|
||||
/// Byte-level pointer
|
||||
uint8_t *byte_pointer_;
|
||||
|
||||
/// Array of boolean values to contain steady-state predicates
|
||||
Mask mask_;
|
||||
|
||||
/// Extent of the matrix tile in rows
|
||||
Index extent_row_;
|
||||
|
||||
/// Extent of the matrix tile in columns
|
||||
Index extent_column_;
|
||||
|
||||
/// A thread's starting row position (assuming steady-state predicates have been computed)
|
||||
Index thread_start_row_;
|
||||
|
||||
/// A thread's starting column
|
||||
Index thread_start_column_;
|
||||
|
||||
/// Internal state counter
|
||||
int state_[3];
|
||||
|
||||
/// Scatter indices
|
||||
int const *indices_;
|
||||
|
||||
//
|
||||
// Static asserts about internal strides
|
||||
//
|
||||
|
||||
static_assert(sizeof(extent_row_) == 4, "Expected 32b extents");
|
||||
static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents");
|
||||
static_assert(sizeof(PredicatedTileIteratorParams::stride) == 8, "Expected 64b strides");
|
||||
|
||||
private:
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
public:
|
||||
|
||||
//
|
||||
// Methods
|
||||
//
|
||||
|
||||
/// Constructor
|
||||
CUTLASS_DEVICE
|
||||
EpiloguePredicatedTileIterator(
|
||||
PredicatedTileIteratorParams const & params,
|
||||
Element *pointer,
|
||||
TensorCoord extent,
|
||||
int thread_idx,
|
||||
TensorCoord threadblock_offset = TensorCoord(),
|
||||
int const *indices = nullptr
|
||||
):
|
||||
params_(params), indices_(indices)
|
||||
{
|
||||
|
||||
TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx) + threadblock_offset;
|
||||
|
||||
extent_row_ = extent.row();
|
||||
extent_column_ = extent.column();
|
||||
|
||||
thread_start_row_ = thread_offset.row();
|
||||
thread_start_column_ = thread_offset.column();
|
||||
|
||||
// Initialize predicates
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) {
|
||||
|
||||
mask_.predicates[c] = ((thread_offset.column()
|
||||
+ ThreadMap::Delta::kColumn * c) < extent.column());
|
||||
}
|
||||
|
||||
// Null pointer performs no accesses
|
||||
if (!pointer) {
|
||||
mask_.clear();
|
||||
}
|
||||
|
||||
if (ScatterD && !indices) {
|
||||
mask_.clear();
|
||||
}
|
||||
|
||||
// Initialize pointer
|
||||
byte_pointer_ = reinterpret_cast<uint8_t *>(pointer) +
|
||||
LongIndex(thread_offset.row()) * LongIndex(params_.stride) +
|
||||
LongIndex(thread_offset.column()) * sizeof(AccessType) / kElementsPerAccess;
|
||||
|
||||
if (ScatterD) {
|
||||
byte_pointer_ = reinterpret_cast<uint8_t *>(pointer) +
|
||||
LongIndex(thread_offset.column()) * sizeof(AccessType) / kElementsPerAccess;
|
||||
}
|
||||
|
||||
// Initialize internal state counter
|
||||
state_[0] = state_[1] = state_[2] = 0;
|
||||
}
|
||||
|
||||
/// Adds a pointer offset in units of Element
|
||||
CUTLASS_HOST_DEVICE
|
||||
void add_pointer_offset(LongIndex pointer_offset) {
|
||||
byte_pointer_ += pointer_offset * sizeof_bits<Element>::value / 8;
|
||||
}
|
||||
|
||||
/// Loads a fragment from memory
|
||||
CUTLASS_DEVICE
|
||||
void load_with_byte_offset(Fragment &frag, int64_t byte_offset) const {
|
||||
|
||||
uint8_t *byte_pointer = byte_pointer_;
|
||||
AccessType *frag_ptr = reinterpret_cast<AccessType *>(&frag);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) {
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) {
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) {
|
||||
|
||||
int frag_row_idx =
|
||||
(row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster));
|
||||
|
||||
int row_offset = row * ThreadMap::Delta::kRow
|
||||
+ group * ThreadMap::Delta::kGroup
|
||||
+ cluster * ThreadMap::Delta::kCluster;
|
||||
|
||||
bool row_guard = ((row_offset + thread_start_row_) < extent_row_);
|
||||
|
||||
AccessType *memory_pointer = reinterpret_cast<AccessType *>(byte_pointer + byte_offset);
|
||||
|
||||
if (ScatterD && row_guard) {
|
||||
assert(indices_);
|
||||
|
||||
memory_pointer = reinterpret_cast<AccessType *>(byte_pointer + byte_offset +
|
||||
LongIndex(indices_[row_offset + thread_start_row_]) * LongIndex(params_.stride));
|
||||
}
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) {
|
||||
|
||||
bool guard = row_guard && mask_.predicates[column];
|
||||
|
||||
cutlass::arch::global_load<
|
||||
AccessType,
|
||||
sizeof(AccessType)
|
||||
>(
|
||||
frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn +
|
||||
column],
|
||||
(void *)&memory_pointer[column * ThreadMap::Delta::kColumn /
|
||||
kElementsPerAccess],
|
||||
guard);
|
||||
}
|
||||
|
||||
if (row + 1 < ThreadMap::Iterations::kRow) {
|
||||
if (!ScatterD) {
|
||||
byte_pointer += params_.increment_row;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (group + 1 < ThreadMap::Iterations::kGroup) {
|
||||
byte_pointer += params_.increment_group;
|
||||
}
|
||||
}
|
||||
|
||||
if (cluster + 1 < ThreadMap::Iterations::kCluster) {
|
||||
byte_pointer += params_.increment_cluster;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Loads a fragment from memory
|
||||
CUTLASS_DEVICE
|
||||
void load(Fragment &frag) const {
|
||||
|
||||
load_with_byte_offset(frag, 0);
|
||||
}
|
||||
|
||||
/// Stores a fragment to memory
|
||||
CUTLASS_DEVICE
|
||||
void store_with_byte_offset(Fragment const &frag, int64_t byte_offset) const {
|
||||
uint8_t *byte_pointer = byte_pointer_;
|
||||
AccessType const *frag_ptr = reinterpret_cast<AccessType const *>(&frag);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) {
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) {
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) {
|
||||
|
||||
int frag_row_idx =
|
||||
(row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster));
|
||||
|
||||
int row_offset = row * ThreadMap::Delta::kRow
|
||||
+ group * ThreadMap::Delta::kGroup
|
||||
+ cluster * ThreadMap::Delta::kCluster;
|
||||
|
||||
bool row_guard = ((row_offset + thread_start_row_) < extent_row_);
|
||||
|
||||
AccessType *memory_pointer = reinterpret_cast<AccessType *>(byte_pointer + byte_offset);
|
||||
|
||||
if (ScatterD && row_guard) {
|
||||
assert(indices_);
|
||||
|
||||
memory_pointer = reinterpret_cast<AccessType *>(byte_pointer + byte_offset +
|
||||
LongIndex(indices_[row_offset + thread_start_row_]) * LongIndex(params_.stride));
|
||||
}
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) {
|
||||
|
||||
bool guard = row_guard && mask_.predicates[column];
|
||||
|
||||
if (UseCUDAStore) {
|
||||
if (guard) {
|
||||
memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess] =
|
||||
frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column];
|
||||
}
|
||||
} else {
|
||||
cutlass::arch::global_store<AccessType, sizeof(AccessType)>(
|
||||
frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column],
|
||||
(void *)&memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess],
|
||||
guard);
|
||||
}
|
||||
}
|
||||
|
||||
if (row + 1 < ThreadMap::Iterations::kRow) {
|
||||
if (!ScatterD) {
|
||||
byte_pointer += params_.increment_row;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
if (group + 1 < ThreadMap::Iterations::kGroup) {
|
||||
byte_pointer += params_.increment_group;
|
||||
}
|
||||
}
|
||||
|
||||
if (cluster + 1 < ThreadMap::Iterations::kCluster) {
|
||||
byte_pointer += params_.increment_cluster;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Stores a fragment to memory
|
||||
CUTLASS_DEVICE
|
||||
void store(Fragment const &frag) const {
|
||||
|
||||
store_with_byte_offset(frag, 0);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
MatrixCoord thread_start() const {
|
||||
return MatrixCoord(thread_start_row_, thread_start_column_);
|
||||
}
|
||||
|
||||
/// Need to get the thread start row from the tile iterator
|
||||
CUTLASS_DEVICE
|
||||
int32_t thread_start_row() const {
|
||||
return thread_start_row_;
|
||||
}
|
||||
|
||||
/// Need to get the thread start column from the tile iterator
|
||||
CUTLASS_DEVICE
|
||||
int32_t thread_start_column() const {
|
||||
return thread_start_column_;
|
||||
}
|
||||
|
||||
/// Extent of the matrix in rows
|
||||
CUTLASS_DEVICE
|
||||
Index extent_row() const {
|
||||
return extent_row_;
|
||||
}
|
||||
|
||||
/// Extent of the matrix in columns
|
||||
CUTLASS_DEVICE
|
||||
Index extent_column() const {
|
||||
return extent_column_;
|
||||
}
|
||||
|
||||
/// Advances to the next position to load or store
|
||||
CUTLASS_HOST_DEVICE
|
||||
void move(const int step=1) {
|
||||
|
||||
if (!ScatterD) {
|
||||
byte_pointer_ += step * params_.advance_row;
|
||||
}
|
||||
|
||||
thread_start_row_ += step * ThreadMap::Shape::kRow;
|
||||
}
|
||||
|
||||
///< Efficiently disables all accesses guarded by mask
|
||||
CUTLASS_DEVICE void clear_mask() {
|
||||
mask_.clear();
|
||||
}
|
||||
|
||||
///< Efficiently enables all accesses guarded by mask
|
||||
CUTLASS_DEVICE void enable_mask() {
|
||||
mask_.enable();
|
||||
}
|
||||
|
||||
///< Gets the mask
|
||||
CUTLASS_DEVICE void get_mask(Mask &mask) const {
|
||||
mask = mask_;
|
||||
}
|
||||
|
||||
///< Sets the mask
|
||||
CUTLASS_DEVICE void set_mask(Mask const &mask) {
|
||||
mask_ = mask;
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
} // namespace fmha
|
154	aten/src/ATen/native/transformers/cuda/flash_attn/fmha.h	Normal file
@ -0,0 +1,154 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright
|
||||
* notice, this list of conditions and the following disclaimer in the
|
||||
* documentation and/or other materials provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the
|
||||
* names of its contributors may be used to endorse or promote products
|
||||
* derived from this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cuda.h>
|
||||
#include <vector>
|
||||
|
||||
#include <ATen/cuda/CUDAGeneratorImpl.h>
|
||||
#include <ATen/cuda/CUDAGraphsUtils.cuh>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/fmha_utils.h>
|
||||
|
||||
|
||||
constexpr int TOTAL_DIM = 0;
|
||||
constexpr int H_DIM = 1;
|
||||
constexpr int D_DIM = 2;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct Qkv_params {
|
||||
// The QKV matrices.
|
||||
void *__restrict__ q_ptr;
|
||||
void *__restrict__ k_ptr;
|
||||
void *__restrict__ v_ptr;
|
||||
|
||||
// The stride between rows of the Q, K and V matrices.
|
||||
// size_t qkv_stride_in_elts;
|
||||
// size_t qkv_stride_in_bytes;
|
||||
// TD [2022-04-16]: We're using 32-bit indexing to save registers.
|
||||
// The code probably won't work for arrays larger than 2GB.
|
||||
uint32_t q_row_stride_in_elts;
|
||||
uint32_t k_row_stride_in_elts;
|
||||
uint32_t v_row_stride_in_elts;
|
||||
uint32_t q_head_stride_in_elts;
|
||||
uint32_t k_head_stride_in_elts;
|
||||
uint32_t v_head_stride_in_elts;
|
||||
|
||||
// The number of heads.
|
||||
int h;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct FMHA_fprop_params : public Qkv_params {
|
||||
|
||||
// The O matrix (output).
|
||||
void * __restrict__ o_ptr;
|
||||
|
||||
// The stride between rows of O.
|
||||
// size_t o_stride_in_elts;
|
||||
// size_t o_stride_in_bytes;
|
||||
uint32_t o_row_stride_in_elts;
|
||||
uint32_t o_head_stride_in_elts;
|
||||
|
||||
// The pointer to the O_tmp matrix, which holds intermediate O values during
// the loop.
|
||||
void *__restrict__ o_tmp_ptr;
|
||||
|
||||
// The pointer to the S matrix.
|
||||
void * __restrict__ s_ptr;
|
||||
// The stride between rows of the S matrix.
|
||||
// int64_t s_stride_in_bytes;
|
||||
uint32_t s_stride_in_bytes;
|
||||
|
||||
// The pointer to the softmax sum.
|
||||
void * __restrict__ softmax_lse_ptr;
|
||||
|
||||
// The dimensions.
|
||||
int b, seqlen_q, seqlen_k, d;
|
||||
|
||||
// The scaling factors for the kernel.
|
||||
float scale_bmm1;
|
||||
|
||||
// array of length b+1 holding starting offset of each sequence.
|
||||
int * __restrict__ cu_seqlens_q;
|
||||
int * __restrict__ cu_seqlens_k;
|
||||
|
||||
int *__restrict__ blockmask;
|
||||
|
||||
// The dropout probability (probability of keeping an activation).
|
||||
float p_dropout;
|
||||
uint32_t p_dropout_in_uint;
|
||||
uint16_t p_dropout_in_uint16_t;
|
||||
|
||||
// Scale factor of 1 / (1 - p_dropout).
|
||||
float rp_dropout;
|
||||
float scale_bmm1_rp_dropout;
|
||||
|
||||
// Random state.
|
||||
at::PhiloxCudaState philox_args;
|
||||
|
||||
bool is_bf16;
|
||||
bool is_causal;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename Kernel_params>
|
||||
struct Launch_params{
|
||||
Launch_params(cudaDeviceProp * props_,
|
||||
cudaStream_t stream_,
|
||||
bool is_dropout_,
|
||||
bool return_softmax_)
|
||||
: elts_per_thread(0)
|
||||
, props(props_)
|
||||
, stream(stream_)
|
||||
, is_dropout(is_dropout_)
|
||||
, return_softmax(return_softmax_) {
|
||||
}
|
||||
|
||||
size_t elts_per_thread;
|
||||
|
||||
cudaDeviceProp * props;
|
||||
|
||||
cudaStream_t stream;
|
||||
|
||||
bool is_dropout;
|
||||
bool return_softmax;
|
||||
|
||||
Kernel_params params;
|
||||
int num_full_heads;
|
||||
int num_main_groups;
|
||||
int heads_last_wave;
|
||||
int main_steps;
|
||||
int rest_steps;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void run_fmha_fprop(Launch_params<FMHA_fprop_params> &launch_params, const bool configure);
|
244	aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.cpp	Normal file
@ -0,0 +1,244 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2022, Tri Dao.
|
||||
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright
|
||||
* notice, this list of conditions and the following disclaimer in the
|
||||
* documentation and/or other materials provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the
|
||||
* names of its contributors may be used to endorse or promote products
|
||||
* derived from this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
******************************************************************************/
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/fmha.h>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/fmha_api.h>
|
||||
|
||||
#include <c10/util/Exception.h>
|
||||
|
||||
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == at::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
|
||||
|
||||
namespace fmha {
|
||||
|
||||
void set_params_fprop(FMHA_fprop_params ¶ms,
|
||||
// sizes
|
||||
const size_t b,
|
||||
const size_t seqlen_q,
|
||||
const size_t seqlen_k,
|
||||
const size_t h,
|
||||
const size_t d,
|
||||
// device pointers
|
||||
const at::Tensor q,
|
||||
const at::Tensor k,
|
||||
const at::Tensor v,
|
||||
void *cu_seqlens_q_d,
|
||||
void *cu_seqlens_k_d,
|
||||
void *o_packed_d,
|
||||
void *o_tmp_d,
|
||||
void *s_d,
|
||||
void *softmax_lse_d,
|
||||
float p_dropout,
|
||||
float softmax_scale,
|
||||
bool is_causal) {
|
||||
|
||||
// Reset the parameters
|
||||
memset(¶ms, 0, sizeof(params));
|
||||
|
||||
params.is_bf16 = q.dtype() == at::kBFloat16;
|
||||
|
||||
// Set the pointers and strides.
|
||||
params.q_ptr = q.data_ptr();
|
||||
params.k_ptr = k.data_ptr();
|
||||
params.v_ptr = v.data_ptr();
|
||||
params.q_row_stride_in_elts = q.stride(0);
|
||||
params.k_row_stride_in_elts = k.stride(0);
|
||||
params.v_row_stride_in_elts = v.stride(0);
|
||||
params.q_head_stride_in_elts = q.stride(1);
|
||||
params.k_head_stride_in_elts = k.stride(1);
|
||||
params.v_head_stride_in_elts = v.stride(1);
|
||||
params.o_ptr = o_packed_d;
|
||||
params.o_row_stride_in_elts = h * d;
|
||||
params.o_head_stride_in_elts = d;
|
||||
params.o_tmp_ptr = o_tmp_d;
|
||||
|
||||
params.cu_seqlens_q = static_cast<int *>(cu_seqlens_q_d);
|
||||
params.cu_seqlens_k = static_cast<int *>(cu_seqlens_k_d);
|
||||
|
||||
// S = softmax(P)
|
||||
params.s_ptr = s_d;
|
||||
params.s_stride_in_bytes = b * h * seqlen_k * 2; // 2 = sizeof(Element)
|
||||
|
||||
// Softmax sum
|
||||
params.softmax_lse_ptr = softmax_lse_d;
|
||||
|
||||
// Set the dimensions.
|
||||
params.b = b;
|
||||
params.h = h;
|
||||
params.seqlen_q = seqlen_q;
|
||||
params.seqlen_k = seqlen_k;
|
||||
params.d = d;
|
||||
|
||||
// Set the different scale values.
|
||||
params.scale_bmm1 = softmax_scale;
|
||||
|
||||
// Set this to probability of keeping an element to simplify things.
|
||||
params.p_dropout = 1.f - p_dropout;
|
||||
// Convert p from float to int so we don't have to convert the random uint to float to compare.
|
||||
// [Minor] We want to round down since when we do the comparison we use <= instead of <
|
||||
params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0));
|
||||
params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0));
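// e.g. p_dropout = 0.1 gives a keep probability of 0.9, so
// p_dropout_in_uint = 3865470565 and p_dropout_in_uint16_t = 58981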
|
||||
params.rp_dropout = 1.f / params.p_dropout;
|
||||
params.scale_bmm1_rp_dropout = params.rp_dropout * params.scale_bmm1;
|
||||
TORCH_CHECK(p_dropout < 1.f);
|
||||
|
||||
params.is_causal = is_causal;
|
||||
}
|
||||
|
||||
std::vector<at::Tensor>
|
||||
mha_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
|
||||
const at::Tensor &k, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
|
||||
const at::Tensor &v, // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
|
||||
const at::Tensor &cu_seqlens_q, // b+1
|
||||
const at::Tensor &cu_seqlens_k, // b+1
|
||||
const int max_seqlen_q_,
|
||||
const int max_seqlen_k_,
|
||||
const float p_dropout,
|
||||
const float softmax_scale,
|
||||
const bool zero_tensors,
|
||||
const bool is_causal,
|
||||
const bool return_softmax,
|
||||
c10::optional<at::Generator> gen_) {
|
||||
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
bool is_sm75 = dprops->major == 7 && dprops->minor == 5;
|
||||
bool is_sm80 = dprops->major == 8 && dprops->minor == 0;
|
||||
bool is_sm8x = dprops->major == 8 && dprops->minor >= 0;
|
||||
TORCH_CHECK(is_sm8x || is_sm75);
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
bool is_dropout = p_dropout > 0.0;
|
||||
Launch_params<FMHA_fprop_params> launch_params(dprops, stream, is_dropout, return_softmax);
|
||||
|
||||
auto q_dtype = q.dtype();
|
||||
TORCH_CHECK(q_dtype == at::kHalf || (is_sm8x && q_dtype == at::kBFloat16));
|
||||
TORCH_CHECK(k.dtype() == q_dtype);
|
||||
TORCH_CHECK(v.dtype() == q_dtype);
|
||||
TORCH_CHECK(cu_seqlens_q.dtype() == at::kInt);
|
||||
TORCH_CHECK(cu_seqlens_k.dtype() == at::kInt);
|
||||
|
||||
TORCH_CHECK(q.is_cuda());
|
||||
TORCH_CHECK(k.is_cuda());
|
||||
TORCH_CHECK(v.is_cuda());
|
||||
TORCH_CHECK(cu_seqlens_q.is_cuda());
|
||||
TORCH_CHECK(cu_seqlens_k.is_cuda());
|
||||
|
||||
TORCH_CHECK(q.stride(-1) == 1);
|
||||
TORCH_CHECK(k.stride(-1) == 1);
|
||||
TORCH_CHECK(v.stride(-1) == 1);
|
||||
TORCH_CHECK(cu_seqlens_q.is_contiguous());
|
||||
TORCH_CHECK(cu_seqlens_k.is_contiguous());
|
||||
|
||||
const auto sizes = q.sizes();
|
||||
|
||||
const int batch_size = cu_seqlens_q.numel() - 1;
|
||||
const int total_q = sizes[TOTAL_DIM];
|
||||
const int num_heads = sizes[H_DIM];
|
||||
const int head_size = sizes[D_DIM];
|
||||
const int total_k = k.size(TOTAL_DIM);
|
||||
TORCH_CHECK(batch_size > 0);
|
||||
TORCH_CHECK((head_size % 8 == 0) && (head_size <= 128));
|
||||
const int head_size_rounded = head_size <= 64 ? 64 : 128;
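// e.g. head_size = 40 rounds up to 64, head_size = 96 rounds up to 128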
|
||||
|
||||
CHECK_SHAPE(q, total_q, num_heads, head_size);
|
||||
CHECK_SHAPE(k, total_k, num_heads, head_size);
|
||||
CHECK_SHAPE(v, total_k, num_heads, head_size);
|
||||
CHECK_SHAPE(cu_seqlens_q, batch_size + 1);
|
||||
CHECK_SHAPE(cu_seqlens_k, batch_size + 1);
|
||||
|
||||
int blocksize_c = ((head_size_rounded == 128 && (is_dropout || !is_sm80)) || (is_sm75 && head_size_rounded == 64 && is_dropout)) ? 128 : 256;
|
||||
// Need to round max_seqlen_k to multiples of blocksize_c
|
||||
int max_seqlen_k = ((max_seqlen_k_ + blocksize_c - 1) / blocksize_c) * blocksize_c;
|
||||
if( max_seqlen_k_ <= 128 ) {
|
||||
max_seqlen_k = 128;
|
||||
} else if( max_seqlen_k_ <= 256 ) {
|
||||
max_seqlen_k = 256;
|
||||
}
|
||||
int max_seqlen_q = ((max_seqlen_q_ + 16 - 1) / 16) * 16;
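// e.g. with blocksize_c = 256, max_seqlen_k_ = 1000 rounds up to 1024 and
// max_seqlen_q_ = 70 rounds up to 80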
|
||||
bool loop = max_seqlen_k > blocksize_c;
|
||||
|
||||
auto opts = q.options();
|
||||
|
||||
auto o = at::empty({ total_q, num_heads, head_size }, opts);
|
||||
|
||||
at::Tensor o_tmp;
|
||||
if (loop) { o_tmp = at::empty({total_q, num_heads, head_size}, opts.dtype(at::kFloat)); }
|
||||
|
||||
auto softmax_lse = at::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
|
||||
// auto softmax_lse = torch::full({batch_size, num_heads, max_seqlen_k}, -std::numeric_limits<float>::infinity(), opts.dtype(at::kFloat));
|
||||
|
||||
at::Tensor s;
|
||||
if (return_softmax) { s = at::empty({ batch_size, num_heads, max_seqlen_q, max_seqlen_k }, opts); }
|
||||
|
||||
if( zero_tensors ) {
|
||||
o.zero_();
|
||||
softmax_lse.fill_(-std::numeric_limits<float>::infinity());
|
||||
if (return_softmax) {s.zero_();}
|
||||
}
|
||||
|
||||
auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
|
||||
gen_, at::cuda::detail::getDefaultCUDAGenerator());
|
||||
|
||||
set_params_fprop(launch_params.params,
|
||||
batch_size,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
num_heads,
|
||||
head_size,
|
||||
q, k, v,
|
||||
cu_seqlens_q.data_ptr(),
|
||||
cu_seqlens_k.data_ptr(),
|
||||
o.data_ptr(),
|
||||
loop ? o_tmp.data_ptr() : nullptr,
|
||||
return_softmax ? s.data_ptr() : nullptr,
|
||||
softmax_lse.data_ptr(),
|
||||
p_dropout,
|
||||
softmax_scale,
|
||||
is_causal);
|
||||
|
||||
run_fmha_fprop(launch_params, /*configure=*/ true);
|
||||
// number of times random will be generated per thread, to offset philox counter in thc random
|
||||
// state
|
||||
int64_t counter_offset = launch_params.elts_per_thread;
|
||||
at::PhiloxCudaState rng_engine_inputs;
|
||||
|
||||
if( is_dropout ) {
|
||||
// See Note [Acquire lock when using random generators]
|
||||
std::lock_guard<std::mutex> lock(gen->mutex_);
|
||||
launch_params.params.philox_args = gen->philox_cuda_state(counter_offset);
|
||||
}
|
||||
|
||||
run_fmha_fprop(launch_params, /*configure=*/false);
|
||||
|
||||
std::vector<at::Tensor> result = {o, softmax_lse};
|
||||
if (return_softmax) {result.push_back(s);}
|
||||
return result;
|
||||
}
|
||||
} // namespace fmha
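
As a quick sanity check on the sequence-length rounding performed in mha_fwd above, here is the same arithmetic in isolation. The sizes are hypothetical (head_size 64 on SM80 without dropout, so blocksize_c is 256); they are not taken from the patch.

// Minimal sketch of the rounding above, with made-up sizes.
int blocksize_c   = 256;   // head_size_rounded == 64, no dropout, SM80
int max_seqlen_k_ = 1000;  // longest key sequence in the batch
int max_seqlen_q_ = 36;    // longest query sequence in the batch
int max_seqlen_k  = ((max_seqlen_k_ + blocksize_c - 1) / blocksize_c) * blocksize_c;  // 1024
int max_seqlen_q  = ((max_seqlen_q_ + 16 - 1) / 16) * 16;                             // 48
bool loop         = max_seqlen_k > blocksize_c;  // true, so the fp32 o_tmp buffer is allocated
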
|
24
aten/src/ATen/native/transformers/cuda/flash_attn/fmha_api.h
Normal file
@ -0,0 +1,24 @@
#pragma once
#include <cstddef>

#include <ATen/ATen.h>
#include <c10/util/Exception.h>

namespace fmha {

std::vector<at::Tensor>
mha_fwd(const at::Tensor &q,            // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i
        const at::Tensor &k,            // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
        const at::Tensor &v,            // total_k x num_heads x head_size, total_k := \sum_{i=0}^{b} s_i
        const at::Tensor &cu_seqlens_q, // b+1
        const at::Tensor &cu_seqlens_k, // b+1
        const int max_seqlen_q_,
        const int max_seqlen_k_,
        const float p_dropout,
        const float softmax_scale,
        const bool zero_tensors,
        const bool is_causal,
        const bool return_softmax,
        c10::optional<at::Generator> gen_);

} // namespace fmha
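
To make the declaration above concrete, here is a hedged host-side sketch of calling fmha::mha_fwd on two packed sequences. The shapes follow the comments in the signature; the concrete sizes, the softmax scale, and the way cu_seqlens is built are illustrative assumptions rather than code from the patch.

// Two sequences of lengths 3 and 5, 2 heads, head_size 64, fp16 on CUDA (illustrative only).
auto opts = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
int total = 3 + 5;  // \sum_i s_i
auto q = at::randn({total, 2, 64}, opts);
auto k = at::randn({total, 2, 64}, opts);
auto v = at::randn({total, 2, 64}, opts);
// Prefix sums of the per-sequence lengths, as int32 on the GPU.
auto cu_seqlens = at::tensor({0, 3, 8}, at::dtype(at::kInt)).to(at::kCUDA);
auto outs = fmha::mha_fwd(q, k, v, cu_seqlens, cu_seqlens,
                          /*max_seqlen_q_=*/5, /*max_seqlen_k_=*/5,
                          /*p_dropout=*/0.0f, /*softmax_scale=*/1.0f / 8.0f,  // 1/sqrt(head_size)
                          /*zero_tensors=*/false, /*is_causal=*/false,
                          /*return_softmax=*/false, /*gen_=*/c10::nullopt);
// outs[0] is o (total x num_heads x head_size); outs[1] is softmax_lse.
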
@ -0,0 +1,722 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2022, Tri Dao.
|
||||
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright
|
||||
* notice, this list of conditions and the following disclaimer in the
|
||||
* documentation and/or other materials provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the
|
||||
* names of its contributors may be used to endorse or promote products
|
||||
* derived from this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/philox.cuh>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/mask.h>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/fmha_kernel.h>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/softmax.h>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/epilogue.h>
|
||||
|
||||
#include <third_party/cutlass/include/cutlass/cutlass.h>
|
||||
#include <third_party/cutlass/include/cutlass/layout/layout.h>
|
||||
#include <third_party/cutlass/include/cutlass/array.h>
|
||||
#include <third_party/cutlass/include/cutlass/numeric_types.h>
|
||||
#include <third_party/cutlass/include/cutlass/numeric_conversion.h>
|
||||
#include <third_party/cutlass/include/cutlass/arch/mma.h>
|
||||
#include <third_party/cutlass/include/cutlass/gemm/warp/default_mma_tensor_op.h>
|
||||
#include <third_party/cutlass/include/cutlass/gemm/warp/mma_tensor_op_tile_iterator.h>
|
||||
#include <third_party/cutlass/include/cutlass/gemm/threadblock/default_mma_core.h>
|
||||
#include <third_party/cutlass/include/cutlass/gemm/threadblock/default_mma_core_sm75.h>
|
||||
#include <third_party/cutlass/include/cutlass/gemm/threadblock/default_mma_core_sm80.h>
|
||||
#include <third_party/cutlass/include/cutlass/epilogue/warp/fragment_iterator_tensor_op.h>
|
||||
#include <third_party/cutlass/include/cutlass/epilogue/warp/tile_iterator_tensor_op.h>
|
||||
#include <third_party/cutlass/include/cutlass/epilogue/threadblock/default_thread_map_tensor_op.h>
|
||||
#include <third_party/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h>
|
||||
#include <third_party/cutlass/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h>
|
||||
|
||||
namespace fmha {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename Kernel_traits>
|
||||
struct Gemm_Q_K_base {
|
||||
using Smem_O = fmha::FMHAEpilogue<typename Kernel_traits::MmaCorePV>;
|
||||
using WarpMma = typename Kernel_traits::MmaCoreQK::MmaTensorOp;
|
||||
|
||||
// The description of the CTA tile for the 1st batched GEMM.
|
||||
using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
|
||||
|
||||
static constexpr size_t SMEM_BYTES_SOFTMAX = Cta_tile_p::M * Cta_tile_p::WARPS_N * sizeof(float) * 2;
|
||||
|
||||
__device__ inline Gemm_Q_K_base(char * smem_ptr_q, char * smem_ptr_k)
|
||||
: smem_q_ptr(smem_ptr_q)
|
||||
, smem_k_ptr(smem_ptr_k) {
|
||||
|
||||
}
|
||||
|
||||
__device__ inline void load_q(int byte_offset=0) {
|
||||
typename WarpMma::LayoutA layout_A = WarpMma::LayoutA::packed({Cta_tile_p::M, Cta_tile_p::K});
|
||||
typename WarpMma::IteratorA iter_A({reinterpret_cast<typename WarpMma::ElementA *>(smem_q_ptr + byte_offset), layout_A}, threadIdx.x % 32);
|
||||
iter_A.load(frag_q[0]);
|
||||
}
|
||||
|
||||
|
||||
__device__ inline void reload_q(int byte_offset=0) {
|
||||
typename WarpMma::LayoutA layout_A = WarpMma::LayoutA::packed({Cta_tile_p::M, Cta_tile_p::K});
|
||||
typename WarpMma::IteratorA iter_A({reinterpret_cast<typename WarpMma::ElementA *>(smem_q_ptr + byte_offset), layout_A}, threadIdx.x % 32);
|
||||
iter_A.load(frag_q[0]);
|
||||
}
|
||||
|
||||
typename WarpMma::FragmentA frag_q[2];
|
||||
char *smem_q_ptr;
|
||||
char *smem_k_ptr;
|
||||
};
|
||||
|
||||
template<typename Kernel_traits, bool K_in_regs>
|
||||
struct Gemm_Q_K : public Gemm_Q_K_base<Kernel_traits> {
|
||||
|
||||
using Base = Gemm_Q_K_base<Kernel_traits>;
|
||||
using Cta_tile_p = typename Base::Cta_tile_p;
|
||||
using Smem_O = typename Base::Smem_O;
|
||||
using WarpMma = typename Base::WarpMma;
|
||||
|
||||
static constexpr int kIterations = WarpMma::Shape::kK / WarpMma::InstructionShape::kK;
|
||||
|
||||
static constexpr bool SHARE_SMEM_FOR_K_AND_V = Kernel_traits::SHARE_SMEM_FOR_K_AND_V;
|
||||
// If V is stored in shared memory, we can't load K using the same shared memory.
|
||||
static_assert(Kernel_traits::V_IN_REGS);
|
||||
|
||||
static constexpr size_t SMEM_OFFSET_O = Kernel_traits::BYTES_PER_SMEM_Q;
|
||||
static constexpr size_t SMEM_OFFSET_SOFTMAX = SMEM_OFFSET_O + sizeof(typename Smem_O::SharedStorage);
|
||||
static constexpr size_t SMEM_OFFSET_V = Kernel_traits::BYTES_PER_SMEM_Q + (SHARE_SMEM_FOR_K_AND_V ? 0 : Kernel_traits::BYTES_PER_SMEM_K);
|
||||
|
||||
// Q | K / V
|
||||
// | O | SOFTMAX
|
||||
static constexpr size_t SMEM_BYTES = Kernel_traits::BYTES_PER_SMEM_Q
|
||||
+ std::max((size_t)(SHARE_SMEM_FOR_K_AND_V ? 1 : 2) * Kernel_traits::BYTES_PER_SMEM_K,
|
||||
sizeof(typename Smem_O::SharedStorage) + Base::SMEM_BYTES_SOFTMAX);
|
||||
|
||||
__device__ inline Gemm_Q_K(char * smem_)
|
||||
: Base(smem_, smem_ + Kernel_traits::BYTES_PER_SMEM_Q) {
|
||||
}
|
||||
|
||||
__device__ inline void load_k(){
|
||||
typename WarpMma::LayoutB layout_B = WarpMma::LayoutB::packed({Cta_tile_p::K, Cta_tile_p::N});
|
||||
typename WarpMma::IteratorB iter_B({reinterpret_cast<typename WarpMma::ElementB *>(Base::smem_k_ptr), layout_B}, threadIdx.x % 32);
|
||||
const int warp_idx = threadIdx.x / 32;
|
||||
iter_B.add_tile_offset({0, warp_idx});
|
||||
#pragma unroll
|
||||
for( int ki = 0; ki < kIterations; ++ki ) {
|
||||
iter_B.load(frag_k[ki]);
|
||||
++iter_B;
|
||||
}
|
||||
}
|
||||
|
||||
__device__ inline void operator()(WarpMma warp_mma, typename WarpMma::FragmentC &acc_p, int byte_offset_q=0){
|
||||
typename WarpMma::LayoutA layout_A = WarpMma::LayoutA::packed({Base::Cta_tile_p::M, Base::Cta_tile_p::K});
|
||||
typename WarpMma::IteratorA iter_A({reinterpret_cast<typename WarpMma::ElementA *>(Base::smem_q_ptr + byte_offset_q), layout_A}, threadIdx.x % 32);
|
||||
++iter_A;
|
||||
// Do this part of P^T = (Q * K^T)^T.
|
||||
#pragma unroll
|
||||
for( int ki = 0; ki < kIterations; ++ki ) {
|
||||
// Trigger the load from shared memory for the next series of Q values.
|
||||
if (ki + 1 < kIterations) { iter_A.load(Base::frag_q[(ki + 1) % 2]); ++iter_A; }
|
||||
// Do the math for the values already in registers.
|
||||
warp_mma(acc_p, Base::frag_q[ki % 2], frag_k[ki], acc_p);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ inline void reload_k(){
|
||||
// Noop.
|
||||
}
|
||||
|
||||
typename WarpMma::FragmentB frag_k[kIterations];
|
||||
};
|
||||
|
||||
|
||||
template<typename Kernel_traits>
|
||||
struct Gemm_Q_K<Kernel_traits, false> : public Gemm_Q_K_base<Kernel_traits> {
|
||||
using Base = Gemm_Q_K_base<Kernel_traits>;
|
||||
using Cta_tile_p = typename Base::Cta_tile_p;
|
||||
using Smem_O = typename Base::Smem_O;
|
||||
using WarpMma = typename Base::WarpMma;
|
||||
|
||||
static constexpr bool SHARE_SMEM_FOR_K_AND_V = Kernel_traits::SHARE_SMEM_FOR_K_AND_V;
|
||||
static constexpr bool V_IN_REGS = Kernel_traits::V_IN_REGS;
|
||||
static_assert(V_IN_REGS || !SHARE_SMEM_FOR_K_AND_V);
|
||||
|
||||
static constexpr size_t SMEM_OFFSET_V = Kernel_traits::BYTES_PER_SMEM_Q + (SHARE_SMEM_FOR_K_AND_V ? 0 : Kernel_traits::BYTES_PER_SMEM_K);
|
||||
static constexpr size_t SMEM_OFFSET_O = SMEM_OFFSET_V + Kernel_traits::BYTES_PER_SMEM_V;
|
||||
static constexpr size_t SMEM_OFFSET_SOFTMAX = SMEM_OFFSET_O + sizeof(typename Smem_O::SharedStorage);
|
||||
|
||||
// If V_IN_REGS and SHARE_SMEM_FOR_K_AND_V: Q | K/V | O | SOFTMAX
|
||||
// If !V_IN_REGS (then !SHARE_SMEM_FOR_K_AND_V): Q | K | V | O | SOFTMAX
|
||||
static constexpr size_t SMEM_BYTES = Kernel_traits::BYTES_PER_SMEM_Q
|
||||
+ (SHARE_SMEM_FOR_K_AND_V ? 1 : 2) * Kernel_traits::BYTES_PER_SMEM_K
|
||||
+ sizeof(typename Smem_O::SharedStorage) + Base::SMEM_BYTES_SOFTMAX;
|
||||
|
||||
__device__ inline Gemm_Q_K(char * smem_)
|
||||
: Base(smem_, smem_ + Kernel_traits::BYTES_PER_SMEM_Q) {
|
||||
}
|
||||
|
||||
__device__ inline void load_k(){
|
||||
typename WarpMma::LayoutB layout_B = WarpMma::LayoutB::packed({Cta_tile_p::K, Cta_tile_p::N});
|
||||
typename WarpMma::IteratorB iter_B({reinterpret_cast<typename WarpMma::ElementB *>(Base::smem_k_ptr), layout_B}, threadIdx.x % 32);
|
||||
const int warp_idx = threadIdx.x / 32;
|
||||
iter_B.add_tile_offset({0, warp_idx});
|
||||
iter_B.load(frag_k[0]);
|
||||
}
|
||||
|
||||
__device__ inline void operator()(WarpMma warp_mma, typename WarpMma::FragmentC &acc_p, int byte_offset_q=0){
|
||||
typename WarpMma::LayoutA layout_A = WarpMma::LayoutA::packed({Base::Cta_tile_p::M, Base::Cta_tile_p::K});
|
||||
typename WarpMma::IteratorA iter_A({reinterpret_cast<typename WarpMma::ElementA *>(Base::smem_q_ptr + byte_offset_q), layout_A}, threadIdx.x % 32);
|
||||
++iter_A;
|
||||
typename WarpMma::LayoutB layout_B = WarpMma::LayoutB::packed({Cta_tile_p::K, Cta_tile_p::N});
|
||||
typename WarpMma::IteratorB iter_B({reinterpret_cast<typename WarpMma::ElementB *>(Base::smem_k_ptr), layout_B}, threadIdx.x % 32);
|
||||
const int warp_idx = threadIdx.x / 32;
|
||||
iter_B.add_tile_offset({0, warp_idx});
|
||||
++iter_B;
|
||||
|
||||
// Do this part of P^T = (Q * K^T)^T.
|
||||
constexpr int kIterations = WarpMma::Shape::kK / WarpMma::InstructionShape::kK;
|
||||
#pragma unroll
|
||||
for( int ki = 0; ki < kIterations; ++ki ) {
|
||||
// Trigger the load from shared memory for the next series of Q values.
|
||||
if (ki + 1 < kIterations) {
|
||||
iter_A.load(Base::frag_q[(ki + 1) % 2]); ++iter_A;
|
||||
iter_B.load(frag_k[(ki + 1) % 2]); ++iter_B;
|
||||
}
|
||||
// Do the math for the values already in registers.
|
||||
warp_mma(acc_p, Base::frag_q[ki % 2], frag_k[ki % 2], acc_p);
|
||||
}
|
||||
}
|
||||
__device__ inline void reload_k(){
|
||||
typename WarpMma::LayoutB layout_B = WarpMma::LayoutB::packed({Cta_tile_p::K, Cta_tile_p::N});
|
||||
typename WarpMma::IteratorB iter_B({reinterpret_cast<typename WarpMma::ElementB *>(Base::smem_k_ptr), layout_B}, threadIdx.x % 32);
|
||||
const int warp_idx = threadIdx.x / 32;
|
||||
iter_B.add_tile_offset({0, warp_idx});
|
||||
iter_B.load(frag_k[0]);
|
||||
}
|
||||
|
||||
typename WarpMma::FragmentB frag_k[2];
|
||||
};
|
||||
|
||||
template<typename Kernel_traits>
|
||||
constexpr size_t get_dynamic_smem_size(){
|
||||
return Gemm_Q_K<Kernel_traits, Kernel_traits::K_IN_REGS>::SMEM_BYTES;
|
||||
}
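
For a rough sense of scale, the following arithmetic plugs hypothetical tile sizes into SMEM_BYTES; the real constants come from kernel_traits.h, which is not shown in this diff.

// Hypothetical: a 128 x 64 fp16 Q tile gives BYTES_PER_SMEM_Q = 128 * 64 * 2 = 16 KB.
// Assuming BYTES_PER_SMEM_K is also 16 KB and SHARE_SMEM_FOR_K_AND_V is true, the
// K-in-registers specialization needs
//     SMEM_BYTES = 16 KB + max(1 * 16 KB, sizeof(Smem_O::SharedStorage) + SMEM_BYTES_SOFTMAX)
// This total is requested as dynamic shared memory at launch, which is why run_fmha_loop_
// further down raises cudaFuncAttributeMaxDynamicSharedMemorySize once it exceeds the 48 KB default.
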
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Return_softmax, bool Is_first, bool Is_last, typename Params, typename Prng>
|
||||
inline __device__ void device_1xN_(const Params ¶ms, const int bidb, const int bidh, int begin, int steps, Prng &ph0, Prng &ph1, const int loop_step_idx) {
|
||||
|
||||
// The description of the CTA tile for the 1st batched GEMM.
|
||||
using Cta_tile_p = typename Kernel_traits::Cta_tile_p;
|
||||
// The description of the CTA tile for the 2nd batched GEMM.
|
||||
using Cta_tile_o = typename Kernel_traits::Cta_tile_o;
|
||||
|
||||
// The MMA tile for the 1st GEMM.
|
||||
using Mma_tile_p = fmha::Hmma_tile<Cta_tile_p>;
|
||||
// The MMA tile for the 2nd GEMM.
|
||||
using Mma_tile_o = fmha::Hmma_tile<Cta_tile_o>;
|
||||
|
||||
using InstructionShape = typename Kernel_traits::MmaInstructionShape;
|
||||
using Element = typename Kernel_traits::Element;
|
||||
using ElementAccum = typename Kernel_traits::ElementAccum;
|
||||
|
||||
using ThreadblockShapeQK = typename Kernel_traits::ThreadblockShapeQK;
|
||||
using LayoutQ = typename Kernel_traits::LayoutQ;
|
||||
using LayoutK = typename Kernel_traits::LayoutK;
|
||||
using LayoutP = typename Kernel_traits::LayoutP;
|
||||
using MmaCoreQK = typename Kernel_traits::MmaCoreQK;
|
||||
using WarpMmaQK = typename MmaCoreQK::MmaTensorOp;
|
||||
using SmemLayoutQ = typename MmaCoreQK::SmemLayoutA;
|
||||
using SmemLayoutK = typename MmaCoreQK::SmemLayoutB;
|
||||
using SmemIteratorQ = typename MmaCoreQK::SmemIteratorA;
|
||||
using SmemIteratorK = typename MmaCoreQK::SmemIteratorB;
|
||||
|
||||
using ThreadblockShapePV = typename Kernel_traits::ThreadblockShapePV;
|
||||
using LayoutV = typename Kernel_traits::LayoutV;
|
||||
using LayoutO = typename Kernel_traits::LayoutO;
|
||||
using MmaCorePV = typename Kernel_traits::MmaCorePV;
|
||||
using WarpMmaPV = typename MmaCorePV::MmaTensorOp;
|
||||
using WarpIteratorV = typename WarpMmaPV::IteratorB;
|
||||
using SmemLayoutV = typename MmaCorePV::SmemLayoutB;
|
||||
using SmemIteratorV = typename MmaCorePV::SmemIteratorB;
|
||||
constexpr int kIterationsPV = WarpMmaPV::Shape::kK / WarpMmaPV::InstructionShape::kK;
|
||||
|
||||
// The global memory tile to load Q.
|
||||
// Copied from mma_pipelined_testbed.h
|
||||
using GmemIteratorQ = typename Kernel_traits::GmemIteratorQ;
|
||||
// The global memory tile to load K.
|
||||
using GmemIteratorK = typename Kernel_traits::GmemIteratorK;
|
||||
// The global memory tile to load V.
|
||||
using GmemIteratorV = typename Kernel_traits::GmemIteratorV;
|
||||
// The global memory tile to store O.
|
||||
using GmemIteratorO = typename fmha::FMHAEpilogue<MmaCorePV>::GmemIterator;
|
||||
using GmemIteratorOAccum = typename fmha::FMHAEpilogue<MmaCorePV>::GmemIteratorAccum;
|
||||
|
||||
using Gmem_tile_s = typename Kernel_traits::Gmem_tile_s;
|
||||
|
||||
using Gmem_softmax_sum = typename Kernel_traits::Gmem_softmax_sum;
|
||||
|
||||
using Smem_softmax_lse = typename Kernel_traits::Smem_softmax_lse;
|
||||
|
||||
using Gemm1 = Gemm_Q_K<Kernel_traits, Kernel_traits::K_IN_REGS>;
|
||||
|
||||
using Softmax = fmha::Softmax<Cta_tile_p, Kernel_traits>;
|
||||
|
||||
// Shared memory.
|
||||
extern __shared__ char smem_[];
|
||||
|
||||
// The thread index.
|
||||
const int tidx = threadIdx.x;
|
||||
|
||||
const BlockInfoPadded<Kernel_traits::THREADS> binfo(params, bidb, bidh, tidx);
|
||||
if( binfo.stop_early(loop_step_idx * Cta_tile_p::N) ) return;
|
||||
|
||||
Gemm1 gemm_q_k(smem_);
|
||||
// Allocate the global memory tile loader for S.
|
||||
Gmem_tile_s gmem_s(params, binfo, tidx);
|
||||
Gmem_softmax_sum gmem_softmax_lse(params.softmax_lse_ptr, params, tidx);
|
||||
|
||||
// Wind gmem tiles to the correct position.
|
||||
static_assert(Cta_tile_p::N % Cta_tile_p::M == 0);
|
||||
const int begin_og = begin;
|
||||
begin = Is_causal ? std::max(begin, loop_step_idx * Cta_tile_p::N / Cta_tile_p::M) : begin;
|
||||
const int steps_og = steps;
|
||||
steps -= begin - begin_og;
|
||||
if (Return_softmax) { gmem_s.move(begin); }
|
||||
gmem_softmax_lse.move(begin);
|
||||
|
||||
fmha::Mask<Cta_tile_p, Is_causal> mask(binfo, tidx, loop_step_idx);
|
||||
|
||||
// The base pointer of smem_v;
|
||||
char *smem_v_addr = &smem_[Gemm1::SMEM_OFFSET_V];
|
||||
|
||||
// Allocate the shared memory tile loader for V. We use the same as K so be careful!!!
|
||||
|
||||
SmemLayoutQ layout_Q = SmemLayoutQ::packed({ThreadblockShapeQK::kM, ThreadblockShapeQK::kK});
|
||||
SmemIteratorQ smem_q({reinterpret_cast<Element *>(smem_), layout_Q}, tidx);
|
||||
SmemLayoutK layout_K = SmemLayoutK::packed({ThreadblockShapeQK::kK, ThreadblockShapeQK::kN});
|
||||
SmemIteratorK smem_k({reinterpret_cast<Element *>(smem_ + Kernel_traits::BYTES_PER_SMEM_Q), layout_K}, tidx);
|
||||
SmemLayoutV layout_V = SmemLayoutV::packed({ThreadblockShapePV::kK, ThreadblockShapePV::kN});
|
||||
// SmemIterator stores to smem and WarpIterator loads from smem
|
||||
SmemIteratorV smem_v({reinterpret_cast<Element *>(smem_v_addr), layout_V}, tidx);
|
||||
WarpIteratorV iter_V({reinterpret_cast<Element *>(smem_v_addr), layout_V}, threadIdx.x % 32);
|
||||
|
||||
// Allocate the shared memory tile loader for O. We use the same as K so be careful!!!
|
||||
using Smem_O = fmha::FMHAEpilogue<MmaCorePV>;
|
||||
Smem_O smem_o(&smem_[Gemm1::SMEM_OFFSET_O], tidx);
|
||||
|
||||
// Allocate the global memory tile loader for Q.
|
||||
// cutlass::transform::threadblock::PredicatedTileIterator deals with seqlen not divisible
|
||||
// by 16 in a different way than we want. If the seqlen_q is 36, the first iteration would
|
||||
// load 4 rows and the next two iterations would load 16 rows each. Instead we round the
|
||||
// actual_seqlen_q to be multiple of 16, then change the mask in the last iteration, so
|
||||
// that in this case we would load 16, 16, 4.
|
||||
LayoutQ gmem_layout_Q(params.q_row_stride_in_elts);
|
||||
typename GmemIteratorQ::Params gmem_Q_params(gmem_layout_Q);
|
||||
const uint32_t row_offset_q = (binfo.sum_s_q + begin * ThreadblockShapeQK::kM) * params.q_row_stride_in_elts + binfo.bidh * params.q_head_stride_in_elts;
|
||||
const int actual_seqlen_q = binfo.actual_seqlen_q - begin * ThreadblockShapeQK::kM;
|
||||
const int seqlen_q_remainder = actual_seqlen_q % ThreadblockShapeQK::kM;
|
||||
const int extent_q = ((actual_seqlen_q <= ThreadblockShapeQK::kM) || (seqlen_q_remainder == 0)) ? actual_seqlen_q : actual_seqlen_q + ThreadblockShapeQK::kM - seqlen_q_remainder;
|
||||
GmemIteratorQ gmem_q(gmem_Q_params,
|
||||
reinterpret_cast<Element *>(params.q_ptr) + row_offset_q,
|
||||
{extent_q, params.d},
|
||||
tidx);
|
||||
|
||||
// Allocate the global memory tile loader for K.
|
||||
LayoutK gmem_layout_K(params.k_row_stride_in_elts);
|
||||
typename GmemIteratorK::Params gmem_K_params(gmem_layout_K);
|
||||
const uint32_t row_offset_k = (binfo.sum_s_k + loop_step_idx * ThreadblockShapeQK::kN) * params.k_row_stride_in_elts + binfo.bidh * params.k_head_stride_in_elts;
|
||||
const int extent_k = min(binfo.actual_seqlen_k - loop_step_idx * ThreadblockShapeQK::kN, ThreadblockShapeQK::kN);
|
||||
GmemIteratorK gmem_k(gmem_K_params,
|
||||
reinterpret_cast<Element *>(params.k_ptr) + row_offset_k,
|
||||
{params.d, extent_k},
|
||||
tidx);
|
||||
|
||||
// Allocate the global memory tile loader for V.
|
||||
LayoutV gmem_layout_V(params.v_row_stride_in_elts);
|
||||
typename GmemIteratorV::Params gmem_V_params(gmem_layout_V);
|
||||
const uint32_t row_offset_v = (binfo.sum_s_k + loop_step_idx * ThreadblockShapePV::kK) * params.v_row_stride_in_elts + binfo.bidh * params.v_head_stride_in_elts;
|
||||
// extent_v is the same as extent_k
|
||||
GmemIteratorV gmem_v(gmem_V_params,
|
||||
reinterpret_cast<Element *>(params.v_ptr) + row_offset_v,
|
||||
{extent_k, params.d},
|
||||
tidx);
|
||||
|
||||
// Allocate the global memory tile loader for O.
|
||||
LayoutO gmem_layout_O(params.o_row_stride_in_elts);
|
||||
typename GmemIteratorO::Params gmem_O_params(gmem_layout_O);
|
||||
const uint32_t row_offset_o = (binfo.sum_s_q + begin * ThreadblockShapeQK::kM) * params.o_row_stride_in_elts + binfo.bidh * params.o_head_stride_in_elts;
|
||||
GmemIteratorO gmem_o(gmem_O_params,
|
||||
reinterpret_cast<Element *>(params.o_ptr) + row_offset_o,
|
||||
{actual_seqlen_q, params.d},
|
||||
tidx);
|
||||
|
||||
typename GmemIteratorOAccum::Params gmem_Oaccum_params(gmem_layout_O);
|
||||
GmemIteratorOAccum gmem_o_accum(gmem_Oaccum_params,
|
||||
reinterpret_cast<ElementAccum *>(params.o_tmp_ptr) + row_offset_o,
|
||||
{actual_seqlen_q, params.d},
|
||||
tidx);
|
||||
|
||||
// Create the object to do the softmax.
|
||||
Softmax softmax(params, &smem_[Gemm1::SMEM_OFFSET_SOFTMAX], tidx);
|
||||
|
||||
Smem_softmax_lse smem_softmax_lse(reinterpret_cast<float *>(&smem_[Gemm1::SMEM_BYTES]));
|
||||
|
||||
if (!Is_first) {
|
||||
if (Return_softmax) { gmem_s.move(loop_step_idx * steps_og); }
|
||||
}
|
||||
|
||||
if (!Is_first) { __syncthreads(); }
|
||||
|
||||
// Trigger the loads for V.
|
||||
typename GmemIteratorV::Fragment gmem_frag_v;
|
||||
gmem_frag_v.clear();
|
||||
gmem_v.load(gmem_frag_v);
|
||||
|
||||
// Trigger the loads for Q.
|
||||
typename GmemIteratorQ::Fragment gmem_frag_q;
|
||||
gmem_frag_q.clear();
|
||||
gmem_q.load(gmem_frag_q);
|
||||
|
||||
// Trigger the loads for K.
|
||||
typename GmemIteratorK::Fragment gmem_frag_k;
|
||||
gmem_frag_k.clear();
|
||||
gmem_k.load(gmem_frag_k);
|
||||
|
||||
float p_prev_lse[Mma_tile_p::MMAS_M * 2];
|
||||
if (!Is_first) {
|
||||
gmem_softmax_lse.load(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(p_prev_lse));
|
||||
}
|
||||
|
||||
// Commit the data for Q and V to shared memory.
|
||||
smem_v.store(gmem_frag_v);
|
||||
smem_q.store(gmem_frag_q);
|
||||
|
||||
// Commit the data for K to shared memory.
|
||||
if( !Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
|
||||
smem_k.store(gmem_frag_k);
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// Load the fragments for Q.
|
||||
gemm_q_k.load_q();
|
||||
|
||||
// Load the fragments for V. We keep the data in registers during the entire
|
||||
// kernel. copied from mma_pipelined.h
|
||||
const int warp_idx = threadIdx.x / 32;
|
||||
iter_V.add_tile_offset({kIterationsPV * warp_idx, 0});
|
||||
typename WarpIteratorV::Fragment frag_v[kIterationsPV];
|
||||
static_assert(WarpIteratorV::Fragment::kStorageElements == 4 * Mma_tile_o::MMAS_N || WarpIteratorV::Fragment::kStorageElements == 2 * Mma_tile_o::MMAS_N );
|
||||
#pragma unroll
|
||||
for( int ki = 0; ki < kIterationsPV; ++ki ) {
|
||||
iter_V.load(frag_v[ki]);
|
||||
++iter_V;
|
||||
}
|
||||
|
||||
// Commit the data for K to shared memory if it has not been done already.
|
||||
if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V ) {
|
||||
// Make sure we are done loading the fragments for K.
|
||||
__syncthreads();
|
||||
|
||||
// Commit the data to shared memory for K.
|
||||
smem_k.store(gmem_frag_k);
|
||||
|
||||
// Make sure the data is in shared memory.
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// Load the fragments for K.
|
||||
gemm_q_k.load_k();
|
||||
|
||||
// Load over the entire sequence length.
|
||||
for( int l = 0; l < steps; l++ ) {
|
||||
if((begin + l) * Cta_tile_p::M >= binfo.actual_seqlen_q) break;
|
||||
|
||||
// Declare the accumulators for the 1st gemm.
|
||||
WarpMmaQK mma_qk;
|
||||
typename WarpMmaQK::FragmentC acc_p;
|
||||
acc_p.clear();
|
||||
|
||||
// Do this part of P = Q * K^T.
|
||||
gemm_q_k(mma_qk, acc_p);
|
||||
|
||||
typename Smem_O::OutputFragment out[Smem_O::kIterationsStore];
|
||||
static_assert(GmemIteratorOAccum::kIterations == Smem_O::kIterationsStore);
|
||||
static_assert(GmemIteratorO::kIterations == Smem_O::kIterationsStore);
|
||||
if (!Is_first) {
|
||||
#pragma unroll
|
||||
for (int iter = 0; iter < GmemIteratorOAccum::kIterations; ++iter) {
|
||||
gmem_o_accum.load(out[iter]);
|
||||
gmem_o_accum.move();
|
||||
}
|
||||
}
|
||||
|
||||
// Trigger the load for the next Q values.
|
||||
if( l < steps - 1) {
|
||||
++gmem_q;
|
||||
// If actual_seqlen_q is not a multiple of 16, we change the mask in the last iteration
|
||||
// to load the "residue" tile.
|
||||
if ((l + 1 == steps - 1) && (actual_seqlen_q % ThreadblockShapeQK::kM != 0)) {
|
||||
// TODO: this probably only works for head_dim = 64 and head_dim = 128, which is
|
||||
// what we have right now. Maybe for head_dim = 32 or 96, this could be different.
|
||||
const int row_idx = tidx / (GmemIteratorQ::Shape::kColumn / GmemIteratorQ::Fragment::kElements);
|
||||
if (row_idx >= actual_seqlen_q - (l + 1) * ThreadblockShapeQK::kM) {
|
||||
gmem_q.clear_mask();
|
||||
}
|
||||
}
|
||||
gmem_q.load(gmem_frag_q);
|
||||
}
|
||||
|
||||
// Load the mask for that iteration.
|
||||
mask.load(begin + l);
|
||||
|
||||
// Convert from the accumulator type to FP32 for Softmax.
|
||||
softmax.unpack_noscale(acc_p);
|
||||
|
||||
// Apply the mask.
|
||||
softmax.apply_mask(mask);
|
||||
|
||||
if( Kernel_traits::SHARE_SMEM_FOR_K_AND_V && l == 0 ) {
|
||||
// if we share K and V, it could be that V was not fully read yet but we write into smem for reduction
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// Compute the max.
|
||||
float p_max[Mma_tile_p::MMAS_M * 2];
|
||||
if (!Is_first) {
|
||||
smem_softmax_lse.store_pair(p_prev_lse);
|
||||
for (int mi = 0; mi < Mma_tile_p::MMAS_M * 2; mi++) { p_max[mi] = p_prev_lse[mi] / params.scale_bmm1; }
|
||||
}
|
||||
|
||||
// Trigger the load for the next LSE values.
|
||||
if( l < steps - 1) {
|
||||
if (!Is_first) {
|
||||
gmem_softmax_lse.load_next(reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M * 2]>(p_prev_lse));
|
||||
}
|
||||
}
|
||||
|
||||
softmax.template reduce_max</*zero_init=*/Is_first>(p_max);
|
||||
|
||||
// Compute the exponential value.
|
||||
softmax.scale_apply_exp(p_max, params.scale_bmm1);
|
||||
|
||||
// We don't finalize the sum reduction here, as that would incur an extra sync_threads().
|
||||
// Instead, we reduce the sum from each warp, write to smem, then wait until the sync_threads()
|
||||
// from storing acc_o. Then we read the sum of each warp from smem and finalize the reduction.
|
||||
// As a consequence, we don't scale acc_p by the inverse sum, we scale the output by the inverse sum.
|
||||
// Compute the sum.
|
||||
float p_sum[Mma_tile_p::MMAS_M * 2];
|
||||
// softmax.reduce_sum(p_sum);
|
||||
softmax.reduce_sum_before_sync_(p_sum);
|
||||
|
||||
constexpr bool encode_dropout_in_sign_bit = Return_softmax;
|
||||
if (Is_dropout) {
|
||||
softmax.template apply_dropout_16bits<encode_dropout_in_sign_bit>(ph0, ph1, params.p_dropout_in_uint16_t);
|
||||
}
|
||||
|
||||
static_assert(Mma_tile_o::MMAS_M == Mma_tile_p::MMAS_M);
|
||||
static_assert(Mma_tile_o::MMAS_K == Mma_tile_p::MMAS_N);
|
||||
softmax.pack_noconvert(acc_p);
|
||||
cutlass::NumericArrayConverter<Element, ElementAccum, decltype(acc_p)::kElements, cutlass::FloatRoundStyle::round_to_nearest> convert_p;
|
||||
auto frag_p = convert_p(acc_p);
|
||||
|
||||
if (Return_softmax) {
|
||||
gmem_s.store(reinterpret_cast<const cutlass::Array<Element, 8>(&)[Mma_tile_o::MMAS_K][Mma_tile_o::MMAS_M]>(frag_p), mask);
|
||||
gmem_s.move();
|
||||
}
|
||||
|
||||
// Commit the values for Q into shared memory.
|
||||
if (l < steps - 1) { smem_q.store(gmem_frag_q); }
|
||||
|
||||
if (Is_dropout && encode_dropout_in_sign_bit) {
|
||||
cutlass::epilogue::thread::ReLu<decltype(frag_p)> relu;
|
||||
frag_p = relu(frag_p);
|
||||
}
|
||||
|
||||
// Declare the accumulators for the 2nd gemm.
|
||||
WarpMmaPV mma_pv;
|
||||
typename WarpMmaPV::FragmentC acc_o;
|
||||
static_assert(WarpMmaPV::FragmentC::kElements == Mma_tile_o::MMAS_M * Mma_tile_o::MMAS_N * 8);
|
||||
acc_o.clear();
|
||||
|
||||
// For some reason, WarpMmaPV::FragmentA has length K * N * (8|4) instead of just N * (8|4).
|
||||
// We have to first cast frag_p to be array of k x (N * (8|4)), then cast each row to be
|
||||
// an array of WarpMmaPV::FragmentA (which is what mma_pv expects).
|
||||
static_assert(decltype(frag_p)::kElements == kIterationsPV * Mma_tile_o::MMAS_M * WarpMmaPV::FragmentA::kElements);
|
||||
const auto frag_p_reshaped = reinterpret_cast<const cutlass::Array<Element, WarpMmaPV::FragmentA::kElements> (&)[kIterationsPV]>(frag_p);
|
||||
#pragma unroll
|
||||
for( int ki = 0; ki < kIterationsPV; ++ki ) {
|
||||
mma_pv(acc_o, reinterpret_cast<const typename WarpMmaPV::FragmentA(&)>(frag_p_reshaped[ki]), frag_v[ki], acc_o);
|
||||
}
|
||||
// Swizzle the elements and do the final reduction.
|
||||
smem_o.store(acc_o);
|
||||
|
||||
// The mapping from tidx to rows changes between the softmax and the
|
||||
// O-reduction. So we recalculate the max.
|
||||
using OutputTileThreadMap = typename Smem_O::OutputTileThreadMap;
|
||||
constexpr int kOutputRowsPerThread = OutputTileThreadMap::Iterations::kRow * Smem_O::kIterationsStore;
|
||||
float p_max_o[kOutputRowsPerThread][Mma_tile_o::MMAS_M];
|
||||
int rows[kOutputRowsPerThread];
|
||||
cutlass::MatrixCoord output_thread_offset = OutputTileThreadMap::initial_offset(tidx);
|
||||
const int output_thread_start_row = output_thread_offset.row();
|
||||
const int output_thread_start_column = output_thread_offset.column();
|
||||
for (int iter = 0; iter < Smem_O::kIterationsStore; ++iter) {
|
||||
for (int row = 0; row < OutputTileThreadMap::Iterations::kRow; ++row) {
|
||||
rows[iter * OutputTileThreadMap::Iterations::kRow + row] = output_thread_start_row + iter * OutputTileThreadMap::Shape::kRow + row;
|
||||
}
|
||||
}
|
||||
|
||||
softmax.reduce_max_after_sync_(p_max_o, rows);
|
||||
static_assert(Mma_tile_o::MMAS_M == 1);
|
||||
for (int jj = 0; jj < kOutputRowsPerThread; jj++) {
|
||||
p_max_o[jj][0] *= params.scale_bmm1;
|
||||
}
|
||||
float p_prev_scale_o[kOutputRowsPerThread];
|
||||
if (!Is_first) {
|
||||
smem_softmax_lse.load(p_prev_scale_o, rows);
|
||||
}
|
||||
|
||||
// Make sure the data is in shared memory.
|
||||
__syncthreads();
|
||||
|
||||
static_assert(Mma_tile_o::MMAS_M == 1);
|
||||
float p_sum_o[kOutputRowsPerThread][Mma_tile_o::MMAS_M];
|
||||
softmax.reduce_sum_after_sync_(p_sum_o, rows);
|
||||
if (!Is_first) {
|
||||
for (int jj = 0; jj < kOutputRowsPerThread; jj++) {
|
||||
p_prev_scale_o[jj] = expf(p_prev_scale_o[jj] - p_max_o[jj][0]);
|
||||
p_sum_o[jj][0] += p_prev_scale_o[jj];
|
||||
}
|
||||
}
|
||||
|
||||
float p_sum_log[kOutputRowsPerThread][Mma_tile_o::MMAS_M];
|
||||
#pragma unroll
|
||||
for (int jj = 0; jj < kOutputRowsPerThread; jj++) {
|
||||
float sum = p_sum_o[jj][0];
|
||||
p_sum_log[jj][0] = (sum == 0.f || sum != sum) ? -INFINITY : p_max_o[jj][0] + __logf(sum);
|
||||
if (output_thread_start_column == 0) {
|
||||
gmem_softmax_lse.store_row(
|
||||
reinterpret_cast<uint32_t(&)[Mma_tile_p::MMAS_M]>(p_sum_log[jj]), rows[jj]);
|
||||
}
|
||||
}
|
||||
gmem_softmax_lse.move();
|
||||
|
||||
// Load from shared memory.
|
||||
using ArrayTypeO = cutlass::Array<ElementAccum, OutputTileThreadMap::kElementsPerAccess>;
|
||||
static_assert(OutputTileThreadMap::kElementsPerAccess * kOutputRowsPerThread == Smem_O::kIterationsStore * Smem_O::OutputFragment::kElements);
|
||||
cutlass::multiplies<ArrayTypeO> multiply_fragments;
|
||||
if (!Is_first) {
|
||||
auto out_reshaped = reinterpret_cast<ArrayTypeO (&)[kOutputRowsPerThread]>(out);
|
||||
for (int jj = 0; jj < kOutputRowsPerThread; jj++) {
|
||||
out_reshaped[jj] = multiply_fragments(out_reshaped[jj], p_prev_scale_o[jj]);
|
||||
}
|
||||
}
|
||||
smem_o.template load</*zero_init=*/Is_first>(out, tidx);
|
||||
|
||||
const bool is_final_write =
|
||||
Is_last
|
||||
|| ((loop_step_idx + 1) * Cta_tile_p::N >= binfo.actual_seqlen_k)
|
||||
|| ((Is_causal) && ((begin + l) * Cta_tile_p::M < (loop_step_idx + 1) * Cta_tile_p::N));
|
||||
auto out_reshaped = reinterpret_cast<ArrayTypeO (&)[kOutputRowsPerThread]>(out);
|
||||
#pragma unroll
|
||||
for (int jj = 0; jj < kOutputRowsPerThread; jj++) {
|
||||
float sum = p_sum_o[jj][0];
|
||||
float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
|
||||
if (Is_dropout && is_final_write) {
|
||||
inv_sum *= params.rp_dropout;
|
||||
}
|
||||
out_reshaped[jj] = multiply_fragments(out_reshaped[jj], inv_sum);
|
||||
}
|
||||
|
||||
// Output the values.
|
||||
if (is_final_write) {
|
||||
typename GmemIteratorO::Fragment out_converted;
|
||||
cutlass::NumericArrayConverter<Element, ElementAccum, decltype(out_converted)::kElements, cutlass::FloatRoundStyle::round_to_nearest> convert_o;
|
||||
#pragma unroll
|
||||
for (int iter = 0; iter < GmemIteratorO::kIterations; ++iter) {
|
||||
out_converted = convert_o(out[iter]);
|
||||
gmem_o.store(out_converted);
|
||||
gmem_o.move();
|
||||
}
|
||||
// We also need to move gmem_o_accum. For example, if Is_causal=true and seqlen=512,
|
||||
// in the first loop, we write the first 256 rows to gmem_o and the last 256 rows to gmem_o_accum.
|
||||
if (Is_first && !Is_last) { gmem_o_accum.move(GmemIteratorOAccum::kIterations); }
|
||||
} else {
|
||||
if (!Is_first) { gmem_o_accum.move(-GmemIteratorOAccum::kIterations); }
|
||||
#pragma unroll
|
||||
for (int iter = 0; iter < GmemIteratorOAccum::kIterations; ++iter) {
|
||||
gmem_o_accum.store(out[iter]);
|
||||
gmem_o_accum.move();
|
||||
}
|
||||
}
|
||||
|
||||
gemm_q_k.reload_k();
|
||||
|
||||
// Trigger the load from shared memory for the next series of Q values.
|
||||
if(l < steps - 1) {
|
||||
gemm_q_k.reload_q();
|
||||
}
|
||||
|
||||
} // Outer loop over the sequence length.
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Return_softmax, typename Params>
|
||||
inline __device__ void device_1xN_loop(const Params ¶ms) {
|
||||
|
||||
// The block index for the batch.
|
||||
const int bidb = blockIdx.x;
|
||||
// The block index for the head.
|
||||
const int bidh = blockIdx.y;
|
||||
// The thread index.
|
||||
const int tidx = threadIdx.x;
|
||||
|
||||
const int tidx_global = (bidb * params.h + bidh) * blockDim.x * 2 + tidx;
|
||||
auto seeds = at::cuda::philox::unpack(params.philox_args);
|
||||
// We use 2 Philox generators to match the dropout pattern in the backward pass.
|
||||
// Forward pass uses 128 threads while backward pass uses 256 threads, so each thread
|
||||
// in the forward pass is simulating the dropout pattern of 2 threads in the backward pass.
|
||||
Philox ph0(std::get<0>(seeds), tidx_global, std::get<1>(seeds));
|
||||
Philox ph1(std::get<0>(seeds), tidx_global + blockDim.x, std::get<1>(seeds));
|
||||
constexpr int M = Kernel_traits::Cta_tile_p::M;
|
||||
const int STEPS = (params.seqlen_q + M - 1) / M;
|
||||
|
||||
constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N;
|
||||
if (params.seqlen_k == blocksize_c) {
|
||||
fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, true, true>(params, bidb, bidh, 0, STEPS, ph0, ph1, 0);
|
||||
} else {
|
||||
const int max_loop_steps = (params.seqlen_k + blocksize_c - 1) / blocksize_c;
|
||||
fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, true, false>(params, bidb, bidh, 0, STEPS, ph0, ph1, 0);
|
||||
for (int loop_step_idx = 1; loop_step_idx < max_loop_steps - 1; loop_step_idx++) {
|
||||
fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, false, false>(params, bidb, bidh, 0, STEPS, ph0, ph1, loop_step_idx);
|
||||
}
|
||||
fmha::device_1xN_<Kernel_traits, Is_dropout, Is_causal, Return_softmax, false, true>(params, bidb, bidh, 0, STEPS, ph0, ph1, max_loop_steps - 1);
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace fmha
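
A hedged walk-through of the dispatch above for one non-causal head, assuming seqlen_k = 1024 and Cta_tile_p::N = 256; the numbers are illustrative, not taken from the patch.

// max_loop_steps = (1024 + 256 - 1) / 256 = 4, so device_1xN_loop issues:
//   loop_step_idx = 0 : device_1xN_<..., Is_first=true,  Is_last=false>  -> partial O to o_tmp (fp32), LSE written
//   loop_step_idx = 1 : device_1xN_<..., Is_first=false, Is_last=false>  -> rescale previous partials, accumulate
//   loop_step_idx = 2 : device_1xN_<..., Is_first=false, Is_last=false>  -> rescale previous partials, accumulate
//   loop_step_idx = 3 : device_1xN_<..., Is_first=false, Is_last=true>   -> final rescale, O written in fp16/bf16
// If seqlen_k were exactly 256, the single Is_first=Is_last=true call would be taken instead.
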
|
@ -0,0 +1,134 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright
|
||||
* notice, this list of conditions and the following disclaimer in the
|
||||
* documentation and/or other materials provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the
|
||||
* names of its contributors may be used to endorse or promote products
|
||||
* derived from this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
******************************************************************************/
|
||||
#include <ATen/native/transformers/cuda/flash_attn/philox.cuh>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/static_switch.h>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/fmha.h>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/kernel_traits.h>
|
||||
#include <ATen/native/transformers/cuda/flash_attn/fmha_fprop_kernel_1xN.h>
|
||||
|
||||
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Return_softmax>
|
||||
__global__ void fmha_fprop_loop_kernel(FMHA_fprop_params params) {
|
||||
fmha::device_1xN_loop<Kernel_traits, Is_dropout, Is_causal, Return_softmax>(params);
|
||||
}
|
||||
|
||||
template<typename Kernel_traits>
|
||||
void run_fmha_loop_(Launch_params<FMHA_fprop_params> &launch_params,
|
||||
const bool configure) {
|
||||
constexpr int blocksize_c = Kernel_traits::Cta_tile_p::N;
|
||||
const int loop_steps = (launch_params.params.seqlen_k + blocksize_c - 1) / blocksize_c;
|
||||
|
||||
if (configure) {
|
||||
using Mma_tile_p = fmha::Hmma_tile<typename Kernel_traits::Cta_tile_p>;
|
||||
constexpr int M = Kernel_traits::Cta_tile_p::M;
|
||||
size_t STEPS = (launch_params.params.seqlen_q + M - 1) / M;
|
||||
constexpr size_t MMAS_M = Mma_tile_p::MMAS_M;
|
||||
constexpr size_t MMAS_N = Mma_tile_p::MMAS_N;
|
||||
size_t elts_per_head = STEPS * MMAS_M * MMAS_N * 8 * loop_steps;
|
||||
launch_params.elts_per_thread = elts_per_head;
|
||||
return;
|
||||
}
|
||||
|
||||
constexpr size_t smem_size_softmax_lse = Kernel_traits::Smem_softmax_lse::BYTES_PER_TILE;
|
||||
// Don't need smem_size_softmax_lse if we're not looping
|
||||
const size_t smem_size = fmha::get_dynamic_smem_size<Kernel_traits>()
|
||||
+ (loop_steps > 1 ? smem_size_softmax_lse : 0);
|
||||
// printf("smem_size = %d\n", smem_size);
|
||||
|
||||
// Work-around for gcc 7. It doesn't like nested BOOL_SWITCH.
|
||||
// https://github.com/kokkos/kokkos-kernels/issues/349
|
||||
// https://github.com/HazyResearch/flash-attention/issues/21
|
||||
BOOL_SWITCH(launch_params.is_dropout, IsDropoutConst, [&] {
|
||||
auto kernel = launch_params.params.is_causal
|
||||
? (launch_params.return_softmax
|
||||
? &fmha_fprop_loop_kernel<Kernel_traits, IsDropoutConst, true, true>
|
||||
: &fmha_fprop_loop_kernel<Kernel_traits, IsDropoutConst, true, false>)
|
||||
: (launch_params.return_softmax
|
||||
? &fmha_fprop_loop_kernel<Kernel_traits, IsDropoutConst, false, true>
|
||||
: &fmha_fprop_loop_kernel<Kernel_traits, IsDropoutConst, false, false>);
|
||||
// constexpr bool IsDropoutConstTmp = false;
|
||||
// auto kernel = launch_params.params.is_causal
|
||||
// ? (launch_params.return_softmax
|
||||
// ? &fmha_fprop_loop_kernel<Kernel_traits, IsDropoutConstTmp, true, true>
|
||||
// : &fmha_fprop_loop_kernel<Kernel_traits, IsDropoutConstTmp, true, false>)
|
||||
// : (launch_params.return_softmax
|
||||
// ? &fmha_fprop_loop_kernel<Kernel_traits, IsDropoutConstTmp, false, true>
|
||||
// : &fmha_fprop_loop_kernel<Kernel_traits, IsDropoutConstTmp, false, false>);
|
||||
if( smem_size >= 48L * 1024 ) {
|
||||
FMHA_CHECK_CUDA(cudaFuncSetAttribute(
|
||||
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
|
||||
}
|
||||
dim3 grid(launch_params.params.b, launch_params.params.h);
|
||||
kernel<<<grid, Kernel_traits::THREADS, smem_size, launch_params.stream>>>(
|
||||
launch_params.params);
|
||||
FMHA_CHECK_CUDA(cudaPeekAtLastError());
|
||||
});
|
||||
}
|
||||
|
||||
void run_fmha_fprop(Launch_params<FMHA_fprop_params> &launch_params,
|
||||
const bool configure) {
|
||||
BOOL_SWITCH(launch_params.params.is_bf16, IsBf16Const, [&] {
|
||||
using elem_type = std::conditional<IsBf16Const, cutlass::bfloat16_t, cutlass::half_t>::type;
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
if (launch_params.params.d <= 64) {
|
||||
if( launch_params.params.seqlen_k == 128 ) {
|
||||
// TD [2022-08-20]: One might expect that not sharing the smem between K & V
|
||||
// could be faster, but seems like it's the same speed.
|
||||
using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
|
||||
run_fmha_loop_<Kernel_traits>(launch_params, configure);
|
||||
} else if( launch_params.params.seqlen_k >= 256 ) {
|
||||
if (dprops->major == 8 && dprops->minor >= 0) {
|
||||
using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>;
|
||||
run_fmha_loop_<Kernel_traits>(launch_params, configure);
|
||||
} else if (dprops->major == 7 && dprops->minor == 5) {
|
||||
if (launch_params.is_dropout) { // Need to use the same block size as backward
|
||||
using Kernel_traits = FMHA_kernel_traits<128, 64, 16, 1, 4, 0x08u, elem_type>;
|
||||
run_fmha_loop_<Kernel_traits>(launch_params, configure);
|
||||
} else {
|
||||
using Kernel_traits = FMHA_kernel_traits<256, 64, 16, 1, 4, 0x08u, elem_type>;
|
||||
run_fmha_loop_<Kernel_traits>(launch_params, configure);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if (launch_params.params.d <= 128) {
|
||||
if( launch_params.params.seqlen_k == 128 ) {
|
||||
using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>;
|
||||
run_fmha_loop_<Kernel_traits>(launch_params, configure);
|
||||
} else {
|
||||
if (dprops->major == 8 && dprops->minor == 0 && !launch_params.is_dropout) {
|
||||
// TD [2022-06-05] Keep K in smem to reduce register spilling
|
||||
// Gives about 6% speedup compared to using block size 128.
|
||||
using Kernel_traits = FMHA_kernel_traits<256, 128, 16, 1, 4, 0x18u, elem_type>;
|
||||
// using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>;
|
||||
run_fmha_loop_<Kernel_traits>(launch_params, configure);
|
||||
} else { // Need to use the same block size as backward
|
||||
using Kernel_traits = FMHA_kernel_traits<128, 128, 16, 1, 4, 0x08u, elem_type>;
|
||||
run_fmha_loop_<Kernel_traits>(launch_params, configure);
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
@ -0,0 +1,71 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2022, Tri Dao.
|
||||
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright
|
||||
* notice, this list of conditions and the following disclaimer in the
|
||||
* documentation and/or other materials provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the
|
||||
* names of its contributors may be used to endorse or promote products
|
||||
* derived from this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
namespace fmha {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<int THREADS_PER_CTA>
struct BlockInfoPadded {

    template<typename Params>
    __device__ BlockInfoPadded(const Params &params,
                               const int bidb,
                               const int bidh,
                               const int tidx)
        : bidb(bidb), bidh(bidh), h(params.h) {

        // The block index.
        sum_s_k = params.cu_seqlens_k[bidb];
        actual_seqlen_k = params.cu_seqlens_k[bidb + 1] - sum_s_k;
        sum_s_q = params.cu_seqlens_q[bidb];
        actual_seqlen_q = params.cu_seqlens_q[bidb + 1] - sum_s_q;

        tidx_global = (bidb * params.h + bidh) * THREADS_PER_CTA + tidx;
    }

    __device__ bool stop_early(const int start_col = 0) const {
        return actual_seqlen_k <= start_col;
    }

    uint32_t actual_seqlen_q;
    uint32_t actual_seqlen_k;
    uint32_t sum_s_q;
    uint32_t sum_s_k;
    uint32_t bidh;
    uint32_t bidb;
    uint32_t tidx_global;
    uint32_t h;
};
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace fmha
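
A small worked example of the bookkeeping above, with made-up cumulative sequence lengths:

// Hypothetical inputs: cu_seqlens_q = {0, 36, 100}, cu_seqlens_k = {0, 40, 104}.
// For the block with bidb = 1:
//   sum_s_q = 36, actual_seqlen_q = 100 - 36 = 64
//   sum_s_k = 40, actual_seqlen_k = 104 - 40 = 64
// stop_early(start_col) returns true once start_col >= 64, i.e. once the key-tile
// offset (loop_step_idx * Cta_tile_p::N) has moved past this sequence's keys.
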
|
@ -0,0 +1,52 @@
|
||||
|
||||
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright
|
||||
* notice, this list of conditions and the following disclaimer in the
|
||||
* documentation and/or other materials provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the
|
||||
* names of its contributors may be used to endorse or promote products
|
||||
* derived from this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cassert>
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
#include <cuda_runtime_api.h>
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
#define FMHA_CHECK_CUDA( call )                                                \
    do {                                                                       \
        cudaError_t status_ = call;                                            \
        if( status_ != cudaSuccess ) {                                         \
            fprintf( stderr,                                                   \
                     "CUDA error (%s:%d): %s\n",                               \
                     __FILE__,                                                 \
                     __LINE__,                                                 \
                     cudaGetErrorString( status_ ) );                          \
            exit( 1 );                                                         \
        }                                                                      \
    } while( 0 )
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
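
A brief usage note: FMHA_CHECK_CUDA wraps any expression returning cudaError_t and aborts with file and line context on failure, exactly as run_fmha_loop_ does above:

FMHA_CHECK_CUDA(cudaFuncSetAttribute(
    kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
FMHA_CHECK_CUDA(cudaPeekAtLastError());
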
|
95
aten/src/ATen/native/transformers/cuda/flash_attn/gemm.h
Normal file
@ -0,0 +1,95 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright
|
||||
* notice, this list of conditions and the following disclaimer in the
|
||||
* documentation and/or other materials provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the
|
||||
* names of its contributors may be used to endorse or promote products
|
||||
* derived from this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/utils.h>
|
||||
|
||||
#include <third_party/cutlass/include/cutlass/cutlass.h>
|
||||
#include <third_party/cutlass/include/cutlass/gemm/warp/default_mma_tensor_op.h>
|
||||
#include <third_party/cutlass/include/cutlass/layout/layout.h>
|
||||
#include <third_party/cutlass/include/cutlass/arch/mma.h>
|
||||
#include <third_party/cutlass/include/cutlass/array.h>
|
||||
#include <third_party/cutlass/include/cutlass/numeric_types.h>
|
||||
|
||||
namespace fmha {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<
|
||||
// The number of rows in the CTA tile.
|
||||
int M_,
|
||||
// The number of cols in the CTA tile.
|
||||
int N_,
|
||||
// The number of elements in the K dimension of the GEMM loop.
|
||||
int K_,
|
||||
// The number of rows of warps.
|
||||
int WARPS_M_,
|
||||
// The number of cols of warps.
|
||||
int WARPS_N_,
|
||||
// The number of warps in the K dimension of the GEMM loop.
|
||||
int WARPS_K_>
|
||||
struct Cta_tile_ {
|
||||
|
||||
static constexpr int M = M_, N = N_, K = K_;
|
||||
// The number of warps.
|
||||
static constexpr int WARPS_M = WARPS_M_, WARPS_N = WARPS_N_, WARPS_K = WARPS_K_;
|
||||
// The number of warps per CTA.
|
||||
static constexpr int WARPS_PER_CTA = WARPS_M * WARPS_N * WARPS_K;
|
||||
// The number of threads per warp.
|
||||
static constexpr int THREADS_PER_WARP = 32;
|
||||
// The number of threads per CTA.
|
||||
static constexpr int THREADS_PER_CTA = WARPS_PER_CTA * THREADS_PER_WARP;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename Cta_tile>
|
||||
struct Hmma_tile {
|
||||
// The number of elements computed with a single warp-MMA.
|
||||
static constexpr int M_PER_MMA = 16, N_PER_MMA = 16, K_PER_MMA = 16;
|
||||
|
||||
// The number of elements computed with a single CTA-MMA.
|
||||
static constexpr int M_PER_MMA_PER_CTA = M_PER_MMA * Cta_tile::WARPS_M,
|
||||
N_PER_MMA_PER_CTA = N_PER_MMA * Cta_tile::WARPS_N,
|
||||
K_PER_MMA_PER_CTA = K_PER_MMA * Cta_tile::WARPS_K;
|
||||
|
||||
// The number of MMAs needed to compute the GEMM.
|
||||
static constexpr int MMAS_M = DivUpConstexpr(Cta_tile::M, M_PER_MMA_PER_CTA),
|
||||
MMAS_N = DivUpConstexpr(Cta_tile::N, N_PER_MMA_PER_CTA),
|
||||
MMAS_K = DivUpConstexpr(Cta_tile::K, K_PER_MMA_PER_CTA);
|
||||
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<int M, int N, int K, int WARPS_M, int WARPS_N, int WARPS_K>
|
||||
using Cta_tile_extd = Cta_tile_<M, N, K, WARPS_M, WARPS_N, WARPS_K>;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace fmha
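
To illustrate the tile bookkeeping above, here is a hedged compile-time example; the tile shape is hypothetical (it is not claimed to match kernel_traits.h), and DivUpConstexpr from utils.h is assumed to be a constexpr ceiling division.

// 16 x 256 x 64 CTA tile with 1 x 4 x 1 warps: 4 warps, 128 threads.
using Tile = fmha::Cta_tile_extd<16, 256, 64, 1, 4, 1>;
using Mma  = fmha::Hmma_tile<Tile>;
static_assert(Tile::THREADS_PER_CTA == 128, "");
static_assert(Mma::MMAS_M == 1, "");  // ceil(16  / (16 * 1))
static_assert(Mma::MMAS_N == 4, "");  // ceil(256 / (16 * 4))
static_assert(Mma::MMAS_K == 4, "");  // ceil(64  / (16 * 1))
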
|
272
aten/src/ATen/native/transformers/cuda/flash_attn/gmem_tile.h
Normal file
@ -0,0 +1,272 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
* * Redistributions of source code must retain the above copyright
|
||||
* notice, this list of conditions and the following disclaimer.
|
||||
* * Redistributions in binary form must reproduce the above copyright
|
||||
* notice, this list of conditions and the following disclaimer in the
|
||||
* documentation and/or other materials provided with the distribution.
|
||||
* * Neither the name of the NVIDIA CORPORATION nor the
|
||||
* names of its contributors may be used to endorse or promote products
|
||||
* derived from this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
|
||||
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
|
||||
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
******************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ATen/native/transformers/cuda/flash_attn/gemm.h>
|
||||
namespace fmha {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template< typename Cta_tile, int BYTES_PER_ELEMENT >
|
||||
struct Gmem_tile_mma_sd {
|
||||
|
||||
// The mma tile.
|
||||
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
|
||||
|
||||
// Each STG stores 8 elements.
|
||||
static constexpr int BYTES_PER_STG = BYTES_PER_ELEMENT * 8;
|
||||
// The number of MMAs in the M dimension.
|
||||
static constexpr int MMAS_M = Mma_tile::MMAS_M;
|
||||
// The number of MMAs in the N dimension.
|
||||
static constexpr int MMAS_N = Mma_tile::MMAS_N;
|
||||
// The number of rows computed per MMA per thread block.
|
||||
static constexpr int M_PER_MMA_PER_CTA = Mma_tile::M_PER_MMA_PER_CTA;
|
||||
// The number of cols computed per MMA per thread block.
|
||||
static constexpr int N_PER_MMA_PER_CTA = Mma_tile::N_PER_MMA_PER_CTA;
|
||||
// The number of threads per block.
|
||||
static constexpr int THREADS_PER_CTA = Cta_tile::THREADS_PER_CTA;
|
||||
// The size of each row in bytes. I.e. how many bytes are stored per STG.
|
||||
static constexpr int BYTES_PER_ROW = THREADS_PER_CTA * BYTES_PER_STG;
|
||||
// The distance between elements stored per loop (in bytes).
|
||||
static constexpr int LOOP_STRIDE_BYTES = MMAS_M * MMAS_N * BYTES_PER_ROW;
|
||||
|
||||
// The type of elements stored per STG.
|
||||
using Type = typename fmha::Uint_from_size_in_bytes<BYTES_PER_STG>::Type;
|
||||
|
||||
// Ctor.
|
||||
template<typename Params>
|
||||
inline __device__ Gmem_tile_mma_sd(void *ptr, const Params ¶ms, const int bidb, const int bidh, const int tidx)
|
||||
: ptr_(static_cast<char *>(ptr)) {
|
||||
|
||||
// The block index.
|
||||
// size_t bidx = bidb * params.h + bidh;
|
||||
uint32_t bidx = bidb * params.h + bidh;
|
||||
|
||||
// The distance between two blocks (in bytes).
|
||||
// const size_t block_stride_bytes = params.seqlen_q * params.seqlen_k * BYTES_PER_ELEMENT;
|
||||
const uint32_t block_stride_bytes = params.seqlen_q * params.seqlen_k * BYTES_PER_ELEMENT;
|
||||
// Set store location for each thread at the beginning of the loop
|
||||
ptr_ += bidx * block_stride_bytes + tidx * BYTES_PER_STG;
|
||||
}
|
||||
|
||||
// Store to global memory.
|
||||
inline __device__ void store(const Type &data, const int mi, const int ni) {
|
||||
// size_t offset = (mi * MMAS_N + ni) * BYTES_PER_ROW;
|
||||
uint32_t offset = (mi * MMAS_N + ni) * BYTES_PER_ROW;
|
||||
fmha::stg(ptr_ + offset, data);
|
||||
}
|
||||
|
||||
// Load from global memory.
|
||||
inline __device__ void load(Type &data, const int mi, const int ni) {
|
||||
// size_t offset = (mi * MMAS_N + ni) * BYTES_PER_ROW;
|
||||
uint32_t offset = (mi * MMAS_N + ni) * BYTES_PER_ROW;
|
||||
fmha::ldg(data, ptr_ + offset);
|
||||
}
|
||||
|
||||
// Move to the next tile.
|
||||
inline __device__ void move(const int steps = 1) {
|
||||
ptr_ += LOOP_STRIDE_BYTES * steps;
|
||||
}
|
||||
|
||||
// The pointer in global memory.
|
||||
char *ptr_;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template< typename Cta_tile, typename Base = Gmem_tile_mma_sd<Cta_tile, sizeof(uint16_t)> >
|
||||
struct Gmem_tile_mma_s : public Base {
|
||||
|
||||
// The number of mmas in the vertical dimension.
|
||||
static constexpr int M = Base::MMAS_M;
|
||||
// The number of mmas in the horizontal dimension.
|
||||
static constexpr int N = Base::MMAS_N;
|
||||
// The type of the vectors stored by each STG.
|
||||
using Type = typename Base::Type;
|
||||
|
||||
// Ctor.
|
||||
template< typename Params, typename Block_info >
|
||||
inline __device__ Gmem_tile_mma_s(const Params ¶ms, const Block_info& binfo, const int tidx)
|
||||
: Base(params.s_ptr, params, binfo.bidb, binfo.bidh, tidx) {
|
||||
}
|
||||
|
||||
// Store to global memory.
|
||||
template<typename Mask, typename Fragment>
|
||||
inline __device__ void store(const Fragment (&frag)[N][M], const Mask& mask){
|
||||
static_assert(Fragment::kStorageElements == 4);
|
||||
#pragma unroll
|
||||
for( int mi = 0; mi < M; mi++ ) {
|
||||
#pragma unroll
|
||||
for( int ni = 0; ni < N; ni++ ) {
|
||||
uint4 dst;
|
||||
dst.x = frag[ni][mi].raw_data()[0];
|
||||
dst.y = frag[ni][mi].raw_data()[2];
|
||||
dst.z = frag[ni][mi].raw_data()[1];
|
||||
dst.w = frag[ni][mi].raw_data()[3];
|
||||
if( mask.any_valid(mi, ni) ) {
|
||||
Base::store(dst, mi, ni);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Load from global memory.
|
||||
template<typename Mask>
|
||||
inline __device__ void load(uint4 (®s)[M][N], const Mask &mask) {
|
||||
#pragma unroll
|
||||
for( int mi = 0; mi < M; mi++ ) {
|
||||
#pragma unroll
|
||||
for( int ni = 0; ni < N; ni++ ) {
|
||||
regs[mi][ni] = make_uint4(0, 0, 0, 0);
|
||||
if( mask.any_valid(mi, ni) ) {
|
||||
Base::load(regs[mi][ni], mi, ni);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<
|
||||
// The dimensions of the tile computed by the CTA.
|
||||
typename Cta_tile
|
||||
>
|
||||
struct Gmem_summary_stats {
|
||||
|
||||
// The Mma tile.
|
||||
using Mma_tile = fmha::Hmma_tile<Cta_tile>;
|
||||
|
||||
// The number of MMAs in M/N dimensions.
|
||||
static constexpr int MMAS_M = Mma_tile::MMAS_M;
|
||||
|
||||
// The size of each element.
|
||||
static constexpr int BYTES_PER_ELEMENT = 4;
|
||||
static constexpr int BYTES_PER_MMA = (Cta_tile::THREADS_PER_WARP / 4) * 2 * BYTES_PER_ELEMENT;
|
||||
static constexpr int ROWS = Cta_tile::M;
|
||||
|
||||
// Ctor.
|
||||
template<typename Params>
|
||||
inline __device__ Gmem_summary_stats(void *ptr, const Params ¶ms, const int tidx)
|
||||
: ptr_(reinterpret_cast<char *>(ptr)), tidx_(tidx) {
|
||||
|
||||
// The block index for the batch.
|
||||
const int bidb = blockIdx.x;
|
||||
// The block index for the head.
|
||||
const int bidh = blockIdx.y;
|
||||
// The block index.
|
||||
// size_t bidx = bidb * params.h + bidh;
|
||||
uint32_t bidx = bidb * params.h + bidh;
|
||||
|
||||
// Extract the position in the warp.
|
||||
int warp = tidx / Cta_tile::THREADS_PER_WARP;
|
||||
int lane = tidx % Cta_tile::THREADS_PER_WARP;
|
||||
|
||||
// The distance between two blocks (in bytes).
|
||||
// size_t block_stride_bytes = params.seqlen_q * BYTES_PER_ELEMENT;
|
||||
uint32_t block_stride_bytes = params.seqlen_q * BYTES_PER_ELEMENT;
|
||||
|
||||
// Set store location for each thread at the beginning of the loop
|
||||
ptr_row_ = ptr_ + bidx * block_stride_bytes;
|
||||
ptr_ += bidx * block_stride_bytes + (lane / 4) * BYTES_PER_ELEMENT;
|
||||
}
|
||||
|
||||
// Store data to global memory.
|
||||
inline __device__ void store(const uint32_t (&data)[MMAS_M * 2]) {
|
||||
int warp = tidx_ / Cta_tile::THREADS_PER_WARP;
|
||||
int lane = tidx_ % Cta_tile::THREADS_PER_WARP;
|
||||
if ((warp == 0) && (lane % 4 == 0)) {
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < MMAS_M; ++mi) {
|
||||
// TODO: Not sure if it's right for MMAS_M > 1
|
||||
fmha::stg(ptr_ + mi * BYTES_PER_MMA + 0 * BYTES_PER_ELEMENT, data[mi * 2 + 0]);
|
||||
fmha::stg(ptr_ + mi * BYTES_PER_MMA + 8 * BYTES_PER_ELEMENT, data[mi * 2 + 1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Store data to global memory.
|
||||
inline __device__ void store_row(const uint32_t (&data)[MMAS_M], const int row) {
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < MMAS_M; ++mi) {
|
||||
// TODO: Not sure if it's right for MMAS_M > 1
|
||||
fmha::stg(ptr_row_ + mi * BYTES_PER_MMA + row * BYTES_PER_ELEMENT, data[mi]);
|
||||
}
|
||||
}
|
||||
|
||||
// Load from global memory.
|
||||
inline __device__ void load(uint32_t (&data)[MMAS_M * 2]) {
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < MMAS_M; ++mi) {
|
||||
// TODO: Not sure if it's right for MMAS_M > 1
|
||||
fmha::ldg(data[mi * 2 + 0], ptr_ + mi * BYTES_PER_MMA + 0 * BYTES_PER_ELEMENT);
|
||||
fmha::ldg(data[mi * 2 + 1], ptr_ + mi * BYTES_PER_MMA + 8 * BYTES_PER_ELEMENT);
|
||||
}
|
||||
}
|
||||
|
||||
// Load from global memory.
|
||||
inline __device__ void load_next(uint32_t (&data)[MMAS_M * 2], int move_steps=1) {
|
||||
char *ptr_next = ptr_ + move_steps * ROWS * BYTES_PER_ELEMENT;
|
||||
#pragma unroll
|
||||
for (int mi = 0; mi < MMAS_M; ++mi) {
|
||||
// TODO: Not sure if it's right for MMAS_M > 1
|
||||
fmha::ldg(data[mi * 2 + 0], ptr_next + mi * BYTES_PER_MMA + 0 * BYTES_PER_ELEMENT);
|
||||
fmha::ldg(data[mi * 2 + 1], ptr_next + mi * BYTES_PER_MMA + 8 * BYTES_PER_ELEMENT);
|
||||
}
|
||||
}
|
||||
|
||||
// Store data to global memory.
|
||||
template <int N>
|
||||
inline __device__ void load_row(uint32_t (&data)[N], const int row[N]) {
|
||||
#pragma unroll
|
||||
for (int ni = 0; ni < N; ++ni) {
|
||||
fmha::ldg(data[ni], ptr_row_ + row[ni] * BYTES_PER_ELEMENT);
|
||||
}
|
||||
}
|
||||
|
||||
// Move the pointer to the next location.
|
||||
inline __device__ void move() {
|
||||
ptr_ += ROWS * BYTES_PER_ELEMENT;
|
||||
ptr_row_ += ROWS * BYTES_PER_ELEMENT;
|
||||
}
|
||||
|
||||
// Move the pointer to the next location.
|
||||
inline __device__ void move(const int steps) {
|
||||
ptr_ += ROWS * BYTES_PER_ELEMENT * steps;
|
||||
ptr_row_ += ROWS * BYTES_PER_ELEMENT * steps;
|
||||
}
|
||||
|
||||
// The pointer.
|
||||
char *ptr_;
|
||||
char *ptr_row_;
|
||||
const int tidx_;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace fmha
|
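The addressing in Gmem_tile_mma_sd amounts to a flat per-(batch, head) layout in which every thread owns one STG-sized slot per (mi, ni) step. A standalone host-side sketch of the same offset arithmetic, with made-up sizes purely for illustration (it is not part of the diff):

#include <cstdint>
#include <cstdio>

// Mirrors Gmem_tile_mma_sd<Cta_tile, sizeof(uint16_t)>: each thread stores 8 two-byte elements per STG.
int main() {
    const uint32_t h = 12, seqlen_q = 128, seqlen_k = 128;            // hypothetical sizes
    const uint32_t BYTES_PER_ELEMENT = 2, BYTES_PER_STG = BYTES_PER_ELEMENT * 8;
    const uint32_t THREADS_PER_CTA = 128, MMAS_N = 2;
    const uint32_t BYTES_PER_ROW = THREADS_PER_CTA * BYTES_PER_STG;

    const uint32_t bidb = 1, bidh = 3, tidx = 64;                     // block/head/thread of interest
    const uint32_t bidx = bidb * h + bidh;
    const uint32_t block_stride_bytes = seqlen_q * seqlen_k * BYTES_PER_ELEMENT;

    // Base offset set in the constructor, then the per-(mi, ni) offset used by store()/load().
    const uint32_t base = bidx * block_stride_bytes + tidx * BYTES_PER_STG;
    const uint32_t mi = 0, ni = 1;
    const uint32_t offset = base + (mi * MMAS_N + ni) * BYTES_PER_ROW;
    std::printf("byte offset for (mi=%u, ni=%u): %u\n", mi, ni, offset);
    return 0;
}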
@ -0,0 +1,154 @@
/******************************************************************************
 * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the
 *       names of its contributors may be used to endorse or promote products
 *       derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/

#pragma once

#include <third_party/cutlass/include/cutlass/cutlass.h>

#include <third_party/cutlass/include/cutlass/gemm/gemm.h>

#include <third_party/cutlass/include/cutlass/layout/layout.h>
#include <third_party/cutlass/include/cutlass/numeric_types.h>
#include <third_party/cutlass/include/cutlass/transform/threadblock/predicated_tile_iterator.h>

#include <ATen/native/transformers/cuda/flash_attn/gemm.h>
#include <ATen/native/transformers/cuda/flash_attn/gmem_tile.h>
#include <ATen/native/transformers/cuda/flash_attn/summary_stats.h>
#include <ATen/native/transformers/cuda/flash_attn/mma_core_sm75.h>

////////////////////////////////////////////////////////////////////////////////////////////////////

template<int S, int D, int STEP, int WARPS_M, int WARPS_N, uint32_t FLAGS = 0x08u, typename elem_type=cutlass::half_t>
struct FMHA_kernel_traits {

    // The CTA description for the 1st GEMM.
    using Cta_tile_p = fmha::Cta_tile_extd<STEP, S, D, WARPS_M, WARPS_N, 1>;
    // The CTA description for the 2nd GEMM.
    using Cta_tile_o = fmha::Cta_tile_extd<STEP, D, S, WARPS_M, 1, WARPS_N>;

    // Do we use one buffer for K and V?
    static constexpr bool SHARE_SMEM_FOR_K_AND_V = (FLAGS & 0x08u) != 0u;
    // Do we keep K in registers?
    static constexpr bool K_IN_REGS = (FLAGS & 0x10u) == 0u;
    // Do we keep V in registers?
    static constexpr bool V_IN_REGS = (FLAGS & 0x100u) == 0u;

    // The global memory tile to load/store S.
    using Gmem_tile_s = fmha::Gmem_tile_mma_s<Cta_tile_p>;

    // The global memory tile to store the softmax sum.
    using Gmem_softmax_sum = fmha::Gmem_summary_stats<Cta_tile_p>;

    // The number of threads.
    static constexpr int THREADS = Cta_tile_p::THREADS_PER_CTA;
    // Make sure the number of threads matches both CTAs.
    static_assert(THREADS == Cta_tile_o::THREADS_PER_CTA, "");

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    using MmaInstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
#elif defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
    using MmaInstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
#else
    // using MmaInstructionShape = cutlass::gemm::GemmShape<8, 8, 4>;
    using MmaInstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
    // TD [2022-06-02] We don't support Volta (SM70) yet.
#endif

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    using Element = elem_type;
#else
    using Element = cutlass::half_t;
#endif
    using ElementAccum = float;

    static_assert(WARPS_M == 1);
    using ThreadblockShapeQK = cutlass::gemm::GemmShape<STEP, S, D>;
    using WarpCountQK = cutlass::gemm::GemmShape<WARPS_M, WARPS_N, 1>;
    using WarpShapeQK = cutlass::gemm::GemmShape<
        ThreadblockShapeQK::kM,
        ThreadblockShapeQK::kN / WarpCountQK::kN, ThreadblockShapeQK::kK>;
    using LayoutQ = cutlass::layout::RowMajor;
    using LayoutK = cutlass::layout::ColumnMajor;
    using LayoutP = cutlass::layout::RowMajor;
    using MmaCoreQK = typename fmha::FMHAMmaCore<
        ThreadblockShapeQK, WarpShapeQK, MmaInstructionShape, Element, LayoutQ,
        Element, LayoutK, ElementAccum, LayoutP,
        cutlass::arch::OpClassTensorOp>;

    using ThreadblockShapePV = cutlass::gemm::GemmShape<STEP, D, S>;
    using WarpCountPV = cutlass::gemm::GemmShape<WARPS_M, 1, WARPS_N>;
    using WarpShapePV = cutlass::gemm::GemmShape<ThreadblockShapePV::kM, ThreadblockShapePV::kN, ThreadblockShapePV::kK / WarpCountPV::kK>;
    using LayoutV = cutlass::layout::RowMajor;
    using LayoutO = cutlass::layout::RowMajor;
    using MmaCorePV = typename fmha::FMHAMmaCore<
        ThreadblockShapePV, WarpShapePV, MmaInstructionShape, Element, LayoutP,
        Element, LayoutV, ElementAccum, LayoutO,
        cutlass::arch::OpClassTensorOp>;

    // The global memory tile to load Q.
    // Copy from mma_pipelined_testbed.h
    using GmemIteratorQ = cutlass::transform::threadblock::PredicatedTileIterator<
        cutlass::MatrixShape<ThreadblockShapeQK::kM, ThreadblockShapeQK::kK>,
        Element,
        LayoutQ,
        0,
        typename MmaCoreQK::IteratorThreadMapA
    >;

    // The global memory tile to load K.
    using GmemIteratorK = cutlass::transform::threadblock::PredicatedTileIterator<
        cutlass::MatrixShape<ThreadblockShapeQK::kK, ThreadblockShapeQK::kN>,
        Element,
        LayoutK,
        1,
        typename MmaCoreQK::IteratorThreadMapB
    >;

    // The global memory tile to load V.
    using GmemIteratorV = cutlass::transform::threadblock::PredicatedTileIterator<
        cutlass::MatrixShape<ThreadblockShapePV::kK, ThreadblockShapePV::kN>,
        Element,
        LayoutV,
        0,
        typename MmaCorePV::IteratorThreadMapB
    >;

    // The shared memory tile to store softmax lse.
    using Smem_softmax_lse = fmha::Smem_tile_softmax_lse<ThreadblockShapeQK::kM, MmaInstructionShape::kM, WarpCountQK::kM>;

    // The amount of shared memory needed to load Q and K.
    static constexpr size_t BYTES_PER_SMEM_Q = ThreadblockShapeQK::kM * ThreadblockShapeQK::kK * sizeof(Element);
    static constexpr size_t BYTES_PER_SMEM_K = ThreadblockShapeQK::kN * ThreadblockShapeQK::kK * sizeof(Element);
    static constexpr size_t BYTES_PER_SMEM_V = ThreadblockShapePV::kN * ThreadblockShapePV::kK * sizeof(Element);
    static_assert(BYTES_PER_SMEM_K == BYTES_PER_SMEM_V);
    static constexpr size_t BYTES_PER_SMEM_QK = BYTES_PER_SMEM_Q + BYTES_PER_SMEM_K;
    // The extra amount of shared memory needed to load V.
    static constexpr size_t BYTES_PER_SMEM_V_EXTRA = SHARE_SMEM_FOR_K_AND_V ? 0u : BYTES_PER_SMEM_V;
    // The amount of shared memory needed for Q, K and V.
    static constexpr size_t BYTES_PER_SMEM_QKV = BYTES_PER_SMEM_QK + BYTES_PER_SMEM_V_EXTRA;

};
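The BYTES_PER_SMEM_* constants above are plain products of the threadblock shapes. A host-side sketch of the same arithmetic for one hypothetical instantiation (S=128, D=64, STEP=16, half-precision elements, K/V buffer sharing enabled); the chosen sizes are illustrative, not taken from this PR:

#include <cstddef>
#include <cstdio>

int main() {
    const std::size_t STEP = 16, S = 128, D = 64, sizeof_elem = 2;
    const bool SHARE_SMEM_FOR_K_AND_V = true;                    // FLAGS & 0x08u

    const std::size_t smem_q = STEP * D * sizeof_elem;           // ThreadblockShapeQK::kM * kK -> 2048 B
    const std::size_t smem_k = S * D * sizeof_elem;              // ThreadblockShapeQK::kN * kK -> 16384 B
    const std::size_t smem_v = D * S * sizeof_elem;              // ThreadblockShapePV::kN * kK -> 16384 B
    const std::size_t smem_qk = smem_q + smem_k;
    const std::size_t smem_qkv = smem_qk + (SHARE_SMEM_FOR_K_AND_V ? 0 : smem_v);
    std::printf("Q: %zu B, K: %zu B, V: %zu B, total: %zu B\n", smem_q, smem_k, smem_v, smem_qkv);
    return 0;
}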
92
aten/src/ATen/native/transformers/cuda/flash_attn/mask.h
Normal file
@ -0,0 +1,92 @@
/******************************************************************************
 * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the
 *       names of its contributors may be used to endorse or promote products
 *       derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/

#pragma once

#include <ATen/cuda/CUDAContext.h>
namespace fmha {


template<typename Cta_tile, bool Is_causal=false>
struct Mask {
    using Mma_tile = fmha::Hmma_tile<Cta_tile>;

    template<typename BInfo>
    __device__ Mask(const BInfo &binfo, int tidx, const int loop_step_idx_ = 0)
        : actual_seqlen_k(binfo.actual_seqlen_k - loop_step_idx_ * Cta_tile::N)
        , loop_step_idx(loop_step_idx_) {

        const int warp = tidx / Cta_tile::THREADS_PER_WARP;
        const int lane = tidx % Cta_tile::THREADS_PER_WARP;

        static_assert(Cta_tile::WARPS_K == 1, "");

        // find the warp in the Cta tile
        const int warp_n = (warp / Cta_tile::WARPS_M);
        const int warp_m = (warp % Cta_tile::WARPS_M);
        // decompose warp into 8x4 tile
        const int quad = lane / 4;
        const int tid = (lane % 4) * 2;
        row = warp_m * 16 + quad;
        // col = warp_n * 16 + tid;
        col = warp_n * Mma_tile::N_PER_MMA * Mma_tile::MMAS_N + tid;
    }

    inline __device__ bool is_valid(const int mi, const int ni, const int ii, const int jj) const {

        // ii and jj iterate over the 2x4 fragment
        // const int current_col = (Is_causal ? loop_step_idx * Cta_tile::N : 0) + ni * Mma_tile::N_PER_MMA_PER_CTA + col + (jj & 2) * 4 + (jj & 1);
        // const int current_col = ni * Mma_tile::N_PER_MMA_PER_CTA + col + (jj & 2) * 4 + (jj & 1);
        const int current_col = ni * Mma_tile::N_PER_MMA + col + (jj & 2) * 4 + (jj & 1);
        const int current_row = row_offset + ii * 8;
        const bool col_valid = current_col < actual_seqlen_k;
        // const bool col_valid = (ni * Mma_tile::N_PER_MMA_PER_CTA + col + (jj & 2) * 4 + (jj & 1)) < actual_seqlen_k;
        //&& (row + mi * Mma_tile::M_PER_MMA_PER_CTA + ii * 8) < actual_seqlen_k;
        // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
        //     printf("current_col=%d, current_row=%d, actual_seqlen_k=%d, col_valid=%d, all_valid=%d\n", current_col, current_row, actual_seqlen_k, col_valid, all_valid);
        // }
        return Is_causal ? col_valid && (current_col + loop_step_idx * Cta_tile::N <= current_row) : col_valid;
        // return row_valid && col_valid;
    }

    // BERT Mask: if upper left is invalid, none are valid
    inline __device__ bool any_valid(const int mi, const int ni) const {
        return is_valid(mi, ni, 0, 0) || is_valid(mi, ni, 1, 0);
    }

    inline __device__ void load(const int it) {
        row_offset = it * Cta_tile::M + row;
    }
    int row_offset;

    int row;
    int col;
    const int loop_step_idx;
    const int actual_seqlen_k;
};

} // namespace fmha
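The causal branch of Mask::is_valid combines a bounds check on the remaining K/V columns of the current chunk with a look-ahead check against the query row. A standalone host-side sketch of that predicate with made-up sequence lengths (illustrative only):

#include <cstdio>

// A probability entry at (row, local col) inside K/V chunk `loop_step_idx` of width N survives
// only if the absolute column is in range and, for causal attention, does not look past the row.
bool is_kept(int row, int col, int loop_step_idx, int N, int seqlen_k, bool is_causal) {
    const int remaining_k = seqlen_k - loop_step_idx * N;   // matches the Mask constructor
    const bool col_valid = col < remaining_k;
    return is_causal ? col_valid && (col + loop_step_idx * N <= row) : col_valid;
}

int main() {
    // With N = 64 and seqlen_k = 100, query row 70 in the second K/V chunk keeps
    // global columns 64..70 and drops everything else.
    for (int col = 0; col < 64; ++col) {
        if (is_kept(70, col, /*loop_step_idx=*/1, /*N=*/64, /*seqlen_k=*/100, /*is_causal=*/true)) {
            std::printf("keep global col %d\n", col + 64);
        }
    }
    return 0;
}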
@ -0,0 +1,382 @@
// Adapted from cutlass/gemm/threadblock/default_mma_core_sm75.h
// This is very similar, except we make it work for head_dim=128.
// The original cutlass version only allows kK of the thread block to be
// at most 64. Here we set kCrosswise = min(ThreadblockShape::kK, 64) instead.

/******************************************************************************
 * Copyright (c) 2022, Tri Dao.
 * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 *    list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 *    this list of conditions and the following disclaimer in the documentation
 *    and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 *    contributors may be used to endorse or promote products derived from
 *    this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/

#pragma once

#include <third_party/cutlass/include/cutlass/cutlass.h>
#include <third_party/cutlass/include/cutlass/array.h>
#include <third_party/cutlass/include/cutlass/platform/platform.h>

#include <third_party/cutlass/include/cutlass/numeric_types.h>
#include <third_party/cutlass/include/cutlass/matrix_shape.h>

#include <third_party/cutlass/include/cutlass/layout/tensor_op_multiplicand_sm75.h>
#include <third_party/cutlass/include/cutlass/transform/pitch_linear_thread_map.h>
#include <third_party/cutlass/include/cutlass/transform/threadblock/regular_tile_iterator_tensor_op.h>

#include <third_party/cutlass/include/cutlass/gemm/warp/default_mma_tensor_op.h>
#include <third_party/cutlass/include/cutlass/gemm/threadblock/default_mma_core.h>

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace fmha {

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Template defining default matrix multiply operators inferred from threadblock tile size,
/// global memory data layout, and target math instruction.
template <
    /// Shape of threadblock-scoped matrix multiply operator
    typename Shape,
    /// Shape of warp-level matrix multiply operator
    typename WarpShape,
    /// Shape of one matrix production operation (concept: GemmShape)
    typename InstructionShape,
    /// Element data type of A operand
    typename ElementA,
    /// Layout of operand A
    typename LayoutA,
    /// Element data type of B operand
    typename ElementB,
    /// Layout of operand B
    typename LayoutB,
    /// Data type of accumulator
    typename ElementC,
    /// Layout of accumulator
    typename LayoutC,
    /// Indicates type of math operator (arch::OpClassSimt or arch::OpClassTensorOp)
    typename OperatorClass,
    /// Operation performed by MMA
    typename Operator = cutlass::arch::OpMultiplyAdd
>
struct FMHAMmaCore;

////////////////////////////////////////////////////////////////////////////////

/// Partial specialization:
///
///   A: row-major
///   B: column-major
///   Operator: tensor op class
///
/// This uses the default warp-level operator given tile sizes
template <
    /// Shape of threadblock-scoped matrix multiply operator (concept:
    /// GemmShape)
    typename Shape_,
    /// Shape of warp-level matrix multiply operator (concept: GemmShape)
    typename WarpShape_,
    /// Shape of one matrix production operation (concept: GemmShape)
    typename InstructionShape_,
    /// Data type of A operand
    typename ElementA_,
    /// Data type of B operand
    typename ElementB_,
    /// Data type of accumulator
    typename ElementC_,
    /// Layout of accumulator
    typename LayoutC_,
    /// Operation performed by MMA
    typename Operator_>
struct FMHAMmaCore<Shape_, WarpShape_, InstructionShape_, ElementA_,
                   cutlass::layout::RowMajor, ElementB_, cutlass::layout::ColumnMajor,
                   ElementC_, LayoutC_, cutlass::arch::OpClassTensorOp, Operator_
                   > {
    using Shape = Shape_;
    using WarpShape = WarpShape_;
    using InstructionShape = InstructionShape_;
    using ElementA = ElementA_;
    using LayoutA = cutlass::layout::RowMajor;
    using ElementB = ElementB_;
    using LayoutB = cutlass::layout::ColumnMajor;
    using ElementC = ElementC_;
    using LayoutC = LayoutC_;
    using OperatorClass = cutlass::arch::OpClassTensorOp;

    /// Number of warps present
    using WarpCount = cutlass::gemm::GemmShape<
        Shape::kM / WarpShape::kM,
        Shape::kN / WarpShape::kN,
        Shape::kK / WarpShape::kK
    >;

    // Divisibility requirements
    static_assert(
        !(Shape::kM % WarpShape::kM) &&
        !(Shape::kN % WarpShape::kN),
        "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."
    );

    /// Number of threads per warp
    static int const kWarpSize = cutlass::gemm::warp::WarpSize<cutlass::arch::OpClassTensorOp>::value;

    /// Number of threads total
    static int const kThreads = WarpCount::kCount * kWarpSize;

    /// Size of a threadblock-scoped access
    static int const kAccessSizeInBits = 128;

    /// Cutlass only supports Crosswise at most 64
    static int const kCrosswise = std::min(Shape::kK, 64);

    /// Default Operator
    using Operator = Operator_;

    // Warp thread arrangement
    static int const kWarpThreadArrangementContiguousA =
        kCrosswise / (kAccessSizeInBits / cutlass::sizeof_bits<ElementA>::value);

    static int const kWarpThreadArrangementStridedA =
        kWarpSize / kWarpThreadArrangementContiguousA;

    static int const kWarpThreadArrangementContiguousB =
        kCrosswise / (kAccessSizeInBits / cutlass::sizeof_bits<ElementB>::value);

    static int const kWarpThreadArrangementStridedB =
        kWarpSize / kWarpThreadArrangementContiguousB;

    //
    // Shared memory layouts
    //

    using SmemLayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<
        cutlass::sizeof_bits<ElementA>::value, kCrosswise>;

    // Shared memory layout
    using SmemLayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<
        cutlass::sizeof_bits<ElementB>::value, kCrosswise>;

    //
    // Iterators to write to shared memory
    //

    /// ThreadMap of iterator A
    using IteratorThreadMapA = cutlass::transform::PitchLinearWarpRakedThreadMap<
        cutlass::layout::PitchLinearShape<Shape::kK, Shape::kM>, kThreads,
        cutlass::layout::PitchLinearShape<kWarpThreadArrangementContiguousA,
                                          kWarpThreadArrangementStridedA>,
        kAccessSizeInBits / cutlass::sizeof_bits<ElementA>::value>;

    /// Shared memory iterator to A operand
    using SmemIteratorA = cutlass::transform::threadblock::RegularTileIterator<
        cutlass::MatrixShape<Shape::kM, Shape::kK>,
        ElementA,
        SmemLayoutA,
        0,
        IteratorThreadMapA
    >;

    /// ThreadMap of iterator B
    using IteratorThreadMapB = cutlass::transform::PitchLinearWarpRakedThreadMap<
        cutlass::layout::PitchLinearShape<Shape::kK, Shape::kN>, kThreads,
        cutlass::layout::PitchLinearShape<kWarpThreadArrangementContiguousB,
                                          kWarpThreadArrangementStridedB>,
        kAccessSizeInBits / cutlass::sizeof_bits<ElementB>::value>;

    /// Shared memory iterator to B operand
    using SmemIteratorB = cutlass::transform::threadblock::RegularTileIterator<
        cutlass::MatrixShape<Shape::kK, Shape::kN>,
        ElementB,
        SmemLayoutB,
        1,
        IteratorThreadMapB
    >;

    //
    // Warp-level matrix multiply operator
    //

    // Define the warp-level tensor op
    using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
        WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB,
        ElementC, LayoutC, Operator, WarpCount::kK>::Type;

    /// Policy used to define MmaPipelined
    using MmaPolicy = cutlass::gemm::threadblock::MmaPolicy<
        MmaTensorOp,
        cutlass::MatrixShape<0, 0>,
        cutlass::MatrixShape<0, 0>,
        WarpCount::kK
    >;
};

////////////////////////////////////////////////////////////////////////////////

/// Partial specialization:
///
///   A: row-major
///   B: row-major
///   Operator: tensor op class
///
/// This uses the default warp-level operator given tile sizes
template <
    /// Shape of threadblock-scoped matrix multiply operator (concept:
    /// GemmShape)
    typename Shape_,
    /// Shape of warp-level matrix multiply operator (concept: GemmShape)
    typename WarpShape_,
    /// Shape of one matrix production operation (concept: GemmShape)
    typename InstructionShape_,
    /// Data type of A operand
    typename ElementA_,
    /// Data type of B operand
    typename ElementB_,
    /// Data type of accumulator
    typename ElementC_,
    /// Layout of accumulator
    typename LayoutC_,
    /// Operation performed by MMA
    typename Operator_>
struct FMHAMmaCore<Shape_, WarpShape_, InstructionShape_, ElementA_,
                   cutlass::layout::RowMajor, ElementB_, cutlass::layout::RowMajor, ElementC_,
                   LayoutC_, cutlass::arch::OpClassTensorOp, Operator_
                   > {
    using Shape = Shape_;
    using WarpShape = WarpShape_;
    using InstructionShape = InstructionShape_;
    using ElementA = ElementA_;
    using LayoutA = cutlass::layout::RowMajor;
    using ElementB = ElementB_;
    using LayoutB = cutlass::layout::RowMajor;
    using ElementC = ElementC_;
    using LayoutC = LayoutC_;
    using OperatorClass = cutlass::arch::OpClassTensorOp;

    /// Number of warps present
    using WarpCount = cutlass::gemm::GemmShape<
        Shape::kM / WarpShape::kM,
        Shape::kN / WarpShape::kN,
        Shape::kK / WarpShape::kK
    >;

    // Divisibility requirements
    static_assert(
        !(Shape::kM % WarpShape::kM) &&
        !(Shape::kN % WarpShape::kN),
        "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size."
    );

    /// Number of threads per warp
    static int const kWarpSize = cutlass::gemm::warp::WarpSize<cutlass::arch::OpClassTensorOp>::value;

    /// Number of threads total
    static int const kThreads = WarpCount::kCount * kWarpSize;

    /// Size of a threadblock-scoped access
    static int const kAccessSizeInBits = 128;

    /// Cutlass only supports Crosswise at most 64
    static int const kCrosswise = std::min(Shape::kK, 64);

    /// Default Operator
    using Operator = Operator_;

    // Warp thread arrangement
    static int const kWarpThreadArrangementContiguousA =
        kCrosswise / (kAccessSizeInBits / cutlass::sizeof_bits<ElementA>::value);

    static int const kWarpThreadArrangementStridedA =
        kWarpSize / kWarpThreadArrangementContiguousA;

    //
    // Shared memory layouts
    //

    using SmemLayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<
        cutlass::sizeof_bits<ElementA>::value, kCrosswise>;

    // Shared memory layout
    using SmemLayoutB = cutlass::layout::RowMajorTensorOpMultiplicandCongruous<
        cutlass::sizeof_bits<ElementB>::value, int(128 / sizeof(ElementB))>;

    //
    // Iterators to write to shared memory
    //

    /// ThreadMap of iterator A
    using IteratorThreadMapA = cutlass::transform::PitchLinearWarpRakedThreadMap<
        cutlass::layout::PitchLinearShape<Shape::kK, Shape::kM>, kThreads,
        cutlass::layout::PitchLinearShape<kWarpThreadArrangementContiguousA,
                                          kWarpThreadArrangementStridedA>,
        kAccessSizeInBits / cutlass::sizeof_bits<ElementA>::value>;

    /// Shared memory iterator to A operand
    using SmemIteratorA = cutlass::transform::threadblock::RegularTileIterator<
        cutlass::MatrixShape<Shape::kM, Shape::kK>,
        ElementA,
        SmemLayoutA,
        0,
        IteratorThreadMapA
    >;

    /// ThreadMap of iterator B
    using IteratorThreadMapB = cutlass::transform::PitchLinearWarpRakedThreadMap<
        cutlass::layout::PitchLinearShape<Shape::kN, Shape::kK>,
        kThreads,
        cutlass::layout::PitchLinearShape<8, 4>,
        kAccessSizeInBits / cutlass::sizeof_bits<ElementB>::value
    >;

    /// Shared memory iterator to B operand
    using SmemIteratorB = cutlass::transform::threadblock::RegularTileIterator<
        cutlass::MatrixShape<Shape::kK, Shape::kN>,
        ElementB,
        SmemLayoutB,
        0,
        IteratorThreadMapB
    >;

    //
    // Warp-level matrix multiply operator
    //

    // Define the warp-level tensor op
    using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
        WarpShape, InstructionShape, ElementA, SmemLayoutA, ElementB, SmemLayoutB,
        ElementC, LayoutC, Operator, WarpCount::kK>::Type;

    /// Policy used to define MmaPipelined
    using MmaPolicy = cutlass::gemm::threadblock::MmaPolicy<
        MmaTensorOp,
        cutlass::MatrixShape<0, 0>,
        cutlass::MatrixShape<0, 0>,
        WarpCount::kK
    >;
};


} // namespace fmha
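The warp thread arrangement above follows directly from the access width and the capped crosswise dimension. A small host-side sketch of that arithmetic for 16-bit elements and kK = 128 (the head_dim=128 case this header exists for); the numbers are for illustration:

#include <cstdio>

int main() {
    const int kAccessSizeInBits = 128, sizeof_bits_elem = 16, kWarpSize = 32;
    const int kK = 128;                                        // e.g. head_dim = 128
    const int kCrosswise = kK < 64 ? kK : 64;                  // std::min(Shape::kK, 64)
    const int elems_per_access = kAccessSizeInBits / sizeof_bits_elem;   // 8 half elements per 128-bit access
    const int contiguous = kCrosswise / elems_per_access;                // 8 threads along the contiguous dim
    const int strided = kWarpSize / contiguous;                          // 4 threads along the strided dim
    std::printf("crosswise=%d, contiguous=%d, strided=%d\n", kCrosswise, contiguous, strided);
    return 0;
}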
146
aten/src/ATen/native/transformers/cuda/flash_attn/philox.cuh
Normal file
@ -0,0 +1,146 @@
// Pytorch also has an implementation of Philox RNG: https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/codegen/cuda/runtime/random_numbers.cu
#pragma once
// Philox CUDA.

#include <ATen/cuda/CUDAContext.h>

namespace {

class Philox {
public:
    __device__ inline Philox(unsigned long long seed,
                             unsigned long long subsequence,
                             unsigned long long offset)
        : STATE(0)
        , key(reinterpret_cast<const uint2&>(seed)) {
        //key.x = (unsigned int)seed;
        //key.y = (unsigned int)(seed >> 32);
        //counter = make_uint4(0, 0, 0, 0);
        //counter.z = (unsigned int)(subsequence);
        //counter.w = (unsigned int)(subsequence >> 32);
        //STATE = 0;
        //incr_n(offset / 4);

        // key = reinterpret_cast<const uint2&>(seed);
        ull2 * tmp = reinterpret_cast<ull2*>(&counter);
        tmp->x = offset / 4;
        tmp->y = subsequence;
        // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
        //     printf("Philox counter: %d, %d, %d, %d\n", counter.x, counter.y, counter.z, counter.w);
        // }
    }
    __device__ inline uint4 operator()() {
        // if (STATE == 0) {
        uint4 counter_ = counter;
        uint2 key_ = key;
        // 7-round philox
        #pragma unroll
        for (int i = 0; i < 6; i++) {
            counter_ = single_round(counter_, key_);
            key_.x += (kPhilox10A);
            key_.y += (kPhilox10B);
        }
        // output = single_round(counter_, key_);
        uint4 output = single_round(counter_, key_);
        // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
        //     printf("Philox counter: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w);
        //     printf("Philox output: %u, %u, %u, %u\n", output.x, output.y, output.z, output.w);
        // }
        incr();
        // }
        // return a float4 directly
        // unsigned long ret;
        // switch(STATE) {
        //  case 0: ret = output.x; break;
        //  case 1: ret = output.y; break;
        //  case 2: ret = output.z; break;
        //  case 3: ret = output.w; break;
        //}
        // STATE = (STATE + 1) % 4;
        return output;
    }

private:
    struct ull2 {
        uint64_t x;
        uint64_t y;
    };
    uint4 counter;
    // uint4 output;
    const uint2 key;
    unsigned int STATE;
    __device__ inline void incr_n(unsigned long long n) {
        unsigned int nlo = (unsigned int)(n);
        unsigned int nhi = (unsigned int)(n >> 32);
        counter.x += nlo;
        if (counter.x < nlo)
            nhi++;
        counter.y += nhi;
        if (nhi <= counter.y)
            return;
        if (++counter.z)
            return;
        ++counter.w;
    }

    __device__ uint4 incr128 (uint4 ctr)
    {
        uint4 res;
        asm ("add.cc.u32 %0, %4, %8;\n\t"
             "addc.cc.u32 %1, %5, %9;\n\t"
             "addc.cc.u32 %2, %6, %10;\n\t"
             "addc.u32 %3, %7, %11;\n\t"
             : "=r"(res.x), "=r"(res.y), "=r"(res.z), "=r"(res.w)
             : "r"(ctr.x), "r"(ctr.y), "r"(ctr.z), "r"(ctr.w),
               "n"(1), "n"(0), "n"(0), "n"(0));
        return res;
    }

    __device__ inline void incr() {
        // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
        //     printf("Counter before: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w);
        // }
        counter = incr128(counter);
        // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
        //     printf("Counter after: %u, %u, %u, %u\n", counter.x, counter.y, counter.z, counter.w);
        // }
    }
    __device__ unsigned int mulhilo32(unsigned int a, unsigned int b,
                                      unsigned int *result_high) {
        *result_high = __umulhi(a, b);
        return a * b;
    }
    __device__ uint2 mulhilo32_v2 (const unsigned int a, const unsigned int b)
    {
        uint2 *res;
        unsigned long long tmp;
        asm ("mul.wide.u32 %0, %1, %2;\n\t"
             : "=l"(tmp)
             : "r"(a), "r"(b));
        res = (uint2*)(&tmp);
        return *res;
    }
    __device__ inline uint4 single_round(const uint4 ctr, const uint2 key) {
        //unsigned int hi0;
        //unsigned int hi1;
        //unsigned int lo0 = mulhilo32(kPhiloxSA, ctr.x, &hi0);
        //unsigned int lo1 = mulhilo32(kPhiloxSB, ctr.z, &hi1);
        //uint4 ret = {hi1 ^ ctr.y ^ key.x, lo1, hi0 ^ ctr.w ^ key.y, lo0};
        uint2 res0 = mulhilo32_v2(kPhiloxSA, ctr.x);
        uint2 res1 = mulhilo32_v2(kPhiloxSB, ctr.z);
        uint4 ret = {res1.y ^ ctr.y ^ key.x, res1.x, res0.y ^ ctr.w ^ key.y, res0.x};
        return ret;
    }
    static const unsigned long kPhilox10A = 0x9E3779B9;
    static const unsigned long kPhilox10B = 0xBB67AE85;
    static const unsigned long kPhiloxSA = 0xD2511F53;
    static const unsigned long kPhiloxSB = 0xCD9E8D57;
};
// Inverse of 2^32.
constexpr float M_RAN_INVM32 = 2.3283064e-10f;
__device__ __inline__ float4 uniform4(const uint4 x) {
    return make_float4(x.x * M_RAN_INVM32, x.y * M_RAN_INVM32, x.z * M_RAN_INVM32,
                       x.w * M_RAN_INVM32);
}

} // namespace
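The constructor packs offset/4 and the subsequence into the 128-bit counter through the ull2 overlay, and uniform4 later rescales the raw 32-bit outputs by 2^-32. A host-side sketch of that packing, assuming a little-endian layout for the seed reinterpretation; the concrete seed/offset values are made up:

#include <cstdint>
#include <cstdio>

int main() {
    const uint64_t seed = 0x1234abcdULL, subsequence = 7, offset = 16;
    uint32_t counter[4] = {0, 0, 0, 0};
    uint32_t key[2];
    // key = reinterpret_cast<const uint2&>(seed) on a little-endian device.
    key[0] = static_cast<uint32_t>(seed);
    key[1] = static_cast<uint32_t>(seed >> 32);
    // tmp->x = offset / 4; tmp->y = subsequence;  (the ull2 overlay of the uint4 counter)
    const uint64_t x = offset / 4, y = subsequence;
    counter[0] = static_cast<uint32_t>(x);
    counter[1] = static_cast<uint32_t>(x >> 32);
    counter[2] = static_cast<uint32_t>(y);
    counter[3] = static_cast<uint32_t>(y >> 32);

    const float inv32 = 2.3283064e-10f;                 // matches M_RAN_INVM32
    std::printf("counter = {%u, %u, %u, %u}, 0xffffffff -> %f\n",
                counter[0], counter[1], counter[2], counter[3], 0xffffffffu * inv32);
    return 0;
}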
446
aten/src/ATen/native/transformers/cuda/flash_attn/softmax.h
Normal file
@ -0,0 +1,446 @@
/******************************************************************************
 * Copyright (c) 2022, Tri Dao.
 * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the
 *       names of its contributors may be used to endorse or promote products
 *       derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/

#pragma once

#include <cmath>
#include <ATen/native/transformers/cuda/flash_attn/philox.cuh>

#include <third_party/cutlass/include/cutlass/cutlass.h>
#include <third_party/cutlass/include/cutlass/array.h>

namespace fmha {

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ float apply_exp_(float x, float max) {
    return __expf(x - max);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ float apply_exp2_(float x, float max) {
    return exp2f(x - max);
    // With fast-math, this produces the same PTX instruction as the assembly below
    // float diff = x - max;
    // float res;
    // asm ("ex2.approx.ftz.f32 %0, %1;\n\t" : "=f"(res) : "f"(diff));
    // return res;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<int COLS> struct ReadType {};
template<> struct ReadType<4> { using T = float;};
template<> struct ReadType<8> { using T = float2;};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Cta_tile, typename Kernel_traits>
struct Smem_tile_reduce {
    // Helper class to distribute MMA tiles reduced over rows per warp over quads.

    // The Mma tile.
    using Mma_tile = fmha::Hmma_tile<Cta_tile>;

    // The number of MMAs in M/N dimensions.
    static constexpr int MMAS_M = Mma_tile::MMAS_M;
    static constexpr int MMAS_N = Mma_tile::MMAS_N;

    static constexpr int WARPS_M = Cta_tile::WARPS_M;
    static constexpr int WARPS_N = Cta_tile::WARPS_N;


    static constexpr int ROWS = WARPS_M * MMAS_M * 16;
    static constexpr int COLS = WARPS_N;
    static_assert(COLS == 4 || COLS == 8);
    static constexpr int ROWS_PER_XOR_PATTERN = (COLS == 8) ? 4 : 8;
    static constexpr int BYTES_PER_TILE = ROWS * COLS * sizeof(float);
    static constexpr int ELTS_PER_TILE = ROWS * COLS;

    using read_t = typename ReadType<COLS>::T;

    __device__ inline Smem_tile_reduce(float *smem_, const int tidx) {

        int lane = tidx % 32;
        int warp = tidx / 32;

        int warp_m = warp % WARPS_M;
        int warp_n = warp / WARPS_M;

        qid_ = lane % 4;
        int qp = lane / 4;

        // Swizzle the column to avoid 2-fold bank conflicts when we have 8 warps.
        // This won't affect reading as we assume commutative reduction ops.
        const int col = warp_n ^ (qp / ROWS_PER_XOR_PATTERN);
        smem_write_ = &smem_[warp_m * 16 * MMAS_M * WARPS_N + qp * WARPS_N + col];
        smem_read_ = &reinterpret_cast<read_t *>(smem_)[warp_m * 16 * MMAS_M * 4 + qp * 4 + qid_];
        smem_read_row_ = &reinterpret_cast<read_t *>(smem_)[warp_m * 16 * MMAS_M * 4 + qid_];

    }

    __device__ inline void store(float (&frag)[2 * MMAS_M]) {
        if( qid_ == 0 ) {
            #pragma unroll
            for( int mi = 0; mi < MMAS_M; mi++ ) {
                int offset = mi * 16 * WARPS_N;
                smem_write_[offset + 0 * 8 * WARPS_N] = frag[mi * 2 + 0];
                smem_write_[offset + 1 * 8 * WARPS_N] = frag[mi * 2 + 1];
            }
        }
    }

    __device__ inline void load(read_t (&frag)[2 * MMAS_M]) {
        #pragma unroll
        for( int mi = 0; mi < MMAS_M; mi++ ) {
            int offset = mi * 16 * 4;
            frag[mi * 2 + 0] = smem_read_[offset + 0 * 8 * 4];
            frag[mi * 2 + 1] = smem_read_[offset + 1 * 8 * 4];
        }
    }

    __device__ inline void load_row(read_t (&frag)[MMAS_M], int row) {
        #pragma unroll
        for( int mi = 0; mi < MMAS_M; mi++ ) {
            int offset = mi * 16 * 4;
            frag[mi] = smem_read_row_[offset + 0 * 8 * 4 + row * 4];
        }
    }

    int qid_;
    float *smem_write_;
    read_t *smem_read_;
    read_t *smem_read_row_;

};


template<typename Cta_tile, typename Kernel_traits>
struct Softmax_base {

    // The Mma tile.
    using Mma_tile = fmha::Hmma_tile<Cta_tile>;

    // The number of MMAs in M/N dimensions.
    static constexpr int MMAS_M = Mma_tile::MMAS_M;
    static constexpr int MMAS_N = Mma_tile::MMAS_N;

    // The number of groups of warps such that we have at most 4 warps writing consecutive elements.
    static constexpr int GROUPS = fmha::DivUpConstexpr(Cta_tile::WARPS_N, 4);
    // The number of elements that we are going to store per row.
    static constexpr int ELEMENTS_PER_ROW = Cta_tile::WARPS_N / GROUPS;
    // The number of rows.
    static constexpr int ROWS = Cta_tile::M * GROUPS;
    // The total number of elements.
    static constexpr int ELEMENTS = ROWS * ELEMENTS_PER_ROW;

    // Ctor.
    template<typename Params>
    inline __device__ Softmax_base(const Params &params, void *smem, int tidx)
        : // packed_mask_ptr_(reinterpret_cast<const char*>(params.packed_mask_ptr)),
          smem_(reinterpret_cast<float *>(smem)), tidx_(tidx) {

        // Extract the position in the warp.
        int warp = tidx / Cta_tile::THREADS_PER_WARP;
        int lane = tidx % Cta_tile::THREADS_PER_WARP;

        // Decompose the warp index into M and N.
        int warp_m = warp % Cta_tile::WARPS_M;
        int warp_n = warp / Cta_tile::WARPS_M;

        // Decompose the warp-n index into group/position-inside-the-group.
        int warp_g = warp_n / ELEMENTS_PER_ROW;
        int warp_i = warp_n % ELEMENTS_PER_ROW;

        // The location written by the threads.
        int write_row = warp_g * (ROWS / GROUPS) + warp_m * Mma_tile::M_PER_MMA + lane / 4;
        int write_col = warp_i;

        // Assemble the write pointer.
        smem_write_ = &smem_[write_row * ELEMENTS_PER_ROW + write_col];

        // Assemble the read pointer.
        smem_read_ = &smem_[warp_m * Mma_tile::M_PER_MMA + lane / 4];
    }

    template<bool zero=false, typename Mask>
    inline __device__ void apply_mask(const Mask &mask) {
        #pragma unroll
        for( int mi = 0; mi < MMAS_M; ++mi ) {
            #pragma unroll
            for( int ii = 0; ii < 2; ++ii ) {
                #pragma unroll
                for( int ni = 0; ni < MMAS_N; ++ni ) {
                    #pragma unroll
                    for( int jj = 0; jj < 4; ++jj ) {
                        if( !mask.is_valid(mi, ni, ii, jj) ) {
                            elt_[2 * mi + ii][4 * ni + jj] = zero ? 0.f : -INFINITY;
                        }
                    }
                }
            }
        }
    }

    // Apply the exp to all the elements.
    template <bool scale_max=true>
    inline __device__ void scale_apply_exp(const float (&max)[MMAS_M * 2], const float scale_) {
        const float max_scale = scale_max ? scale_ * M_LOG2E : M_LOG2E;
        const float scale = scale_ * M_LOG2E;
        #pragma unroll
        for( int mi = 0; mi < MMAS_M * 2; ++mi ) {
            // Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
            // max * log_2(e)). This allows the compiler to use the ffma
            // instruction instead of fadd and fmul separately.
            const float max_scaled = max[mi] * max_scale;
            #pragma unroll
            for( int ni = 0; ni < MMAS_N * 4; ++ni ) {
                elt_[mi][ni] = apply_exp2_(elt_[mi][ni] * scale, max_scaled);
            }
        }
    }
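    // Worked example of the identity used above (the concrete values are illustrative):
    // with scale_ = 0.125, x = 10.0 and max = 12.0,
    //   exp(0.125 * (10.0 - 12.0)) = exp(-0.25) ~= 0.7788, and
    //   exp2(10.0 * 0.125 * M_LOG2E - 12.0 * 0.125 * M_LOG2E) = exp2(-0.25 * M_LOG2E) ~= 0.7788,
    // so folding M_LOG2E into the scale lets the multiply and the subtraction of max_scaled
    // fuse into a single ffma before the ex2 instruction.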
template <bool encode_dropout_in_sign_bit=false>
|
||||
inline __device__ void apply_dropout_16bits(Philox &ph, uint16_t p_dropout_in_uint16_t) {
|
||||
// We encode the dropout pattern in the sign bit of the non-negative
|
||||
// softmax to distinguish from pre-existing zeros
|
||||
auto encode_dropout = [](bool keep, float val) {
|
||||
return keep ? val : (encode_dropout_in_sign_bit ? -val : float(0));
|
||||
};
|
||||
#pragma unroll
|
||||
for( int mi = 0; mi < MMAS_M; mi++ ) {
|
||||
#pragma unroll
|
||||
for( int ni = 0; ni < MMAS_N; ni++ ) {
|
||||
uint4 random_uint4 = ph();
|
||||
uint16_t (&rnd)[8] = reinterpret_cast<uint16_t (&)[8]>(random_uint4);
|
||||
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
|
||||
// printf("ni = %d, ph Philox: %u, %u, %u, %u\n", ni, rnd.x, rnd.y, rnd.z, rnd.w);
|
||||
// }
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < 2; ++ii) {
|
||||
#pragma unroll
|
||||
for (int jj = 0; jj < 4; ++jj) {
|
||||
elt_[mi * 2 + ii][4 * ni + jj] =
|
||||
encode_dropout(rnd[ii * 4 + jj] <= p_dropout_in_uint16_t, elt_[mi * 2 + ii][4 * ni + jj]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|

    template <bool encode_dropout_in_sign_bit=false>
    inline __device__ void apply_dropout_16bits(Philox &ph0, Philox &ph1, uint16_t p_dropout_in_uint16_t) {
        // We encode the dropout pattern in the sign bit of the non-negative
        // softmax to distinguish from pre-existing zeros
        auto encode_dropout = [](bool keep, float val) {
            return keep ? val : (encode_dropout_in_sign_bit ? -val : float(0));
        };
        #pragma unroll
        for( int mi = 0; mi < MMAS_M; mi++ ) {
            static_assert(MMAS_N % 2 == 0);
            #pragma unroll
            for( int ni = 0; ni < MMAS_N; ni += 2 ) {
                uint4 random_uint4 = ph0();
                uint16_t (&rnd0)[8] = reinterpret_cast<uint16_t (&)[8]>(random_uint4);
                // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
                //     printf("ni = %d, ph Philox: %u, %u, %u, %u\n", ni, rnd0.x, rnd0.y, rnd0.z, rnd0.w);
                // }
                #pragma unroll
                for (int ii = 0; ii < 2; ++ii) {
                    #pragma unroll
                    for (int jj = 0; jj < 4; ++jj) {
                        elt_[mi * 2 + ii][4 * ni + jj] =
                            encode_dropout(rnd0[ii * 4 + jj] <= p_dropout_in_uint16_t, elt_[mi * 2 + ii][4 * ni + jj]);
                    }
                }
                random_uint4 = ph1();
                uint16_t (&rnd1)[8] = reinterpret_cast<uint16_t (&)[8]>(random_uint4);
                // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
                //     printf("ni = %d, ph Philox: %u, %u, %u, %u\n", ni, rnd1.x, rnd1.y, rnd1.z, rnd1.w);
                // }
                #pragma unroll
                for (int ii = 0; ii < 2; ++ii) {
                    #pragma unroll
                    for (int jj = 0; jj < 4; ++jj) {
                        elt_[mi * 2 + ii][4 * (ni + 1) + jj] =
                            encode_dropout(rnd1[ii * 4 + jj] <= p_dropout_in_uint16_t, elt_[mi * 2 + ii][4 * (ni + 1) + jj]);
                    }
                }
            }
        }
    }

    // Shared memory for the CTA-wide reduction.
    float *smem_, *smem_write_, *smem_read_;
    // The current thread index.
    int tidx_;
    // The elements.
    float elt_[MMAS_M * 2][MMAS_N * 4];
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Cta_tile, typename Kernel_traits>
struct Softmax : public Softmax_base<Cta_tile, Kernel_traits> {

    // The base class.
    using Base = Softmax_base<Cta_tile, Kernel_traits>;

    static constexpr int WARPS_M = Cta_tile::WARPS_M;
    static constexpr int WARPS_N = Cta_tile::WARPS_N;
    // The MMAs.
    static constexpr int MMAS_M = Base::MMAS_M;
    static constexpr int MMAS_N = Base::MMAS_N;

    using Smem_tile_red = Smem_tile_reduce<Cta_tile, Kernel_traits>;
    static_assert(Smem_tile_red::ELTS_PER_TILE == Cta_tile::M * WARPS_N);
    // Ctor.
    template<typename Params>
    inline __device__ Softmax(const Params &params, void *smem, int tidx)
        : Base(params, smem, tidx)
        , smem_sum_(static_cast<float*>(smem), tidx)
        , smem_max_(static_cast<float*>(smem) + Smem_tile_red::ELTS_PER_TILE, tidx) {
    }

    // Pack the data to a fragment for the next GEMM.
    inline __device__ void pack_noconvert(cutlass::Array<float, MMAS_M * MMAS_N * 8> &frag) const {
        #pragma unroll
        for( int mi = 0; mi < MMAS_M; ++mi ) {
            #pragma unroll
            for( int ki = 0; ki < MMAS_N; ++ki ) {
                // 1st row - 4 elements per row.
                frag[ki * MMAS_M * 8 + mi * 8 + 0] = this->elt_[2 * mi + 0][4 * ki + 0];
                frag[ki * MMAS_M * 8 + mi * 8 + 1] = this->elt_[2 * mi + 0][4 * ki + 1];
                frag[ki * MMAS_M * 8 + mi * 8 + 4] = this->elt_[2 * mi + 0][4 * ki + 2];
                frag[ki * MMAS_M * 8 + mi * 8 + 5] = this->elt_[2 * mi + 0][4 * ki + 3];
                // 2nd row - 4 elements per row.
                frag[ki * MMAS_M * 8 + mi * 8 + 2] = this->elt_[2 * mi + 1][4 * ki + 0];
                frag[ki * MMAS_M * 8 + mi * 8 + 3] = this->elt_[2 * mi + 1][4 * ki + 1];
                frag[ki * MMAS_M * 8 + mi * 8 + 6] = this->elt_[2 * mi + 1][4 * ki + 2];
                frag[ki * MMAS_M * 8 + mi * 8 + 7] = this->elt_[2 * mi + 1][4 * ki + 3];
            }
        }
    }

    template <typename FragmentC>
    inline __device__ void unpack_noscale(const FragmentC (&acc)) {
        static_assert(FragmentC::kElements == MMAS_M * MMAS_N * 8, "");
        #pragma unroll
        for( int mi = 0; mi < MMAS_M; ++mi ) {
            #pragma unroll
            for( int ni = 0; ni < MMAS_N; ++ni ) {
                // 1st row - 4 elements per row.
                this->elt_[2 * mi + 0][4 * ni + 0] = acc[mi * MMAS_N * 8 + ni * 8 + 0];
                this->elt_[2 * mi + 0][4 * ni + 1] = acc[mi * MMAS_N * 8 + ni * 8 + 1];
                this->elt_[2 * mi + 0][4 * ni + 2] = acc[mi * MMAS_N * 8 + ni * 8 + 4];
                this->elt_[2 * mi + 0][4 * ni + 3] = acc[mi * MMAS_N * 8 + ni * 8 + 5];
                // 2nd row - 4 elements per row.
                this->elt_[2 * mi + 1][4 * ni + 0] = acc[mi * MMAS_N * 8 + ni * 8 + 2];
                this->elt_[2 * mi + 1][4 * ni + 1] = acc[mi * MMAS_N * 8 + ni * 8 + 3];
                this->elt_[2 * mi + 1][4 * ni + 2] = acc[mi * MMAS_N * 8 + ni * 8 + 6];
                this->elt_[2 * mi + 1][4 * ni + 3] = acc[mi * MMAS_N * 8 + ni * 8 + 7];
            }
        }
    }

    template<bool zero_init=true, typename Operator>
    __device__ inline void thread_reduce_(float (&frag)[2 * MMAS_M], Operator &op) {
        #pragma unroll
        for( int mi = 0; mi < 2 * MMAS_M; mi++ ) {
            frag[mi] = zero_init ? this->elt_[mi][0] : op(frag[mi], this->elt_[mi][0]);
            #pragma unroll
            for( int ni = 1; ni < 4 * MMAS_N; ni++ ) {
                frag[mi] = op(frag[mi], this->elt_[mi][ni]);
            }
        }
    }

    template<bool zero_init=true, typename Operator>
    __device__ inline void reduce_(float (&frag)[2 * MMAS_M], Operator &op, Smem_tile_red & smem_red) {
        thread_reduce_<zero_init>(frag, op);
        quad_reduce(frag, frag, op);
        smem_red.store(frag);
        __syncthreads();
        typename Smem_tile_red::read_t tmp[2 * MMAS_M];
        smem_red.load(tmp);
        quad_allreduce(frag, tmp, op);
    }
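
    // [Editorial note] reduce_ above chains three reduction levels: a per-thread
    // sweep over that thread's 4 * MMAS_N elements per row (thread_reduce_), a
    // shuffle reduction across the four lanes of a quad that share a row
    // (quad_reduce), and a CTA-wide exchange through shared memory across the
    // warps in the N dimension (store, __syncthreads, load, quad_allreduce).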

    template<bool zero_init=true>
    __device__ inline void reduce_max(float (&frag)[2 * MMAS_M]){
        MaxOp<float> max;
        reduce_<zero_init>(frag, max, smem_max_);
    }

    __device__ inline void reduce_sum(float (&frag)[2 * MMAS_M]){
        SumOp<float> sum;
        reduce_(frag, sum, smem_sum_);
    }

    template<bool zero_init=true>
    __device__ inline void reduce_sum_before_sync_(float (&frag)[2 * MMAS_M]){
        SumOp<float> sum;
        thread_reduce_<zero_init>(frag, sum);
        quad_reduce(frag, frag, sum);
        smem_sum_.store(frag);
    }

    template<int NROWS, typename Operator>
    __device__ inline void reduce_after_sync_(float (&frag)[NROWS][MMAS_M],
                                              const int (&rows)[NROWS],
                                              Operator &op, Smem_tile_red & smem_red) {
        #pragma unroll
        for (int ii = 0; ii < NROWS; ii++) {
            typename Smem_tile_red::read_t tmp[MMAS_M];
            smem_red.load_row(tmp, rows[ii]);
            quad_allreduce(frag[ii], tmp, op);
        }
    }

    template<int NROWS>
    __device__ inline void reduce_sum_after_sync_(float (&frag)[NROWS][MMAS_M],
                                                  const int (&rows)[NROWS]){
        SumOp<float> sum;
        reduce_after_sync_(frag, rows, sum, smem_sum_);
    }

    template<int NROWS>
    __device__ inline void reduce_max_after_sync_(float (&frag)[NROWS][MMAS_M],
                                                  const int (&rows)[NROWS]){
        MaxOp<float> max;
        reduce_after_sync_(frag, rows, max, smem_max_);
    }

    Smem_tile_red smem_max_;
    Smem_tile_red smem_sum_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace fmha
@ -0,0 +1,25 @@
// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h

#pragma once

/// @param COND       - a boolean expression to switch by
/// @param CONST_NAME - a name given for the constexpr bool variable.
/// @param ...        - code to execute for true and false
///
/// Usage:
/// ```
/// BOOL_SWITCH(flag, BoolConst, [&] {
///     some_function<BoolConst>(...);
/// });
/// ```
#define BOOL_SWITCH(COND, CONST_NAME, ...)         \
    [&] {                                          \
        if (COND) {                                \
            constexpr bool CONST_NAME = true;      \
            return __VA_ARGS__();                  \
        } else {                                   \
            constexpr bool CONST_NAME = false;     \
            return __VA_ARGS__();                  \
        }                                          \
    }()
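
// [Editorial aside - not part of the upstream file] A hypothetical launcher
// showing why the macro wraps the body in a lambda: two runtime flags become
// two constexpr template parameters without spelling out all four branches.
template<bool IsCausal, bool IsDropout>
void run_kernel() { /* instantiate and launch the templated kernel here */ }

inline void launch(bool is_causal, bool is_dropout) {
    BOOL_SWITCH(is_causal, IsCausal, [&] {
        BOOL_SWITCH(is_dropout, IsDropout, [&] {
            run_kernel<IsCausal, IsDropout>();
        });
    });
}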
@ -0,0 +1,55 @@
/******************************************************************************
 * Copyright (c) 2022, Tri Dao.
 ******************************************************************************/

#pragma once

namespace fmha {

////////////////////////////////////////////////////////////////////////////////////////////////////

template<int kRows, int kRowsPerMma, int kWarpCountM>
struct Smem_tile_softmax_lse {

    static constexpr int kMmaM = (kRows / kWarpCountM) / kRowsPerMma;
    static_assert(kMmaM * kRowsPerMma * kWarpCountM == kRows);
    // static_assert(kWarpCountM == 1);
    // Otherwise we might need to check warp_idx / kWarpCountM == 0 instead of just warp_idx == 0

    // The size of one buffer in bytes in shared memory.
    static constexpr size_t BYTES_PER_TILE = kRows * sizeof(float);

    inline __device__ Smem_tile_softmax_lse(float *smem) : smem_(smem) {
    }

    inline __device__ void store_pair(const float (&sum)[kMmaM * 2]) {
        // Broadcast the warp_id computed by lane 0 to ensure dependent code
        // is compiled as warp-uniform.
        // This makes a difference of 50us for BERT.
        // const int warp_idx = threadIdx.x / 32;
        const int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
        const int lane_idx = threadIdx.x % 32;
        const int warp_n = warp_idx / kWarpCountM;
        // Extract the position in the warp.
        const int row = lane_idx / 4;
        if ((lane_idx % 4 == 0) && (warp_n == 0)) {
            #pragma unroll
            for (int mi = 0; mi < kMmaM; ++mi) {
                smem_[mi * kRowsPerMma + row + 0] = sum[mi * 2 + 0];
                smem_[mi * kRowsPerMma + row + 8] = sum[mi * 2 + 1];
            }
        }
    }

    template<int N>
    inline __device__ void load(float (&sum)[N], const int (&row)[N]) {
        #pragma unroll
        for( int ni = 0; ni < N; ni++ ) {
            sum[ni] = smem_[row[ni]];
        }
    }

    float * const smem_;
};
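
// [Editorial aside - not part of the upstream file] The __shfl_sync in
// store_pair broadcasts lane 0's warp index so the compiler can treat the value
// as warp-uniform and hoist the dependent address math. The same idiom in
// isolation (hypothetical helper):
inline __device__ int warp_uniform_warp_idx() {
    // Every lane computes threadIdx.x / 32, but taking lane 0's copy lets the
    // compiler prove that all 32 lanes agree on the result.
    return __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
}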

} // namespace fmha
404
aten/src/ATen/native/transformers/cuda/flash_attn/utils.h
Normal file
@ -0,0 +1,404 @@
/******************************************************************************
 * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the
 *       names of its contributors may be used to endorse or promote products
 *       derived from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 ******************************************************************************/

#pragma once

#include <cassert>
#include <cstdint>
#include <cstdlib>

#include <ATen/cuda/CUDAContext.h>
// #include <cuda_fp16.h>

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#include <cuda_bf16.h>
#endif

extern "C" __device__ uint32_t __nvvm_get_smem_pointer(void *ptr);

////////////////////////////////////////////////////////////////////////////////////////////////////

namespace fmha {

////////////////////////////////////////////////////////////////////////////////////////////////////

struct Row {};
struct Col {};

////////////////////////////////////////////////////////////////////////////////////////////////////

template< int M, int N >
struct Div_up {
    enum { VALUE = (M + N-1) / N };
};

constexpr int DivUpConstexpr(int M, int N) { return (M + N - 1) / N; }

////////////////////////////////////////////////////////////////////////////////////////////////////

template< int A, int B >
struct Max {
    enum { VALUE = A >= B ? A : B };
};

constexpr int MaxConstexpr(int A, int B) { return A >= B ? A : B; }

////////////////////////////////////////////////////////////////////////////////////////////////////

template< int A, int B, int C >
struct Max_3 {
    enum { VALUE = Max<Max<A, B>::VALUE, C>::VALUE };
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template< int A, int B >
struct Min {
    enum { VALUE = A <= B ? A : B };
};

constexpr int MinConstexpr(int A, int B) { return A <= B ? A : B; }

////////////////////////////////////////////////////////////////////////////////////////////////////

template< int SIZE_IN_BYTES >
struct Uint_from_size_in_bytes {
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<>
struct Uint_from_size_in_bytes<1> {
    using Type = uint8_t;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<>
struct Uint_from_size_in_bytes<2> {
    using Type = uint16_t;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<>
struct Uint_from_size_in_bytes<4> {
    using Type = uint32_t;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<>
struct Uint_from_size_in_bytes<8> {
    using Type = uint2;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<>
struct Uint_from_size_in_bytes<16> {
    using Type = uint4;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template< typename T >
inline __device__ __host__ T div_up(T m, T n) {
    return (m + n-1) / n;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T>
inline __device__ uint32_t hrelu2(uint32_t x);

template<>
inline __device__ uint32_t hrelu2<__half>(uint32_t x) {
    uint32_t res;
    const uint32_t zero = 0u;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    asm volatile( "max.f16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero));
#else
    asm volatile( \
        "{\n" \
        "\t .reg .f16x2 sela;\n" \
        "\t set.gtu.u32.f16x2 sela, %1, %2;\n" \
        "\t and.b32 %0, sela, %1;\n" \
        "}\n" : "=r"(res) : "r"(x), "r"(zero));
#endif
    return res;
}

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
template<>
inline __device__ uint32_t hrelu2<__nv_bfloat16>(uint32_t x) {
    uint32_t res;
    const uint32_t zero = 0u;
    asm volatile( "max.bf16x2 %0, %1, %2;\n" : "=r"(res) : "r"(x), "r"(zero));
    return res;
}
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline __device__ uint16_t float_to_half(float f) {
    uint16_t h;
    asm volatile("cvt.rn.f16.f32 %0, %1;" : "=h"(h) : "f"(f));
    return h;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T>
inline __device__ uint32_t float2_pack(float a, float b);

template <>
inline __device__ uint32_t float2_pack<__half>(float a, float b) {
    __half2 result = __floats2half2_rn(a, b);
    return reinterpret_cast<uint32_t(&)>(result);
}

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
template <>
inline __device__ uint32_t float2_pack<__nv_bfloat16>(float a, float b) {
    __nv_bfloat162 result = __floats2bfloat162_rn(a, b);
    return reinterpret_cast<uint32_t(&)>(result);
}
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T>
inline __device__ uint2 float4_pack(float x, float y, float z, float w) {
    uint2 d;
    d.x = float2_pack<T>(x, y);
    d.y = float2_pack<T>(z, w);
    return d;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T>
inline __device__ float2 half2_unpack(uint32_t a);

template <>
inline __device__ float2 half2_unpack<__half>(uint32_t a) {
    return __half22float2(reinterpret_cast<__half2 (&)>(a));
}

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
template <>
inline __device__ float2 half2_unpack<__nv_bfloat16>(uint32_t a) {
    return __bfloat1622float2(reinterpret_cast<__nv_bfloat162 (&)>(a));
}
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////

// Convert two half2's or bf162's into float, then take their dot product.
template <typename T>
inline __device__ float hfma2_to_float(const uint32_t a, const uint32_t b) {
    float2 af = fmha::half2_unpack<T>(a);
    float2 bf = fmha::half2_unpack<T>(b);
    return af.x * bf.x + af.y * bf.y;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

// Convert two vectors of 8 half's or bf16's into float, then take their dot product.
template<typename T>
inline __device__ float hmulsum8(const uint4 a, const uint4 b) {
    float sum;
    sum  = fmha::hfma2_to_float<T>(a.x, b.x);
    sum += fmha::hfma2_to_float<T>(a.y, b.y);
    sum += fmha::hfma2_to_float<T>(a.z, b.z);
    sum += fmha::hfma2_to_float<T>(a.w, b.w);
    return sum;
}
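
// [Editorial aside - not part of the upstream file] hmulsum8 treats each uint4
// as 8 packed half (or bf16) values and returns their dot product in fp32. A
// hypothetical round trip through the packing helpers above:
inline __device__ float hmulsum8_demo() {
    uint2 lo = float4_pack<__half>(1.f, 2.f, 3.f, 4.f);
    uint2 hi = float4_pack<__half>(5.f, 6.f, 7.f, 8.f);
    uint4 a = make_uint4(lo.x, lo.y, hi.x, hi.y);   // 8 halves: 1..8
    return hmulsum8<__half>(a, a);                  // sum of squares = 204.f
}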

////////////////////////////////////////////////////////////////////////////////////////////////////
//
// L D G
//
////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void ldg(uint8_t &dst, const void *ptr) {
    dst = *reinterpret_cast<const uint8_t*>(ptr);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void ldg(uint16_t &dst, const void *ptr) {
    dst = *reinterpret_cast<const uint16_t*>(ptr);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void ldg(uint32_t &dst, const void *ptr) {
    dst = *reinterpret_cast<const uint32_t*>(ptr);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void ldg(uint2 &dst, const void *ptr) {
    dst = *reinterpret_cast<const uint2*>(ptr);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void ldg(uint4 &dst, const void *ptr) {
    dst = *reinterpret_cast<const uint4*>(ptr);
}

////////////////////////////////////////////////////////////////////////////////////////////////////
//
// S T G
//
////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void stg(void *ptr, uint8_t val) {
    *reinterpret_cast<uint8_t*>(ptr) = val;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void stg(void *ptr, uint16_t val) {
    *reinterpret_cast<uint16_t*>(ptr) = val;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void stg(void *ptr, uint32_t val) {
    *reinterpret_cast<uint32_t*>(ptr) = val;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void stg(void *ptr, uint2 val) {
    *reinterpret_cast<uint2*>(ptr) = val;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void stg(void *ptr, uint4 val) {
    *reinterpret_cast<uint4*>(ptr) = val;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T>
struct MaxOp {
    __device__ inline T operator()(T const & x, T const & y) { return x > y ? x : y; }
};

template <>
struct MaxOp<float> {
    // This is slightly faster
    __device__ inline float operator()(float const &x, float const &y) { return max(x, y); }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T>
struct SumOp {
    __device__ inline T operator()(T const & x, T const & y) { return x + y; }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<int THREADS>
struct Allreduce {
    static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
    template<typename T, typename Operator>
    static __device__ inline T run(T x, Operator &op) {
        constexpr int OFFSET = THREADS / 2;
        x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
        return Allreduce<OFFSET>::run(x, op);
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<>
struct Allreduce<2> {
    template<typename T, typename Operator>
    static __device__ inline T run(T x, Operator &op) {
        x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
        return x;
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Operator, int M>
__device__ inline void quad_reduce(float (&dst)[M], float (&src)[M], Operator &op) {
    #pragma unroll
    for(int mi=0; mi < M; mi++){
        dst[mi] = src[mi];
        dst[mi] = op(dst[mi], __shfl_down_sync(uint32_t(-1), dst[mi], 2));
        dst[mi] = op(dst[mi], __shfl_down_sync(uint32_t(-1), dst[mi], 1));
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Operator, int M>
__device__ inline void quad_reduce(float (&dst)[M], float2 (&src)[M], Operator &op) {
    float tmp[M];
    #pragma unroll
    for(int mi=0; mi < M; mi++){
        tmp[mi] = op(src[mi].x, src[mi].y);
    }
    quad_reduce(dst, tmp, op);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Operator, int M>
__device__ inline void quad_allreduce(float (&dst)[M], float (&src)[M], Operator &op) {
    #pragma unroll
    for(int mi=0; mi < M; mi++){
        dst[mi] = src[mi];
        dst[mi] = Allreduce<4>::run(dst[mi], op);
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename Operator, int M>
__device__ inline void quad_allreduce(float (&dst)[M], float2 (&src)[M], Operator &op) {
    float tmp[M];
    #pragma unroll
    for(int mi=0; mi < M; mi++){
        tmp[mi] = op(src[mi].x, src[mi].y);
    }
    quad_allreduce(dst, tmp, op);
}
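
// [Editorial aside - not part of the upstream file] Allreduce<4> performs a
// butterfly (XOR-shuffle) reduction over groups of four consecutive lanes,
// using offsets 2 and then 1; afterwards all four lanes hold the same value.
// For example, summing a register across a quad:
inline __device__ float quad_sum(float x) {
    SumOp<float> sum;
    return Allreduce<4>::run(x, sum);
}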

////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace fmha
@ -77,6 +77,7 @@ function(caffe2_print_configuration_summary)
  message(STATUS "    USE_CUDNN             : ${USE_CUDNN}")
  message(STATUS "    USE_EXPERIMENTAL_CUDNN_V8_API: ${USE_EXPERIMENTAL_CUDNN_V8_API}")
  message(STATUS "    CUDA version          : ${CUDA_VERSION}")
  message(STATUS "    USE_FLASH_ATTENTION   : ${USE_FLASH_ATTENTION}")
  if(${USE_CUDNN})
    message(STATUS "    cuDNN version         : ${CUDNN_VERSION}")
  endif()
2
setup.py
@ -322,7 +322,7 @@ def get_submodule_folders():
    git_modules_path = os.path.join(cwd, ".gitmodules")
    default_modules_path = [os.path.join(third_party_path, name) for name in [
                            "gloo", "cpuinfo", "tbb", "onnx",
                            "foxi", "QNNPACK", "fbgemm"
                            "foxi", "QNNPACK", "fbgemm", "cutlass"
                            ]]
    if not os.path.exists(git_modules_path):
        return default_modules_path
1
third_party/cutlass
vendored
Submodule
Submodule third_party/cutlass added at b72cbf957d