mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[BE][CUDA][Bugfix]: Enable extended MMA shapes in CUTLASS. (#133686)
* This fixes a major CMake/Bazel configuration bug where we were leaving CUTLASS performance on the table, especially with FlashAttention. This now enables using MMA instructions on SM90+, which should close the gap between SDPA and the external FA2. Note these operations only affect H100 and newer GPUs. Thankfully, this seems to have been updated recently into being a noop on the CUTLASS side. Still better set the CMake variable properly. * Also enables additional new shape kernels added in the recent CUTLASS 3.5.1+ update. This was the original motivation of the PR before I realized the basic MMA kernels were accidentally disabled since we didn't go through the submodule's CMake/Bazel files. * Adds a bit to compile time and code size, but well worth it considering it speeds up our internal flash attention significantly on H100s at the cost of some minor additional compile time. * These kernels and settings will be needed for Flash Attention 3 whenever we add that too. Fixes #133695 Pull Request resolved: https://github.com/pytorch/pytorch/pull/133686 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
1d6e0412f5
commit
d3c2123ea6
@ -467,6 +467,9 @@ if(NOT EMSCRIPTEN AND NOT INTERN_BUILD_MOBILE)
|
|||||||
endif()
|
endif()
|
||||||
|
|
||||||
if(USE_CUDA AND NOT USE_ROCM)
|
if(USE_CUDA AND NOT USE_ROCM)
|
||||||
|
add_definitions(-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1)
|
||||||
|
add_definitions(-DCUTLASS_ENABLE_SM90_EXTENDED_MMA_SHAPES=1)
|
||||||
|
add_definitions(-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
|
||||||
list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/include)
|
list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/include)
|
||||||
list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/tools/util/include)
|
list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/tools/util/include)
|
||||||
if($ENV{ATEN_STATIC_CUDA})
|
if($ENV{ATEN_STATIC_CUDA})
|
||||||
|
5
third_party/cutlass.BUILD
vendored
5
third_party/cutlass.BUILD
vendored
@ -13,6 +13,11 @@ cc_library(
|
|||||||
"tools/util/include/**/*.hpp",
|
"tools/util/include/**/*.hpp",
|
||||||
"tools/util/include/**/*.inl",
|
"tools/util/include/**/*.inl",
|
||||||
]),
|
]),
|
||||||
|
defines = [
|
||||||
|
"CUTLASS_ENABLE_TENSOR_CORE_MMA=1",
|
||||||
|
"CUTLASS_ENABLE_SM90_EXTENDED_MMA_SHAPES=1",
|
||||||
|
"CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED",
|
||||||
|
],
|
||||||
includes = [
|
includes = [
|
||||||
"include/",
|
"include/",
|
||||||
"tools/util/include/",
|
"tools/util/include/",
|
||||||
|
@ -3063,6 +3063,8 @@ def _nvcc_compiler_options() -> List[str]:
|
|||||||
options = [
|
options = [
|
||||||
"-t=0",
|
"-t=0",
|
||||||
"-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1",
|
"-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1",
|
||||||
|
"-DCUTLASS_ENABLE_SM90_EXTENDED_MMA_SHAPES=1",
|
||||||
|
"-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED",
|
||||||
"-w",
|
"-w",
|
||||||
f"-gencode=arch=compute_{arch},code=[{','.join(code)}]",
|
f"-gencode=arch=compute_{arch},code=[{','.join(code)}]",
|
||||||
config.cuda.compile_opt_level,
|
config.cuda.compile_opt_level,
|
||||||
|
Reference in New Issue
Block a user