[ROCm/Windows] Support aotriton for scaled_dot_product_attention on Windows. (#162330)
Enables flash attention and/or memory-efficient attention for scaled_dot_product_attention on Windows via aotriton. Already tested and working on Windows with TheRock. To enable, simply set `USE_FLASH_ATTENTION=1` and `USE_MEM_EFF_ATTENTION=1` as usual. See https://github.com/ROCm/TheRock/blob/main/external-builds/pytorch/build_prod_wheels.py#L578-L604

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162330
Approved by: https://github.com/jeffdaily

Co-authored-by: Scott Todd <scott.todd0@gmail.com>
committed by PyTorch MergeBot
parent 5dc4e78047
commit 0826aafa04
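Once a wheel is built with those flags, the quickest sanity check is to pin each SDPA backend and run a forward pass. Below is a minimal smoke-test sketch (illustrative, not part of this commit), assuming a recent PyTorch that ships the `torch.nn.attention.sdpa_kernel` context manager; note that on ROCm the device string is still `"cuda"`, and the tensor shapes are arbitrary:

```python
# Smoke test (illustrative): verify the flash and memory-efficient SDPA
# kernels actually dispatch after building with USE_FLASH_ATTENTION=1
# and USE_MEM_EFF_ATTENTION=1.
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# (batch, heads, seq_len, head_dim); fp16 keeps both backends eligible.
q, k, v = (torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
           for _ in range(3))

for backend in (SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION):
    # Pinning a single backend makes dispatch raise a RuntimeError if that
    # kernel was not compiled in, instead of silently using the math path.
    with sdpa_kernel(backend):
        out = F.scaled_dot_product_attention(q, k, v)
    print(f"{backend.name}: OK, output shape {tuple(out.shape)}")
```

If either backend was compiled out, the pinned dispatch fails with a "no available kernel" RuntimeError rather than silently falling back.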
```diff
@@ -874,7 +874,7 @@ cmake_dependent_option(
   "Whether to build the flash_attention kernel for scaled dot product attention.\
   Will be disabled if not supported by the platform"
   ON
-  "USE_CUDA OR USE_ROCM;NOT MSVC"
+  "(USE_CUDA AND NOT MSVC) OR USE_ROCM"
   OFF)
 
 cmake_dependent_option(
```
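The condition rewrite above is the crux of the Windows enablement. In `cmake_dependent_option`, the depends string is a semicolon-separated list of conditions that must all hold, so the old `"USE_CUDA OR USE_ROCM;NOT MSVC"` meant (USE_CUDA OR USE_ROCM) AND (NOT MSVC), forcing the option off for every MSVC build, ROCm included. The new `"(USE_CUDA AND NOT MSVC) OR USE_ROCM"` keeps the MSVC restriction only for CUDA builds, so the option can default to ON for ROCm on Windows.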
```diff
@@ -909,7 +909,7 @@ cmake_dependent_option(
 # USE_FLASH_ATTENTION -> USE_ROCM -> Dependencies.cmake -> aotriton.cmake
 #
 if(USE_ROCM)
-  if(UNIX AND (USE_FLASH_ATTENTION OR USE_MEM_EFF_ATTENTION))
+  if(USE_FLASH_ATTENTION OR USE_MEM_EFF_ATTENTION)
     include(cmake/External/aotriton.cmake)
   endif()
 endif()
```
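The second hunk drops the `UNIX` guard: previously `cmake/External/aotriton.cmake` was only included on non-Windows platforms even when an attention backend was requested, so Windows ROCm builds never pulled in aotriton. With the guard removed, setting `USE_FLASH_ATTENTION` or `USE_MEM_EFF_ATTENTION` brings in aotriton on Windows as well, completing the dependency chain noted in the comment (USE_FLASH_ATTENTION -> USE_ROCM -> Dependencies.cmake -> aotriton.cmake).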