mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Revert "[ROCm/Windows] Support aotriton for scaled_dot_product_attention on Windows. (#162330)"
This reverts commit 62843c14bbf694f5722fd6e1075da4792507fe42. Reverted https://github.com/pytorch/pytorch/pull/162330 on behalf of https://github.com/atalman due to Sorry reverting looks like broke windows nightlies see https://github.com/pytorch/pytorch/issues/162881 ([comment](https://github.com/pytorch/pytorch/pull/162330#issuecomment-3288544921))
This commit is contained in:
@ -874,7 +874,7 @@ cmake_dependent_option(
|
||||
"Whether to build the flash_attention kernel for scaled dot product attention.\
|
||||
Will be disabled if not supported by the platform"
|
||||
ON
|
||||
"USE_CUDA OR USE_ROCM"
|
||||
"USE_CUDA OR USE_ROCM;NOT MSVC"
|
||||
OFF)
|
||||
|
||||
cmake_dependent_option(
|
||||
@ -909,7 +909,7 @@ cmake_dependent_option(
|
||||
# USE_FLASH_ATTENTION -> USE_ROCM -> Dependencies.cmake -> aotriton.cmake
|
||||
#
|
||||
if(USE_ROCM)
|
||||
if(USE_FLASH_ATTENTION OR USE_MEM_EFF_ATTENTION)
|
||||
if(UNIX AND (USE_FLASH_ATTENTION OR USE_MEM_EFF_ATTENTION))
|
||||
include(cmake/External/aotriton.cmake)
|
||||
endif()
|
||||
endif()
|
||||
|
@ -95,72 +95,6 @@
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#if defined(USE_ROCM) && (defined(USE_FLASH_ATTENTION) || defined(USE_MEM_EFF_ATTENTION))
|
||||
namespace pytorch_flash
|
||||
{
|
||||
std::tuple<
|
||||
at::Tensor,
|
||||
at::Tensor,
|
||||
at::Tensor,
|
||||
at::Tensor,
|
||||
at::Tensor,
|
||||
at::Tensor,
|
||||
at::Tensor,
|
||||
at::Tensor>
|
||||
mha_fwd(
|
||||
const at::Tensor& q, // batch_size x seqlen_q x num_heads x head_size
|
||||
const at::Tensor& k, // batch_size x seqlen_k x num_heads_k x head_size
|
||||
const at::Tensor& v, // batch_size x seqlen_k x num_heads_k x head_size
|
||||
std::optional<at::Tensor>&
|
||||
out_, // batch_size x seqlen_q x num_heads x head_size
|
||||
std::optional<at::Tensor>&
|
||||
alibi_slopes_, // num_heads or batch_size x num_heads
|
||||
const float p_dropout,
|
||||
const float softmax_scale,
|
||||
bool is_causal,
|
||||
std::optional<int64_t> window_size_left,
|
||||
std::optional<int64_t> window_size_right,
|
||||
const float softcap,
|
||||
const bool return_softmax,
|
||||
std::optional<at::Generator> gen_) {
|
||||
#if defined(USE_ROCM_CK_SDPA)
|
||||
if (at::globalContext().getROCmFAPreferredBackend() ==
|
||||
at::ROCmFABackend::Ck) {
|
||||
const int non_null_window_left = window_size_left.value_or(-1);
|
||||
const int non_null_window_right = window_size_right.value_or(-1);
|
||||
std::optional<at::Tensor> dummy_attn_bias = std::nullopt;
|
||||
return mha_fwd_ck(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
out_,
|
||||
p_dropout,
|
||||
softmax_scale,
|
||||
is_causal,
|
||||
non_null_window_left,
|
||||
non_null_window_right,
|
||||
return_softmax,
|
||||
gen_,
|
||||
dummy_attn_bias); // Not used in flash attention
|
||||
}
|
||||
#endif
|
||||
return mha_fwd_aot(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
out_,
|
||||
alibi_slopes_,
|
||||
p_dropout,
|
||||
softmax_scale,
|
||||
is_causal,
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
return_softmax,
|
||||
gen_);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
namespace at {
|
||||
|
||||
namespace cuda::philox {
|
||||
|
@ -270,7 +270,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_varle
|
||||
#endif
|
||||
|
||||
TORCH_API
|
||||
std::tuple<
|
||||
inline std::tuple<
|
||||
at::Tensor,
|
||||
at::Tensor,
|
||||
at::Tensor,
|
||||
@ -294,7 +294,42 @@ mha_fwd(
|
||||
std::optional<int64_t> window_size_right,
|
||||
const float softcap,
|
||||
const bool return_softmax,
|
||||
std::optional<at::Generator> gen_);
|
||||
std::optional<at::Generator> gen_) {
|
||||
#if defined(USE_ROCM_CK_SDPA)
|
||||
if (at::globalContext().getROCmFAPreferredBackend() ==
|
||||
at::ROCmFABackend::Ck) {
|
||||
const int non_null_window_left = window_size_left.value_or(-1);
|
||||
const int non_null_window_right = window_size_right.value_or(-1);
|
||||
std::optional<at::Tensor> dummy_attn_bias = std::nullopt;
|
||||
return mha_fwd_ck(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
out_,
|
||||
p_dropout,
|
||||
softmax_scale,
|
||||
is_causal,
|
||||
non_null_window_left,
|
||||
non_null_window_right,
|
||||
return_softmax,
|
||||
gen_,
|
||||
dummy_attn_bias); // Not used in flash attention
|
||||
}
|
||||
#endif
|
||||
return mha_fwd_aot(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
out_,
|
||||
alibi_slopes_,
|
||||
p_dropout,
|
||||
softmax_scale,
|
||||
is_causal,
|
||||
window_size_left,
|
||||
window_size_right,
|
||||
return_softmax,
|
||||
gen_);
|
||||
}
|
||||
|
||||
inline std::tuple<
|
||||
at::Tensor,
|
||||
|
113
cmake/External/aotriton.cmake
vendored
113
cmake/External/aotriton.cmake
vendored
@ -45,88 +45,13 @@ if(NOT __AOTRITON_INCLUDED)
|
||||
)
|
||||
set(__AOTRITON_BASE_URL "https://github.com/ROCm/aotriton/releases/download/") # @lint-ignore
|
||||
set(__AOTRITON_Z "gz")
|
||||
# Set the default __AOTRITON_LIB path
|
||||
set(__AOTRITON_LIB "${__AOTRITON_INSTALL_DIR}/lib/libaotriton_v2.so")
|
||||
if(WIN32)
|
||||
set(__AOTRITON_LIB "${__AOTRITON_INSTALL_DIR}/lib/aotriton_v2.lib")
|
||||
endif()
|
||||
|
||||
function(aotriton_build_windows_dependencies dlfcn-win32_external xz_external dlfcn-win32_DIR liblzma_DIR)
|
||||
# Windows-specific dependencies - build these first
|
||||
if(NOT noimage)
|
||||
message(FATAL_ERROR "noimage must be ON for Windows builds")
|
||||
endif()
|
||||
# Build dlfcn-win32
|
||||
set(__DLFCN_WIN32_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/dlfcn-win32")
|
||||
set(__DLFCN_WIN32_INSTALL_DIR "${CMAKE_CURRENT_BINARY_DIR}/dlfcn-win32-install")
|
||||
|
||||
ExternalProject_Add(${dlfcn-win32_external}
|
||||
GIT_REPOSITORY https://github.com/dlfcn-win32/dlfcn-win32.git
|
||||
GIT_TAG v1.4.2
|
||||
PREFIX ${__DLFCN_WIN32_PREFIX}
|
||||
INSTALL_DIR ${__DLFCN_WIN32_INSTALL_DIR}
|
||||
CMAKE_ARGS
|
||||
-DCMAKE_INSTALL_PREFIX=${__DLFCN_WIN32_INSTALL_DIR}
|
||||
-DCMAKE_BUILD_TYPE=Release
|
||||
-DCMAKE_C_COMPILER=cl
|
||||
-DCMAKE_CXX_COMPILER=cl
|
||||
-DBUILD_SHARED_LIBS=ON
|
||||
-DBUILD_TESTS=OFF
|
||||
BUILD_BYPRODUCTS
|
||||
"${__DLFCN_WIN32_INSTALL_DIR}/lib/dl.lib"
|
||||
"${__DLFCN_WIN32_INSTALL_DIR}/bin/dl.dll"
|
||||
)
|
||||
ExternalProject_Add_Step(${dlfcn-win32_external} copy_to_aotriton
|
||||
COMMAND ${CMAKE_COMMAND} -E copy_if_different
|
||||
"${__DLFCN_WIN32_INSTALL_DIR}/bin/dl.dll"
|
||||
"${__AOTRITON_INSTALL_DIR}/lib/"
|
||||
DEPENDEES install
|
||||
)
|
||||
set(${dlfcn-win32_DIR} "${__DLFCN_WIN32_INSTALL_DIR}/share/dlfcn-win32" CACHE PATH "Path to dlfcn-win32 CMake config" FORCE)
|
||||
|
||||
# Build xz/liblzma
|
||||
set(__XZ_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/xz")
|
||||
set(__XZ_INSTALL_DIR "${CMAKE_CURRENT_BINARY_DIR}/xz-install")
|
||||
|
||||
ExternalProject_Add(${xz_external}
|
||||
GIT_REPOSITORY https://github.com/tukaani-project/xz.git
|
||||
GIT_TAG v5.8.1
|
||||
PREFIX ${__XZ_PREFIX}
|
||||
INSTALL_DIR ${__XZ_INSTALL_DIR}
|
||||
CMAKE_ARGS
|
||||
-DCMAKE_INSTALL_PREFIX=${__XZ_INSTALL_DIR}
|
||||
-DCMAKE_BUILD_TYPE=Release
|
||||
-DBUILD_SHARED_LIBS=ON
|
||||
-DENABLE_NLS=OFF
|
||||
-DXZ_TOOL_LZMAINFO=OFF
|
||||
-DXZ_TOOL_XZ=OFF
|
||||
-DXZ_TOOL_XZDEC=OFF
|
||||
-DXZ_TOOL_LZMADEC=OFF
|
||||
BUILD_BYPRODUCTS
|
||||
"${__XZ_INSTALL_DIR}/lib/lzma.lib"
|
||||
"${__XZ_INSTALL_DIR}/bin/liblzma.dll"
|
||||
)
|
||||
ExternalProject_Add_Step(${xz_external} copy_to_aotriton
|
||||
COMMAND ${CMAKE_COMMAND} -E copy_if_different
|
||||
"${__XZ_INSTALL_DIR}/bin/liblzma.dll"
|
||||
"${__AOTRITON_INSTALL_DIR}/lib/"
|
||||
DEPENDEES install
|
||||
)
|
||||
set(${liblzma_DIR} "${__XZ_INSTALL_DIR}/lib/cmake/liblzma" CACHE PATH "Path to xz/liblzma CMake config" FORCE)
|
||||
endfunction()
|
||||
|
||||
function(aotriton_build_from_source noimage project)
|
||||
if(noimage)
|
||||
SET(RECURSIVE "OFF")
|
||||
else()
|
||||
SET(RECURSIVE "ON")
|
||||
endif()
|
||||
if(WIN32)
|
||||
message(STATUS "Building AOTriton Windows dependencies")
|
||||
aotriton_build_windows_dependencies(dlfcn-win32_external xz_external dlfcn-win32_DIR liblzma_DIR)
|
||||
endif()
|
||||
message(STATUS "PYTORCH_ROCM_ARCH ${PYTORCH_ROCM_ARCH}")
|
||||
|
||||
ExternalProject_Add(${project}
|
||||
GIT_REPOSITORY https://github.com/ROCm/aotriton.git
|
||||
GIT_SUBMODULES_RECURSE ${RECURSIVE}
|
||||
@ -140,19 +65,12 @@ if(NOT __AOTRITON_INCLUDED)
|
||||
-DAOTRITON_GPU_BUILD_TIMEOUT=0
|
||||
-DAOTRITON_NO_PYTHON=ON
|
||||
-DAOTRITON_NOIMAGE_MODE=${noimage}
|
||||
-DHIP_PLATFORM=amd
|
||||
$<$<BOOL:${WIN32}>:-Ddlfcn-win32_DIR=${dlfcn-win32_DIR}>
|
||||
$<$<BOOL:${WIN32}>:-Dliblzma_DIR=${liblzma_DIR}>
|
||||
BUILD_BYPRODUCTS
|
||||
"${__AOTRITON_LIB}"
|
||||
BUILD_BYPRODUCTS "${__AOTRITON_INSTALL_DIR}/lib/libaotriton_v2.so"
|
||||
USES_TERMINAL_DOWNLOAD TRUE
|
||||
USES_TERMINAL_CONFIGURE TRUE
|
||||
USES_TERMINAL_BUILD TRUE
|
||||
USES_TERMINAL_INSTALL TRUE
|
||||
)
|
||||
if(WIN32)
|
||||
add_dependencies(${project} dlfcn-win32_external xz_external)
|
||||
endif()
|
||||
endfunction()
|
||||
|
||||
set(__AOTRITON_ARCH ${CMAKE_HOST_SYSTEM_PROCESSOR})
|
||||
@ -177,7 +95,7 @@ if(NOT __AOTRITON_INCLUDED)
|
||||
INSTALL_COMMAND ${CMAKE_COMMAND} -E copy_directory
|
||||
"${CMAKE_CURRENT_BINARY_DIR}/aotriton_runtime"
|
||||
"${__AOTRITON_INSTALL_DIR}"
|
||||
BUILD_BYPRODUCTS "${__AOTRITON_LIB}"
|
||||
BUILD_BYPRODUCTS "${__AOTRITON_INSTALL_DIR}/lib/libaotriton_v2.so"
|
||||
)
|
||||
message(STATUS "Using AOTriton Runtime from pre-compiled binary ${__AOTRITON_URL}.\
|
||||
Set env variables AOTRITON_INSTALL_FROM_SOURCE=1 to build from source.")
|
||||
@ -193,35 +111,14 @@ if(NOT __AOTRITON_INCLUDED)
|
||||
string(CONCAT __AOTRITON_URL
|
||||
"${__AOTRITON_BASE_URL}"
|
||||
"${__AOTRITON_VER}/${__AOTRITON_FILE}")
|
||||
|
||||
# Set up directories
|
||||
set(__AOTRITON_DOWNLOAD_DIR ${CMAKE_CURRENT_BINARY_DIR}/aotriton_download-${image})
|
||||
set(__AOTRITON_EXTRACT_DIR ${CMAKE_CURRENT_BINARY_DIR}/aotriton_image-${image})
|
||||
set(__AOTRITON_INSTALL_SOURCE_DIR ${__AOTRITON_EXTRACT_DIR})
|
||||
set(__DOWNLOAD_NO_EXTRACT "")
|
||||
set(__BUILD_COMMANDS "")
|
||||
|
||||
# On Windows, we need custom tar extraction with UTF-8 support
|
||||
if(WIN32)
|
||||
set(__DOWNLOAD_NO_EXTRACT "DOWNLOAD_NO_EXTRACT;TRUE")
|
||||
set(__BUILD_COMMANDS
|
||||
COMMAND ${CMAKE_COMMAND} -E make_directory "${__AOTRITON_EXTRACT_DIR}"
|
||||
COMMAND tar --options hdrcharset=UTF-8 -xf "${__AOTRITON_DOWNLOAD_DIR}/${__AOTRITON_FILE}" -C "${__AOTRITON_EXTRACT_DIR}"
|
||||
)
|
||||
set(__AOTRITON_INSTALL_SOURCE_DIR ${__AOTRITON_EXTRACT_DIR}/aotriton)
|
||||
endif()
|
||||
|
||||
ExternalProject_Add(${project}
|
||||
URL "${__AOTRITON_URL}"
|
||||
URL_HASH SHA256=${__AOTRITON_SHA256}
|
||||
DOWNLOAD_DIR ${__AOTRITON_DOWNLOAD_DIR}
|
||||
${__DOWNLOAD_NO_EXTRACT}
|
||||
SOURCE_DIR ${__AOTRITON_EXTRACT_DIR}
|
||||
SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/aotriton_image-${image}
|
||||
CONFIGURE_COMMAND ""
|
||||
BUILD_COMMAND ""
|
||||
${__BUILD_COMMANDS}
|
||||
INSTALL_COMMAND ${CMAKE_COMMAND} -E copy_directory
|
||||
"${__AOTRITON_INSTALL_SOURCE_DIR}"
|
||||
"${CMAKE_CURRENT_BINARY_DIR}/aotriton_image-${image}"
|
||||
"${__AOTRITON_INSTALL_DIR}"
|
||||
BUILD_BYPRODUCTS
|
||||
"${__AOTRITON_INSTALL_DIR}/lib/aotriton.images/${image}/__signature__"
|
||||
@ -267,7 +164,7 @@ if(NOT __AOTRITON_INCLUDED)
|
||||
endforeach()
|
||||
endforeach()
|
||||
endif()
|
||||
target_link_libraries(__caffe2_aotriton INTERFACE ${__AOTRITON_LIB})
|
||||
target_link_libraries(__caffe2_aotriton INTERFACE ${__AOTRITON_INSTALL_DIR}/lib/libaotriton_v2.so)
|
||||
target_include_directories(__caffe2_aotriton INTERFACE ${__AOTRITON_INSTALL_DIR}/include)
|
||||
set(AOTRITON_FOUND TRUE)
|
||||
endif() # __AOTRITON_INCLUDED
|
||||
|
@ -12,7 +12,6 @@ BU
|
||||
contiguities
|
||||
contiguity
|
||||
coo
|
||||
DEPENDEES
|
||||
deser
|
||||
din
|
||||
dout
|
||||
|
Reference in New Issue
Block a user