mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[ROCm] Ck backend UX refactor (#152951)
Refactors how the enablement/disablement of CK Gemms and SDPA works. - Adds USE_ROCM_CK_GEMM compile flag for enabling CK gemms. - USE_ROCM_CK_GEMM is set to True by default on Linux - Updates USE_CK_FLASH_ATTENTION to USE_ROCM_CK_SDPA. - USE_ROCM_CK_SDPA is set to False by default - (USE_CK_FLASH_ATTENTION still works for now, but will be deprecated in a future release) - Prevents these CK libraries from being used unless pytorch has been built specifically with the functionality AND is running on a system architecture that supports it. - the getters for these library backends will also do some validity checking in case the user used an environment variable to change the backend. If invalid, (i.e. one of the cases mentioned above is false) the backend will be set as the current non-CK default Pull Request resolved: https://github.com/pytorch/pytorch/pull/152951 Approved by: https://github.com/eqy, https://github.com/jeffdaily, https://github.com/m-gallus Co-authored-by: Jeff Daily <jeff.daily@amd.com> Co-authored-by: Jithun Nair <jithun.nair@amd.com> Co-authored-by: Jane (Yuan) Xu <31798555+janeyx99@users.noreply.github.com>
This commit is contained in:
committed by
PyTorch MergeBot
parent
da1f608ca3
commit
5f5f508aa8
@ -240,6 +240,8 @@ cmake_dependent_option(
|
||||
BUILD_LAZY_CUDA_LINALG "Build cuda linalg ops as separate library" ON
|
||||
"USE_CUDA AND LINUX AND BUILD_PYTHON" OFF)
|
||||
cmake_dependent_option(USE_ROCM "Use ROCm" ON "LINUX" OFF)
|
||||
cmake_dependent_option(USE_ROCM_CK_GEMM "Use ROCm Composable Kernel for GEMMs" ON "USE_ROCM;NOT WIN32" OFF)
|
||||
option(USE_ROCM_CK_SDPA "Use ROCm Composable Kernel for SDPA" OFF)
|
||||
option(CAFFE2_STATIC_LINK_CUDA "Statically link CUDA libraries" OFF)
|
||||
cmake_dependent_option(USE_CUDNN "Use cuDNN" ON "USE_CUDA" OFF)
|
||||
cmake_dependent_option(USE_STATIC_CUDNN "Use cuDNN static libraries" OFF
|
||||
|
@ -180,26 +180,27 @@ file(GLOB native_flash_attn_api_cpp "native/transformers/cuda/flash_attn/flash_a
|
||||
file(GLOB flash_attention_hip_hip "native/transformers/hip/flash_attn/*.hip")
|
||||
# if USE_FLASH_ATTENTION is set, ensure CK instances get generated
|
||||
if(USE_FLASH_ATTENTION)
|
||||
if(DEFINED ENV{USE_CK_FLASH_ATTENTION})
|
||||
set(USE_CK_FLASH_ATTENTION $ENV{USE_CK_FLASH_ATTENTION})
|
||||
if(USE_CK_FLASH_ATTENTION STREQUAL "1")
|
||||
if(DEFINED ENV{PYTORCH_ROCM_ARCH})
|
||||
list(LENGTH PYTORCH_ROCM_ARCH NUM_ARCHS)
|
||||
if(NUM_ARCHS GREATER 1)
|
||||
message(WARNING "Building CK for multiple archs can increase build time considerably!
|
||||
Consider setting PYTORCH_ROCM_ARCH env var value as the gfx arch you need to build for")
|
||||
endif()
|
||||
endif()
|
||||
message(STATUS "USE_CK_FLASH_ATTENTION is set; building PyTorch with CK Flash Attention enabled")
|
||||
message(STATUS "Generating CK kernel instances...")
|
||||
add_subdirectory(native/transformers/hip/flash_attn/ck)
|
||||
file(GLOB flash_attention_hip_ck_hip "native/transformers/hip/flash_attn/ck/*.hip")
|
||||
list(APPEND native_transformers_hip_hip ${flash_attention_hip_ck_hip})
|
||||
# FAv3 Generation
|
||||
add_subdirectory(native/transformers/hip/flash_attn/ck/fav_v3)
|
||||
file(GLOB flash_attention_v3_hip "native/transformers/hip/flash_attn/ck/fav_v3/*.hip")
|
||||
list(APPEND native_transformers_hip_hip ${flash_attention_v3_hip})
|
||||
if("$ENV{USE_CK_FLASH_ATTENTION}" STREQUAL "1")
|
||||
message(STATUS "USE_CK_FLASH_ATTENTION is being deprecated. Please use USE_ROCM_CK_SDPA instead")
|
||||
caffe2_update_option(USE_ROCM_CK_SDPA ON)
|
||||
endif()
|
||||
if(USE_ROCM_CK_SDPA)
|
||||
if(DEFINED ENV{PYTORCH_ROCM_ARCH})
|
||||
list(LENGTH PYTORCH_ROCM_ARCH NUM_ARCHS)
|
||||
if(NUM_ARCHS GREATER 1)
|
||||
message(WARNING "Building CK for multiple archs can increase build time considerably!
|
||||
Consider setting PYTORCH_ROCM_ARCH env var value as the gfx arch you need to build for")
|
||||
endif()
|
||||
endif()
|
||||
message(STATUS "USE_ROCM_CK_SDPA is set; building PyTorch with CK SDPA enabled")
|
||||
message(STATUS "Generating CK kernel instances...")
|
||||
add_subdirectory(native/transformers/hip/flash_attn/ck)
|
||||
file(GLOB flash_attention_hip_ck_hip "native/transformers/hip/flash_attn/ck/*.hip")
|
||||
list(APPEND native_transformers_hip_hip ${flash_attention_hip_ck_hip})
|
||||
# FAv3 Generation
|
||||
add_subdirectory(native/transformers/hip/flash_attn/ck/fav_v3)
|
||||
file(GLOB flash_attention_v3_hip "native/transformers/hip/flash_attn/ck/fav_v3/*.hip")
|
||||
list(APPEND native_transformers_hip_hip ${flash_attention_v3_hip})
|
||||
endif()
|
||||
file(GLOB flash_attention_hip_aot_hip "native/transformers/hip/flash_attn/aot/*.hip")
|
||||
file(GLOB flash_attention_src_hip_hip "native/transformers/hip/flash_attn/src/*.hip")
|
||||
@ -418,40 +419,42 @@ if(USE_CUDA)
|
||||
endif()
|
||||
|
||||
if(USE_ROCM)
|
||||
# NOTE: The PyTorch build does not actually add_subdirectory
|
||||
# third_party/composable_kernel or use it as a CMake library. What is used
|
||||
# is header only, so this should be ok, except that the CMake build generates
|
||||
# a ck/config.h. We just do that part here. Without this, the ck.h from the
|
||||
# ROCM SDK may get accidentally used instead.
|
||||
function(_pytorch_rocm_generate_ck_conf)
|
||||
set(CK_ENABLE_INT8 "ON")
|
||||
set(CK_ENABLE_FP16 "ON")
|
||||
set(CK_ENABLE_FP32 "ON")
|
||||
set(CK_ENABLE_FP64 "ON")
|
||||
set(CK_ENABLE_BF16 "ON")
|
||||
set(CK_ENABLE_FP8 "ON")
|
||||
set(CK_ENABLE_BF8 "ON")
|
||||
set(CK_USE_XDL "ON")
|
||||
set(CK_USE_WMMA "ON")
|
||||
configure_file(
|
||||
"${Torch_SOURCE_DIR}/third_party/composable_kernel/include/ck/config.h.in"
|
||||
"${CMAKE_CURRENT_BINARY_DIR}/composable_kernel/ck/config.h"
|
||||
)
|
||||
endfunction()
|
||||
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/hip)
|
||||
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/include)
|
||||
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/library/include)
|
||||
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/example/ck_tile/01_fmha)
|
||||
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_BINARY_DIR}/composable_kernel)
|
||||
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/aiter/csrc/include)
|
||||
_pytorch_rocm_generate_ck_conf()
|
||||
if((USE_FLASH_ATTENTION AND USE_ROCM_CK_SDPA) OR USE_ROCM_CK_GEMM)
|
||||
# NOTE: The PyTorch build does not actually add_subdirectory
|
||||
# third_party/composable_kernel or use it as a CMake library. What is used
|
||||
# is header only, so this should be ok, except that the CMake build generates
|
||||
# a ck/config.h. We just do that part here. Without this, the ck.h from the
|
||||
# ROCM SDK may get accidentally used instead.
|
||||
function(_pytorch_rocm_generate_ck_conf)
|
||||
set(CK_ENABLE_INT8 "ON")
|
||||
set(CK_ENABLE_FP16 "ON")
|
||||
set(CK_ENABLE_FP32 "ON")
|
||||
set(CK_ENABLE_FP64 "ON")
|
||||
set(CK_ENABLE_BF16 "ON")
|
||||
set(CK_ENABLE_FP8 "ON")
|
||||
set(CK_ENABLE_BF8 "ON")
|
||||
set(CK_USE_XDL "ON")
|
||||
set(CK_USE_WMMA "ON")
|
||||
configure_file(
|
||||
"${Torch_SOURCE_DIR}/third_party/composable_kernel/include/ck/config.h.in"
|
||||
"${CMAKE_CURRENT_BINARY_DIR}/composable_kernel/ck/config.h"
|
||||
)
|
||||
endfunction()
|
||||
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/hip)
|
||||
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/include)
|
||||
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/library/include)
|
||||
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/example/ck_tile/01_fmha)
|
||||
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_BINARY_DIR}/composable_kernel)
|
||||
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/aiter/csrc/include)
|
||||
_pytorch_rocm_generate_ck_conf()
|
||||
endif()
|
||||
|
||||
# Next two lines are needed because TunableOp uses third-party/fmt
|
||||
list(APPEND ATen_HIP_INCLUDE $<TARGET_PROPERTY:fmt::fmt-header-only,INTERFACE_INCLUDE_DIRECTORIES>)
|
||||
list(APPEND ATen_HIP_DEPENDENCY_LIBS fmt::fmt-header-only)
|
||||
if(USE_FLASH_ATTENTION)
|
||||
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/native/transformers/hip/flash_attn/ck)
|
||||
endif()
|
||||
if(USE_FLASH_ATTENTION AND USE_ROCM_CK_SDPA)
|
||||
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/native/transformers/hip/flash_attn/ck)
|
||||
endif()
|
||||
list(APPEND ATen_HIP_SRCS
|
||||
${ATen_HIP_SRCS}
|
||||
${hip_hip}
|
||||
@ -461,12 +464,17 @@ endif()
|
||||
${native_quantized_hip_hip}
|
||||
${native_transformers_hip_hip} ${native_transformers_src_hip_hip}
|
||||
)
|
||||
if(WIN32) # Windows doesn't support Composable Kernels
|
||||
if(NOT USE_ROCM_CK_GEMM)
|
||||
file(GLOB native_hip_bgemm "native/hip/bgemm_kernels/*.hip")
|
||||
file(GLOB native_hip_ck "native/hip/ck*.hip")
|
||||
exclude(ATen_HIP_SRCS "${ATen_HIP_SRCS}"
|
||||
${native_hip_bgemm} ${native_hip_ck})
|
||||
endif()
|
||||
if(WIN32) # Windows doesn't support Composable Kernels and Triton
|
||||
exclude(ATen_HIP_SRCS "${ATen_HIP_SRCS}"
|
||||
${native_transformers_hip_hip} ${native_transformers_hip_cpp})
|
||||
endif()
|
||||
|
||||
# TODO: Codegen separate files for HIP and use those (s/cuda_generated_sources/hip_generated_sources)
|
||||
list(APPEND all_hip_cpp
|
||||
${native_nested_hip_cpp}
|
||||
|
@ -480,6 +480,9 @@ at::BlasBackend Context::blasPreferredBackend() {
|
||||
// call site for blasPreferredBackend(), we set it to an actual value.
|
||||
if (blas_preferred_backend == at::BlasBackend::Default) {
|
||||
blas_preferred_backend = at::BlasBackend::Cublas;
|
||||
// This logic sits in the getter because it needs to validate
|
||||
// values set via env vars such as TORCH_BLAS_PREFER_CUBLASLT
|
||||
// which initialize the backend without calling the setter
|
||||
#ifdef USE_ROCM
|
||||
// AMD Instinct targets prefer hipblaslt
|
||||
static const bool hipblaslt_preferred = []() {
|
||||
@ -509,6 +512,10 @@ at::BlasBackend Context::blasPreferredBackend() {
|
||||
// hipblaslt support for all archs is not as complete as hipblas
|
||||
if (blas_preferred_backend == at::BlasBackend::Cublaslt) {
|
||||
static const bool hipblaslt_unsupported = []() {
|
||||
if(!hasCuBLASLt())
|
||||
{
|
||||
return true;
|
||||
}
|
||||
static const std::vector<std::string> archs = {
|
||||
"gfx90a", "gfx942",
|
||||
#if ROCM_VERSION >= 60300
|
||||
@ -534,6 +541,24 @@ at::BlasBackend Context::blasPreferredBackend() {
|
||||
return blas_preferred_backend;
|
||||
}
|
||||
|
||||
bool Context::ckSupported() {
|
||||
#ifdef USE_ROCM
|
||||
static const std::vector<std::string> supported_archs = {
|
||||
"gfx90a", "gfx942", "gfx950"
|
||||
};
|
||||
for (auto index : c10::irange(detail::getCUDAHooks().deviceCount())) {
|
||||
if(!detail::getCUDAHooks().isGPUArch(supported_archs, index)) {
|
||||
TORCH_WARN_ONCE(
|
||||
"Attempting to use CK on an unsupported architecture! Cannot set backend to CK");
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
#else
|
||||
return false;
|
||||
#endif
|
||||
}
|
||||
|
||||
void Context::setBlasPreferredBackend(at::BlasBackend b) {
|
||||
#ifdef _MSC_VER
|
||||
TORCH_WARN_ONCE(
|
||||
@ -543,8 +568,14 @@ void Context::setBlasPreferredBackend(at::BlasBackend b) {
|
||||
#else
|
||||
TORCH_CHECK((b != at::BlasBackend::Cublaslt) || hasCuBLASLt(),
|
||||
"Cannot set preferred backend to cuBLASLt if PyTorch has not been compiled with cuBLASLt.");
|
||||
TORCH_CHECK((b != at::BlasBackend::Ck) || hasROCM(),
|
||||
"Cannot set preferred backend to Ck if PyTorch has not been compiled for ROCm.");
|
||||
#ifdef USE_ROCM
|
||||
static const bool ckSupportedFlag = ckSupported();
|
||||
static const bool hasCKGEMMFlag = hasCKGEMM();
|
||||
TORCH_CHECK((b != at::BlasBackend::Ck) || (ckSupportedFlag && hasCKGEMMFlag),
|
||||
"Cannot set preferred blas backend to CK since following conditions are not true: ",
|
||||
"architecture supported for CK: ", ckSupportedFlag,
|
||||
", PyTorch built with CK GEMM support: ", hasCKGEMMFlag);
|
||||
#endif
|
||||
if (b != at::BlasBackend::Default && b != at::BlasBackend::Cublas) {
|
||||
TORCH_WARN_ONCE(
|
||||
"torch.backends.cuda.preferred_blas_library is an experimental feature. "
|
||||
@ -556,35 +587,40 @@ void Context::setBlasPreferredBackend(at::BlasBackend b) {
|
||||
#endif
|
||||
}
|
||||
|
||||
at::ROCmFABackend Context::getROCmFAPreferredBackend() const {
|
||||
at::ROCmFABackend Context::getROCmFAPreferredBackend() {
|
||||
#ifdef USE_ROCM
|
||||
// Set potential "Default" value so we don't have to interpret at call sites.
|
||||
// We use aotriton backend as the default, for now.
|
||||
if(rocm_fa_preferred_backend == at::ROCmFABackend::Default) {
|
||||
rocm_fa_preferred_backend = at::ROCmFABackend::AOTriton;
|
||||
} else if (rocm_fa_preferred_backend == at::ROCmFABackend::Ck) {
|
||||
// This logic sits in the getter because it needs to validate
|
||||
// values set via env vars such as TORCH_ROCM_FA_PREFER_CK
|
||||
// which initialize the backend without calling the setter
|
||||
// Perform validity checking
|
||||
static const bool hasCKSDPAFlag = hasCKSDPA();
|
||||
static const bool ckSupportedFlag = ckSupported();
|
||||
if(!(hasCKSDPAFlag && ckSupportedFlag)){
|
||||
TORCH_WARN_ONCE(
|
||||
"Cannot set preferred SDPA backend to CK since following conditions are not true: ",
|
||||
"architecture supported for CK: ", ckSupportedFlag,
|
||||
", PyTorch built with CK SDPA support: ", hasCKSDPAFlag);
|
||||
rocm_fa_preferred_backend = at::ROCmFABackend::AOTriton;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
return rocm_fa_preferred_backend;
|
||||
}
|
||||
|
||||
void Context::setROCmFAPreferredBackend(at::ROCmFABackend b) {
|
||||
|
||||
// TODO: add plumbing for hasCK for validity checking
|
||||
TORCH_CHECK((b != at::ROCmFABackend::Ck) || hasROCM(),
|
||||
"Cannot set preferred flash attention backend to Ck if PyTorch has not been compiled for ROCm.");
|
||||
#ifdef USE_ROCM
|
||||
if(b == at::ROCmFABackend::Ck) {
|
||||
static const bool ck_unsupported = []() {
|
||||
static const std::vector<std::string> archs = {
|
||||
"gfx90a", "gfx942"
|
||||
};
|
||||
for (auto index: c10::irange(detail::getCUDAHooks().deviceCount())) {
|
||||
if (!detail::getCUDAHooks().isGPUArch(archs, index)) {
|
||||
TORCH_WARN_ONCE(
|
||||
"Attempting to use CK on an unsupported architecture! Cannot set backend to CK");
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}();
|
||||
if(!ck_unsupported) rocm_fa_preferred_backend = b;
|
||||
}
|
||||
else {
|
||||
rocm_fa_preferred_backend = b;
|
||||
}
|
||||
static const bool hasCKSDPAFlag = hasCKSDPA();
|
||||
static const bool ckSupportedFlag = ckSupported();
|
||||
TORCH_CHECK((b != at::ROCmFABackend::Ck) || (hasCKSDPAFlag && ckSupportedFlag),
|
||||
"Cannot set preferred SDPA backend to CK since following conditions are not true: ",
|
||||
"architecture supported for CK: ", ckSupportedFlag,
|
||||
", PyTorch built with CK SDPA support: ", hasCKSDPAFlag);
|
||||
#endif
|
||||
rocm_fa_preferred_backend = b;
|
||||
}
|
||||
|
@ -132,6 +132,7 @@ class TORCH_API Context {
|
||||
static bool hasKleidiAI();
|
||||
static bool hasLAPACK();
|
||||
static bool hasMKLDNN();
|
||||
static bool ckSupported();
|
||||
static bool hasMAGMA() {
|
||||
return detail::getCUDAHooks().hasMAGMA();
|
||||
}
|
||||
@ -162,6 +163,12 @@ class TORCH_API Context {
|
||||
static bool hasROCM() {
|
||||
return detail::getCUDAHooks().hasROCM();
|
||||
}
|
||||
static bool hasCKSDPA() {
|
||||
return detail::getCUDAHooks().hasCKSDPA();
|
||||
}
|
||||
static bool hasCKGEMM() {
|
||||
return detail::getCUDAHooks().hasCKGEMM();
|
||||
}
|
||||
static bool hasHIP() {
|
||||
return detail::getHIPHooks().hasHIP();
|
||||
}
|
||||
@ -252,7 +259,7 @@ class TORCH_API Context {
|
||||
at::BlasBackend blasPreferredBackend();
|
||||
void setBlasPreferredBackend(at::BlasBackend);
|
||||
|
||||
at::ROCmFABackend getROCmFAPreferredBackend() const;
|
||||
at::ROCmFABackend getROCmFAPreferredBackend();
|
||||
void setROCmFAPreferredBackend(at::ROCmFABackend);
|
||||
|
||||
// Note [Enabling Deterministic Operations]
|
||||
|
@ -832,7 +832,7 @@ void bgemm_internal<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16))
|
||||
bgemm_internal_cublas<at::BFloat16>(CUDABLAS_BGEMM_ARGS(at::BFloat16));
|
||||
}
|
||||
}
|
||||
#if defined(USE_ROCM) && !defined(_MSC_VER)
|
||||
#if defined(USE_ROCM) && defined(USE_ROCM_CK_GEMM)
|
||||
else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
|
||||
at::native::bgemm_internal_ck<at::BFloat16>(CUDABLAS_BGEMM_ARGS(at::BFloat16));
|
||||
}
|
||||
@ -1273,7 +1273,7 @@ void gemm_internal<double>(CUDABLAS_GEMM_ARGTYPES(double))
|
||||
gemm_internal_cublaslt<double>(CUDABLAS_GEMM_ARGS(double));
|
||||
#endif
|
||||
}
|
||||
#if defined(USE_ROCM) && !defined(_MSC_VER)
|
||||
#if defined(USE_ROCM) && defined(USE_ROCM_CK_GEMM)
|
||||
else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
|
||||
at::native::gemm_internal_ck<double>(CUDABLAS_GEMM_ARGS(double));
|
||||
}
|
||||
@ -1289,7 +1289,7 @@ void gemm_internal<float>(CUDABLAS_GEMM_ARGTYPES(float))
|
||||
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
|
||||
gemm_internal_cublaslt<float>(CUDABLAS_GEMM_ARGS(float));
|
||||
}
|
||||
#if defined(USE_ROCM) && !defined(_MSC_VER)
|
||||
#if defined(USE_ROCM) && defined(USE_ROCM_CK_GEMM)
|
||||
else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
|
||||
if (at::detail::getCUDAHooks().isGPUArch({"gfx1100"})) { //no CK GEMM version for gfx1100
|
||||
gemm_internal_cublaslt<float>(CUDABLAS_GEMM_ARGS(float));
|
||||
@ -1341,7 +1341,7 @@ void gemm_internal<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half))
|
||||
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
|
||||
gemm_internal_cublaslt<at::Half>(CUDABLAS_GEMM_ARGS(at::Half));
|
||||
}
|
||||
#if defined(USE_ROCM) && !defined(_MSC_VER)
|
||||
#if defined(USE_ROCM) && defined(USE_ROCM_CK_GEMM)
|
||||
else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
|
||||
at::native::gemm_internal_ck<at::Half>(CUDABLAS_GEMM_ARGS(at::Half));
|
||||
}
|
||||
@ -1357,7 +1357,7 @@ void gemm_internal<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16))
|
||||
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
|
||||
gemm_internal_cublaslt<at::BFloat16>(CUDABLAS_GEMM_ARGS(at::BFloat16));
|
||||
}
|
||||
#if defined(USE_ROCM) && !defined(_MSC_VER)
|
||||
#if defined(USE_ROCM) && defined(USE_ROCM_CK_GEMM)
|
||||
else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
|
||||
at::native::gemm_internal_ck<at::BFloat16>(CUDABLAS_GEMM_ARGS(at::BFloat16));
|
||||
}
|
||||
|
@ -207,6 +207,27 @@ bool CUDAHooks::hasCuBLASLt() const {
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
bool CUDAHooks::hasCKSDPA() const {
|
||||
#if !defined(USE_ROCM)
|
||||
return false;
|
||||
#elif defined(USE_ROCM) && defined(USE_ROCM_CK_SDPA)
|
||||
return true;
|
||||
#else
|
||||
return false;
|
||||
#endif
|
||||
}
|
||||
|
||||
bool CUDAHooks::hasCKGEMM() const {
|
||||
#if !defined(USE_ROCM)
|
||||
return false;
|
||||
#elif defined(USE_ROCM) && defined(USE_ROCM_CK_GEMM)
|
||||
return true;
|
||||
#else
|
||||
return false;
|
||||
#endif
|
||||
}
|
||||
|
||||
bool CUDAHooks::hasROCM() const {
|
||||
// Currently, this is same as `compiledWithMIOpen`.
|
||||
// But in future if there are ROCm builds without MIOpen,
|
||||
|
@ -31,6 +31,8 @@ struct CUDAHooks : public at::CUDAHooksInterface {
|
||||
bool hasCuSOLVER() const override;
|
||||
bool hasCuBLASLt() const override;
|
||||
bool hasROCM() const override;
|
||||
bool hasCKSDPA() const override;
|
||||
bool hasCKGEMM() const override;
|
||||
const at::cuda::NVRTC& nvrtc() const override;
|
||||
DeviceIndex current_device() const override;
|
||||
bool isBuilt() const override {return true;}
|
||||
|
@ -118,6 +118,14 @@ struct TORCH_API CUDAHooksInterface : AcceleratorHooksInterface {
|
||||
return false;
|
||||
}
|
||||
|
||||
virtual bool hasCKSDPA() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
virtual bool hasCKGEMM() const {
|
||||
return false;
|
||||
}
|
||||
|
||||
virtual const at::cuda::NVRTC& nvrtc() const {
|
||||
TORCH_CHECK(false, "NVRTC requires CUDA. ", CUDA_HELP);
|
||||
}
|
||||
|
@ -10,6 +10,7 @@ inline void gemm_internal_ck(CUDABLAS_GEMM_ARGTYPES(Dtype)) {
|
||||
static_assert(false&&sizeof(Dtype),"at::cuda::blas_gemm_internal_ck: not implemented");
|
||||
}
|
||||
|
||||
#if defined(USE_ROCM) && defined(USE_ROCM_CK_GEMM)
|
||||
template <>
|
||||
void gemm_internal_ck<double>(CUDABLAS_GEMM_ARGTYPES(double));
|
||||
template <>
|
||||
@ -18,7 +19,7 @@ template <>
|
||||
void gemm_internal_ck<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half));
|
||||
template <>
|
||||
void gemm_internal_ck<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16));
|
||||
|
||||
#endif
|
||||
|
||||
|
||||
} // namespace at::native
|
||||
|
@ -1,6 +1,7 @@
|
||||
#undef __HIP_NO_HALF_CONVERSIONS__
|
||||
|
||||
#include <ATen/native/hip/ck_gemm.h>
|
||||
|
||||
#if defined(USE_ROCM_CK_GEMM)
|
||||
#include <ATen/native/hip/ck_gemm_template.h>
|
||||
#include <ck/utility/sequence.hpp>
|
||||
|
||||
@ -781,3 +782,4 @@ void gemm_internal_ck<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
|
||||
}
|
||||
|
||||
} // namespace at::native
|
||||
#endif // USE_ROCM_CK_GEMM
|
||||
|
@ -1,6 +1,7 @@
|
||||
#undef __HIP_NO_HALF_CONVERSIONS__
|
||||
|
||||
#include <ATen/native/hip/ck_gemm.h>
|
||||
#if defined(USE_ROCM_CK_GEMM)
|
||||
#include <ATen/native/hip/ck_gemm_template.h>
|
||||
#include <ck/utility/sequence.hpp>
|
||||
|
||||
@ -484,3 +485,4 @@ void gemm_internal_ck<double>(CUDABLAS_GEMM_ARGTYPES(double)) {
|
||||
}
|
||||
|
||||
} // namespace at::native
|
||||
#endif // USE_ROCM_CK_GEMM
|
||||
|
@ -1,6 +1,7 @@
|
||||
#undef __HIP_NO_HALF_CONVERSIONS__
|
||||
|
||||
#include <ATen/native/hip/ck_gemm.h>
|
||||
#if defined(USE_ROCM_CK_GEMM)
|
||||
#include <ATen/native/hip/ck_gemm_template.h>
|
||||
|
||||
#include <ck/utility/sequence.hpp>
|
||||
@ -606,3 +607,4 @@ void gemm_internal_ck<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
|
||||
}
|
||||
|
||||
} // namespace at::native
|
||||
#endif // USE_ROCM_CK_GEMM
|
||||
|
@ -1346,7 +1346,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt> _efficient_
|
||||
if(at::globalContext().getROCmFAPreferredBackend() ==
|
||||
at::ROCmFABackend::Ck) {
|
||||
|
||||
#if defined(USE_CK_FLASH_ATTENTION)
|
||||
#if defined(USE_ROCM_CK_SDPA)
|
||||
std::optional<Tensor> out(res);
|
||||
std::optional<Tensor> seqused_k = std::nullopt;
|
||||
std::optional<Tensor> alibi_slopes = std::nullopt;
|
||||
|
@ -431,7 +431,7 @@ _efficient_attention_backward(
|
||||
// ROCM Implementation
|
||||
if(at::globalContext().getROCmFAPreferredBackend() == at::ROCmFABackend::Ck)
|
||||
{
|
||||
#if defined(USE_CK_FLASH_ATTENTION)
|
||||
#if defined(USE_ROCM_CK_SDPA)
|
||||
const auto my_softmax_scale = sdp::calculate_scale(query, scale).expect_float();
|
||||
// Store grad_bias in optional
|
||||
std::optional<at::Tensor> opt_grad_bias = grad_bias;
|
||||
|
@ -1,7 +1,7 @@
|
||||
#include <ATen/native/transformers/hip/flash_attn/flash_common_hip.hpp>
|
||||
#include <ATen/native/transformers/hip/flash_attn/ck/me_ck_api.h>
|
||||
|
||||
#if defined(USE_CK_FLASH_ATTENTION)
|
||||
#if defined(USE_ROCM_CK_SDPA)
|
||||
namespace pytorch_flash {
|
||||
std::tuple<
|
||||
at::Tensor, // dQ
|
||||
@ -117,4 +117,4 @@ mem_eff_backward_ck(
|
||||
}
|
||||
|
||||
} // namespace pytorch_flash
|
||||
#endif // USE_CK_FLASH_ATTENTION
|
||||
#endif // USE_ROCM_CK_SDPA
|
||||
|
@ -3,7 +3,7 @@
|
||||
|
||||
#include <ATen/core/Tensor.h>
|
||||
|
||||
#if defined(USE_CK_FLASH_ATTENTION)
|
||||
#if defined(USE_ROCM_CK_SDPA)
|
||||
namespace pytorch_flash {
|
||||
|
||||
std::tuple<
|
||||
@ -64,4 +64,4 @@ mem_eff_backward_ck(
|
||||
const at::Tensor philox_offset);
|
||||
|
||||
} // namespace pytorch_flash
|
||||
#endif // USE_CK_FLASH_ATTENTION
|
||||
#endif // USE_ROCM_CK_SDPA
|
||||
|
@ -1,7 +1,7 @@
|
||||
#include <ATen/native/transformers/hip/flash_attn/flash_common_hip.hpp>
|
||||
#include <ATen/native/transformers/hip/flash_attn/ck/me_ck_api.h>
|
||||
|
||||
#if defined(USE_CK_FLASH_ATTENTION)
|
||||
#if defined(USE_ROCM_CK_SDPA)
|
||||
namespace pytorch_flash {
|
||||
std::tuple<
|
||||
at::Tensor, // output
|
||||
@ -93,4 +93,4 @@ mem_eff_forward_ck(
|
||||
}
|
||||
|
||||
} // namespace pytorch_flash
|
||||
#endif // USE_CK_FLASH_ATTENTION
|
||||
#endif // USE_ROCM_CK_SDPA
|
||||
|
@ -147,7 +147,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_varlen_bwd_aot(
|
||||
const at::Tensor& philox_seed,
|
||||
const at::Tensor& philox_offset);
|
||||
|
||||
#if defined(USE_CK_FLASH_ATTENTION)
|
||||
#if defined(USE_ROCM_CK_SDPA)
|
||||
// CK implementation
|
||||
TORCH_API
|
||||
std::tuple<
|
||||
@ -295,7 +295,7 @@ mha_fwd(
|
||||
const float softcap,
|
||||
const bool return_softmax,
|
||||
std::optional<at::Generator> gen_) {
|
||||
#if defined(USE_CK_FLASH_ATTENTION)
|
||||
#if defined(USE_ROCM_CK_SDPA)
|
||||
if (at::globalContext().getROCmFAPreferredBackend() ==
|
||||
at::ROCmFABackend::Ck) {
|
||||
const int non_null_window_left = window_size_left.value_or(-1);
|
||||
@ -368,7 +368,7 @@ mha_varlen_fwd(
|
||||
const float softcap,
|
||||
const bool return_softmax,
|
||||
std::optional<at::Generator> gen_) {
|
||||
#if defined(USE_CK_FLASH_ATTENTION)
|
||||
#if defined(USE_ROCM_CK_SDPA)
|
||||
if (at::globalContext().getROCmFAPreferredBackend() ==
|
||||
at::ROCmFABackend::Ck) {
|
||||
std::optional<at::Tensor> dummy_attn_bias = std::nullopt;
|
||||
@ -441,9 +441,10 @@ inline std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_bwd(
|
||||
const bool deterministic,
|
||||
const at::Tensor philox_seed,
|
||||
const at::Tensor philox_offset) {
|
||||
|
||||
#if defined(USE_ROCM_CK_SDPA)
|
||||
if (at::globalContext().getROCmFAPreferredBackend() ==
|
||||
at::ROCmFABackend::Ck) {
|
||||
#if defined(USE_CK_FLASH_ATTENTION)
|
||||
std::optional<at::Tensor> non_null_dbias = std::nullopt;
|
||||
const int non_null_window_left = window_size_left.value_or(-1);
|
||||
const int non_null_window_right = window_size_right.value_or(-1);
|
||||
@ -474,10 +475,8 @@ inline std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_bwd(
|
||||
philox_offset);
|
||||
// for FA return [dQ, dV, dK, dSoftmax]
|
||||
return std::make_tuple(std::move(dQuery), std::move(dKey), std::move(dValue), std::move(dSoftmax));
|
||||
#else
|
||||
TORCH_WARN_ONCE("Warning! You have opted to use CK flash attention backend in a build that was not compiled using USE_CK_FLASH_ATTENTION=1. Please set this variable and try again. Defaulting to use aotriton backend...");
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
return mha_bwd_aot(
|
||||
dout,
|
||||
q,
|
||||
@ -530,7 +529,7 @@ inline std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_varlen_bwd
|
||||
const bool deterministic,
|
||||
const at::Tensor philox_seed,
|
||||
const at::Tensor philox_offset) {
|
||||
#if defined(USE_CK_FLASH_ATTENTION)
|
||||
#if defined(USE_ROCM_CK_SDPA)
|
||||
if (at::globalContext().getROCmFAPreferredBackend() ==
|
||||
at::ROCmFABackend::Ck) {
|
||||
std::optional<at::Tensor> non_null_dbias = std::nullopt;
|
||||
|
@ -1446,8 +1446,8 @@ if(USE_ROCM)
|
||||
if(USE_MEM_EFF_ATTENTION)
|
||||
target_compile_definitions(torch_hip PRIVATE USE_MEM_EFF_ATTENTION)
|
||||
endif()
|
||||
if(USE_CK_FLASH_ATTENTION)
|
||||
target_compile_definitions(torch_hip PRIVATE USE_CK_FLASH_ATTENTION)
|
||||
if(USE_ROCM_CK_SDPA)
|
||||
target_compile_definitions(torch_hip PRIVATE USE_ROCM_CK_SDPA)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
|
@ -1045,6 +1045,9 @@ if(USE_ROCM)
|
||||
if(HIPBLASLT_VEC_EXT)
|
||||
list(APPEND HIP_CXX_FLAGS -DHIPBLASLT_VEC_EXT)
|
||||
endif()
|
||||
if(USE_ROCM_CK_GEMM)
|
||||
list(APPEND HIP_CXX_FLAGS -DUSE_ROCM_CK_GEMM)
|
||||
endif()
|
||||
list(APPEND HIP_HIPCC_FLAGS --offload-compress)
|
||||
if(WIN32)
|
||||
add_definitions(-DROCM_ON_WINDOWS)
|
||||
|
@ -127,10 +127,11 @@ function(caffe2_print_configuration_summary)
|
||||
endif()
|
||||
message(STATUS " USE_ROCM : ${USE_ROCM}")
|
||||
if(${USE_ROCM})
|
||||
message(STATUS " ROCM_VERSION : ${ROCM_VERSION}")
|
||||
message(STATUS " USE_FLASH_ATTENTION : ${USE_FLASH_ATTENTION}")
|
||||
message(STATUS " USE_CK_FLASH_ATTENTION : ${USE_CK_FLASH_ATTENTION}")
|
||||
message(STATUS " ROCM_VERSION : ${ROCM_VERSION}")
|
||||
message(STATUS " USE_FLASH_ATTENTION : ${USE_FLASH_ATTENTION}")
|
||||
message(STATUS " USE_MEM_EFF_ATTENTION : ${USE_MEM_EFF_ATTENTION}")
|
||||
message(STATUS " USE_ROCM_CK_SDPA : ${USE_ROCM_CK_SDPA}")
|
||||
message(STATUS " USE_ROCM_CK_GEMM : ${USE_ROCM_CK_GEMM}")
|
||||
endif()
|
||||
message(STATUS " BUILD_NVFUSER : ${BUILD_NVFUSER}")
|
||||
message(STATUS " USE_EIGEN_FOR_BLAS : ${CAFFE2_USE_EIGEN_FOR_BLAS}")
|
||||
|
@ -179,3 +179,30 @@ by recompiling the PyTorch from source.
|
||||
Please add below line as an argument to cmake command parameters::
|
||||
|
||||
-DROCM_FORCE_ENABLE_GPU_ASSERTS:BOOL=ON
|
||||
|
||||
Enabling/Disabling ROCm Composable Kernel
|
||||
-----------------------------------------
|
||||
|
||||
Enabling composable_kernel (CK) for both SDPA and GEMMs is a two-part process. First the user must have built
|
||||
pytorch while setting the corresponding environment variable to '1'
|
||||
|
||||
SDPA:
|
||||
``USE_ROCM_CK_SDPA=1``
|
||||
|
||||
GEMMs:
|
||||
``USE_ROCM_CK_GEMM=1``
|
||||
|
||||
Second, the user must explicitly request that CK be used as the backend library via the corresponding python
|
||||
call
|
||||
|
||||
SDPA:
|
||||
``setROCmFAPreferredBackend('<choice>')``
|
||||
|
||||
GEMMs:
|
||||
``setBlasPreferredBackend('<choice>')``
|
||||
|
||||
To enable CK in either scenario, simply pass 'ck' to those functions.
|
||||
|
||||
In order to set the backend to CK, the user MUST have built with the correct environment variable. If not,
|
||||
PyTorch will print a warning and use the "default" backend. For GEMMs, this will route to hipblas and
|
||||
for SDPA it routes to aotriton.
|
||||
|
6
setup.py
6
setup.py
@ -156,6 +156,12 @@
|
||||
# USE_ROCM_KERNEL_ASSERT=1
|
||||
# Enable kernel assert in ROCm platform
|
||||
#
|
||||
# USE_ROCM_CK_GEMM=1
|
||||
# Enable building CK GEMM backend in ROCm platform
|
||||
#
|
||||
# USE_ROCM_CK_SDPA=1
|
||||
# Enable building CK SDPA backend in ROCm platform
|
||||
#
|
||||
# Environment variables we respect (these environment variables are
|
||||
# conventional and are often understood/set by other software.)
|
||||
#
|
||||
|
Reference in New Issue
Block a user