Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[ROCM] Properly disable Flash Attention/Efficient Attention with environment variables (#133866)

Now `USE_FLASH_ATTENTION=0 USE_MEM_EFF_ATTENTION=0 python setup.py` compiles correctly.

Fixes #125230

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133866
Approved by: https://github.com/jithunnair-amd, https://github.com/jeffdaily, https://github.com/malfet
commit 5fd670e0ef
parent 5b392d22c6
committed by PyTorch MergeBot
@@ -883,6 +883,16 @@ cmake_dependent_option(
   Will be disabled if not supported by the platform" ON
   "USE_CUDA OR USE_ROCM" OFF)
 
+#
+# Cannot be put into Dependencies.cmake due to circular dependency:
+# USE_FLASH_ATTENTION -> USE_ROCM -> Dependencies.cmake -> aotriton.cmake
+#
+if(USE_ROCM)
+  if(USE_FLASH_ATTENTION OR USE_MEM_EFF_ATTENTION)
+    include(cmake/External/aotriton.cmake)
+  endif()
+endif()
+
 if(DEBUG_CUDA)
   string(APPEND CMAKE_CUDA_FLAGS_DEBUG " -lineinfo")
   string(APPEND CMAKE_CUDA_FLAGS_RELWITHDEBINFO " -lineinfo")
@@ -25,7 +25,10 @@
 #include <c10/util/string_view.h>
 
 #if USE_ROCM
+#if defined(USE_FLASH_ATTENTION) || defined(USE_MEM_EFF_ATTENTION)
 #include <aotriton/flash.h>
+#define USE_AOTRITON 1
+#endif
 #endif
 
 /**
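This header hunk is the pivot of the fix: `USE_AOTRITON` is defined only when at least one of the two attention backends is actually compiled, so downstream code can test a single macro instead of repeating the two-flag check. A minimal standalone sketch of the pattern follows; it is a toy program, not the PyTorch sources.

// Sketch of the feature-macro derivation in the hunk above. Compile
// with -DUSE_FLASH_ATTENTION and/or -DUSE_MEM_EFF_ATTENTION to simulate
// the flags CMake passes; leaving both off simulates a
// USE_FLASH_ATTENTION=0 USE_MEM_EFF_ATTENTION=0 build.
#include <cstdio>

#if defined(USE_FLASH_ATTENTION) || defined(USE_MEM_EFF_ATTENTION)
// Only in this case does the real file include <aotriton/flash.h>.
#define USE_AOTRITON 1
#endif

int main() {
  // In a preprocessor #if, an undefined identifier evaluates to 0,
  // so testing USE_AOTRITON is safe even when it was never defined.
#if USE_AOTRITON
  std::printf("aotriton-backed attention paths compiled in\n");
#else
  std::printf("aotriton disabled; ROCm SDPA checks report unsupported\n");
#endif
}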
@@ -208,6 +211,7 @@ bool check_flash_attention_hardware_support(sdp_params const& params, bool debug
   using sm80 = SMVersion<8, 0>;
   using sm90 = SMVersion<9, 0>;
 #if USE_ROCM
+#if USE_AOTRITON
   auto stream = at::cuda::getCurrentCUDAStream().stream();
   if (hipSuccess != aotriton::v2::flash::check_gpu(stream)) {
       auto dprops = at::cuda::getCurrentDeviceProperties();
@@ -217,6 +221,9 @@ bool check_flash_attention_hardware_support(sdp_params const& params, bool debug
     }
     return false;
   }
+#else
+  return false;
+#endif
 #else
   auto dprops = at::cuda::getCurrentDeviceProperties();
   if (!check_sm_version<sm80, sm90>(dprops)) {
@@ -239,6 +246,7 @@ bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug)
   using sm50 = SMVersion<5, 0>;
   using sm90 = SMVersion<9, 0>;
 #if USE_ROCM
+#if USE_AOTRITON
   auto stream = at::cuda::getCurrentCUDAStream().stream();
   if (hipSuccess != aotriton::v2::flash::check_gpu(stream)) {
       auto dprops = at::cuda::getCurrentDeviceProperties();
@@ -248,6 +256,9 @@ bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug)
     }
     return false;
   }
+#else
+  return false;
+#endif
 #else
   auto dprops = at::cuda::getCurrentDeviceProperties();
   if (!check_sm_version<sm50, sm90>(dprops)) {
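Both hardware checks now share the same three-way structure: probe aotriton when it was compiled, report unsupported when ROCm is on but both backends are off, and fall back to the SM-version range check for CUDA builds. Flattened out of diff form as a runnable sketch; the probe helpers below are placeholders, not the real aotriton/PyTorch APIs.

#include <cstdio>

// Placeholder probes: in the real code these are
// aotriton::v2::flash::check_gpu(stream) and check_sm_version<lo, hi>(dprops).
bool gpu_supported_by_aotriton() { return true; }
bool sm_version_in_range() { return true; }

bool check_hardware_support() {
#if USE_ROCM
#if USE_AOTRITON
  // ROCm build with flash and/or mem-efficient attention compiled in:
  // ask aotriton whether this GPU architecture has kernels.
  return gpu_supported_by_aotriton();
#else
  // ROCm build, but both backends were disabled at compile time,
  // so there is nothing to dispatch to: always unsupported.
  return false;
#endif
#else
  // CUDA build: gate on the supported SM version range instead.
  return sm_version_in_range();
#endif
}

int main() {
  std::printf("hardware support: %s\n", check_hardware_support() ? "yes" : "no");
}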
@@ -1103,10 +1103,6 @@ if(USE_ROCM)
     message(STATUS "Disabling Kernel Assert for ROCm")
   endif()
 
-  include(${CMAKE_CURRENT_LIST_DIR}/External/aotriton.cmake)
-  if(USE_CUDA)
-    caffe2_update_option(USE_MEM_EFF_ATTENTION OFF)
-  endif()
 else()
   caffe2_update_option(USE_ROCM OFF)
 endif()