diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index 6c095680733f..bbf79491e2d3 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -1,6 +1,7 @@ cmake_minimum_required(VERSION 3.27 FATAL_ERROR) set(CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake ${CMAKE_MODULE_PATH}) + if(NOT MSVC) string(APPEND CMAKE_CXX_FLAGS " -Wno-ignored-qualifiers") string(APPEND CMAKE_C_FLAGS " -Wno-ignored-qualifiers") @@ -195,14 +196,94 @@ if(USE_FLASH_ATTENTION) endif() endif() message(STATUS "USE_ROCM_CK_SDPA is set; building PyTorch with CK SDPA enabled") + + # CK SDPA sources require specific compilation flags + set(CK_SDPA_EXTRA_HIPCC_FLAGS + -fno-autolink + -fhip-new-launch-api + -fgnuc-version=4.2.1 + -fno-implicit-modules + -fskip-odr-check-in-gmf + -fcxx-exceptions + -fexceptions + -fcolor-diagnostics + -faddrsig + -fno-rounding-math + -mconstructor-aliases + -mllvm + -amdgpu-internalize-symbols + -fvisibility=hidden + -Wno-float-equal + -fgpu-flush-denormals-to-zero + -Wno-unused-parameter) + + #TODO: The following flags are specific to 8-bit width types which are not integrated via CK yet. + # Add once that support is integrated + #check_cxx_compiler_flag("-fno-offload-uniform-block" HAS_NO_OFFLOAD_UNIFORM_BLOCK) + #if(HAS_NO_OFFLOAD_UNIFORM_BLOCK) + # list(APPEND CK_SDPA_EXTRA_HIPCC_FLAGS -fno-offload-uniform-block) + #endif() + #check_cxx_compiler_flag("-mllvm --lsr-drop-solution=1" HAS_LSR_DROP_SOLUTION) + #if(HAS_LSR_DROP_SOLUTION) + # list(APPEND CK_SDPA_EXTRA_HIPCC_FLAGS -mllvm --lsr-drop-solution=1) + #endif() + #check_cxx_compiler_flag("-mllvm -enable-post-misched=0" HAS_ENABLE_POST_MISCHED) + #if(HAS_ENABLE_POST_MISCHED) + # list(APPEND CK_SDPA_EXTRA_HIPCC_FLAGS -mllvm -enable-post-misched=0) + #endif() + #set(check-coerce) + #check_cxx_compiler_flag(" -mllvm -amdgpu-coerce-illegal-types=1" check-coerce) + #if(check-coerce) + # list(APPEND CK_SDPA_EXTRA_HIPCC_FLAGS -mllvm -amdgpu-coerce-illegal-types=1) + #endif() + + list(APPEND CK_SDPA_EXTRA_HIPCC_FLAGS -mllvm -amdgpu-early-inline-all=true) + list(APPEND CK_SDPA_EXTRA_HIPCC_FLAGS -mllvm -amdgpu-function-calls=false) + + # Additional CK compiler flags + set(CK_SDPA_EXTRA_HIPCC_OPTIONS + CK_ENABLE_BF16 + CK_ENABLE_BF8 + CK_ENABLE_FP16 + CK_ENABLE_FP32 + CK_ENABLE_FP64 + CK_ENABLE_FP8 + CK_ENABLE_INT8 + CK_USE_FNUZ_FP8 + CK_USE_GFX94 + CK_USE_XDL + __HIP_PLATFORM_AMD__=1 + __HIP_PLATFORM_HCC__=1 + CK_TILE_FMHA_FWD_FAST_EXP2=1 + CK_TILE_FMHA_FWD_SPLITKV_API=1 + CK_TILE_FMHA_FWD_APPENDKV_API=1 + CK_TILE_FMHA_FWD_PAGEDKV_API=1 + __GCC_HAVE_DWARF2_CFI_ASM=1 + USE_ROCM_CK_SDPA) + message(STATUS "Generating CK kernel instances...") add_subdirectory(native/transformers/hip/flash_attn/ck) - file(GLOB flash_attention_hip_ck_hip "native/transformers/hip/flash_attn/ck/*.hip") - list(APPEND native_transformers_hip_hip ${flash_attention_hip_ck_hip}) # FAv3 Generation add_subdirectory(native/transformers/hip/flash_attn/ck/fav_v3) - file(GLOB flash_attention_v3_hip "native/transformers/hip/flash_attn/ck/fav_v3/*.hip") - list(APPEND native_transformers_hip_hip ${flash_attention_v3_hip}) + file(GLOB ck_sdpa_sources_hip + "native/transformers/hip/flash_attn/ck/*.hip" + "native/transformers/hip/flash_attn/ck/fav_v3/*.hip") + + set_source_files_properties(${ck_sdpa_sources_hip} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1) + hip_add_library(ck_sdpa STATIC + ${ck_sdpa_sources_hip} + HIPCC_OPTIONS ${HIP_HCC_FLAGS} ${CK_SDPA_EXTRA_HIPCC_FLAGS}) + set_target_properties(ck_sdpa PROPERTIES POSITION_INDEPENDENT_CODE ON) + target_compile_definitions(ck_sdpa PUBLIC ${CK_SDPA_EXTRA_HIPCC_OPTIONS}) + target_include_directories(ck_sdpa PUBLIC + ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/include + ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/library/include + ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/example/ck_tile/01_fmha + ${CMAKE_CURRENT_BINARY_DIR}/composable_kernel + ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/aiter/csrc/include + ${CMAKE_CURRENT_SOURCE_DIR}/native/transformers/hip/flash_attn/ck + ) + endif() file(GLOB flash_attention_hip_aot_hip "native/transformers/hip/flash_attn/aot/*.hip") file(GLOB flash_attention_src_hip_hip "native/transformers/hip/flash_attn/src/*.hip") diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/launch_kernel_pt.hpp b/aten/src/ATen/native/transformers/hip/flash_attn/ck/launch_kernel_pt.hpp index 400da17426f1..f4e1ef71f5a9 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/ck/launch_kernel_pt.hpp +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/launch_kernel_pt.hpp @@ -8,9 +8,9 @@ namespace ck_tile { // Added by hipification to become a no-op on non supported architectures -template +template #if CK_TILE_USE_LAUNCH_BOUNDS -__launch_bounds__(MaxThreadPerBlock, MinBlockPerCu) +__launch_bounds__(Kernel::kBlockSize, MinBlockPerCu) #endif __global__ void kentry_pt(Args... args) { @@ -29,14 +29,13 @@ __launch_bounds__(MaxThreadPerBlock, MinBlockPerCu) // // the "static __device__ operator()(some_arg)" is the entry point of KernelImpl // -template CK_TILE_HOST auto make_kernel_pt(KernelImpl /*f*/, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args) { - const auto kernel = kentry_pt; + const auto kernel = kentry_pt; return [=](const stream_config& s) { kernel<<>>(args...); diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 4cd773bc1612..9c75baa0bf94 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -1762,6 +1762,10 @@ if(USE_ROCM) target_link_libraries(torch_hip PUBLIC torch_cpu_library ${Caffe2_PUBLIC_HIP_DEPENDENCY_LIBS}) target_link_libraries(torch_hip PRIVATE ${Caffe2_HIP_DEPENDENCY_LIBS}) + if(USE_ROCM_CK_SDPA) + target_link_libraries(torch_hip PRIVATE ck_sdpa) + endif() + if(USE_FBGEMM_GENAI) if(USE_ROCM) target_link_libraries(torch_hip PRIVATE fbgemm_genai) diff --git a/third_party/composable_kernel b/third_party/composable_kernel index 7fe50dc3da20..de61e5549382 160000 --- a/third_party/composable_kernel +++ b/third_party/composable_kernel @@ -1 +1 @@ -Subproject commit 7fe50dc3da2069d6645d9deb8c017a876472a977 +Subproject commit de61e554938265a5d17a1bba8c148457125e80cd