[ROCm/Windows] Support aotriton for scaled_dot_product_attention on Windows. (#162330)

Enables flash attention and/or memory-efficient attention for scaled_dot_product_attention on Windows via aotriton.
Already tested and confirmed working on Windows with TheRock.
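For context, a minimal runtime sketch (not part of this PR) of what the change enables: pinning SDPA on a ROCm device to the flash / memory-efficient backends. Shapes and dtype below are arbitrary illustration values.

```python
# Hedged sketch: assumes a PyTorch build with this PR's flags enabled and a
# ROCm GPU visible (ROCm devices use the "cuda" device type in PyTorch).
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# (batch, num_heads, seq_len, head_dim) -- illustrative sizes only
q = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)

# Restrict dispatch to the backends this PR enables on Windows.
with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```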

To enable, set `USE_FLASH_ATTENTION=1` and `USE_MEM_EFF_ATTENTION=1` at build time, as usual. See https://github.com/ROCm/TheRock/blob/main/external-builds/pytorch/build_prod_wheels.py#L578-L604 for an example.
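A minimal sketch of such a build invocation, assuming a working ROCm/TheRock toolchain is already in the environment; the real entry point may be TheRock's build_prod_wheels.py linked above, and the `pip wheel` command here is only an illustration.

```python
# Hedged sketch: apply the SDPA-related build flags, then build a torch wheel
# from a pytorch checkout. Adjust the build command to your actual workflow.
import os
import subprocess

env = os.environ.copy()
env["USE_ROCM"] = "1"                 # assumption: ROCm build
env["USE_FLASH_ATTENTION"] = "1"
env["USE_MEM_EFF_ATTENTION"] = "1"

subprocess.run(
    ["python", "-m", "pip", "wheel", "--no-build-isolation", "-w", "dist", "."],
    env=env,
    check=True,
)
```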

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162330
Approved by: https://github.com/xinyazhang, https://github.com/ScottTodd, https://github.com/jeffdaily

Co-authored-by: Scott Todd <scott.todd0@gmail.com>
Author: Aaryaman Vasishta
Date: 2025-09-11 22:35:09 +00:00
Committed by: PyTorch MergeBot
parent 082d3dd9d5
commit 62843c14bb
5 changed files with 179 additions and 44 deletions

@@ -874,7 +874,7 @@ cmake_dependent_option(
"Whether to build the flash_attention kernel for scaled dot product attention.\
Will be disabled if not supported by the platform"
ON
"USE_CUDA OR USE_ROCM;NOT MSVC"
"USE_CUDA OR USE_ROCM"
OFF)
cmake_dependent_option(
@@ -909,7 +909,7 @@ cmake_dependent_option(
# USE_FLASH_ATTENTION -> USE_ROCM -> Dependencies.cmake -> aotriton.cmake
#
if(USE_ROCM)
if(UNIX AND (USE_FLASH_ATTENTION OR USE_MEM_EFF_ATTENTION))
if(USE_FLASH_ATTENTION OR USE_MEM_EFF_ATTENTION)
include(cmake/External/aotriton.cmake)
endif()
endif()

@@ -95,6 +95,72 @@
#endif
#endif
#if defined(USE_ROCM) && (defined(USE_FLASH_ATTENTION) || defined(USE_MEM_EFF_ATTENTION))
namespace pytorch_flash
{
std::tuple<
at::Tensor,
at::Tensor,
at::Tensor,
at::Tensor,
at::Tensor,
at::Tensor,
at::Tensor,
at::Tensor>
mha_fwd(
const at::Tensor& q, // batch_size x seqlen_q x num_heads x head_size
const at::Tensor& k, // batch_size x seqlen_k x num_heads_k x head_size
const at::Tensor& v, // batch_size x seqlen_k x num_heads_k x head_size
std::optional<at::Tensor>&
out_, // batch_size x seqlen_q x num_heads x head_size
std::optional<at::Tensor>&
alibi_slopes_, // num_heads or batch_size x num_heads
const float p_dropout,
const float softmax_scale,
bool is_causal,
std::optional<int64_t> window_size_left,
std::optional<int64_t> window_size_right,
const float softcap,
const bool return_softmax,
std::optional<at::Generator> gen_) {
#if defined(USE_ROCM_CK_SDPA)
if (at::globalContext().getROCmFAPreferredBackend() ==
at::ROCmFABackend::Ck) {
const int non_null_window_left = window_size_left.value_or(-1);
const int non_null_window_right = window_size_right.value_or(-1);
std::optional<at::Tensor> dummy_attn_bias = std::nullopt;
return mha_fwd_ck(
q,
k,
v,
out_,
p_dropout,
softmax_scale,
is_causal,
non_null_window_left,
non_null_window_right,
return_softmax,
gen_,
dummy_attn_bias); // Not used in flash attention
}
#endif
return mha_fwd_aot(
q,
k,
v,
out_,
alibi_slopes_,
p_dropout,
softmax_scale,
is_causal,
window_size_left,
window_size_right,
return_softmax,
gen_);
}
}
#endif
namespace at {
namespace cuda::philox {

@@ -270,7 +270,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_varle
#endif
TORCH_API
inline std::tuple<
std::tuple<
at::Tensor,
at::Tensor,
at::Tensor,
@@ -294,42 +294,7 @@ mha_fwd(
std::optional<int64_t> window_size_right,
const float softcap,
const bool return_softmax,
std::optional<at::Generator> gen_) {
#if defined(USE_ROCM_CK_SDPA)
if (at::globalContext().getROCmFAPreferredBackend() ==
at::ROCmFABackend::Ck) {
const int non_null_window_left = window_size_left.value_or(-1);
const int non_null_window_right = window_size_right.value_or(-1);
std::optional<at::Tensor> dummy_attn_bias = std::nullopt;
return mha_fwd_ck(
q,
k,
v,
out_,
p_dropout,
softmax_scale,
is_causal,
non_null_window_left,
non_null_window_right,
return_softmax,
gen_,
dummy_attn_bias); // Not used in flash attention
}
#endif
return mha_fwd_aot(
q,
k,
v,
out_,
alibi_slopes_,
p_dropout,
softmax_scale,
is_causal,
window_size_left,
window_size_right,
return_softmax,
gen_);
}
std::optional<at::Generator> gen_);
inline std::tuple<
at::Tensor,

@@ -45,13 +45,88 @@ if(NOT __AOTRITON_INCLUDED)
)
set(__AOTRITON_BASE_URL "https://github.com/ROCm/aotriton/releases/download/") # @lint-ignore
set(__AOTRITON_Z "gz")
# Set the default __AOTRITON_LIB path
set(__AOTRITON_LIB "${__AOTRITON_INSTALL_DIR}/lib/libaotriton_v2.so")
if(WIN32)
set(__AOTRITON_LIB "${__AOTRITON_INSTALL_DIR}/lib/aotriton_v2.lib")
endif()
function(aotriton_build_windows_dependencies dlfcn-win32_external xz_external dlfcn-win32_DIR liblzma_DIR)
# Windows-specific dependencies - build these first
if(NOT noimage)
message(FATAL_ERROR "noimage must be ON for Windows builds")
endif()
# Build dlfcn-win32
set(__DLFCN_WIN32_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/dlfcn-win32")
set(__DLFCN_WIN32_INSTALL_DIR "${CMAKE_CURRENT_BINARY_DIR}/dlfcn-win32-install")
ExternalProject_Add(${dlfcn-win32_external}
GIT_REPOSITORY https://github.com/dlfcn-win32/dlfcn-win32.git
GIT_TAG v1.4.2
PREFIX ${__DLFCN_WIN32_PREFIX}
INSTALL_DIR ${__DLFCN_WIN32_INSTALL_DIR}
CMAKE_ARGS
-DCMAKE_INSTALL_PREFIX=${__DLFCN_WIN32_INSTALL_DIR}
-DCMAKE_BUILD_TYPE=Release
-DCMAKE_C_COMPILER=cl
-DCMAKE_CXX_COMPILER=cl
-DBUILD_SHARED_LIBS=ON
-DBUILD_TESTS=OFF
BUILD_BYPRODUCTS
"${__DLFCN_WIN32_INSTALL_DIR}/lib/dl.lib"
"${__DLFCN_WIN32_INSTALL_DIR}/bin/dl.dll"
)
ExternalProject_Add_Step(${dlfcn-win32_external} copy_to_aotriton
COMMAND ${CMAKE_COMMAND} -E copy_if_different
"${__DLFCN_WIN32_INSTALL_DIR}/bin/dl.dll"
"${__AOTRITON_INSTALL_DIR}/lib/"
DEPENDEES install
)
set(${dlfcn-win32_DIR} "${__DLFCN_WIN32_INSTALL_DIR}/share/dlfcn-win32" CACHE PATH "Path to dlfcn-win32 CMake config" FORCE)
# Build xz/liblzma
set(__XZ_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/xz")
set(__XZ_INSTALL_DIR "${CMAKE_CURRENT_BINARY_DIR}/xz-install")
ExternalProject_Add(${xz_external}
GIT_REPOSITORY https://github.com/tukaani-project/xz.git
GIT_TAG v5.8.1
PREFIX ${__XZ_PREFIX}
INSTALL_DIR ${__XZ_INSTALL_DIR}
CMAKE_ARGS
-DCMAKE_INSTALL_PREFIX=${__XZ_INSTALL_DIR}
-DCMAKE_BUILD_TYPE=Release
-DBUILD_SHARED_LIBS=ON
-DENABLE_NLS=OFF
-DXZ_TOOL_LZMAINFO=OFF
-DXZ_TOOL_XZ=OFF
-DXZ_TOOL_XZDEC=OFF
-DXZ_TOOL_LZMADEC=OFF
BUILD_BYPRODUCTS
"${__XZ_INSTALL_DIR}/lib/lzma.lib"
"${__XZ_INSTALL_DIR}/bin/liblzma.dll"
)
ExternalProject_Add_Step(${xz_external} copy_to_aotriton
COMMAND ${CMAKE_COMMAND} -E copy_if_different
"${__XZ_INSTALL_DIR}/bin/liblzma.dll"
"${__AOTRITON_INSTALL_DIR}/lib/"
DEPENDEES install
)
set(${liblzma_DIR} "${__XZ_INSTALL_DIR}/lib/cmake/liblzma" CACHE PATH "Path to xz/liblzma CMake config" FORCE)
endfunction()
function(aotriton_build_from_source noimage project)
if(noimage)
SET(RECURSIVE "OFF")
else()
SET(RECURSIVE "ON")
endif()
if(WIN32)
message(STATUS "Building AOTriton Windows dependencies")
aotriton_build_windows_dependencies(dlfcn-win32_external xz_external dlfcn-win32_DIR liblzma_DIR)
endif()
message(STATUS "PYTORCH_ROCM_ARCH ${PYTORCH_ROCM_ARCH}")
ExternalProject_Add(${project}
GIT_REPOSITORY https://github.com/ROCm/aotriton.git
GIT_SUBMODULES_RECURSE ${RECURSIVE}
@@ -65,12 +65,19 @@ if(NOT __AOTRITON_INCLUDED)
-DAOTRITON_GPU_BUILD_TIMEOUT=0
-DAOTRITON_NO_PYTHON=ON
-DAOTRITON_NOIMAGE_MODE=${noimage}
BUILD_BYPRODUCTS "${__AOTRITON_INSTALL_DIR}/lib/libaotriton_v2.so"
-DHIP_PLATFORM=amd
$<$<BOOL:${WIN32}>:-Ddlfcn-win32_DIR=${dlfcn-win32_DIR}>
$<$<BOOL:${WIN32}>:-Dliblzma_DIR=${liblzma_DIR}>
BUILD_BYPRODUCTS
"${__AOTRITON_LIB}"
USES_TERMINAL_DOWNLOAD TRUE
USES_TERMINAL_CONFIGURE TRUE
USES_TERMINAL_BUILD TRUE
USES_TERMINAL_INSTALL TRUE
)
if(WIN32)
add_dependencies(${project} dlfcn-win32_external xz_external)
endif()
endfunction()
set(__AOTRITON_ARCH ${CMAKE_HOST_SYSTEM_PROCESSOR})
@@ -95,7 +177,7 @@ if(NOT __AOTRITON_INCLUDED)
INSTALL_COMMAND ${CMAKE_COMMAND} -E copy_directory
"${CMAKE_CURRENT_BINARY_DIR}/aotriton_runtime"
"${__AOTRITON_INSTALL_DIR}"
BUILD_BYPRODUCTS "${__AOTRITON_INSTALL_DIR}/lib/libaotriton_v2.so"
BUILD_BYPRODUCTS "${__AOTRITON_LIB}"
)
message(STATUS "Using AOTriton Runtime from pre-compiled binary ${__AOTRITON_URL}.\
Set env variables AOTRITON_INSTALL_FROM_SOURCE=1 to build from source.")
@@ -111,14 +193,35 @@ if(NOT __AOTRITON_INCLUDED)
string(CONCAT __AOTRITON_URL
"${__AOTRITON_BASE_URL}"
"${__AOTRITON_VER}/${__AOTRITON_FILE}")
# Set up directories
set(__AOTRITON_DOWNLOAD_DIR ${CMAKE_CURRENT_BINARY_DIR}/aotriton_download-${image})
set(__AOTRITON_EXTRACT_DIR ${CMAKE_CURRENT_BINARY_DIR}/aotriton_image-${image})
set(__AOTRITON_INSTALL_SOURCE_DIR ${__AOTRITON_EXTRACT_DIR})
set(__DOWNLOAD_NO_EXTRACT "")
set(__BUILD_COMMANDS "")
# On Windows, we need custom tar extraction with UTF-8 support
if(WIN32)
set(__DOWNLOAD_NO_EXTRACT "DOWNLOAD_NO_EXTRACT;TRUE")
set(__BUILD_COMMANDS
COMMAND ${CMAKE_COMMAND} -E make_directory "${__AOTRITON_EXTRACT_DIR}"
COMMAND tar --options hdrcharset=UTF-8 -xf "${__AOTRITON_DOWNLOAD_DIR}/${__AOTRITON_FILE}" -C "${__AOTRITON_EXTRACT_DIR}"
)
set(__AOTRITON_INSTALL_SOURCE_DIR ${__AOTRITON_EXTRACT_DIR}/aotriton)
endif()
ExternalProject_Add(${project}
URL "${__AOTRITON_URL}"
URL_HASH SHA256=${__AOTRITON_SHA256}
SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/aotriton_image-${image}
DOWNLOAD_DIR ${__AOTRITON_DOWNLOAD_DIR}
${__DOWNLOAD_NO_EXTRACT}
SOURCE_DIR ${__AOTRITON_EXTRACT_DIR}
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
${__BUILD_COMMANDS}
INSTALL_COMMAND ${CMAKE_COMMAND} -E copy_directory
"${CMAKE_CURRENT_BINARY_DIR}/aotriton_image-${image}"
"${__AOTRITON_INSTALL_SOURCE_DIR}"
"${__AOTRITON_INSTALL_DIR}"
BUILD_BYPRODUCTS
"${__AOTRITON_INSTALL_DIR}/lib/aotriton.images/${image}/__signature__"
@@ -164,7 +267,7 @@ if(NOT __AOTRITON_INCLUDED)
endforeach()
endforeach()
endif()
target_link_libraries(__caffe2_aotriton INTERFACE ${__AOTRITON_INSTALL_DIR}/lib/libaotriton_v2.so)
target_link_libraries(__caffe2_aotriton INTERFACE ${__AOTRITON_LIB})
target_include_directories(__caffe2_aotriton INTERFACE ${__AOTRITON_INSTALL_DIR}/include)
set(AOTRITON_FOUND TRUE)
endif() # __AOTRITON_INCLUDED

@@ -12,6 +12,7 @@ BU
contiguities
contiguity
coo
DEPENDEES
deser
din
dout