[ROCM] Properly disable Flash Attention/Efficient Attention with environment variables (#133866)

Now `USE_FLASH_ATTENTION=0 USE_MEM_EFF_ATTENTION=0 python setup.py` can compile correctly

Fixes #125230

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133866
Approved by: https://github.com/jithunnair-amd, https://github.com/jeffdaily, https://github.com/malfet
This commit is contained in:
Xinya Zhang
2024-08-27 18:24:27 +00:00
committed by PyTorch MergeBot
parent 5b392d22c6
commit 5fd670e0ef
3 changed files with 21 additions and 4 deletions

View File

@ -883,6 +883,16 @@ cmake_dependent_option(
Will be disabled if not supported by the platform" ON
"USE_CUDA OR USE_ROCM" OFF)
#
# Cannot be put into Dependencies.cmake due circular dependency:
# USE_FLASH_ATTENTION -> USE_ROCM -> Dependencies.cmake -> aotriton.cmake
#
if(USE_ROCM)
if(USE_FLASH_ATTENTION OR USE_MEM_EFF_ATTENTION)
include(cmake/External/aotriton.cmake)
endif()
endif()
if(DEBUG_CUDA)
string(APPEND CMAKE_CUDA_FLAGS_DEBUG " -lineinfo")
string(APPEND CMAKE_CUDA_FLAGS_RELWITHDEBINFO " -lineinfo")

View File

@ -25,7 +25,10 @@
#include <c10/util/string_view.h>
#if USE_ROCM
#if defined(USE_FLASH_ATTENTION) || defined(USE_MEM_EFF_ATTENTION)
#include <aotriton/flash.h>
#define USE_AOTRITON 1
#endif
#endif
/**
@ -208,6 +211,7 @@ bool check_flash_attention_hardware_support(sdp_params const& params, bool debug
using sm80 = SMVersion<8, 0>;
using sm90 = SMVersion<9, 0>;
#if USE_ROCM
#if USE_AOTRITON
auto stream = at::cuda::getCurrentCUDAStream().stream();
if (hipSuccess != aotriton::v2::flash::check_gpu(stream)) {
auto dprops = at::cuda::getCurrentDeviceProperties();
@ -217,6 +221,9 @@ bool check_flash_attention_hardware_support(sdp_params const& params, bool debug
}
return false;
}
#else
return false;
#endif
#else
auto dprops = at::cuda::getCurrentDeviceProperties();
if (!check_sm_version<sm80, sm90>(dprops)) {
@ -239,6 +246,7 @@ bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug)
using sm50 = SMVersion<5, 0>;
using sm90 = SMVersion<9, 0>;
#if USE_ROCM
#if USE_AOTRITON
auto stream = at::cuda::getCurrentCUDAStream().stream();
if (hipSuccess != aotriton::v2::flash::check_gpu(stream)) {
auto dprops = at::cuda::getCurrentDeviceProperties();
@ -248,6 +256,9 @@ bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug)
}
return false;
}
#else
return false;
#endif
#else
auto dprops = at::cuda::getCurrentDeviceProperties();
if (!check_sm_version<sm50, sm90>(dprops)) {

View File

@ -1103,10 +1103,6 @@ if(USE_ROCM)
message(STATUS "Disabling Kernel Assert for ROCm")
endif()
include(${CMAKE_CURRENT_LIST_DIR}/External/aotriton.cmake)
if(USE_CUDA)
caffe2_update_option(USE_MEM_EFF_ATTENTION OFF)
endif()
else()
caffe2_update_option(USE_ROCM OFF)
endif()