Add Independent Memory Efficient and Flash Attention Build Flags (#107985)

# Summary
In an effort to simplify https://github.com/pytorch/pytorch/pull/105602, this PR pulls out independent chunks of that change so they can land before FlashV2 does. Concretely, it introduces a dedicated `USE_MEM_EFF_ATTENTION` build flag so memory-efficient attention is configured and compiled independently of `USE_FLASH_ATTENTION`.
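
With the flags split, each backend is compiled behind its own macro and reports its own error when it was compiled out. Below is a minimal sketch of that guard pattern; `sketch_mem_eff_forward` is a placeholder name, not an actual ATen function, and its body is elided.

```cpp
// Minimal sketch of the per-backend guard pattern (illustrative only).
#include <ATen/core/Tensor.h>
#include <c10/util/Exception.h>

at::Tensor sketch_mem_eff_forward(const at::Tensor& query) {
#if defined(USE_MEM_EFF_ATTENTION)
  // Real code would dispatch to the memory-efficient attention kernels here.
  return query;
#endif
  // Only reachable when the backend was disabled at build time.
  TORCH_CHECK(false, "USE_MEM_EFF_ATTENTION was not enabled for build.");
  return at::Tensor{};
}
```

At configure time the two kernels can then be toggled separately, e.g. building with `USE_MEM_EFF_ATTENTION=0` while leaving `USE_FLASH_ATTENTION` on (see the build-option and environment-variable hunks below).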

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107985
Approved by: https://github.com/cpuhrsch
drisspg authored on 2023-08-28 18:39:15 +00:00; committed by PyTorch MergeBot
commit 182a9cf366 (parent f0c6e5c91f)
8 changed files with 28 additions and 9 deletions

View File

@@ -727,6 +727,12 @@ cmake_dependent_option(
   "Whether to build the flash_attention kernel for scaled dot product attention" ON
   "USE_CUDA AND NOT ROCM AND NOT CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.6" OFF)
+# Flash Attention2 will error while building for sm52 while Mem Eff Attention won't
+cmake_dependent_option(
+  USE_MEM_EFF_ATTENTION
+  "Enable memory-efficient attention for scaled dot product attention" ON
+  "USE_CUDA AND NOT ROCM AND NOT CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.6" OFF)
 if(DEBUG_CUDA)
   string(APPEND CMAKE_CUDA_FLAGS_DEBUG " -lineinfo")
   string(APPEND CMAKE_CUDA_FLAGS_RELWITHDEBINFO " -lineinfo")

View File

@@ -170,7 +170,9 @@ file(GLOB mem_eff_attention_cuda_cpp "native/transformers/cuda/mem_eff_attention
 if(USE_FLASH_ATTENTION)
   list(APPEND native_transformers_cuda_cu ${flash_attention_cuda_cu})
   list(APPEND native_transformers_cuda_cpp ${flash_attention_cuda_cpp})
+endif()
+if(USE_MEM_EFF_ATTENTION)
   list(APPEND native_transformers_cuda_cu ${mem_eff_attention_cuda_cu})
   list(APPEND native_transformers_cuda_cu ${mem_eff_attention_cuda_kernels_cu})
   list(APPEND native_transformers_cuda_cpp ${mem_eff_attention_cuda_cpp})

View File

@@ -38,6 +38,8 @@
 #ifdef USE_FLASH_ATTENTION
 // FlashAttention Specific Imports
 #include <ATen/native/transformers/cuda/flash_attn/fmha_api.h>
+#endif
+#ifdef USE_MEM_EFF_ATTENTION
 // MemoryEfficient Attention Specific Imports
 #include <ATen/native/transformers/cuda/mem_eff_attention/kernel_forward.h>
 #include <ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassF.h>
@@ -830,7 +832,7 @@ std::tuple<at::Tensor, at::Tensor, Tensor, Tensor> _efficient_attention_forward(
     c10::optional<double> scale,
     const c10::optional<at::Tensor>& causal_diagonal,
     const c10::optional<at::Tensor>& seqlen_k) {
-#if defined(USE_FLASH_ATTENTION)
+#if defined(USE_MEM_EFF_ATTENTION)
   // TODO In theory it is possible to compile with _CUDA_ARCH < 5.0 and run on a
   // machine that is >= 5.0. In practice, this is not a problem but since
   // this would avoid runtime architecture checks, we should look into it
@@ -1097,7 +1099,7 @@ std::tuple<at::Tensor, at::Tensor, Tensor, Tensor> _efficient_attention_forward(
       std::move(seed_t),
       std::move(offset_t));
 #endif
-  TORCH_CHECK(false, "USE_FLASH_ATTENTION was not enabled for build.")
+  TORCH_CHECK(false, "USE_MEM_EFF_ATTENTION was not enabled for build.")
   return std::make_tuple(Tensor{}, Tensor{}, Tensor{}, Tensor{});
 }
@@ -1108,7 +1110,7 @@ Tensor triton_scaled_dot_attention(const Tensor& q, const Tensor& k, const Tenso
 REGISTER_CUDA_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cuda);
-#ifdef USE_FLASH_ATTENTION
+#ifdef USE_MEM_EFF_ATTENTION
 namespace {
 /**
  * simple kernel that populates a tensor with rand uniform values.
@@ -1174,7 +1176,7 @@ at::Tensor& _fill_mem_eff_dropout_mask_(
   const int64_t n_heads = self.size(1);
   const int64_t n_queries = self.size(2);
   const int64_t n_keys = self.size(3);
-#if defined(USE_FLASH_ATTENTION)
+#if defined(USE_MEM_EFF_ATTENTION)
   at::PhiloxCudaState rng_engine_inputs;
   rng_engine_inputs = at::PhiloxCudaState(seed, offset);
@@ -1192,7 +1194,7 @@ at::Tensor& _fill_mem_eff_dropout_mask_(
   return self;
 #endif
-  TORCH_CHECK(false, "USE_FLASH_ATTENTION was not enabled for build.")
+  TORCH_CHECK(false, "USE_MEM_EFF_ATTENTION was not enabled for build.")
   return self;
 }

View File

@@ -19,6 +19,8 @@
 #ifdef USE_FLASH_ATTENTION
 // FlashAttention Specific Imports
 #include <ATen/native/transformers/cuda/flash_attn/fmha_api.h>
+#endif
+#ifdef USE_MEM_EFF_ATTENTION
 // MemoryEfficient Attention Specific Imports
 #include <ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h>
 #include <ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB.h>
@@ -118,7 +120,7 @@ _efficient_attention_backward(
     const bool bias_requires_grad,
     const c10::optional<double> scale,
     c10::optional <int64_t> num_splits_key) {
-#if defined(USE_FLASH_ATTENTION)
+#if defined(USE_MEM_EFF_ATTENTION)
   if (!grad_out_.defined()) {
     return std::make_tuple(Tensor{}, Tensor{}, Tensor{}, Tensor{});
   }
@@ -473,7 +475,7 @@ _efficient_attention_backward(
   AT_CUDA_CHECK(cudaGetLastError());
   return std::make_tuple(grad_q, grad_k, grad_v, grad_bias);
 #endif
-  TORCH_CHECK(false, "USE_FLASH_ATTENTION was not enabled for build.")
+  TORCH_CHECK(false, "USE_MEM_EFF_ATTENTION was not enabled for build.")
   return std::make_tuple(Tensor{}, Tensor{}, Tensor{}, Tensor{});
 }

View File

@@ -284,8 +284,8 @@ bool use_flash_attention(sdp_params params, bool debug) {
 }
 bool use_mem_efficient_attention(sdp_params params, bool debug) {
-#ifndef USE_FLASH_ATTENTION
-  TORCH_CHECK(!debug, "Torch was not compiled with flash attention.");
+#ifndef USE_MEM_EFF_ATTENTION
+  TORCH_CHECK(!debug, "Torch was not compiled with memory efficient attention.");
   return false;
 #endif
   // Constraints specific to mem efficient attention

View File

@@ -972,6 +972,9 @@ elseif(USE_CUDA)
   if(USE_FLASH_ATTENTION)
     target_compile_definitions(torch_cuda PRIVATE USE_FLASH_ATTENTION)
   endif()
+  if(USE_MEM_EFF_ATTENTION)
+    target_compile_definitions(torch_cuda PRIVATE USE_MEM_EFF_ATTENTION)
+  endif()
   if(BUILD_LAZY_CUDA_LINALG)
     add_library(torch_cuda_linalg ${ATen_CUDA_LINALG_SRCS})
     target_compile_definitions(torch_cuda_linalg PRIVATE USE_CUDA BUILD_LAZY_CUDA_LINALG)

View File

@@ -80,6 +80,7 @@ function(caffe2_print_configuration_summary)
   message(STATUS " USE_CUSPARSELT : ${USE_CUSPARSELT}")
   message(STATUS " CUDA version : ${CUDA_VERSION}")
   message(STATUS " USE_FLASH_ATTENTION : ${USE_FLASH_ATTENTION}")
+  message(STATUS " USE_MEM_EFF_ATTENTION : ${USE_MEM_EFF_ATTENTION}")
   if(${USE_CUDNN})
     message(STATUS " cuDNN version : ${CUDNN_VERSION}")
   endif()

View File

@@ -106,6 +106,9 @@
 # USE_FLASH_ATTENTION=0
 #   disables building flash attention for scaled dot product attention
 #
+# USE_MEM_EFF_ATTENTION=0
+#   disables building memory efficient attention for scaled dot product attention
+#
 # USE_LEVELDB
 #   enables use of LevelDB for storage
 #