mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Revert "[ROCm] Add specific compile options for CK SDPA (#161759)"
This reverts commit d22d916719eb7daff8455a01d216d65f81899a9e. Reverted https://github.com/pytorch/pytorch/pull/161759 on behalf of https://github.com/huydhn due to Sorry for reverting your change but this seems to break internal ROCm jobs ([comment](https://github.com/pytorch/pytorch/pull/161759#issuecomment-3272807726))
This commit is contained in:
@ -1,7 +1,6 @@
|
||||
cmake_minimum_required(VERSION 3.27 FATAL_ERROR)
|
||||
set(CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake ${CMAKE_MODULE_PATH})
|
||||
|
||||
|
||||
if(NOT MSVC)
|
||||
string(APPEND CMAKE_CXX_FLAGS " -Wno-ignored-qualifiers")
|
||||
string(APPEND CMAKE_C_FLAGS " -Wno-ignored-qualifiers")
|
||||
@ -196,94 +195,14 @@ if(USE_FLASH_ATTENTION)
|
||||
endif()
|
||||
endif()
|
||||
message(STATUS "USE_ROCM_CK_SDPA is set; building PyTorch with CK SDPA enabled")
|
||||
|
||||
# CK SDPA sources require specific compilation flags
|
||||
set(CK_SDPA_EXTRA_HIPCC_FLAGS
|
||||
-fno-autolink
|
||||
-fhip-new-launch-api
|
||||
-fgnuc-version=4.2.1
|
||||
-fno-implicit-modules
|
||||
-fskip-odr-check-in-gmf
|
||||
-fcxx-exceptions
|
||||
-fexceptions
|
||||
-fcolor-diagnostics
|
||||
-faddrsig
|
||||
-fno-rounding-math
|
||||
-mconstructor-aliases
|
||||
-mllvm
|
||||
-amdgpu-internalize-symbols
|
||||
-fvisibility=hidden
|
||||
-Wno-float-equal
|
||||
-fgpu-flush-denormals-to-zero
|
||||
-Wno-unused-parameter)
|
||||
|
||||
#TODO: The following flags are specific to 8-bit width types which are not integrated via CK yet.
|
||||
# Add once that support is integrated
|
||||
#check_cxx_compiler_flag("-fno-offload-uniform-block" HAS_NO_OFFLOAD_UNIFORM_BLOCK)
|
||||
#if(HAS_NO_OFFLOAD_UNIFORM_BLOCK)
|
||||
# list(APPEND CK_SDPA_EXTRA_HIPCC_FLAGS -fno-offload-uniform-block)
|
||||
#endif()
|
||||
#check_cxx_compiler_flag("-mllvm --lsr-drop-solution=1" HAS_LSR_DROP_SOLUTION)
|
||||
#if(HAS_LSR_DROP_SOLUTION)
|
||||
# list(APPEND CK_SDPA_EXTRA_HIPCC_FLAGS -mllvm --lsr-drop-solution=1)
|
||||
#endif()
|
||||
#check_cxx_compiler_flag("-mllvm -enable-post-misched=0" HAS_ENABLE_POST_MISCHED)
|
||||
#if(HAS_ENABLE_POST_MISCHED)
|
||||
# list(APPEND CK_SDPA_EXTRA_HIPCC_FLAGS -mllvm -enable-post-misched=0)
|
||||
#endif()
|
||||
#set(check-coerce)
|
||||
#check_cxx_compiler_flag(" -mllvm -amdgpu-coerce-illegal-types=1" check-coerce)
|
||||
#if(check-coerce)
|
||||
# list(APPEND CK_SDPA_EXTRA_HIPCC_FLAGS -mllvm -amdgpu-coerce-illegal-types=1)
|
||||
#endif()
|
||||
|
||||
list(APPEND CK_SDPA_EXTRA_HIPCC_FLAGS -mllvm -amdgpu-early-inline-all=true)
|
||||
list(APPEND CK_SDPA_EXTRA_HIPCC_FLAGS -mllvm -amdgpu-function-calls=false)
|
||||
|
||||
# Additional CK compiler flags
|
||||
set(CK_SDPA_EXTRA_HIPCC_OPTIONS
|
||||
CK_ENABLE_BF16
|
||||
CK_ENABLE_BF8
|
||||
CK_ENABLE_FP16
|
||||
CK_ENABLE_FP32
|
||||
CK_ENABLE_FP64
|
||||
CK_ENABLE_FP8
|
||||
CK_ENABLE_INT8
|
||||
CK_USE_FNUZ_FP8
|
||||
CK_USE_GFX94
|
||||
CK_USE_XDL
|
||||
__HIP_PLATFORM_AMD__=1
|
||||
__HIP_PLATFORM_HCC__=1
|
||||
CK_TILE_FMHA_FWD_FAST_EXP2=1
|
||||
CK_TILE_FMHA_FWD_SPLITKV_API=1
|
||||
CK_TILE_FMHA_FWD_APPENDKV_API=1
|
||||
CK_TILE_FMHA_FWD_PAGEDKV_API=1
|
||||
__GCC_HAVE_DWARF2_CFI_ASM=1
|
||||
USE_ROCM_CK_SDPA)
|
||||
|
||||
message(STATUS "Generating CK kernel instances...")
|
||||
add_subdirectory(native/transformers/hip/flash_attn/ck)
|
||||
file(GLOB flash_attention_hip_ck_hip "native/transformers/hip/flash_attn/ck/*.hip")
|
||||
list(APPEND native_transformers_hip_hip ${flash_attention_hip_ck_hip})
|
||||
# FAv3 Generation
|
||||
add_subdirectory(native/transformers/hip/flash_attn/ck/fav_v3)
|
||||
file(GLOB ck_sdpa_sources_hip
|
||||
"native/transformers/hip/flash_attn/ck/*.hip"
|
||||
"native/transformers/hip/flash_attn/ck/fav_v3/*.hip")
|
||||
|
||||
set_source_files_properties(${ck_sdpa_sources_hip} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1)
|
||||
hip_add_library(ck_sdpa STATIC
|
||||
${ck_sdpa_sources_hip}
|
||||
HIPCC_OPTIONS ${HIP_HCC_FLAGS} ${CK_SDPA_EXTRA_HIPCC_FLAGS})
|
||||
set_target_properties(ck_sdpa PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
target_compile_definitions(ck_sdpa PUBLIC ${CK_SDPA_EXTRA_HIPCC_OPTIONS})
|
||||
target_include_directories(ck_sdpa PUBLIC
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/include
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/library/include
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/example/ck_tile/01_fmha
|
||||
${CMAKE_CURRENT_BINARY_DIR}/composable_kernel
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/aiter/csrc/include
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/native/transformers/hip/flash_attn/ck
|
||||
)
|
||||
|
||||
file(GLOB flash_attention_v3_hip "native/transformers/hip/flash_attn/ck/fav_v3/*.hip")
|
||||
list(APPEND native_transformers_hip_hip ${flash_attention_v3_hip})
|
||||
endif()
|
||||
file(GLOB flash_attention_hip_aot_hip "native/transformers/hip/flash_attn/aot/*.hip")
|
||||
file(GLOB flash_attention_src_hip_hip "native/transformers/hip/flash_attn/src/*.hip")
|
||||
|
@ -8,9 +8,9 @@
|
||||
|
||||
namespace ck_tile {
|
||||
// Added by hipification to become a no-op on non supported architectures
|
||||
template <int MinBlockPerCu, typename Kernel, typename... Args>
|
||||
template <int MaxThreadPerBlock, int MinBlockPerCu, typename Kernel, typename... Args>
|
||||
#if CK_TILE_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(Kernel::kBlockSize, MinBlockPerCu)
|
||||
__launch_bounds__(MaxThreadPerBlock, MinBlockPerCu)
|
||||
#endif
|
||||
__global__ void kentry_pt(Args... args)
|
||||
{
|
||||
@ -29,13 +29,14 @@ __launch_bounds__(Kernel::kBlockSize, MinBlockPerCu)
|
||||
//
|
||||
// the "static __device__ operator()(some_arg)" is the entry point of KernelImpl
|
||||
//
|
||||
template <int MinBlockPerCu = CK_TILE_MIN_BLOCK_PER_CU,
|
||||
template <int MaxThreadPerBlock = CK_TILE_MAX_THREAD_PER_BLOCK,
|
||||
int MinBlockPerCu = CK_TILE_MIN_BLOCK_PER_CU,
|
||||
typename KernelImpl,
|
||||
typename... Args>
|
||||
CK_TILE_HOST auto
|
||||
make_kernel_pt(KernelImpl /*f*/, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
|
||||
{
|
||||
const auto kernel = kentry_pt<MinBlockPerCu, KernelImpl, Args...>;
|
||||
const auto kernel = kentry_pt<MaxThreadPerBlock, MinBlockPerCu, KernelImpl, Args...>;
|
||||
|
||||
return [=](const stream_config& s) {
|
||||
kernel<<<grid_dim, block_dim, lds_byte, s.stream_id_>>>(args...);
|
||||
|
@ -1762,10 +1762,6 @@ if(USE_ROCM)
|
||||
target_link_libraries(torch_hip PUBLIC torch_cpu_library ${Caffe2_PUBLIC_HIP_DEPENDENCY_LIBS})
|
||||
target_link_libraries(torch_hip PRIVATE ${Caffe2_HIP_DEPENDENCY_LIBS})
|
||||
|
||||
if(USE_ROCM_CK_SDPA)
|
||||
target_link_libraries(torch_hip PRIVATE ck_sdpa)
|
||||
endif()
|
||||
|
||||
if(USE_FBGEMM_GENAI)
|
||||
if(USE_ROCM)
|
||||
target_link_libraries(torch_hip PRIVATE fbgemm_genai)
|
||||
|
2
third_party/composable_kernel
vendored
2
third_party/composable_kernel
vendored
Submodule third_party/composable_kernel updated: de61e55493...7fe50dc3da
Reference in New Issue
Block a user