[ROCm] Ck backend UX refactor (#152951)

Refactors how CK GEMMs and CK SDPA are enabled and disabled.

- Adds a USE_ROCM_CK_GEMM compile flag for enabling CK GEMMs.
  - USE_ROCM_CK_GEMM defaults to ON on Linux ROCm builds.
- Renames USE_CK_FLASH_ATTENTION to USE_ROCM_CK_SDPA.
  - USE_ROCM_CK_SDPA defaults to OFF.
  - USE_CK_FLASH_ATTENTION still works for now, but will be deprecated in a future release.
- Prevents the CK backends from being used unless PyTorch has been built with the corresponding flag AND is running on a GPU architecture that supports CK.
- The getters for these backends also validate the selection in case it was changed via an environment variable; if the selection is invalid (i.e. one of the conditions above does not hold), the backend falls back to the current non-CK default. See the usage sketch below.
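
For illustration only (not part of the diff), here is a minimal sketch of the runtime side, assuming the `torch.backends.cuda.preferred_blas_library` and `torch.backends.cuda.preferred_rocm_fa_library` wrappers are the Python entry points for these setters/getters; exact accepted strings may vary by release.

```python
# Minimal sketch, assuming the torch.backends.cuda wrappers are the Python
# entry points for the C++ setters/getters touched by this PR.
import torch

if torch.version.hip is not None:  # only meaningful on ROCm builds
    try:
        # Raises if PyTorch was not built with USE_ROCM_CK_GEMM=1 or the GPU
        # architecture does not support CK.
        torch.backends.cuda.preferred_blas_library("ck")
    except RuntimeError as err:
        print("CK GEMM unavailable, keeping the default BLAS backend:", err)
    print("BLAS backend:", torch.backends.cuda.preferred_blas_library())

    try:
        # Same idea for SDPA: requires USE_ROCM_CK_SDPA=1 plus a supported arch.
        torch.backends.cuda.preferred_rocm_fa_library("ck")
    except RuntimeError as err:
        print("CK SDPA unavailable, keeping aotriton:", err)
    print("ROCm FA backend:", torch.backends.cuda.preferred_rocm_fa_library())
```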

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152951
Approved by: https://github.com/eqy, https://github.com/jeffdaily, https://github.com/m-gallus

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
Co-authored-by: Jithun Nair <jithun.nair@amd.com>
Co-authored-by: Jane (Yuan) Xu <31798555+janeyx99@users.noreply.github.com>
Andres Lugo
2025-08-08 18:40:17 +00:00
committed by PyTorch MergeBot
parent da1f608ca3
commit 5f5f508aa8
23 changed files with 232 additions and 105 deletions

View File

@ -240,6 +240,8 @@ cmake_dependent_option(
BUILD_LAZY_CUDA_LINALG "Build cuda linalg ops as separate library" ON
"USE_CUDA AND LINUX AND BUILD_PYTHON" OFF)
cmake_dependent_option(USE_ROCM "Use ROCm" ON "LINUX" OFF)
cmake_dependent_option(USE_ROCM_CK_GEMM "Use ROCm Composable Kernel for GEMMs" ON "USE_ROCM;NOT WIN32" OFF)
option(USE_ROCM_CK_SDPA "Use ROCm Composable Kernel for SDPA" OFF)
option(CAFFE2_STATIC_LINK_CUDA "Statically link CUDA libraries" OFF)
cmake_dependent_option(USE_CUDNN "Use cuDNN" ON "USE_CUDA" OFF)
cmake_dependent_option(USE_STATIC_CUDNN "Use cuDNN static libraries" OFF

View File

@ -180,9 +180,11 @@ file(GLOB native_flash_attn_api_cpp "native/transformers/cuda/flash_attn/flash_a
file(GLOB flash_attention_hip_hip "native/transformers/hip/flash_attn/*.hip")
# if USE_FLASH_ATTENTION is set, ensure CK instances get generated
if(USE_FLASH_ATTENTION)
if(DEFINED ENV{USE_CK_FLASH_ATTENTION})
set(USE_CK_FLASH_ATTENTION $ENV{USE_CK_FLASH_ATTENTION})
if(USE_CK_FLASH_ATTENTION STREQUAL "1")
if("$ENV{USE_CK_FLASH_ATTENTION}" STREQUAL "1")
message(STATUS "USE_CK_FLASH_ATTENTION is being deprecated. Please use USE_ROCM_CK_SDPA instead")
caffe2_update_option(USE_ROCM_CK_SDPA ON)
endif()
if(USE_ROCM_CK_SDPA)
if(DEFINED ENV{PYTORCH_ROCM_ARCH})
list(LENGTH PYTORCH_ROCM_ARCH NUM_ARCHS)
if(NUM_ARCHS GREATER 1)
@ -190,7 +192,7 @@ if(USE_FLASH_ATTENTION)
Consider setting PYTORCH_ROCM_ARCH env var value as the gfx arch you need to build for")
endif()
endif()
message(STATUS "USE_CK_FLASH_ATTENTION is set; building PyTorch with CK Flash Attention enabled")
message(STATUS "USE_ROCM_CK_SDPA is set; building PyTorch with CK SDPA enabled")
message(STATUS "Generating CK kernel instances...")
add_subdirectory(native/transformers/hip/flash_attn/ck)
file(GLOB flash_attention_hip_ck_hip "native/transformers/hip/flash_attn/ck/*.hip")
@ -200,7 +202,6 @@ if(USE_FLASH_ATTENTION)
file(GLOB flash_attention_v3_hip "native/transformers/hip/flash_attn/ck/fav_v3/*.hip")
list(APPEND native_transformers_hip_hip ${flash_attention_v3_hip})
endif()
endif()
file(GLOB flash_attention_hip_aot_hip "native/transformers/hip/flash_attn/aot/*.hip")
file(GLOB flash_attention_src_hip_hip "native/transformers/hip/flash_attn/src/*.hip")
endif()
@ -418,6 +419,7 @@ if(USE_CUDA)
endif()
if(USE_ROCM)
if((USE_FLASH_ATTENTION AND USE_ROCM_CK_SDPA) OR USE_ROCM_CK_GEMM)
# NOTE: The PyTorch build does not actually add_subdirectory
# third_party/composable_kernel or use it as a CMake library. What is used
# is header only, so this should be ok, except that the CMake build generates
@ -445,13 +447,14 @@ if(USE_ROCM)
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_BINARY_DIR}/composable_kernel)
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/aiter/csrc/include)
_pytorch_rocm_generate_ck_conf()
endif()
# Next two lines are needed because TunableOp uses third-party/fmt
list(APPEND ATen_HIP_INCLUDE $<TARGET_PROPERTY:fmt::fmt-header-only,INTERFACE_INCLUDE_DIRECTORIES>)
list(APPEND ATen_HIP_DEPENDENCY_LIBS fmt::fmt-header-only)
if(USE_FLASH_ATTENTION)
if(USE_FLASH_ATTENTION AND USE_ROCM_CK_SDPA)
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/native/transformers/hip/flash_attn/ck)
endif()
endif()
list(APPEND ATen_HIP_SRCS
${ATen_HIP_SRCS}
${hip_hip}
@ -461,12 +464,17 @@ endif()
${native_quantized_hip_hip}
${native_transformers_hip_hip} ${native_transformers_src_hip_hip}
)
if(WIN32) # Windows doesn't support Composable Kernels
if(NOT USE_ROCM_CK_GEMM)
file(GLOB native_hip_bgemm "native/hip/bgemm_kernels/*.hip")
file(GLOB native_hip_ck "native/hip/ck*.hip")
exclude(ATen_HIP_SRCS "${ATen_HIP_SRCS}"
${native_hip_bgemm} ${native_hip_ck})
endif()
if(WIN32) # Windows doesn't support Composable Kernels and Triton
exclude(ATen_HIP_SRCS "${ATen_HIP_SRCS}"
${native_transformers_hip_hip} ${native_transformers_hip_cpp})
endif()
# TODO: Codegen separate files for HIP and use those (s/cuda_generated_sources/hip_generated_sources)
list(APPEND all_hip_cpp
${native_nested_hip_cpp}

View File

@ -480,6 +480,9 @@ at::BlasBackend Context::blasPreferredBackend() {
// call site for blasPreferredBackend(), we set it to an actual value.
if (blas_preferred_backend == at::BlasBackend::Default) {
blas_preferred_backend = at::BlasBackend::Cublas;
// This logic sits in the getter because it needs to validate
// values set via env vars such as TORCH_BLAS_PREFER_CUBLASLT
// which initialize the backend without calling the setter
#ifdef USE_ROCM
// AMD Instinct targets prefer hipblaslt
static const bool hipblaslt_preferred = []() {
@ -509,6 +512,10 @@ at::BlasBackend Context::blasPreferredBackend() {
// hipblaslt support for all archs is not as complete as hipblas
if (blas_preferred_backend == at::BlasBackend::Cublaslt) {
static const bool hipblaslt_unsupported = []() {
if(!hasCuBLASLt())
{
return true;
}
static const std::vector<std::string> archs = {
"gfx90a", "gfx942",
#if ROCM_VERSION >= 60300
@ -534,6 +541,24 @@ at::BlasBackend Context::blasPreferredBackend() {
return blas_preferred_backend;
}
bool Context::ckSupported() {
#ifdef USE_ROCM
static const std::vector<std::string> supported_archs = {
"gfx90a", "gfx942", "gfx950"
};
for (auto index : c10::irange(detail::getCUDAHooks().deviceCount())) {
if(!detail::getCUDAHooks().isGPUArch(supported_archs, index)) {
TORCH_WARN_ONCE(
"Attempting to use CK on an unsupported architecture! Cannot set backend to CK");
return false;
}
}
return true;
#else
return false;
#endif
}
void Context::setBlasPreferredBackend(at::BlasBackend b) {
#ifdef _MSC_VER
TORCH_WARN_ONCE(
@ -543,8 +568,14 @@ void Context::setBlasPreferredBackend(at::BlasBackend b) {
#else
TORCH_CHECK((b != at::BlasBackend::Cublaslt) || hasCuBLASLt(),
"Cannot set preferred backend to cuBLASLt if PyTorch has not been compiled with cuBLASLt.");
TORCH_CHECK((b != at::BlasBackend::Ck) || hasROCM(),
"Cannot set preferred backend to Ck if PyTorch has not been compiled for ROCm.");
#ifdef USE_ROCM
static const bool ckSupportedFlag = ckSupported();
static const bool hasCKGEMMFlag = hasCKGEMM();
TORCH_CHECK((b != at::BlasBackend::Ck) || (ckSupportedFlag && hasCKGEMMFlag),
"Cannot set preferred blas backend to CK since following conditions are not true: ",
"architecture supported for CK: ", ckSupportedFlag,
", PyTorch built with CK GEMM support: ", hasCKGEMMFlag);
#endif
if (b != at::BlasBackend::Default && b != at::BlasBackend::Cublas) {
TORCH_WARN_ONCE(
"torch.backends.cuda.preferred_blas_library is an experimental feature. "
@ -556,35 +587,40 @@ void Context::setBlasPreferredBackend(at::BlasBackend b) {
#endif
}
at::ROCmFABackend Context::getROCmFAPreferredBackend() const {
at::ROCmFABackend Context::getROCmFAPreferredBackend() {
#ifdef USE_ROCM
// Set potential "Default" value so we don't have to interpret at call sites.
// We use aotriton backend as the default, for now.
if(rocm_fa_preferred_backend == at::ROCmFABackend::Default) {
rocm_fa_preferred_backend = at::ROCmFABackend::AOTriton;
} else if (rocm_fa_preferred_backend == at::ROCmFABackend::Ck) {
// This logic sits in the getter because it needs to validate
// values set via env vars such as TORCH_ROCM_FA_PREFER_CK
// which initialize the backend without calling the setter
// Perform validity checking
static const bool hasCKSDPAFlag = hasCKSDPA();
static const bool ckSupportedFlag = ckSupported();
if(!(hasCKSDPAFlag && ckSupportedFlag)){
TORCH_WARN_ONCE(
"Cannot set preferred SDPA backend to CK since following conditions are not true: ",
"architecture supported for CK: ", ckSupportedFlag,
", PyTorch built with CK SDPA support: ", hasCKSDPAFlag);
rocm_fa_preferred_backend = at::ROCmFABackend::AOTriton;
}
}
#endif
return rocm_fa_preferred_backend;
}
void Context::setROCmFAPreferredBackend(at::ROCmFABackend b) {
// TODO: add plumbing for hasCK for validity checking
TORCH_CHECK((b != at::ROCmFABackend::Ck) || hasROCM(),
"Cannot set preferred flash attention backend to Ck if PyTorch has not been compiled for ROCm.");
#ifdef USE_ROCM
if(b == at::ROCmFABackend::Ck) {
static const bool ck_unsupported = []() {
static const std::vector<std::string> archs = {
"gfx90a", "gfx942"
};
for (auto index: c10::irange(detail::getCUDAHooks().deviceCount())) {
if (!detail::getCUDAHooks().isGPUArch(archs, index)) {
TORCH_WARN_ONCE(
"Attempting to use CK on an unsupported architecture! Cannot set backend to CK");
return true;
}
}
return false;
}();
if(!ck_unsupported) rocm_fa_preferred_backend = b;
}
else {
rocm_fa_preferred_backend = b;
}
static const bool hasCKSDPAFlag = hasCKSDPA();
static const bool ckSupportedFlag = ckSupported();
TORCH_CHECK((b != at::ROCmFABackend::Ck) || (hasCKSDPAFlag && ckSupportedFlag),
"Cannot set preferred SDPA backend to CK since following conditions are not true: ",
"architecture supported for CK: ", ckSupportedFlag,
", PyTorch built with CK SDPA support: ", hasCKSDPAFlag);
#endif
rocm_fa_preferred_backend = b;
}
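
The comments in this hunk note that env vars such as `TORCH_BLAS_PREFER_CUBLASLT` and `TORCH_ROCM_FA_PREFER_CK` initialize a backend without going through the setter, which is why the getters re-validate. A rough sketch of that path (hedged; the wrapper name and the exact point at which the env var is read are assumptions):

```python
# Sketch of the env-var path that the getter-side validation covers.
import os

# The env var must be visible before torch initializes its global context.
os.environ["TORCH_ROCM_FA_PREFER_CK"] = "1"

import torch  # noqa: E402

# On a build without USE_ROCM_CK_SDPA, or on an unsupported GPU architecture,
# the first call to the getter warns once and resets the preference to
# aotriton instead of leaving an unusable CK selection in place.
print(torch.backends.cuda.preferred_rocm_fa_library())
```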

View File

@ -132,6 +132,7 @@ class TORCH_API Context {
static bool hasKleidiAI();
static bool hasLAPACK();
static bool hasMKLDNN();
static bool ckSupported();
static bool hasMAGMA() {
return detail::getCUDAHooks().hasMAGMA();
}
@ -162,6 +163,12 @@ class TORCH_API Context {
static bool hasROCM() {
return detail::getCUDAHooks().hasROCM();
}
static bool hasCKSDPA() {
return detail::getCUDAHooks().hasCKSDPA();
}
static bool hasCKGEMM() {
return detail::getCUDAHooks().hasCKGEMM();
}
static bool hasHIP() {
return detail::getHIPHooks().hasHIP();
}
@ -252,7 +259,7 @@ class TORCH_API Context {
at::BlasBackend blasPreferredBackend();
void setBlasPreferredBackend(at::BlasBackend);
at::ROCmFABackend getROCmFAPreferredBackend() const;
at::ROCmFABackend getROCmFAPreferredBackend();
void setROCmFAPreferredBackend(at::ROCmFABackend);
// Note [Enabling Deterministic Operations]

View File

@ -832,7 +832,7 @@ void bgemm_internal<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16))
bgemm_internal_cublas<at::BFloat16>(CUDABLAS_BGEMM_ARGS(at::BFloat16));
}
}
#if defined(USE_ROCM) && !defined(_MSC_VER)
#if defined(USE_ROCM) && defined(USE_ROCM_CK_GEMM)
else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
at::native::bgemm_internal_ck<at::BFloat16>(CUDABLAS_BGEMM_ARGS(at::BFloat16));
}
@ -1273,7 +1273,7 @@ void gemm_internal<double>(CUDABLAS_GEMM_ARGTYPES(double))
gemm_internal_cublaslt<double>(CUDABLAS_GEMM_ARGS(double));
#endif
}
#if defined(USE_ROCM) && !defined(_MSC_VER)
#if defined(USE_ROCM) && defined(USE_ROCM_CK_GEMM)
else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
at::native::gemm_internal_ck<double>(CUDABLAS_GEMM_ARGS(double));
}
@ -1289,7 +1289,7 @@ void gemm_internal<float>(CUDABLAS_GEMM_ARGTYPES(float))
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
gemm_internal_cublaslt<float>(CUDABLAS_GEMM_ARGS(float));
}
#if defined(USE_ROCM) && !defined(_MSC_VER)
#if defined(USE_ROCM) && defined(USE_ROCM_CK_GEMM)
else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
if (at::detail::getCUDAHooks().isGPUArch({"gfx1100"})) { //no CK GEMM version for gfx1100
gemm_internal_cublaslt<float>(CUDABLAS_GEMM_ARGS(float));
@ -1341,7 +1341,7 @@ void gemm_internal<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half))
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
gemm_internal_cublaslt<at::Half>(CUDABLAS_GEMM_ARGS(at::Half));
}
#if defined(USE_ROCM) && !defined(_MSC_VER)
#if defined(USE_ROCM) && defined(USE_ROCM_CK_GEMM)
else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
at::native::gemm_internal_ck<at::Half>(CUDABLAS_GEMM_ARGS(at::Half));
}
@ -1357,7 +1357,7 @@ void gemm_internal<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16))
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
gemm_internal_cublaslt<at::BFloat16>(CUDABLAS_GEMM_ARGS(at::BFloat16));
}
#if defined(USE_ROCM) && !defined(_MSC_VER)
#if defined(USE_ROCM) && defined(USE_ROCM_CK_GEMM)
else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
at::native::gemm_internal_ck<at::BFloat16>(CUDABLAS_GEMM_ARGS(at::BFloat16));
}

View File

@ -207,6 +207,27 @@ bool CUDAHooks::hasCuBLASLt() const {
#endif
}
bool CUDAHooks::hasCKSDPA() const {
#if !defined(USE_ROCM)
return false;
#elif defined(USE_ROCM) && defined(USE_ROCM_CK_SDPA)
return true;
#else
return false;
#endif
}
bool CUDAHooks::hasCKGEMM() const {
#if !defined(USE_ROCM)
return false;
#elif defined(USE_ROCM) && defined(USE_ROCM_CK_GEMM)
return true;
#else
return false;
#endif
}
bool CUDAHooks::hasROCM() const {
// Currently, this is same as `compiledWithMIOpen`.
// But in future if there are ROCm builds without MIOpen,

View File

@ -31,6 +31,8 @@ struct CUDAHooks : public at::CUDAHooksInterface {
bool hasCuSOLVER() const override;
bool hasCuBLASLt() const override;
bool hasROCM() const override;
bool hasCKSDPA() const override;
bool hasCKGEMM() const override;
const at::cuda::NVRTC& nvrtc() const override;
DeviceIndex current_device() const override;
bool isBuilt() const override {return true;}

View File

@ -118,6 +118,14 @@ struct TORCH_API CUDAHooksInterface : AcceleratorHooksInterface {
return false;
}
virtual bool hasCKSDPA() const {
return false;
}
virtual bool hasCKGEMM() const {
return false;
}
virtual const at::cuda::NVRTC& nvrtc() const {
TORCH_CHECK(false, "NVRTC requires CUDA. ", CUDA_HELP);
}

View File

@ -10,6 +10,7 @@ inline void gemm_internal_ck(CUDABLAS_GEMM_ARGTYPES(Dtype)) {
static_assert(false&&sizeof(Dtype),"at::cuda::blas_gemm_internal_ck: not implemented");
}
#if defined(USE_ROCM) && defined(USE_ROCM_CK_GEMM)
template <>
void gemm_internal_ck<double>(CUDABLAS_GEMM_ARGTYPES(double));
template <>
@ -18,7 +19,7 @@ template <>
void gemm_internal_ck<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half));
template <>
void gemm_internal_ck<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16));
#endif
} // namespace at::native

View File

@ -1,6 +1,7 @@
#undef __HIP_NO_HALF_CONVERSIONS__
#include <ATen/native/hip/ck_gemm.h>
#if defined(USE_ROCM_CK_GEMM)
#include <ATen/native/hip/ck_gemm_template.h>
#include <ck/utility/sequence.hpp>
@ -781,3 +782,4 @@ void gemm_internal_ck<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
}
} // namespace at::native
#endif // USE_ROCM_CK_GEMM

View File

@ -1,6 +1,7 @@
#undef __HIP_NO_HALF_CONVERSIONS__
#include <ATen/native/hip/ck_gemm.h>
#if defined(USE_ROCM_CK_GEMM)
#include <ATen/native/hip/ck_gemm_template.h>
#include <ck/utility/sequence.hpp>
@ -484,3 +485,4 @@ void gemm_internal_ck<double>(CUDABLAS_GEMM_ARGTYPES(double)) {
}
} // namespace at::native
#endif // USE_ROCM_CK_GEMM

View File

@ -1,6 +1,7 @@
#undef __HIP_NO_HALF_CONVERSIONS__
#include <ATen/native/hip/ck_gemm.h>
#if defined(USE_ROCM_CK_GEMM)
#include <ATen/native/hip/ck_gemm_template.h>
#include <ck/utility/sequence.hpp>
@ -606,3 +607,4 @@ void gemm_internal_ck<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half)) {
}
} // namespace at::native
#endif // USE_ROCM_CK_GEMM

View File

@ -1346,7 +1346,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, c10::SymInt, c10::SymInt> _efficient_
if(at::globalContext().getROCmFAPreferredBackend() ==
at::ROCmFABackend::Ck) {
#if defined(USE_CK_FLASH_ATTENTION)
#if defined(USE_ROCM_CK_SDPA)
std::optional<Tensor> out(res);
std::optional<Tensor> seqused_k = std::nullopt;
std::optional<Tensor> alibi_slopes = std::nullopt;

View File

@ -431,7 +431,7 @@ _efficient_attention_backward(
// ROCM Implementation
if(at::globalContext().getROCmFAPreferredBackend() == at::ROCmFABackend::Ck)
{
#if defined(USE_CK_FLASH_ATTENTION)
#if defined(USE_ROCM_CK_SDPA)
const auto my_softmax_scale = sdp::calculate_scale(query, scale).expect_float();
// Store grad_bias in optional
std::optional<at::Tensor> opt_grad_bias = grad_bias;

View File

@ -1,7 +1,7 @@
#include <ATen/native/transformers/hip/flash_attn/flash_common_hip.hpp>
#include <ATen/native/transformers/hip/flash_attn/ck/me_ck_api.h>
#if defined(USE_CK_FLASH_ATTENTION)
#if defined(USE_ROCM_CK_SDPA)
namespace pytorch_flash {
std::tuple<
at::Tensor, // dQ
@ -117,4 +117,4 @@ mem_eff_backward_ck(
}
} // namespace pytorch_flash
#endif // USE_CK_FLASH_ATTENTION
#endif // USE_ROCM_CK_SDPA

View File

@ -3,7 +3,7 @@
#include <ATen/core/Tensor.h>
#if defined(USE_CK_FLASH_ATTENTION)
#if defined(USE_ROCM_CK_SDPA)
namespace pytorch_flash {
std::tuple<
@ -64,4 +64,4 @@ mem_eff_backward_ck(
const at::Tensor philox_offset);
} // namespace pytorch_flash
#endif // USE_CK_FLASH_ATTENTION
#endif // USE_ROCM_CK_SDPA

View File

@ -1,7 +1,7 @@
#include <ATen/native/transformers/hip/flash_attn/flash_common_hip.hpp>
#include <ATen/native/transformers/hip/flash_attn/ck/me_ck_api.h>
#if defined(USE_CK_FLASH_ATTENTION)
#if defined(USE_ROCM_CK_SDPA)
namespace pytorch_flash {
std::tuple<
at::Tensor, // output
@ -93,4 +93,4 @@ mem_eff_forward_ck(
}
} // namespace pytorch_flash
#endif // USE_CK_FLASH_ATTENTION
#endif // USE_ROCM_CK_SDPA

View File

@ -147,7 +147,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_varlen_bwd_aot(
const at::Tensor& philox_seed,
const at::Tensor& philox_offset);
#if defined(USE_CK_FLASH_ATTENTION)
#if defined(USE_ROCM_CK_SDPA)
// CK implementation
TORCH_API
std::tuple<
@ -295,7 +295,7 @@ mha_fwd(
const float softcap,
const bool return_softmax,
std::optional<at::Generator> gen_) {
#if defined(USE_CK_FLASH_ATTENTION)
#if defined(USE_ROCM_CK_SDPA)
if (at::globalContext().getROCmFAPreferredBackend() ==
at::ROCmFABackend::Ck) {
const int non_null_window_left = window_size_left.value_or(-1);
@ -368,7 +368,7 @@ mha_varlen_fwd(
const float softcap,
const bool return_softmax,
std::optional<at::Generator> gen_) {
#if defined(USE_CK_FLASH_ATTENTION)
#if defined(USE_ROCM_CK_SDPA)
if (at::globalContext().getROCmFAPreferredBackend() ==
at::ROCmFABackend::Ck) {
std::optional<at::Tensor> dummy_attn_bias = std::nullopt;
@ -441,9 +441,10 @@ inline std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_bwd(
const bool deterministic,
const at::Tensor philox_seed,
const at::Tensor philox_offset) {
#if defined(USE_ROCM_CK_SDPA)
if (at::globalContext().getROCmFAPreferredBackend() ==
at::ROCmFABackend::Ck) {
#if defined(USE_CK_FLASH_ATTENTION)
std::optional<at::Tensor> non_null_dbias = std::nullopt;
const int non_null_window_left = window_size_left.value_or(-1);
const int non_null_window_right = window_size_right.value_or(-1);
@ -474,10 +475,8 @@ inline std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_bwd(
philox_offset);
// for FA return [dQ, dV, dK, dSoftmax]
return std::make_tuple(std::move(dQuery), std::move(dKey), std::move(dValue), std::move(dSoftmax));
#else
TORCH_WARN_ONCE("Warning! You have opted to use CK flash attention backend in a build that was not compiled using USE_CK_FLASH_ATTENTION=1. Please set this variable and try again. Defaulting to use aotriton backend...");
#endif
}
#endif
return mha_bwd_aot(
dout,
q,
@ -530,7 +529,7 @@ inline std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_varlen_bwd
const bool deterministic,
const at::Tensor philox_seed,
const at::Tensor philox_offset) {
#if defined(USE_CK_FLASH_ATTENTION)
#if defined(USE_ROCM_CK_SDPA)
if (at::globalContext().getROCmFAPreferredBackend() ==
at::ROCmFABackend::Ck) {
std::optional<at::Tensor> non_null_dbias = std::nullopt;

View File

@ -1446,8 +1446,8 @@ if(USE_ROCM)
if(USE_MEM_EFF_ATTENTION)
target_compile_definitions(torch_hip PRIVATE USE_MEM_EFF_ATTENTION)
endif()
if(USE_CK_FLASH_ATTENTION)
target_compile_definitions(torch_hip PRIVATE USE_CK_FLASH_ATTENTION)
if(USE_ROCM_CK_SDPA)
target_compile_definitions(torch_hip PRIVATE USE_ROCM_CK_SDPA)
endif()
endif()

View File

@ -1045,6 +1045,9 @@ if(USE_ROCM)
if(HIPBLASLT_VEC_EXT)
list(APPEND HIP_CXX_FLAGS -DHIPBLASLT_VEC_EXT)
endif()
if(USE_ROCM_CK_GEMM)
list(APPEND HIP_CXX_FLAGS -DUSE_ROCM_CK_GEMM)
endif()
list(APPEND HIP_HIPCC_FLAGS --offload-compress)
if(WIN32)
add_definitions(-DROCM_ON_WINDOWS)

View File

@ -129,8 +129,9 @@ function(caffe2_print_configuration_summary)
if(${USE_ROCM})
message(STATUS " ROCM_VERSION : ${ROCM_VERSION}")
message(STATUS " USE_FLASH_ATTENTION : ${USE_FLASH_ATTENTION}")
message(STATUS " USE_CK_FLASH_ATTENTION : ${USE_CK_FLASH_ATTENTION}")
message(STATUS " USE_MEM_EFF_ATTENTION : ${USE_MEM_EFF_ATTENTION}")
message(STATUS " USE_ROCM_CK_SDPA : ${USE_ROCM_CK_SDPA}")
message(STATUS " USE_ROCM_CK_GEMM : ${USE_ROCM_CK_GEMM}")
endif()
message(STATUS " BUILD_NVFUSER : ${BUILD_NVFUSER}")
message(STATUS " USE_EIGEN_FOR_BLAS : ${CAFFE2_USE_EIGEN_FOR_BLAS}")

View File

@ -179,3 +179,30 @@ by recompiling the PyTorch from source.
Please add below line as an argument to cmake command parameters::
-DROCM_FORCE_ENABLE_GPU_ASSERTS:BOOL=ON
Enabling/Disabling ROCm Composable Kernel
-----------------------------------------
Enabling composable_kernel (CK) for both SDPA and GEMMs is a two-part process. First, the user must have built
PyTorch with the corresponding environment variable set to ``1``:

SDPA:
``USE_ROCM_CK_SDPA=1``

GEMMs:
``USE_ROCM_CK_GEMM=1``

Second, the user must explicitly request that CK be used as the backend library via the corresponding Python
call:

SDPA:
``setROCmFAPreferredBackend('<choice>')``

GEMMs:
``setBlasPreferredBackend('<choice>')``

To enable CK in either case, simply pass ``'ck'`` to those functions.

To set the backend to CK, the user MUST have built with the correct environment variable. If not,
PyTorch will print a warning and fall back to the default backend: hipblas for GEMMs and aotriton for SDPA
(see the example below).
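
A compact example of the full flow (a sketch; the ``torch.backends.cuda`` wrappers are assumed to be the
Python-facing spellings of the calls above)::

    # Build once with the CK backends compiled in, e.g.:
    #   USE_ROCM_CK_GEMM=1 USE_ROCM_CK_SDPA=1 python setup.py develop
    # Then opt in explicitly at runtime:
    import torch

    torch.backends.cuda.preferred_blas_library("ck")       # GEMMs
    torch.backends.cuda.preferred_rocm_fa_library("ck")    # SDPA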

View File

@ -156,6 +156,12 @@
# USE_ROCM_KERNEL_ASSERT=1
# Enable kernel assert in ROCm platform
#
# USE_ROCM_CK_GEMM=1
# Enable building CK GEMM backend in ROCm platform
#
# USE_ROCM_CK_SDPA=1
# Enable building CK SDPA backend in ROCm platform
#
# Environment variables we respect (these environment variables are
# conventional and are often understood/set by other software.)
#