Revert "[ROCm/Windows] Support aotriton for scaled_dot_product_attention on Windows. (#162330)"

This reverts commit 62843c14bbf694f5722fd6e1075da4792507fe42.

Reverted https://github.com/pytorch/pytorch/pull/162330 on behalf of https://github.com/atalman due to the change breaking the Windows nightlies; see https://github.com/pytorch/pytorch/issues/162881 ([comment](https://github.com/pytorch/pytorch/pull/162330#issuecomment-3288544921))
Author:  PyTorch MergeBot
Date:    2025-09-13 15:43:50 +00:00
Parent:  deb7ebe0a3
Commit:  5b9114bf19

5 changed files with 44 additions and 179 deletions


@@ -874,7 +874,7 @@ cmake_dependent_option(
   USE_FLASH_ATTENTION
   "Whether to build the flash_attention kernel for scaled dot product attention.\
    Will be disabled if not supported by the platform"
   ON
-  "USE_CUDA OR USE_ROCM"
+  "USE_CUDA OR USE_ROCM;NOT MSVC"
   OFF)
 cmake_dependent_option(
@@ -909,7 +909,7 @@
 # USE_FLASH_ATTENTION -> USE_ROCM -> Dependencies.cmake -> aotriton.cmake
 #
 if(USE_ROCM)
-  if(USE_FLASH_ATTENTION OR USE_MEM_EFF_ATTENTION)
+  if(UNIX AND (USE_FLASH_ATTENTION OR USE_MEM_EFF_ATTENTION))
     include(cmake/External/aotriton.cmake)
   endif()
 endif()
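
Note on the restored gating: cmake_dependent_option only lets USE_FLASH_ATTENTION default to ON when every condition in the semicolon-separated dependency list holds, and forces it to the fallback (OFF) otherwise; since MSVC is true on native Windows configures, "NOT MSVC" pins the option OFF there. A minimal, self-contained sketch of that behavior (the gating_demo project and the hard-coded USE_CUDA/USE_ROCM values are illustrative, not PyTorch's):

# Sketch only: demonstrates how the dependency list gates the option.
cmake_minimum_required(VERSION 3.18)
project(gating_demo LANGUAGES NONE)   # LANGUAGES NONE: MSVC stays undefined here
include(CMakeDependentOption)

set(USE_CUDA OFF)   # assumed values for the demo
set(USE_ROCM ON)

cmake_dependent_option(
  USE_FLASH_ATTENTION
  "Whether to build the flash_attention kernel"
  ON
  "USE_CUDA OR USE_ROCM;NOT MSVC"   # every condition must hold
  OFF)

message(STATUS "USE_FLASH_ATTENTION=${USE_FLASH_ATTENTION}")   # ON in this configuration

Because the forced value also removes the cache entry, a user-supplied -DUSE_FLASH_ATTENTION=ON cannot override the MSVC restriction, which is exactly how this revert re-disables the flash kernels on Windows.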


@@ -95,72 +95,6 @@
 #endif
 #endif
-#if defined(USE_ROCM) && (defined(USE_FLASH_ATTENTION) || defined(USE_MEM_EFF_ATTENTION))
-namespace pytorch_flash
-{
-std::tuple<
-    at::Tensor,
-    at::Tensor,
-    at::Tensor,
-    at::Tensor,
-    at::Tensor,
-    at::Tensor,
-    at::Tensor,
-    at::Tensor>
-mha_fwd(
-    const at::Tensor& q, // batch_size x seqlen_q x num_heads x head_size
-    const at::Tensor& k, // batch_size x seqlen_k x num_heads_k x head_size
-    const at::Tensor& v, // batch_size x seqlen_k x num_heads_k x head_size
-    std::optional<at::Tensor>&
-        out_, // batch_size x seqlen_q x num_heads x head_size
-    std::optional<at::Tensor>&
-        alibi_slopes_, // num_heads or batch_size x num_heads
-    const float p_dropout,
-    const float softmax_scale,
-    bool is_causal,
-    std::optional<int64_t> window_size_left,
-    std::optional<int64_t> window_size_right,
-    const float softcap,
-    const bool return_softmax,
-    std::optional<at::Generator> gen_) {
-#if defined(USE_ROCM_CK_SDPA)
-  if (at::globalContext().getROCmFAPreferredBackend() ==
-      at::ROCmFABackend::Ck) {
-    const int non_null_window_left = window_size_left.value_or(-1);
-    const int non_null_window_right = window_size_right.value_or(-1);
-    std::optional<at::Tensor> dummy_attn_bias = std::nullopt;
-    return mha_fwd_ck(
-        q,
-        k,
-        v,
-        out_,
-        p_dropout,
-        softmax_scale,
-        is_causal,
-        non_null_window_left,
-        non_null_window_right,
-        return_softmax,
-        gen_,
-        dummy_attn_bias); // Not used in flash attention
-  }
-#endif
-  return mha_fwd_aot(
-      q,
-      k,
-      v,
-      out_,
-      alibi_slopes_,
-      p_dropout,
-      softmax_scale,
-      is_causal,
-      window_size_left,
-      window_size_right,
-      return_softmax,
-      gen_);
-}
-}
-#endif
 namespace at {
 namespace cuda::philox {


@@ -270,7 +270,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_varle
 #endif
 TORCH_API
-std::tuple<
+inline std::tuple<
     at::Tensor,
     at::Tensor,
     at::Tensor,
@@ -294,7 +294,42 @@ mha_fwd(
     std::optional<int64_t> window_size_right,
     const float softcap,
     const bool return_softmax,
-    std::optional<at::Generator> gen_);
+    std::optional<at::Generator> gen_) {
+#if defined(USE_ROCM_CK_SDPA)
+  if (at::globalContext().getROCmFAPreferredBackend() ==
+      at::ROCmFABackend::Ck) {
+    const int non_null_window_left = window_size_left.value_or(-1);
+    const int non_null_window_right = window_size_right.value_or(-1);
+    std::optional<at::Tensor> dummy_attn_bias = std::nullopt;
+    return mha_fwd_ck(
+        q,
+        k,
+        v,
+        out_,
+        p_dropout,
+        softmax_scale,
+        is_causal,
+        non_null_window_left,
+        non_null_window_right,
+        return_softmax,
+        gen_,
+        dummy_attn_bias); // Not used in flash attention
+  }
+#endif
+  return mha_fwd_aot(
+      q,
+      k,
+      v,
+      out_,
+      alibi_slopes_,
+      p_dropout,
+      softmax_scale,
+      is_causal,
+      window_size_left,
+      window_size_right,
+      return_softmax,
+      gen_);
+}
 inline std::tuple<
     at::Tensor,
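
Two things are going on in the hunk above. First, the mha_fwd body moves back into the header, so the definition must be inline: the header is included from multiple translation units, and a non-inline definition would emit duplicate strong symbols at link time. Second, the body dispatches on the USE_ROCM_CK_SDPA preprocessor guard, which originates in the build system as a compile definition. A hedged sketch of that wiring (the demo_attention target and DEMO_-prefixed option are hypothetical, not PyTorch's actual names):

# Sketch only: a CMake switch surfaces in C++ as the USE_ROCM_CK_SDPA
# definition that the inline mha_fwd tests with #if defined(USE_ROCM_CK_SDPA).
option(DEMO_USE_ROCM_CK_SDPA "Enable the composable_kernel SDPA path" OFF)

add_library(demo_attention STATIC flash_api.cpp)   # placeholder source file
if(DEMO_USE_ROCM_CK_SDPA)
  # PUBLIC so every TU that includes the header sees the same guard and the
  # inline definition stays identical across translation units.
  target_compile_definitions(demo_attention PUBLIC USE_ROCM_CK_SDPA)
endif()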


@@ -45,88 +45,13 @@ if(NOT __AOTRITON_INCLUDED)
   )
   set(__AOTRITON_BASE_URL "https://github.com/ROCm/aotriton/releases/download/") # @lint-ignore
   set(__AOTRITON_Z "gz")
-  # Set the default __AOTRITON_LIB path
-  set(__AOTRITON_LIB "${__AOTRITON_INSTALL_DIR}/lib/libaotriton_v2.so")
-  if(WIN32)
-    set(__AOTRITON_LIB "${__AOTRITON_INSTALL_DIR}/lib/aotriton_v2.lib")
-  endif()
-  function(aotriton_build_windows_dependencies dlfcn-win32_external xz_external dlfcn-win32_DIR liblzma_DIR)
-    # Windows-specific dependencies - build these first
-    if(NOT noimage)
-      message(FATAL_ERROR "noimage must be ON for Windows builds")
-    endif()
-    # Build dlfcn-win32
-    set(__DLFCN_WIN32_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/dlfcn-win32")
-    set(__DLFCN_WIN32_INSTALL_DIR "${CMAKE_CURRENT_BINARY_DIR}/dlfcn-win32-install")
-    ExternalProject_Add(${dlfcn-win32_external}
-      GIT_REPOSITORY https://github.com/dlfcn-win32/dlfcn-win32.git
-      GIT_TAG v1.4.2
-      PREFIX ${__DLFCN_WIN32_PREFIX}
-      INSTALL_DIR ${__DLFCN_WIN32_INSTALL_DIR}
-      CMAKE_ARGS
-        -DCMAKE_INSTALL_PREFIX=${__DLFCN_WIN32_INSTALL_DIR}
-        -DCMAKE_BUILD_TYPE=Release
-        -DCMAKE_C_COMPILER=cl
-        -DCMAKE_CXX_COMPILER=cl
-        -DBUILD_SHARED_LIBS=ON
-        -DBUILD_TESTS=OFF
-      BUILD_BYPRODUCTS
-        "${__DLFCN_WIN32_INSTALL_DIR}/lib/dl.lib"
-        "${__DLFCN_WIN32_INSTALL_DIR}/bin/dl.dll"
-    )
-    ExternalProject_Add_Step(${dlfcn-win32_external} copy_to_aotriton
-      COMMAND ${CMAKE_COMMAND} -E copy_if_different
-        "${__DLFCN_WIN32_INSTALL_DIR}/bin/dl.dll"
-        "${__AOTRITON_INSTALL_DIR}/lib/"
-      DEPENDEES install
-    )
-    set(${dlfcn-win32_DIR} "${__DLFCN_WIN32_INSTALL_DIR}/share/dlfcn-win32" CACHE PATH "Path to dlfcn-win32 CMake config" FORCE)
-    # Build xz/liblzma
-    set(__XZ_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/xz")
-    set(__XZ_INSTALL_DIR "${CMAKE_CURRENT_BINARY_DIR}/xz-install")
-    ExternalProject_Add(${xz_external}
-      GIT_REPOSITORY https://github.com/tukaani-project/xz.git
-      GIT_TAG v5.8.1
-      PREFIX ${__XZ_PREFIX}
-      INSTALL_DIR ${__XZ_INSTALL_DIR}
-      CMAKE_ARGS
-        -DCMAKE_INSTALL_PREFIX=${__XZ_INSTALL_DIR}
-        -DCMAKE_BUILD_TYPE=Release
-        -DBUILD_SHARED_LIBS=ON
-        -DENABLE_NLS=OFF
-        -DXZ_TOOL_LZMAINFO=OFF
-        -DXZ_TOOL_XZ=OFF
-        -DXZ_TOOL_XZDEC=OFF
-        -DXZ_TOOL_LZMADEC=OFF
-      BUILD_BYPRODUCTS
-        "${__XZ_INSTALL_DIR}/lib/lzma.lib"
-        "${__XZ_INSTALL_DIR}/bin/liblzma.dll"
-    )
-    ExternalProject_Add_Step(${xz_external} copy_to_aotriton
-      COMMAND ${CMAKE_COMMAND} -E copy_if_different
-        "${__XZ_INSTALL_DIR}/bin/liblzma.dll"
-        "${__AOTRITON_INSTALL_DIR}/lib/"
-      DEPENDEES install
-    )
-    set(${liblzma_DIR} "${__XZ_INSTALL_DIR}/lib/cmake/liblzma" CACHE PATH "Path to xz/liblzma CMake config" FORCE)
-  endfunction()
   function(aotriton_build_from_source noimage project)
     if(noimage)
       SET(RECURSIVE "OFF")
     else()
       SET(RECURSIVE "ON")
     endif()
-    if(WIN32)
-      message(STATUS "Building AOTriton Windows dependencies")
-      aotriton_build_windows_dependencies(dlfcn-win32_external xz_external dlfcn-win32_DIR liblzma_DIR)
-    endif()
     message(STATUS "PYTORCH_ROCM_ARCH ${PYTORCH_ROCM_ARCH}")
     ExternalProject_Add(${project}
       GIT_REPOSITORY https://github.com/ROCm/aotriton.git
       GIT_SUBMODULES_RECURSE ${RECURSIVE}
@@ -140,19 +65,12 @@ if(NOT __AOTRITON_INCLUDED)
       -DAOTRITON_GPU_BUILD_TIMEOUT=0
       -DAOTRITON_NO_PYTHON=ON
       -DAOTRITON_NOIMAGE_MODE=${noimage}
-      -DHIP_PLATFORM=amd
-      $<$<BOOL:${WIN32}>:-Ddlfcn-win32_DIR=${dlfcn-win32_DIR}>
-      $<$<BOOL:${WIN32}>:-Dliblzma_DIR=${liblzma_DIR}>
-      BUILD_BYPRODUCTS
-        "${__AOTRITON_LIB}"
+      BUILD_BYPRODUCTS "${__AOTRITON_INSTALL_DIR}/lib/libaotriton_v2.so"
       USES_TERMINAL_DOWNLOAD TRUE
       USES_TERMINAL_CONFIGURE TRUE
       USES_TERMINAL_BUILD TRUE
       USES_TERMINAL_INSTALL TRUE
     )
-    if(WIN32)
-      add_dependencies(${project} dlfcn-win32_external xz_external)
-    endif()
   endfunction()
   set(__AOTRITON_ARCH ${CMAKE_HOST_SYSTEM_PROCESSOR})
@@ -177,7 +95,7 @@ if(NOT __AOTRITON_INCLUDED)
     INSTALL_COMMAND ${CMAKE_COMMAND} -E copy_directory
       "${CMAKE_CURRENT_BINARY_DIR}/aotriton_runtime"
       "${__AOTRITON_INSTALL_DIR}"
-    BUILD_BYPRODUCTS "${__AOTRITON_LIB}"
+    BUILD_BYPRODUCTS "${__AOTRITON_INSTALL_DIR}/lib/libaotriton_v2.so"
   )
   message(STATUS "Using AOTriton Runtime from pre-compiled binary ${__AOTRITON_URL}.\
     Set env variables AOTRITON_INSTALL_FROM_SOURCE=1 to build from source.")
@@ -193,35 +111,14 @@ if(NOT __AOTRITON_INCLUDED)
     string(CONCAT __AOTRITON_URL
       "${__AOTRITON_BASE_URL}"
       "${__AOTRITON_VER}/${__AOTRITON_FILE}")
-    # Set up directories
-    set(__AOTRITON_DOWNLOAD_DIR ${CMAKE_CURRENT_BINARY_DIR}/aotriton_download-${image})
-    set(__AOTRITON_EXTRACT_DIR ${CMAKE_CURRENT_BINARY_DIR}/aotriton_image-${image})
-    set(__AOTRITON_INSTALL_SOURCE_DIR ${__AOTRITON_EXTRACT_DIR})
-    set(__DOWNLOAD_NO_EXTRACT "")
-    set(__BUILD_COMMANDS "")
-    # On Windows, we need custom tar extraction with UTF-8 support
-    if(WIN32)
-      set(__DOWNLOAD_NO_EXTRACT "DOWNLOAD_NO_EXTRACT;TRUE")
-      set(__BUILD_COMMANDS
-        COMMAND ${CMAKE_COMMAND} -E make_directory "${__AOTRITON_EXTRACT_DIR}"
-        COMMAND tar --options hdrcharset=UTF-8 -xf "${__AOTRITON_DOWNLOAD_DIR}/${__AOTRITON_FILE}" -C "${__AOTRITON_EXTRACT_DIR}"
-      )
-      set(__AOTRITON_INSTALL_SOURCE_DIR ${__AOTRITON_EXTRACT_DIR}/aotriton)
-    endif()
     ExternalProject_Add(${project}
       URL "${__AOTRITON_URL}"
       URL_HASH SHA256=${__AOTRITON_SHA256}
-      DOWNLOAD_DIR ${__AOTRITON_DOWNLOAD_DIR}
-      ${__DOWNLOAD_NO_EXTRACT}
-      SOURCE_DIR ${__AOTRITON_EXTRACT_DIR}
+      SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/aotriton_image-${image}
      CONFIGURE_COMMAND ""
      BUILD_COMMAND ""
-      ${__BUILD_COMMANDS}
      INSTALL_COMMAND ${CMAKE_COMMAND} -E copy_directory
-        "${__AOTRITON_INSTALL_SOURCE_DIR}"
+        "${CMAKE_CURRENT_BINARY_DIR}/aotriton_image-${image}"
        "${__AOTRITON_INSTALL_DIR}"
      BUILD_BYPRODUCTS
        "${__AOTRITON_INSTALL_DIR}/lib/aotriton.images/${image}/__signature__"
@@ -267,7 +164,7 @@ if(NOT __AOTRITON_INCLUDED)
     endforeach()
   endforeach()
   endif()
-  target_link_libraries(__caffe2_aotriton INTERFACE ${__AOTRITON_LIB})
+  target_link_libraries(__caffe2_aotriton INTERFACE ${__AOTRITON_INSTALL_DIR}/lib/libaotriton_v2.so)
   target_include_directories(__caffe2_aotriton INTERFACE ${__AOTRITON_INSTALL_DIR}/include)
   set(AOTRITON_FOUND TRUE)
 endif() # __AOTRITON_INCLUDED
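
The deleted Windows path above leaned on ExternalProject_Add_Step, whose DEPENDEES keyword orders a grafted command after a named built-in step (install, in this case). A minimal sketch of that copy-after-install pattern, with placeholder names and URLs (demo_dep, example.com) rather than PyTorch's:

include(ExternalProject)

ExternalProject_Add(demo_dep
  GIT_REPOSITORY https://example.com/demo.git   # placeholder repository
  CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${CMAKE_CURRENT_BINARY_DIR}/demo-install
)

# Runs only after demo_dep's own install step finishes, mirroring how the
# reverted code staged dl.dll and liblzma.dll into the AOTriton install tree.
ExternalProject_Add_Step(demo_dep copy_runtime
  COMMAND ${CMAKE_COMMAND} -E copy_if_different
          "${CMAKE_CURRENT_BINARY_DIR}/demo-install/bin/demo.dll"
          "${CMAKE_CURRENT_BINARY_DIR}/staging/"
  DEPENDEES install
)

With the Windows branch gone, DEPENDEES presumably no longer appears anywhere in the tree, which is why the final hunk below also drops it from the spell-check dictionary.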


@@ -12,7 +12,6 @@ BU
 contiguities
 contiguity
 coo
-DEPENDEES
 deser
 din
 dout