Add Independent Memory Efficient and Flash Attention Build Flags (#107985)
# Summary
In an effort to simplify https://github.com/pytorch/pytorch/pull/105602, this PR pulls out independent chunks of code that can be landed prior to FlashV2 landing.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107985
Approved by: https://github.com/cpuhrsch
commit 182a9cf366
parent f0c6e5c91f
committed by PyTorch MergeBot
@@ -727,6 +727,12 @@ cmake_dependent_option(
     "Whether to build the flash_attention kernel for scaled dot product attention" ON
     "USE_CUDA AND NOT ROCM AND NOT CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.6" OFF)
 
+# Flash Attention2 will error while building for sm52 while Mem Eff Attention won't
+cmake_dependent_option(
+    USE_MEM_EFF_ATTENTION
+    "Enable memory-efficient attention for scaled dot product attention" ON
+    "USE_CUDA AND NOT ROCM AND NOT CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.6" OFF)
+
 if(DEBUG_CUDA)
   string(APPEND CMAKE_CUDA_FLAGS_DEBUG " -lineinfo")
   string(APPEND CMAKE_CUDA_FLAGS_RELWITHDEBINFO " -lineinfo")
@@ -170,7 +170,9 @@ file(GLOB mem_eff_attention_cuda_cpp "native/transformers/cuda/mem_eff_attention
 if(USE_FLASH_ATTENTION)
   list(APPEND native_transformers_cuda_cu ${flash_attention_cuda_cu})
   list(APPEND native_transformers_cuda_cpp ${flash_attention_cuda_cpp})
+endif()
 
+if(USE_MEM_EFF_ATTENTION)
   list(APPEND native_transformers_cuda_cu ${mem_eff_attention_cuda_cu})
   list(APPEND native_transformers_cuda_cu ${mem_eff_attention_cuda_kernels_cu})
   list(APPEND native_transformers_cuda_cpp ${mem_eff_attention_cuda_cpp})
@@ -38,6 +38,8 @@
 #ifdef USE_FLASH_ATTENTION
 // FlashAttention Specific Imports
 #include <ATen/native/transformers/cuda/flash_attn/fmha_api.h>
+#endif
+#ifdef USE_MEM_EFF_ATTENTION
 // MemoryEfficient Attention Specific Imports
 #include <ATen/native/transformers/cuda/mem_eff_attention/kernel_forward.h>
 #include <ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassF.h>
@@ -830,7 +832,7 @@ std::tuple<at::Tensor, at::Tensor, Tensor, Tensor> _efficient_attention_forward(
     c10::optional<double> scale,
     const c10::optional<at::Tensor>& causal_diagonal,
     const c10::optional<at::Tensor>& seqlen_k) {
-#if defined(USE_FLASH_ATTENTION)
+#if defined(USE_MEM_EFF_ATTENTION)
   // TODO In theory it is possible to compile with _CUDA_ARCH < 5.0 and run on a
   // machine that is >= 5.0. In practice, this is not a problem but since
   // this would avoid runtime architecture checks, we should look into it
@@ -1097,7 +1099,7 @@ std::tuple<at::Tensor, at::Tensor, Tensor, Tensor> _efficient_attention_forward(
       std::move(seed_t),
       std::move(offset_t));
 #endif
-  TORCH_CHECK(false, "USE_FLASH_ATTENTION was not enabled for build.")
+  TORCH_CHECK(false, "USE_MEM_EFF_ATTENTION was not enabled for build.")
   return std::make_tuple(Tensor{}, Tensor{}, Tensor{}, Tensor{});
 }
 
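The two hunks above, and the matching ones in the backward pass below, follow a single pattern: the body of each memory-efficient kernel entry point is compiled only under `USE_MEM_EFF_ATTENTION`, and a `TORCH_CHECK(false, ...)` fallback fires when that backend was compiled out. A minimal sketch of the pattern, with a hypothetical function name and a simplified signature rather than the real `_efficient_attention_forward`:

```cpp
// Sketch of the guard-plus-fallback pattern used above; the function name and
// signature are illustrative only.
#include <tuple>
#include <ATen/core/Tensor.h>
#include <c10/util/Exception.h>

std::tuple<at::Tensor, at::Tensor> efficient_attention_forward_sketch(
    const at::Tensor& query,
    const at::Tensor& key) {
#if defined(USE_MEM_EFF_ATTENTION)
  // In the real kernel, the CUTLASS-based implementation runs here and
  // returns before the fallback below is ever reached.
  return std::make_tuple(query, key);
#endif
  // In builds without USE_MEM_EFF_ATTENTION this is the only reachable path:
  // fail loudly instead of silently returning empty tensors.
  TORCH_CHECK(false, "USE_MEM_EFF_ATTENTION was not enabled for build.");
  return std::make_tuple(at::Tensor{}, at::Tensor{});
}
```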
@@ -1108,7 +1110,7 @@ Tensor triton_scaled_dot_attention(const Tensor& q, const Tensor& k, const Tenso
 
 REGISTER_CUDA_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_cuda);
 
-#ifdef USE_FLASH_ATTENTION
+#ifdef USE_MEM_EFF_ATTENTION
 namespace {
 /**
  * simple kernel that populates a tensor with rand uniform values.
@@ -1174,7 +1176,7 @@ at::Tensor& _fill_mem_eff_dropout_mask_(
   const int64_t n_heads = self.size(1);
   const int64_t n_queries = self.size(2);
   const int64_t n_keys = self.size(3);
-#if defined(USE_FLASH_ATTENTION)
+#if defined(USE_MEM_EFF_ATTENTION)
 
   at::PhiloxCudaState rng_engine_inputs;
   rng_engine_inputs = at::PhiloxCudaState(seed, offset);
@@ -1192,7 +1194,7 @@ at::Tensor& _fill_mem_eff_dropout_mask_(
 
   return self;
 #endif
-  TORCH_CHECK(false, "USE_FLASH_ATTENTION was not enabled for build.")
+  TORCH_CHECK(false, "USE_MEM_EFF_ATTENTION was not enabled for build.")
   return self;
 }
 
@@ -19,6 +19,8 @@
 #ifdef USE_FLASH_ATTENTION
 // FlashAttention Specific Imports
 #include <ATen/native/transformers/cuda/flash_attn/fmha_api.h>
+#endif
+#ifdef USE_MEM_EFF_ATTENTION
 // MemoryEfficient Attention Specific Imports
 #include <ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h>
 #include <ATen/native/transformers/cuda/mem_eff_attention/kernels/cutlassB.h>
@@ -118,7 +120,7 @@ _efficient_attention_backward(
     const bool bias_requires_grad,
     const c10::optional<double> scale,
     c10::optional <int64_t> num_splits_key) {
-#if defined(USE_FLASH_ATTENTION)
+#if defined(USE_MEM_EFF_ATTENTION)
   if (!grad_out_.defined()) {
     return std::make_tuple(Tensor{}, Tensor{}, Tensor{}, Tensor{});
   }
@@ -473,7 +475,7 @@ _efficient_attention_backward(
   AT_CUDA_CHECK(cudaGetLastError());
   return std::make_tuple(grad_q, grad_k, grad_v, grad_bias);
 #endif
-  TORCH_CHECK(false, "USE_FLASH_ATTENTION was not enabled for build.")
+  TORCH_CHECK(false, "USE_MEM_EFF_ATTENTION was not enabled for build.")
   return std::make_tuple(Tensor{}, Tensor{}, Tensor{}, Tensor{});
 }
 
@@ -284,8 +284,8 @@ bool use_flash_attention(sdp_params params, bool debug) {
 }
 
 bool use_mem_efficient_attention(sdp_params params, bool debug) {
-#ifndef USE_FLASH_ATTENTION
-  TORCH_CHECK(!debug, "Torch was not compiled with flash attention.");
+#ifndef USE_MEM_EFF_ATTENTION
+  TORCH_CHECK(!debug, "Torch was not compiled with memory efficient attention.");
   return false;
 #endif
   // Constraints specific to mem efficient attention
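Previously `use_mem_efficient_attention` keyed its early-out on `USE_FLASH_ATTENTION`, so a build with only one of the two backends could report the wrong capability. A rough sketch of the now-independent compile-time gates, using hypothetical free functions instead of the real `sdp_params`-based checks:

```cpp
// Hypothetical helpers (not the PyTorch sdp_utils API) illustrating the
// independent compile-time gates for the two SDPA backends.
#include <c10/util/Exception.h>

bool mem_efficient_attention_compiled(bool debug) {
#ifndef USE_MEM_EFF_ATTENTION
  // In debug mode surface the reason as an error; otherwise just report false.
  TORCH_CHECK(!debug, "Torch was not compiled with memory efficient attention.");
  return false;
#else
  return true;
#endif
}

bool flash_attention_compiled(bool debug) {
#ifndef USE_FLASH_ATTENTION
  TORCH_CHECK(!debug, "Torch was not compiled with flash attention.");
  return false;
#else
  return true;
#endif
}
```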
@@ -972,6 +972,9 @@ elseif(USE_CUDA)
     if(USE_FLASH_ATTENTION)
       target_compile_definitions(torch_cuda PRIVATE USE_FLASH_ATTENTION)
     endif()
+    if(USE_MEM_EFF_ATTENTION)
+      target_compile_definitions(torch_cuda PRIVATE USE_MEM_EFF_ATTENTION)
+    endif()
     if(BUILD_LAZY_CUDA_LINALG)
       add_library(torch_cuda_linalg ${ATen_CUDA_LINALG_SRCS})
       target_compile_definitions(torch_cuda_linalg PRIVATE USE_CUDA BUILD_LAZY_CUDA_LINALG)
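Since each definition is added to `torch_cuda` as `PRIVATE` only when the corresponding CMake option is ON, sources compiled into that target can now tell the two backends apart. A small illustrative translation unit (a hypothetical helper, not part of torch_cuda):

```cpp
// Reports which scaled-dot-product-attention backends this build was
// compiled with, based on the per-option compile definitions added above.
#include <string>

std::string compiled_sdpa_backends() {
  std::string backends;
#ifdef USE_FLASH_ATTENTION
  backends += "flash_attention ";
#endif
#ifdef USE_MEM_EFF_ATTENTION
  backends += "mem_efficient_attention ";
#endif
  return backends.empty() ? "none" : backends;
}
```

A build with USE_MEM_EFF_ATTENTION turned OFF and USE_FLASH_ATTENTION left ON would then report only the flash backend, which is exactly the independence this PR is after.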
@@ -80,6 +80,7 @@ function(caffe2_print_configuration_summary)
     message(STATUS "    USE_CUSPARSELT        : ${USE_CUSPARSELT}")
     message(STATUS "    CUDA version          : ${CUDA_VERSION}")
     message(STATUS "    USE_FLASH_ATTENTION   : ${USE_FLASH_ATTENTION}")
+    message(STATUS "    USE_MEM_EFF_ATTENTION : ${USE_MEM_EFF_ATTENTION}")
     if(${USE_CUDNN})
       message(STATUS "    cuDNN version         : ${CUDNN_VERSION}")
     endif()