[BE][CUDA][Bugfix]: Enable extended MMA shapes in CUTLASS. (#133686)

* This fixes a major CMake/Bazel configuration bug where we were leaving CUTLASS performance on the table, especially with FlashAttention. This now enables using MMA instructions on SM90+, which should close the gap between SDPA and the external FA2. Note these operations only affect H100 and newer GPUs. Thankfully, this seems to have been updated recently into being a noop on the CUTLASS side. Still better set the CMake variable properly.
*  Also enables additional new shape kernels added in the recent CUTLASS 3.5.1+ update. This was the original motivatin of the PR before I realized the basic MMA kernels were accidentally disabled since we didn't go through the submodule's CMake/Bazels.
* Adds a bit to compile time and code size, but well worth it considering it speeds up our internal flash attention significantly on H100s at the cost of some minor additional compile time.
* These kernels and settings will be needed for Flash Attention 3 whenever we add that too.

Fixes #133695

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133686
Approved by: https://github.com/ezyang
This commit is contained in:
Aaron Gokaslan
2024-09-28 21:11:15 +00:00
committed by PyTorch MergeBot
parent 1d6e0412f5
commit d3c2123ea6
3 changed files with 10 additions and 0 deletions

View File

@ -467,6 +467,9 @@ if(NOT EMSCRIPTEN AND NOT INTERN_BUILD_MOBILE)
endif()
if(USE_CUDA AND NOT USE_ROCM)
add_definitions(-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1)
add_definitions(-DCUTLASS_ENABLE_SM90_EXTENDED_MMA_SHAPES=1)
add_definitions(-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/include)
list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/tools/util/include)
if($ENV{ATEN_STATIC_CUDA})

View File

@ -13,6 +13,11 @@ cc_library(
"tools/util/include/**/*.hpp",
"tools/util/include/**/*.inl",
]),
defines = [
"CUTLASS_ENABLE_TENSOR_CORE_MMA=1",
"CUTLASS_ENABLE_SM90_EXTENDED_MMA_SHAPES=1",
"CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED",
],
includes = [
"include/",
"tools/util/include/",

View File

@ -3063,6 +3063,8 @@ def _nvcc_compiler_options() -> List[str]:
options = [
"-t=0",
"-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1",
"-DCUTLASS_ENABLE_SM90_EXTENDED_MMA_SHAPES=1",
"-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED",
"-w",
f"-gencode=arch=compute_{arch},code=[{','.join(code)}]",
config.cuda.compile_opt_level,