Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Revert "[build] Create target for flash attention (#156235)"
This reverts commit 6d02321472ee0761092166dd273eb3ec386cf0c0.
Reverted https://github.com/pytorch/pytorch/pull/156235 on behalf of https://github.com/ZainRizvi due to: Weird, but seems to have broken trunk: test_jit_fuser_te.py::TestTEFuserDynamic::test_skip_grad_in_check [GH job link](https://github.com/pytorch/pytorch/actions/runs/15748768079/job/44390494621) [HUD commit link](6d02321472) ([comment](https://github.com/pytorch/pytorch/pull/156235#issuecomment-2987784207))
@@ -169,10 +169,14 @@ file(GLOB native_transformers_hip_hip "native/transformers/hip/*.hip")
file(GLOB native_transformers_hip_cpp "native/transformers/hip/*.cpp")
file(GLOB native_quantized_cudnn_hip_cpp "native/quantized/cudnn/hip/*.cpp")
file(GLOB native_utils_cpp "native/utils/*.cpp")
file(GLOB flash_attention_cuda_kernels_cu ${PROJECT_SOURCE_DIR}/third_party/flash-attention/csrc/flash_attn/src/*.cu)
file(GLOB flash_attention_cuda_cpp ${PROJECT_SOURCE_DIR}/third_party/flash-attention/csrc/flash_attn/src/*.cpp)
file(GLOB native_flash_attn_api_cpp "native/transformers/cuda/flash_attn/flash_api.cpp")

# flash_attention sources
file(GLOB flash_attention_cuda_kernels_cu ${PROJECT_SOURCE_DIR}/third_party/flash-attention/csrc/flash_attn/src/*.cu)
# Flash attention C++ sources
file(GLOB flash_attention_cuda_cpp
  "${PROJECT_SOURCE_DIR}/third_party/flash-attention/csrc/flash_attn/src/*.cpp"
  "native/transformers/cuda/flash_attn/flash_api.cpp"
)

# flash_attention hip sources
file(GLOB flash_attention_hip_hip "native/transformers/hip/flash_attn/*.hip")
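Note: the hunk above contrasts two ways of collecting the same flash-attention sources: one single-pattern file(GLOB) call per variable versus one multi-pattern glob that folds the vendored .cpp files and the in-tree flash_api.cpp into a single variable. A minimal standalone sketch of the multi-pattern form follows; the my_attention_cpp variable and its paths are illustrative placeholders, not PyTorch's real layout.

# file(GLOB) accepts several patterns and collects all matches into one variable.
# Paths below are placeholders for illustration only.
file(GLOB my_attention_cpp
  "${PROJECT_SOURCE_DIR}/third_party/some_lib/csrc/*.cpp"   # vendored sources
  "native/my_attention/api.cpp"                             # in-tree API shim
)
message(STATUS "Collected attention sources: ${my_attention_cpp}")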
@@ -204,29 +208,10 @@ file(GLOB mem_eff_attention_cuda_cu "native/transformers/cuda/mem_eff_attention/
file(GLOB mem_eff_attention_cuda_kernels_cu "native/transformers/cuda/mem_eff_attention/kernels/*.cu")
file(GLOB mem_eff_attention_cuda_cpp "native/transformers/cuda/mem_eff_attention/*.cpp")

if(USE_CUDA AND (USE_FLASH_ATTENTION OR USE_MEM_EFF_ATTENTION))
  add_library(flash_attention OBJECT EXCLUDE_FROM_ALL ${flash_attention_cuda_kernels_cu} ${flash_attention_cuda_cpp})

  target_include_directories(flash_attention PUBLIC
    ${PROJECT_SOURCE_DIR}/third_party/flash-attention/csrc
    ${PROJECT_SOURCE_DIR}/third_party/flash-attention/include
    ${PROJECT_SOURCE_DIR}/third_party/cutlass/include
    ${PROJECT_SOURCE_DIR}/third_party/flash-attention/csrc/flash_attn/src
  )

  target_compile_definitions(flash_attention PRIVATE
    # Copied from https://github.com/pytorch/pytorch/blob/a10024d7dea47c52469059a47efe376eb20adca0/caffe2/CMakeLists.txt#L1431
    FLASH_NAMESPACE=pytorch_flash
    FLASHATTENTION_DISABLE_ALIBI
    FLASHATTENTION_DISABLE_SOFTCAP
    UNFUSE_FMA
  )

  set_target_properties(flash_attention PROPERTIES POSITION_INDEPENDENT_CODE ON)
endif()

if(USE_FLASH_ATTENTION)
  list(APPEND native_transformers_cuda_cpp ${native_flash_attn_api_cpp})
  list(APPEND native_transformers_cuda_cu ${flash_attention_cuda_cu})
  list(APPEND native_transformers_cuda_cu ${flash_attention_cuda_kernels_cu})
  list(APPEND native_transformers_cuda_cpp ${flash_attention_cuda_cpp})
  list(APPEND FLASH_ATTENTION_CUDA_SOURCES ${flash_attention_cuda_cu} ${flash_attention_cuda_kernels_cu})
  list(APPEND ATen_ATTENTION_KERNEL_SRCS ${flash_attention_cuda_kernels_cu})

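Note: the if(USE_CUDA AND (USE_FLASH_ATTENTION OR USE_MEM_EFF_ATTENTION)) block removed above is an instance of the CMake OBJECT-library pattern that #156235 introduced: the flash-attention kernels are compiled once into an object collection that is never linked or installed on its own, and a consuming target later absorbs those objects. A minimal, self-contained sketch of that pattern, using placeholder target and file names rather than PyTorch's real ones:

cmake_minimum_required(VERSION 3.18)
project(objlib_demo LANGUAGES CXX)

# Compile the kernel sources once into an object collection; EXCLUDE_FROM_ALL
# keeps it out of the default build unless a consumer pulls it in.
add_library(my_kernels OBJECT EXCLUDE_FROM_ALL kernels.cpp)
target_include_directories(my_kernels PUBLIC ${PROJECT_SOURCE_DIR}/include)
target_compile_definitions(my_kernels PRIVATE MY_NAMESPACE=demo)
# Required when the objects end up inside a shared library.
set_target_properties(my_kernels PROPERTIES POSITION_INDEPENDENT_CODE ON)

# The consumer absorbs the compiled objects via $<TARGET_OBJECTS:...>.
add_library(my_framework SHARED framework.cpp)
target_sources(my_framework PRIVATE $<TARGET_OBJECTS:my_kernels>)

The hunk below reverts the matching consumer side, where torch_cuda had been given the objects through target_sources(torch_cuda PRIVATE $<TARGET_OBJECTS:flash_attention>).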
@@ -1044,13 +1044,8 @@ elseif(USE_CUDA)
  FLASH_NAMESPACE=pytorch_flash
  UNFUSE_FMA # Addressing issue #121558
)
target_sources(torch_cuda PRIVATE $<TARGET_OBJECTS:flash_attention>)
target_include_directories(torch_cuda PUBLIC
  $<BUILD_INTERFACE:${PROJECT_SOURCE_DIR}/third_party/flash-attention/csrc>
  $<BUILD_INTERFACE:${PROJECT_SOURCE_DIR}/third_party/flash-attention/include>
  $<BUILD_INTERFACE:${PROJECT_SOURCE_DIR}/third_party/cutlass/include>
  $<BUILD_INTERFACE:${PROJECT_SOURCE_DIR}/third_party/flash-attention/csrc/flash_attn/src>
  $<INSTALL_INTERFACE:include>
target_include_directories(torch_cuda PRIVATE
  ${PROJECT_SOURCE_DIR}/third_party/flash-attention/csrc/flash_attn/src/
)
endif()
if(USE_MEM_EFF_ATTENTION)
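Note: the torch_cuda include block dropped above uses CMake's interface generator expressions: $<BUILD_INTERFACE:...> paths apply only while building inside the source tree, while $<INSTALL_INTERFACE:...> paths apply to consumers of the installed package. A short sketch of the idiom with a placeholder target name, not PyTorch's actual build:

# Placeholder library; the generator expressions below are the standard way to
# give different include paths to in-tree builds vs. installed consumers.
add_library(my_lib SHARED my_lib.cpp)
target_include_directories(my_lib PUBLIC
  $<BUILD_INTERFACE:${PROJECT_SOURCE_DIR}/third_party/some_dep/include>
  $<INSTALL_INTERFACE:include>
)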