Initial Flash Attention support on ROCM (#114309)
This pull request adds initial Flash Attention support for the AMD/ROCm platform. It adds a specialized Triton repository/branch as a compile-time dependency for the Flash Attention math library on AMD/ROCm. This Triton submodule is not used at runtime and will not be shipped in the final PyTorch package. We plan to release this specialized Triton as a separate project.

Known limitations:

- [ ] Only supports MI200-series GPUs (i.e., `gcnArchName == gfx90a:sramecc+:xnack-`).
- [ ] Only supports power-of-two sequence lengths.
- [ ] No support for varlen APIs.
- [ ] Only supports head dimensions 16, 32, 64, and 128.
- [ ] Performance is still being optimized.

Fixes https://github.com/pytorch/pytorch/issues/112997

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114309
Approved by: https://github.com/jeffdaily, https://github.com/malfet

---------

Co-authored-by: Joseph Groenenboom <joseph.groenenboom@amd.com>
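As a quick illustration of how this path is exercised (not part of this PR), here is a minimal sketch that routes `scaled_dot_product_attention` through the flash backend. It assumes an MI200 (gfx90a) GPU and a ROCm build configured with `USE_FLASH_ATTENTION=1`, and uses shapes that respect the limitations above (power-of-two sequence length, head dimension 64):

```python
# Minimal sketch: exercising the SDPA flash-attention path on a ROCm build.
# Assumes an MI200 (gfx90a) GPU and a PyTorch build with USE_FLASH_ATTENTION=1;
# shapes respect the stated limitations (power-of-two seq length, head dim 64).
import torch
import torch.nn.functional as F

device = "cuda"  # ROCm devices are exposed through the CUDA device API

q = torch.randn(2, 8, 1024, 64, device=device, dtype=torch.float16)
k = torch.randn(2, 8, 1024, 64, device=device, dtype=torch.float16)
v = torch.randn(2, 8, 1024, 64, device=device, dtype=torch.float16)

# Restrict SDPA to the flash-attention backend so the new kernel is dispatched.
with torch.backends.cuda.sdp_kernel(
    enable_flash=True, enable_math=False, enable_mem_efficient=False
):
    out = F.scaled_dot_product_attention(q, k, v)

print(out.shape)  # torch.Size([2, 8, 1024, 64])
```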
@@ -117,6 +117,7 @@ function(caffe2_print_configuration_summary)
   message(STATUS "  USE_ROCM              : ${USE_ROCM}")
   if(${USE_ROCM})
     message(STATUS "    ROCM_VERSION        : ${ROCM_VERSION}")
+    message(STATUS "    USE_FLASH_ATTENTION : ${USE_FLASH_ATTENTION}")
   endif()
   message(STATUS "  BUILD_NVFUSER         : ${BUILD_NVFUSER}")
   message(STATUS "  USE_EIGEN_FOR_BLAS    : ${CAFFE2_USE_EIGEN_FOR_BLAS}")
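The new summary line above only reports the build-time flag. A hedged sketch of a runtime check, using `torch.backends.cuda` helpers that exist in recent PyTorch releases (whether flash attention actually dispatches on ROCm also depends on the GPU and the input shapes listed in the PR description):

```python
# Hedged sketch: verify at runtime that the flash SDP backend is enabled
# in this build. Availability of the ROCm flash kernel itself additionally
# depends on the GPU (MI200) and the shape restrictions noted above.
import torch

print(torch.version.hip)                        # ROCm/HIP version of this build, or None
print(torch.backends.cuda.flash_sdp_enabled())  # True if the flash backend is allowed

# Explicitly (re-)enable the flash backend for scaled_dot_product_attention.
torch.backends.cuda.enable_flash_sdp(True)
```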