flash_attention integration (#81434)

# Summary:
- I added a new submodule Cutlass pointing to 2.10 release. The inclusion of flash_attention code should be gated by the flag: USE_FLASH_ATTENTION. This is defaulted to off resulting in flash to not be build anywhere. This is done on purpose since we don't have A100 machines to compile and test on.

- Only looked at CMake did not attempt bazel or buck yet.

-  I included the mha_fwd from flash_attention that has ben refactored to use cutlass 2.10. There is currently no backwards kernel on this branch. That would be a good follow up.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/81434
Approved by: https://github.com/cpuhrsch
This commit is contained in:
Driss Guessous
2022-09-09 20:11:26 +00:00
committed by PyTorch MergeBot
parent 219ff26172
commit 0fc02dbba4
29 changed files with 4395 additions and 5 deletions

View File

@ -77,6 +77,7 @@ function(caffe2_print_configuration_summary)
message(STATUS " USE_CUDNN : ${USE_CUDNN}")
message(STATUS " USE_EXPERIMENTAL_CUDNN_V8_API: ${USE_EXPERIMENTAL_CUDNN_V8_API}")
message(STATUS " CUDA version : ${CUDA_VERSION}")
message(STATUS " USE_FLASH_ATTENTION : ${USE_FLASH_ATTENTION}")
if(${USE_CUDNN})
message(STATUS " cuDNN version : ${CUDNN_VERSION}")
endif()