Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
flash_attention integration (#81434)
# Summary:

- Added a new submodule, Cutlass, pointing to the 2.10 release. The inclusion of the flash_attention code is gated by the flag USE_FLASH_ATTENTION. This defaults to off, so flash is not built anywhere. This is done on purpose, since we don't have A100 machines to compile and test on.
- Only looked at CMake; did not attempt bazel or buck yet.
- Included the mha_fwd from flash_attention, which has been refactored to use cutlass 2.10. There is currently no backward kernel on this branch; that would be a good follow-up.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/81434
Approved by: https://github.com/cpuhrsch
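As a rough illustration of how such a build gate is commonly wired, the sketch below shows an environment switch named USE_FLASH_ATTENTION defaulting to off and being forwarded to CMake as a cache option. This is not code from this PR; the helper name `configure_build` and the exact CMake invocation are assumptions for illustration only.

```python
# Hypothetical sketch (not the actual PyTorch build code): read the
# USE_FLASH_ATTENTION switch from the environment and forward it to CMake.
import os
import subprocess

def configure_build(source_dir: str, build_dir: str) -> None:
    # Default is off, so the flash_attention kernels are not built unless
    # the user explicitly opts in with USE_FLASH_ATTENTION=1.
    use_flash_attention = os.getenv("USE_FLASH_ATTENTION", "0") == "1"
    cmake_args = [
        "cmake",
        "-S", source_dir,
        "-B", build_dir,
        f"-DUSE_FLASH_ATTENTION={'ON' if use_flash_attention else 'OFF'}",
    ]
    subprocess.check_call(cmake_args)
```

With the default left alone, the option resolves to OFF and the gated sources are skipped, which matches the intent stated above of not building flash anywhere by default.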
Committed by: PyTorch MergeBot
Parent: 219ff26172
Commit: 0fc02dbba4
setup.py | 2 +-
@@ -322,7 +322,7 @@ def get_submodule_folders():
     git_modules_path = os.path.join(cwd, ".gitmodules")
     default_modules_path = [os.path.join(third_party_path, name) for name in [
         "gloo", "cpuinfo", "tbb", "onnx",
-        "foxi", "QNNPACK", "fbgemm"
+        "foxi", "QNNPACK", "fbgemm", "cutlass"
     ]]
     if not os.path.exists(git_modules_path):
         return default_modules_path
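For context, the list above is a fallback used when .gitmodules is absent, and it now includes cutlass so the new submodule is accounted for. A caller might validate the returned folders roughly as follows; this is a sketch, and the function name `check_submodules` and its error message are illustrative rather than copied from setup.py.

```python
# Hypothetical sketch of how the folder list might be consumed:
# verify each expected submodule (now including cutlass) was checked out.
import os

def check_submodules(folders):
    # A submodule folder that is missing or empty was never initialized.
    missing = [f for f in folders if not os.path.isdir(f) or not os.listdir(f)]
    if missing:
        names = ", ".join(os.path.basename(f) for f in missing)
        raise RuntimeError(
            f"Missing or empty submodule folders: {names}. "
            "Run `git submodule update --init --recursive` and retry."
        )
```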