mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "[ROCm/Windows] Support aotriton for scaled_dot_product_attention on Windows. (#162330)"
This reverts commit 62843c14bbf694f5722fd6e1075da4792507fe42. Reverted https://github.com/pytorch/pytorch/pull/162330 on behalf of https://github.com/atalman due to Sorry reverting looks like broke windows nightlies see https://github.com/pytorch/pytorch/issues/162881 ([comment](https://github.com/pytorch/pytorch/pull/162330#issuecomment-3288544921))
This commit is contained in:
@ -874,7 +874,7 @@ cmake_dependent_option(
|
|||||||
"Whether to build the flash_attention kernel for scaled dot product attention.\
|
"Whether to build the flash_attention kernel for scaled dot product attention.\
|
||||||
Will be disabled if not supported by the platform"
|
Will be disabled if not supported by the platform"
|
||||||
ON
|
ON
|
||||||
"USE_CUDA OR USE_ROCM"
|
"USE_CUDA OR USE_ROCM;NOT MSVC"
|
||||||
OFF)
|
OFF)
|
||||||
|
|
||||||
cmake_dependent_option(
|
cmake_dependent_option(
|
||||||
@ -909,7 +909,7 @@ cmake_dependent_option(
|
|||||||
# USE_FLASH_ATTENTION -> USE_ROCM -> Dependencies.cmake -> aotriton.cmake
|
# USE_FLASH_ATTENTION -> USE_ROCM -> Dependencies.cmake -> aotriton.cmake
|
||||||
#
|
#
|
||||||
if(USE_ROCM)
|
if(USE_ROCM)
|
||||||
if(USE_FLASH_ATTENTION OR USE_MEM_EFF_ATTENTION)
|
if(UNIX AND (USE_FLASH_ATTENTION OR USE_MEM_EFF_ATTENTION))
|
||||||
include(cmake/External/aotriton.cmake)
|
include(cmake/External/aotriton.cmake)
|
||||||
endif()
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
@ -95,72 +95,6 @@
|
|||||||
#endif
|
#endif
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if defined(USE_ROCM) && (defined(USE_FLASH_ATTENTION) || defined(USE_MEM_EFF_ATTENTION))
|
|
||||||
namespace pytorch_flash
|
|
||||||
{
|
|
||||||
std::tuple<
|
|
||||||
at::Tensor,
|
|
||||||
at::Tensor,
|
|
||||||
at::Tensor,
|
|
||||||
at::Tensor,
|
|
||||||
at::Tensor,
|
|
||||||
at::Tensor,
|
|
||||||
at::Tensor,
|
|
||||||
at::Tensor>
|
|
||||||
mha_fwd(
|
|
||||||
const at::Tensor& q, // batch_size x seqlen_q x num_heads x head_size
|
|
||||||
const at::Tensor& k, // batch_size x seqlen_k x num_heads_k x head_size
|
|
||||||
const at::Tensor& v, // batch_size x seqlen_k x num_heads_k x head_size
|
|
||||||
std::optional<at::Tensor>&
|
|
||||||
out_, // batch_size x seqlen_q x num_heads x head_size
|
|
||||||
std::optional<at::Tensor>&
|
|
||||||
alibi_slopes_, // num_heads or batch_size x num_heads
|
|
||||||
const float p_dropout,
|
|
||||||
const float softmax_scale,
|
|
||||||
bool is_causal,
|
|
||||||
std::optional<int64_t> window_size_left,
|
|
||||||
std::optional<int64_t> window_size_right,
|
|
||||||
const float softcap,
|
|
||||||
const bool return_softmax,
|
|
||||||
std::optional<at::Generator> gen_) {
|
|
||||||
#if defined(USE_ROCM_CK_SDPA)
|
|
||||||
if (at::globalContext().getROCmFAPreferredBackend() ==
|
|
||||||
at::ROCmFABackend::Ck) {
|
|
||||||
const int non_null_window_left = window_size_left.value_or(-1);
|
|
||||||
const int non_null_window_right = window_size_right.value_or(-1);
|
|
||||||
std::optional<at::Tensor> dummy_attn_bias = std::nullopt;
|
|
||||||
return mha_fwd_ck(
|
|
||||||
q,
|
|
||||||
k,
|
|
||||||
v,
|
|
||||||
out_,
|
|
||||||
p_dropout,
|
|
||||||
softmax_scale,
|
|
||||||
is_causal,
|
|
||||||
non_null_window_left,
|
|
||||||
non_null_window_right,
|
|
||||||
return_softmax,
|
|
||||||
gen_,
|
|
||||||
dummy_attn_bias); // Not used in flash attention
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
return mha_fwd_aot(
|
|
||||||
q,
|
|
||||||
k,
|
|
||||||
v,
|
|
||||||
out_,
|
|
||||||
alibi_slopes_,
|
|
||||||
p_dropout,
|
|
||||||
softmax_scale,
|
|
||||||
is_causal,
|
|
||||||
window_size_left,
|
|
||||||
window_size_right,
|
|
||||||
return_softmax,
|
|
||||||
gen_);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
|
|
||||||
namespace at {
|
namespace at {
|
||||||
|
|
||||||
namespace cuda::philox {
|
namespace cuda::philox {
|
||||||
|
@ -270,7 +270,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_varle
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
TORCH_API
|
TORCH_API
|
||||||
std::tuple<
|
inline std::tuple<
|
||||||
at::Tensor,
|
at::Tensor,
|
||||||
at::Tensor,
|
at::Tensor,
|
||||||
at::Tensor,
|
at::Tensor,
|
||||||
@ -294,7 +294,42 @@ mha_fwd(
|
|||||||
std::optional<int64_t> window_size_right,
|
std::optional<int64_t> window_size_right,
|
||||||
const float softcap,
|
const float softcap,
|
||||||
const bool return_softmax,
|
const bool return_softmax,
|
||||||
std::optional<at::Generator> gen_);
|
std::optional<at::Generator> gen_) {
|
||||||
|
#if defined(USE_ROCM_CK_SDPA)
|
||||||
|
if (at::globalContext().getROCmFAPreferredBackend() ==
|
||||||
|
at::ROCmFABackend::Ck) {
|
||||||
|
const int non_null_window_left = window_size_left.value_or(-1);
|
||||||
|
const int non_null_window_right = window_size_right.value_or(-1);
|
||||||
|
std::optional<at::Tensor> dummy_attn_bias = std::nullopt;
|
||||||
|
return mha_fwd_ck(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
out_,
|
||||||
|
p_dropout,
|
||||||
|
softmax_scale,
|
||||||
|
is_causal,
|
||||||
|
non_null_window_left,
|
||||||
|
non_null_window_right,
|
||||||
|
return_softmax,
|
||||||
|
gen_,
|
||||||
|
dummy_attn_bias); // Not used in flash attention
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
return mha_fwd_aot(
|
||||||
|
q,
|
||||||
|
k,
|
||||||
|
v,
|
||||||
|
out_,
|
||||||
|
alibi_slopes_,
|
||||||
|
p_dropout,
|
||||||
|
softmax_scale,
|
||||||
|
is_causal,
|
||||||
|
window_size_left,
|
||||||
|
window_size_right,
|
||||||
|
return_softmax,
|
||||||
|
gen_);
|
||||||
|
}
|
||||||
|
|
||||||
inline std::tuple<
|
inline std::tuple<
|
||||||
at::Tensor,
|
at::Tensor,
|
||||||
|
113
cmake/External/aotriton.cmake
vendored
113
cmake/External/aotriton.cmake
vendored
@ -45,88 +45,13 @@ if(NOT __AOTRITON_INCLUDED)
|
|||||||
)
|
)
|
||||||
set(__AOTRITON_BASE_URL "https://github.com/ROCm/aotriton/releases/download/") # @lint-ignore
|
set(__AOTRITON_BASE_URL "https://github.com/ROCm/aotriton/releases/download/") # @lint-ignore
|
||||||
set(__AOTRITON_Z "gz")
|
set(__AOTRITON_Z "gz")
|
||||||
# Set the default __AOTRITON_LIB path
|
|
||||||
set(__AOTRITON_LIB "${__AOTRITON_INSTALL_DIR}/lib/libaotriton_v2.so")
|
|
||||||
if(WIN32)
|
|
||||||
set(__AOTRITON_LIB "${__AOTRITON_INSTALL_DIR}/lib/aotriton_v2.lib")
|
|
||||||
endif()
|
|
||||||
|
|
||||||
function(aotriton_build_windows_dependencies dlfcn-win32_external xz_external dlfcn-win32_DIR liblzma_DIR)
|
|
||||||
# Windows-specific dependencies - build these first
|
|
||||||
if(NOT noimage)
|
|
||||||
message(FATAL_ERROR "noimage must be ON for Windows builds")
|
|
||||||
endif()
|
|
||||||
# Build dlfcn-win32
|
|
||||||
set(__DLFCN_WIN32_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/dlfcn-win32")
|
|
||||||
set(__DLFCN_WIN32_INSTALL_DIR "${CMAKE_CURRENT_BINARY_DIR}/dlfcn-win32-install")
|
|
||||||
|
|
||||||
ExternalProject_Add(${dlfcn-win32_external}
|
|
||||||
GIT_REPOSITORY https://github.com/dlfcn-win32/dlfcn-win32.git
|
|
||||||
GIT_TAG v1.4.2
|
|
||||||
PREFIX ${__DLFCN_WIN32_PREFIX}
|
|
||||||
INSTALL_DIR ${__DLFCN_WIN32_INSTALL_DIR}
|
|
||||||
CMAKE_ARGS
|
|
||||||
-DCMAKE_INSTALL_PREFIX=${__DLFCN_WIN32_INSTALL_DIR}
|
|
||||||
-DCMAKE_BUILD_TYPE=Release
|
|
||||||
-DCMAKE_C_COMPILER=cl
|
|
||||||
-DCMAKE_CXX_COMPILER=cl
|
|
||||||
-DBUILD_SHARED_LIBS=ON
|
|
||||||
-DBUILD_TESTS=OFF
|
|
||||||
BUILD_BYPRODUCTS
|
|
||||||
"${__DLFCN_WIN32_INSTALL_DIR}/lib/dl.lib"
|
|
||||||
"${__DLFCN_WIN32_INSTALL_DIR}/bin/dl.dll"
|
|
||||||
)
|
|
||||||
ExternalProject_Add_Step(${dlfcn-win32_external} copy_to_aotriton
|
|
||||||
COMMAND ${CMAKE_COMMAND} -E copy_if_different
|
|
||||||
"${__DLFCN_WIN32_INSTALL_DIR}/bin/dl.dll"
|
|
||||||
"${__AOTRITON_INSTALL_DIR}/lib/"
|
|
||||||
DEPENDEES install
|
|
||||||
)
|
|
||||||
set(${dlfcn-win32_DIR} "${__DLFCN_WIN32_INSTALL_DIR}/share/dlfcn-win32" CACHE PATH "Path to dlfcn-win32 CMake config" FORCE)
|
|
||||||
|
|
||||||
# Build xz/liblzma
|
|
||||||
set(__XZ_PREFIX "${CMAKE_CURRENT_BINARY_DIR}/xz")
|
|
||||||
set(__XZ_INSTALL_DIR "${CMAKE_CURRENT_BINARY_DIR}/xz-install")
|
|
||||||
|
|
||||||
ExternalProject_Add(${xz_external}
|
|
||||||
GIT_REPOSITORY https://github.com/tukaani-project/xz.git
|
|
||||||
GIT_TAG v5.8.1
|
|
||||||
PREFIX ${__XZ_PREFIX}
|
|
||||||
INSTALL_DIR ${__XZ_INSTALL_DIR}
|
|
||||||
CMAKE_ARGS
|
|
||||||
-DCMAKE_INSTALL_PREFIX=${__XZ_INSTALL_DIR}
|
|
||||||
-DCMAKE_BUILD_TYPE=Release
|
|
||||||
-DBUILD_SHARED_LIBS=ON
|
|
||||||
-DENABLE_NLS=OFF
|
|
||||||
-DXZ_TOOL_LZMAINFO=OFF
|
|
||||||
-DXZ_TOOL_XZ=OFF
|
|
||||||
-DXZ_TOOL_XZDEC=OFF
|
|
||||||
-DXZ_TOOL_LZMADEC=OFF
|
|
||||||
BUILD_BYPRODUCTS
|
|
||||||
"${__XZ_INSTALL_DIR}/lib/lzma.lib"
|
|
||||||
"${__XZ_INSTALL_DIR}/bin/liblzma.dll"
|
|
||||||
)
|
|
||||||
ExternalProject_Add_Step(${xz_external} copy_to_aotriton
|
|
||||||
COMMAND ${CMAKE_COMMAND} -E copy_if_different
|
|
||||||
"${__XZ_INSTALL_DIR}/bin/liblzma.dll"
|
|
||||||
"${__AOTRITON_INSTALL_DIR}/lib/"
|
|
||||||
DEPENDEES install
|
|
||||||
)
|
|
||||||
set(${liblzma_DIR} "${__XZ_INSTALL_DIR}/lib/cmake/liblzma" CACHE PATH "Path to xz/liblzma CMake config" FORCE)
|
|
||||||
endfunction()
|
|
||||||
|
|
||||||
function(aotriton_build_from_source noimage project)
|
function(aotriton_build_from_source noimage project)
|
||||||
if(noimage)
|
if(noimage)
|
||||||
SET(RECURSIVE "OFF")
|
SET(RECURSIVE "OFF")
|
||||||
else()
|
else()
|
||||||
SET(RECURSIVE "ON")
|
SET(RECURSIVE "ON")
|
||||||
endif()
|
endif()
|
||||||
if(WIN32)
|
|
||||||
message(STATUS "Building AOTriton Windows dependencies")
|
|
||||||
aotriton_build_windows_dependencies(dlfcn-win32_external xz_external dlfcn-win32_DIR liblzma_DIR)
|
|
||||||
endif()
|
|
||||||
message(STATUS "PYTORCH_ROCM_ARCH ${PYTORCH_ROCM_ARCH}")
|
message(STATUS "PYTORCH_ROCM_ARCH ${PYTORCH_ROCM_ARCH}")
|
||||||
|
|
||||||
ExternalProject_Add(${project}
|
ExternalProject_Add(${project}
|
||||||
GIT_REPOSITORY https://github.com/ROCm/aotriton.git
|
GIT_REPOSITORY https://github.com/ROCm/aotriton.git
|
||||||
GIT_SUBMODULES_RECURSE ${RECURSIVE}
|
GIT_SUBMODULES_RECURSE ${RECURSIVE}
|
||||||
@ -140,19 +65,12 @@ if(NOT __AOTRITON_INCLUDED)
|
|||||||
-DAOTRITON_GPU_BUILD_TIMEOUT=0
|
-DAOTRITON_GPU_BUILD_TIMEOUT=0
|
||||||
-DAOTRITON_NO_PYTHON=ON
|
-DAOTRITON_NO_PYTHON=ON
|
||||||
-DAOTRITON_NOIMAGE_MODE=${noimage}
|
-DAOTRITON_NOIMAGE_MODE=${noimage}
|
||||||
-DHIP_PLATFORM=amd
|
BUILD_BYPRODUCTS "${__AOTRITON_INSTALL_DIR}/lib/libaotriton_v2.so"
|
||||||
$<$<BOOL:${WIN32}>:-Ddlfcn-win32_DIR=${dlfcn-win32_DIR}>
|
|
||||||
$<$<BOOL:${WIN32}>:-Dliblzma_DIR=${liblzma_DIR}>
|
|
||||||
BUILD_BYPRODUCTS
|
|
||||||
"${__AOTRITON_LIB}"
|
|
||||||
USES_TERMINAL_DOWNLOAD TRUE
|
USES_TERMINAL_DOWNLOAD TRUE
|
||||||
USES_TERMINAL_CONFIGURE TRUE
|
USES_TERMINAL_CONFIGURE TRUE
|
||||||
USES_TERMINAL_BUILD TRUE
|
USES_TERMINAL_BUILD TRUE
|
||||||
USES_TERMINAL_INSTALL TRUE
|
USES_TERMINAL_INSTALL TRUE
|
||||||
)
|
)
|
||||||
if(WIN32)
|
|
||||||
add_dependencies(${project} dlfcn-win32_external xz_external)
|
|
||||||
endif()
|
|
||||||
endfunction()
|
endfunction()
|
||||||
|
|
||||||
set(__AOTRITON_ARCH ${CMAKE_HOST_SYSTEM_PROCESSOR})
|
set(__AOTRITON_ARCH ${CMAKE_HOST_SYSTEM_PROCESSOR})
|
||||||
@ -177,7 +95,7 @@ if(NOT __AOTRITON_INCLUDED)
|
|||||||
INSTALL_COMMAND ${CMAKE_COMMAND} -E copy_directory
|
INSTALL_COMMAND ${CMAKE_COMMAND} -E copy_directory
|
||||||
"${CMAKE_CURRENT_BINARY_DIR}/aotriton_runtime"
|
"${CMAKE_CURRENT_BINARY_DIR}/aotriton_runtime"
|
||||||
"${__AOTRITON_INSTALL_DIR}"
|
"${__AOTRITON_INSTALL_DIR}"
|
||||||
BUILD_BYPRODUCTS "${__AOTRITON_LIB}"
|
BUILD_BYPRODUCTS "${__AOTRITON_INSTALL_DIR}/lib/libaotriton_v2.so"
|
||||||
)
|
)
|
||||||
message(STATUS "Using AOTriton Runtime from pre-compiled binary ${__AOTRITON_URL}.\
|
message(STATUS "Using AOTriton Runtime from pre-compiled binary ${__AOTRITON_URL}.\
|
||||||
Set env variables AOTRITON_INSTALL_FROM_SOURCE=1 to build from source.")
|
Set env variables AOTRITON_INSTALL_FROM_SOURCE=1 to build from source.")
|
||||||
@ -193,35 +111,14 @@ if(NOT __AOTRITON_INCLUDED)
|
|||||||
string(CONCAT __AOTRITON_URL
|
string(CONCAT __AOTRITON_URL
|
||||||
"${__AOTRITON_BASE_URL}"
|
"${__AOTRITON_BASE_URL}"
|
||||||
"${__AOTRITON_VER}/${__AOTRITON_FILE}")
|
"${__AOTRITON_VER}/${__AOTRITON_FILE}")
|
||||||
|
|
||||||
# Set up directories
|
|
||||||
set(__AOTRITON_DOWNLOAD_DIR ${CMAKE_CURRENT_BINARY_DIR}/aotriton_download-${image})
|
|
||||||
set(__AOTRITON_EXTRACT_DIR ${CMAKE_CURRENT_BINARY_DIR}/aotriton_image-${image})
|
|
||||||
set(__AOTRITON_INSTALL_SOURCE_DIR ${__AOTRITON_EXTRACT_DIR})
|
|
||||||
set(__DOWNLOAD_NO_EXTRACT "")
|
|
||||||
set(__BUILD_COMMANDS "")
|
|
||||||
|
|
||||||
# On Windows, we need custom tar extraction with UTF-8 support
|
|
||||||
if(WIN32)
|
|
||||||
set(__DOWNLOAD_NO_EXTRACT "DOWNLOAD_NO_EXTRACT;TRUE")
|
|
||||||
set(__BUILD_COMMANDS
|
|
||||||
COMMAND ${CMAKE_COMMAND} -E make_directory "${__AOTRITON_EXTRACT_DIR}"
|
|
||||||
COMMAND tar --options hdrcharset=UTF-8 -xf "${__AOTRITON_DOWNLOAD_DIR}/${__AOTRITON_FILE}" -C "${__AOTRITON_EXTRACT_DIR}"
|
|
||||||
)
|
|
||||||
set(__AOTRITON_INSTALL_SOURCE_DIR ${__AOTRITON_EXTRACT_DIR}/aotriton)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
ExternalProject_Add(${project}
|
ExternalProject_Add(${project}
|
||||||
URL "${__AOTRITON_URL}"
|
URL "${__AOTRITON_URL}"
|
||||||
URL_HASH SHA256=${__AOTRITON_SHA256}
|
URL_HASH SHA256=${__AOTRITON_SHA256}
|
||||||
DOWNLOAD_DIR ${__AOTRITON_DOWNLOAD_DIR}
|
SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/aotriton_image-${image}
|
||||||
${__DOWNLOAD_NO_EXTRACT}
|
|
||||||
SOURCE_DIR ${__AOTRITON_EXTRACT_DIR}
|
|
||||||
CONFIGURE_COMMAND ""
|
CONFIGURE_COMMAND ""
|
||||||
BUILD_COMMAND ""
|
BUILD_COMMAND ""
|
||||||
${__BUILD_COMMANDS}
|
|
||||||
INSTALL_COMMAND ${CMAKE_COMMAND} -E copy_directory
|
INSTALL_COMMAND ${CMAKE_COMMAND} -E copy_directory
|
||||||
"${__AOTRITON_INSTALL_SOURCE_DIR}"
|
"${CMAKE_CURRENT_BINARY_DIR}/aotriton_image-${image}"
|
||||||
"${__AOTRITON_INSTALL_DIR}"
|
"${__AOTRITON_INSTALL_DIR}"
|
||||||
BUILD_BYPRODUCTS
|
BUILD_BYPRODUCTS
|
||||||
"${__AOTRITON_INSTALL_DIR}/lib/aotriton.images/${image}/__signature__"
|
"${__AOTRITON_INSTALL_DIR}/lib/aotriton.images/${image}/__signature__"
|
||||||
@ -267,7 +164,7 @@ if(NOT __AOTRITON_INCLUDED)
|
|||||||
endforeach()
|
endforeach()
|
||||||
endforeach()
|
endforeach()
|
||||||
endif()
|
endif()
|
||||||
target_link_libraries(__caffe2_aotriton INTERFACE ${__AOTRITON_LIB})
|
target_link_libraries(__caffe2_aotriton INTERFACE ${__AOTRITON_INSTALL_DIR}/lib/libaotriton_v2.so)
|
||||||
target_include_directories(__caffe2_aotriton INTERFACE ${__AOTRITON_INSTALL_DIR}/include)
|
target_include_directories(__caffe2_aotriton INTERFACE ${__AOTRITON_INSTALL_DIR}/include)
|
||||||
set(AOTRITON_FOUND TRUE)
|
set(AOTRITON_FOUND TRUE)
|
||||||
endif() # __AOTRITON_INCLUDED
|
endif() # __AOTRITON_INCLUDED
|
||||||
|
@ -12,7 +12,6 @@ BU
|
|||||||
contiguities
|
contiguities
|
||||||
contiguity
|
contiguity
|
||||||
coo
|
coo
|
||||||
DEPENDEES
|
|
||||||
deser
|
deser
|
||||||
din
|
din
|
||||||
dout
|
dout
|
||||||
|
Reference in New Issue
Block a user