diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index 1896530c0af6..16e4641ddf20 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -467,6 +467,9 @@ if(NOT EMSCRIPTEN AND NOT INTERN_BUILD_MOBILE) endif() if(USE_CUDA AND NOT USE_ROCM) + add_definitions(-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1) + add_definitions(-DCUTLASS_ENABLE_SM90_EXTENDED_MMA_SHAPES=1) + add_definitions(-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/include) list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/tools/util/include) if($ENV{ATEN_STATIC_CUDA}) diff --git a/third_party/cutlass.BUILD b/third_party/cutlass.BUILD index e3e7b7b288e7..10100531d9be 100644 --- a/third_party/cutlass.BUILD +++ b/third_party/cutlass.BUILD @@ -13,6 +13,11 @@ cc_library( "tools/util/include/**/*.hpp", "tools/util/include/**/*.inl", ]), + defines = [ + "CUTLASS_ENABLE_TENSOR_CORE_MMA=1", + "CUTLASS_ENABLE_SM90_EXTENDED_MMA_SHAPES=1", + "CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED", + ], includes = [ "include/", "tools/util/include/", diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 17b0e1688037..ab39e6660d76 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -3063,6 +3063,8 @@ def _nvcc_compiler_options() -> List[str]: options = [ "-t=0", "-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1", + "-DCUTLASS_ENABLE_SM90_EXTENDED_MMA_SHAPES=1", + "-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED", "-w", f"-gencode=arch=compute_{arch},code=[{','.join(code)}]", config.cuda.compile_opt_level,