Revert "[ROCm] Add specific compile options for CK SDPA (#161759)"

This reverts commit d22d916719eb7daff8455a01d216d65f81899a9e.

Reverted https://github.com/pytorch/pytorch/pull/161759 on behalf of https://github.com/huydhn due to Sorry for reverting your change but this seems to break internal ROCm jobs ([comment](https://github.com/pytorch/pytorch/pull/161759#issuecomment-3272807726))
This commit is contained in:
PyTorch MergeBot
2025-09-10 00:44:29 +00:00
parent 33589374b6
commit 2281d009e5
4 changed files with 10 additions and 94 deletions

View File

@ -1,7 +1,6 @@
cmake_minimum_required(VERSION 3.27 FATAL_ERROR) cmake_minimum_required(VERSION 3.27 FATAL_ERROR)
set(CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake ${CMAKE_MODULE_PATH}) set(CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake ${CMAKE_MODULE_PATH})
if(NOT MSVC) if(NOT MSVC)
string(APPEND CMAKE_CXX_FLAGS " -Wno-ignored-qualifiers") string(APPEND CMAKE_CXX_FLAGS " -Wno-ignored-qualifiers")
string(APPEND CMAKE_C_FLAGS " -Wno-ignored-qualifiers") string(APPEND CMAKE_C_FLAGS " -Wno-ignored-qualifiers")
@ -196,94 +195,14 @@ if(USE_FLASH_ATTENTION)
endif() endif()
endif() endif()
message(STATUS "USE_ROCM_CK_SDPA is set; building PyTorch with CK SDPA enabled") message(STATUS "USE_ROCM_CK_SDPA is set; building PyTorch with CK SDPA enabled")
# CK SDPA sources require specific compilation flags
set(CK_SDPA_EXTRA_HIPCC_FLAGS
-fno-autolink
-fhip-new-launch-api
-fgnuc-version=4.2.1
-fno-implicit-modules
-fskip-odr-check-in-gmf
-fcxx-exceptions
-fexceptions
-fcolor-diagnostics
-faddrsig
-fno-rounding-math
-mconstructor-aliases
-mllvm
-amdgpu-internalize-symbols
-fvisibility=hidden
-Wno-float-equal
-fgpu-flush-denormals-to-zero
-Wno-unused-parameter)
#TODO: The following flags are specific to 8-bit width types which are not integrated via CK yet.
# Add once that support is integrated
#check_cxx_compiler_flag("-fno-offload-uniform-block" HAS_NO_OFFLOAD_UNIFORM_BLOCK)
#if(HAS_NO_OFFLOAD_UNIFORM_BLOCK)
# list(APPEND CK_SDPA_EXTRA_HIPCC_FLAGS -fno-offload-uniform-block)
#endif()
#check_cxx_compiler_flag("-mllvm --lsr-drop-solution=1" HAS_LSR_DROP_SOLUTION)
#if(HAS_LSR_DROP_SOLUTION)
# list(APPEND CK_SDPA_EXTRA_HIPCC_FLAGS -mllvm --lsr-drop-solution=1)
#endif()
#check_cxx_compiler_flag("-mllvm -enable-post-misched=0" HAS_ENABLE_POST_MISCHED)
#if(HAS_ENABLE_POST_MISCHED)
# list(APPEND CK_SDPA_EXTRA_HIPCC_FLAGS -mllvm -enable-post-misched=0)
#endif()
#set(check-coerce)
#check_cxx_compiler_flag(" -mllvm -amdgpu-coerce-illegal-types=1" check-coerce)
#if(check-coerce)
# list(APPEND CK_SDPA_EXTRA_HIPCC_FLAGS -mllvm -amdgpu-coerce-illegal-types=1)
#endif()
list(APPEND CK_SDPA_EXTRA_HIPCC_FLAGS -mllvm -amdgpu-early-inline-all=true)
list(APPEND CK_SDPA_EXTRA_HIPCC_FLAGS -mllvm -amdgpu-function-calls=false)
# Additional CK compiler flags
set(CK_SDPA_EXTRA_HIPCC_OPTIONS
CK_ENABLE_BF16
CK_ENABLE_BF8
CK_ENABLE_FP16
CK_ENABLE_FP32
CK_ENABLE_FP64
CK_ENABLE_FP8
CK_ENABLE_INT8
CK_USE_FNUZ_FP8
CK_USE_GFX94
CK_USE_XDL
__HIP_PLATFORM_AMD__=1
__HIP_PLATFORM_HCC__=1
CK_TILE_FMHA_FWD_FAST_EXP2=1
CK_TILE_FMHA_FWD_SPLITKV_API=1
CK_TILE_FMHA_FWD_APPENDKV_API=1
CK_TILE_FMHA_FWD_PAGEDKV_API=1
__GCC_HAVE_DWARF2_CFI_ASM=1
USE_ROCM_CK_SDPA)
message(STATUS "Generating CK kernel instances...") message(STATUS "Generating CK kernel instances...")
add_subdirectory(native/transformers/hip/flash_attn/ck) add_subdirectory(native/transformers/hip/flash_attn/ck)
file(GLOB flash_attention_hip_ck_hip "native/transformers/hip/flash_attn/ck/*.hip")
list(APPEND native_transformers_hip_hip ${flash_attention_hip_ck_hip})
# FAv3 Generation # FAv3 Generation
add_subdirectory(native/transformers/hip/flash_attn/ck/fav_v3) add_subdirectory(native/transformers/hip/flash_attn/ck/fav_v3)
file(GLOB ck_sdpa_sources_hip file(GLOB flash_attention_v3_hip "native/transformers/hip/flash_attn/ck/fav_v3/*.hip")
"native/transformers/hip/flash_attn/ck/*.hip" list(APPEND native_transformers_hip_hip ${flash_attention_v3_hip})
"native/transformers/hip/flash_attn/ck/fav_v3/*.hip")
set_source_files_properties(${ck_sdpa_sources_hip} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1)
hip_add_library(ck_sdpa STATIC
${ck_sdpa_sources_hip}
HIPCC_OPTIONS ${HIP_HCC_FLAGS} ${CK_SDPA_EXTRA_HIPCC_FLAGS})
set_target_properties(ck_sdpa PROPERTIES POSITION_INDEPENDENT_CODE ON)
target_compile_definitions(ck_sdpa PUBLIC ${CK_SDPA_EXTRA_HIPCC_OPTIONS})
target_include_directories(ck_sdpa PUBLIC
${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/include
${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/library/include
${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/example/ck_tile/01_fmha
${CMAKE_CURRENT_BINARY_DIR}/composable_kernel
${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/aiter/csrc/include
${CMAKE_CURRENT_SOURCE_DIR}/native/transformers/hip/flash_attn/ck
)
endif() endif()
file(GLOB flash_attention_hip_aot_hip "native/transformers/hip/flash_attn/aot/*.hip") file(GLOB flash_attention_hip_aot_hip "native/transformers/hip/flash_attn/aot/*.hip")
file(GLOB flash_attention_src_hip_hip "native/transformers/hip/flash_attn/src/*.hip") file(GLOB flash_attention_src_hip_hip "native/transformers/hip/flash_attn/src/*.hip")

View File

@ -8,9 +8,9 @@
namespace ck_tile { namespace ck_tile {
// Added by hipification to become a no-op on non supported architectures // Added by hipification to become a no-op on non supported architectures
template <int MinBlockPerCu, typename Kernel, typename... Args> template <int MaxThreadPerBlock, int MinBlockPerCu, typename Kernel, typename... Args>
#if CK_TILE_USE_LAUNCH_BOUNDS #if CK_TILE_USE_LAUNCH_BOUNDS
__launch_bounds__(Kernel::kBlockSize, MinBlockPerCu) __launch_bounds__(MaxThreadPerBlock, MinBlockPerCu)
#endif #endif
__global__ void kentry_pt(Args... args) __global__ void kentry_pt(Args... args)
{ {
@ -29,13 +29,14 @@ __launch_bounds__(Kernel::kBlockSize, MinBlockPerCu)
// //
// the "static __device__ operator()(some_arg)" is the entry point of KernelImpl // the "static __device__ operator()(some_arg)" is the entry point of KernelImpl
// //
template <int MinBlockPerCu = CK_TILE_MIN_BLOCK_PER_CU, template <int MaxThreadPerBlock = CK_TILE_MAX_THREAD_PER_BLOCK,
int MinBlockPerCu = CK_TILE_MIN_BLOCK_PER_CU,
typename KernelImpl, typename KernelImpl,
typename... Args> typename... Args>
CK_TILE_HOST auto CK_TILE_HOST auto
make_kernel_pt(KernelImpl /*f*/, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args) make_kernel_pt(KernelImpl /*f*/, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
{ {
const auto kernel = kentry_pt<MinBlockPerCu, KernelImpl, Args...>; const auto kernel = kentry_pt<MaxThreadPerBlock, MinBlockPerCu, KernelImpl, Args...>;
return [=](const stream_config& s) { return [=](const stream_config& s) {
kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...); kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...);

View File

@ -1762,10 +1762,6 @@ if(USE_ROCM)
target_link_libraries(torch_hip PUBLIC torch_cpu_library ${Caffe2_PUBLIC_HIP_DEPENDENCY_LIBS}) target_link_libraries(torch_hip PUBLIC torch_cpu_library ${Caffe2_PUBLIC_HIP_DEPENDENCY_LIBS})
target_link_libraries(torch_hip PRIVATE ${Caffe2_HIP_DEPENDENCY_LIBS}) target_link_libraries(torch_hip PRIVATE ${Caffe2_HIP_DEPENDENCY_LIBS})
if(USE_ROCM_CK_SDPA)
target_link_libraries(torch_hip PRIVATE ck_sdpa)
endif()
if(USE_FBGEMM_GENAI) if(USE_FBGEMM_GENAI)
if(USE_ROCM) if(USE_ROCM)
target_link_libraries(torch_hip PRIVATE fbgemm_genai) target_link_libraries(torch_hip PRIVATE fbgemm_genai)