From d3c2123ea61140985325a5f654d38b295eff4242 Mon Sep 17 00:00:00 2001 From: Aaron Gokaslan Date: Sat, 28 Sep 2024 21:11:15 +0000 Subject: [PATCH] [BE][CUDA][Bugfix]: Enable extended MMA shapes in CUTLASS. (#133686) * This fixes a major CMake/Bazel configuration bug where we were leaving CUTLASS performance on the table, especially with FlashAttention. This now enables using MMA instructions on SM90+, which should close the gap between SDPA and the external FA2. Note these operations only affect H100 and newer GPUs. Thankfully, this seems to have been updated recently into being a noop on the CUTLASS side. Still better set the CMake variable properly. * Also enables additional new shape kernels added in the recent CUTLASS 3.5.1+ update. This was the original motivatin of the PR before I realized the basic MMA kernels were accidentally disabled since we didn't go through the submodule's CMake/Bazels. * Adds a bit to compile time and code size, but well worth it considering it speeds up our internal flash attention significantly on H100s at the cost of some minor additional compile time. * These kernels and settings will be needed for Flash Attention 3 whenever we add that too. Fixes #133695 Pull Request resolved: https://github.com/pytorch/pytorch/pull/133686 Approved by: https://github.com/ezyang --- aten/src/ATen/CMakeLists.txt | 3 +++ third_party/cutlass.BUILD | 5 +++++ torch/_inductor/codecache.py | 2 ++ 3 files changed, 10 insertions(+) diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index 1896530c0af6..16e4641ddf20 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -467,6 +467,9 @@ if(NOT EMSCRIPTEN AND NOT INTERN_BUILD_MOBILE) endif() if(USE_CUDA AND NOT USE_ROCM) + add_definitions(-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1) + add_definitions(-DCUTLASS_ENABLE_SM90_EXTENDED_MMA_SHAPES=1) + add_definitions(-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/include) list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/tools/util/include) if($ENV{ATEN_STATIC_CUDA}) diff --git a/third_party/cutlass.BUILD b/third_party/cutlass.BUILD index e3e7b7b288e7..10100531d9be 100644 --- a/third_party/cutlass.BUILD +++ b/third_party/cutlass.BUILD @@ -13,6 +13,11 @@ cc_library( "tools/util/include/**/*.hpp", "tools/util/include/**/*.inl", ]), + defines = [ + "CUTLASS_ENABLE_TENSOR_CORE_MMA=1", + "CUTLASS_ENABLE_SM90_EXTENDED_MMA_SHAPES=1", + "CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED", + ], includes = [ "include/", "tools/util/include/", diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 17b0e1688037..ab39e6660d76 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -3063,6 +3063,8 @@ def _nvcc_compiler_options() -> List[str]: options = [ "-t=0", "-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1", + "-DCUTLASS_ENABLE_SM90_EXTENDED_MMA_SHAPES=1", + "-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED", "-w", f"-gencode=arch=compute_{arch},code=[{','.join(code)}]", config.cuda.compile_opt_level,