diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index bbf79491e2d3..6c095680733f 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -1,7 +1,6 @@ cmake_minimum_required(VERSION 3.27 FATAL_ERROR) set(CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake ${CMAKE_MODULE_PATH}) - if(NOT MSVC) string(APPEND CMAKE_CXX_FLAGS " -Wno-ignored-qualifiers") string(APPEND CMAKE_C_FLAGS " -Wno-ignored-qualifiers") @@ -196,94 +195,14 @@ if(USE_FLASH_ATTENTION) endif() endif() message(STATUS "USE_ROCM_CK_SDPA is set; building PyTorch with CK SDPA enabled") - - # CK SDPA sources require specific compilation flags - set(CK_SDPA_EXTRA_HIPCC_FLAGS - -fno-autolink - -fhip-new-launch-api - -fgnuc-version=4.2.1 - -fno-implicit-modules - -fskip-odr-check-in-gmf - -fcxx-exceptions - -fexceptions - -fcolor-diagnostics - -faddrsig - -fno-rounding-math - -mconstructor-aliases - -mllvm - -amdgpu-internalize-symbols - -fvisibility=hidden - -Wno-float-equal - -fgpu-flush-denormals-to-zero - -Wno-unused-parameter) - - #TODO: The following flags are specific to 8-bit width types which are not integrated via CK yet. - # Add once that support is integrated - #check_cxx_compiler_flag("-fno-offload-uniform-block" HAS_NO_OFFLOAD_UNIFORM_BLOCK) - #if(HAS_NO_OFFLOAD_UNIFORM_BLOCK) - # list(APPEND CK_SDPA_EXTRA_HIPCC_FLAGS -fno-offload-uniform-block) - #endif() - #check_cxx_compiler_flag("-mllvm --lsr-drop-solution=1" HAS_LSR_DROP_SOLUTION) - #if(HAS_LSR_DROP_SOLUTION) - # list(APPEND CK_SDPA_EXTRA_HIPCC_FLAGS -mllvm --lsr-drop-solution=1) - #endif() - #check_cxx_compiler_flag("-mllvm -enable-post-misched=0" HAS_ENABLE_POST_MISCHED) - #if(HAS_ENABLE_POST_MISCHED) - # list(APPEND CK_SDPA_EXTRA_HIPCC_FLAGS -mllvm -enable-post-misched=0) - #endif() - #set(check-coerce) - #check_cxx_compiler_flag(" -mllvm -amdgpu-coerce-illegal-types=1" check-coerce) - #if(check-coerce) - # list(APPEND CK_SDPA_EXTRA_HIPCC_FLAGS -mllvm -amdgpu-coerce-illegal-types=1) - #endif() - - list(APPEND CK_SDPA_EXTRA_HIPCC_FLAGS -mllvm -amdgpu-early-inline-all=true) - list(APPEND CK_SDPA_EXTRA_HIPCC_FLAGS -mllvm -amdgpu-function-calls=false) - - # Additional CK compiler flags - set(CK_SDPA_EXTRA_HIPCC_OPTIONS - CK_ENABLE_BF16 - CK_ENABLE_BF8 - CK_ENABLE_FP16 - CK_ENABLE_FP32 - CK_ENABLE_FP64 - CK_ENABLE_FP8 - CK_ENABLE_INT8 - CK_USE_FNUZ_FP8 - CK_USE_GFX94 - CK_USE_XDL - __HIP_PLATFORM_AMD__=1 - __HIP_PLATFORM_HCC__=1 - CK_TILE_FMHA_FWD_FAST_EXP2=1 - CK_TILE_FMHA_FWD_SPLITKV_API=1 - CK_TILE_FMHA_FWD_APPENDKV_API=1 - CK_TILE_FMHA_FWD_PAGEDKV_API=1 - __GCC_HAVE_DWARF2_CFI_ASM=1 - USE_ROCM_CK_SDPA) - message(STATUS "Generating CK kernel instances...") add_subdirectory(native/transformers/hip/flash_attn/ck) + file(GLOB flash_attention_hip_ck_hip "native/transformers/hip/flash_attn/ck/*.hip") + list(APPEND native_transformers_hip_hip ${flash_attention_hip_ck_hip}) # FAv3 Generation add_subdirectory(native/transformers/hip/flash_attn/ck/fav_v3) - file(GLOB ck_sdpa_sources_hip - "native/transformers/hip/flash_attn/ck/*.hip" - "native/transformers/hip/flash_attn/ck/fav_v3/*.hip") - - set_source_files_properties(${ck_sdpa_sources_hip} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1) - hip_add_library(ck_sdpa STATIC - ${ck_sdpa_sources_hip} - HIPCC_OPTIONS ${HIP_HCC_FLAGS} ${CK_SDPA_EXTRA_HIPCC_FLAGS}) - set_target_properties(ck_sdpa PROPERTIES POSITION_INDEPENDENT_CODE ON) - target_compile_definitions(ck_sdpa PUBLIC ${CK_SDPA_EXTRA_HIPCC_OPTIONS}) - target_include_directories(ck_sdpa PUBLIC - ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/include - ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/library/include - ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/example/ck_tile/01_fmha - ${CMAKE_CURRENT_BINARY_DIR}/composable_kernel - ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/aiter/csrc/include - ${CMAKE_CURRENT_SOURCE_DIR}/native/transformers/hip/flash_attn/ck - ) - + file(GLOB flash_attention_v3_hip "native/transformers/hip/flash_attn/ck/fav_v3/*.hip") + list(APPEND native_transformers_hip_hip ${flash_attention_v3_hip}) endif() file(GLOB flash_attention_hip_aot_hip "native/transformers/hip/flash_attn/aot/*.hip") file(GLOB flash_attention_src_hip_hip "native/transformers/hip/flash_attn/src/*.hip") diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/launch_kernel_pt.hpp b/aten/src/ATen/native/transformers/hip/flash_attn/ck/launch_kernel_pt.hpp index f4e1ef71f5a9..400da17426f1 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/ck/launch_kernel_pt.hpp +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/launch_kernel_pt.hpp @@ -8,9 +8,9 @@ namespace ck_tile { // Added by hipification to become a no-op on non supported architectures -template +template #if CK_TILE_USE_LAUNCH_BOUNDS -__launch_bounds__(Kernel::kBlockSize, MinBlockPerCu) +__launch_bounds__(MaxThreadPerBlock, MinBlockPerCu) #endif __global__ void kentry_pt(Args... args) { @@ -29,13 +29,14 @@ __launch_bounds__(Kernel::kBlockSize, MinBlockPerCu) // // the "static __device__ operator()(some_arg)" is the entry point of KernelImpl // -template CK_TILE_HOST auto make_kernel_pt(KernelImpl /*f*/, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args) { - const auto kernel = kentry_pt; + const auto kernel = kentry_pt; return [=](const stream_config& s) { kernel<<>>(args...); diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 9c75baa0bf94..4cd773bc1612 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -1762,10 +1762,6 @@ if(USE_ROCM) target_link_libraries(torch_hip PUBLIC torch_cpu_library ${Caffe2_PUBLIC_HIP_DEPENDENCY_LIBS}) target_link_libraries(torch_hip PRIVATE ${Caffe2_HIP_DEPENDENCY_LIBS}) - if(USE_ROCM_CK_SDPA) - target_link_libraries(torch_hip PRIVATE ck_sdpa) - endif() - if(USE_FBGEMM_GENAI) if(USE_ROCM) target_link_libraries(torch_hip PRIVATE fbgemm_genai) diff --git a/third_party/composable_kernel b/third_party/composable_kernel index de61e5549382..7fe50dc3da20 160000 --- a/third_party/composable_kernel +++ b/third_party/composable_kernel @@ -1 +1 @@ -Subproject commit de61e554938265a5d17a1bba8c148457125e80cd +Subproject commit 7fe50dc3da2069d6645d9deb8c017a876472a977