[ROCm] Ck backend UX refactor (#152951)

Refactors how the enablement/disablement of CK Gemms and SDPA works.

- Adds USE_ROCM_CK_GEMM compile flag for enabling CK gemms.
- USE_ROCM_CK_GEMM is set to True by default on Linux
- Updates USE_CK_FLASH_ATTENTION to USE_ROCM_CK_SDPA.
- USE_ROCM_CK_SDPA is set to False by default
- (USE_CK_FLASH_ATTENTION still works for now, but will be deprecated in a future release)
- Prevents these CK libraries from being used unless pytorch has been built specifically with the functionality AND is running on a system architecture that supports it.
- the getters for these library backends will also do some validity checking in case the user used an environment variable to change the backend. If invalid, (i.e. one of the cases mentioned above is false) the backend will be set as the current non-CK default

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152951
Approved by: https://github.com/eqy, https://github.com/jeffdaily, https://github.com/m-gallus

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
Co-authored-by: Jithun Nair <jithun.nair@amd.com>
Co-authored-by: Jane (Yuan) Xu <31798555+janeyx99@users.noreply.github.com>
This commit is contained in:
Andres Lugo
2025-08-08 18:40:17 +00:00
committed by PyTorch MergeBot
parent da1f608ca3
commit 5f5f508aa8
23 changed files with 232 additions and 105 deletions

View File

@ -1446,8 +1446,8 @@ if(USE_ROCM)
if(USE_MEM_EFF_ATTENTION)
target_compile_definitions(torch_hip PRIVATE USE_MEM_EFF_ATTENTION)
endif()
if(USE_CK_FLASH_ATTENTION)
target_compile_definitions(torch_hip PRIVATE USE_CK_FLASH_ATTENTION)
if(USE_ROCM_CK_SDPA)
target_compile_definitions(torch_hip PRIVATE USE_ROCM_CK_SDPA)
endif()
endif()