mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	[BE][CUDA][Bugfix]: Enable extended MMA shapes in CUTLASS. (#133686)
* This fixes a major CMake/Bazel configuration bug where we were leaving CUTLASS performance on the table, especially with FlashAttention. This now enables using MMA instructions on SM90+, which should close the gap between SDPA and the external FA2. Note these operations only affect H100 and newer GPUs. Thankfully, this seems to have been updated recently into being a noop on the CUTLASS side. Still better set the CMake variable properly. * Also enables additional new shape kernels added in the recent CUTLASS 3.5.1+ update. This was the original motivatin of the PR before I realized the basic MMA kernels were accidentally disabled since we didn't go through the submodule's CMake/Bazels. * Adds a bit to compile time and code size, but well worth it considering it speeds up our internal flash attention significantly on H100s at the cost of some minor additional compile time. * These kernels and settings will be needed for Flash Attention 3 whenever we add that too. Fixes #133695 Pull Request resolved: https://github.com/pytorch/pytorch/pull/133686 Approved by: https://github.com/ezyang
This commit is contained in:
		
				
					committed by
					
						 PyTorch MergeBot
						PyTorch MergeBot
					
				
			
			
				
	
			
			
			
						parent
						
							1d6e0412f5
						
					
				
				
					commit
					d3c2123ea6
				
			| @ -467,6 +467,9 @@ if(NOT EMSCRIPTEN AND NOT INTERN_BUILD_MOBILE) | |||||||
| endif() | endif() | ||||||
|  |  | ||||||
| if(USE_CUDA AND NOT USE_ROCM) | if(USE_CUDA AND NOT USE_ROCM) | ||||||
|  |   add_definitions(-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1) | ||||||
|  |   add_definitions(-DCUTLASS_ENABLE_SM90_EXTENDED_MMA_SHAPES=1) | ||||||
|  |   add_definitions(-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) | ||||||
|   list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/include) |   list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/include) | ||||||
|   list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/tools/util/include) |   list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/cutlass/tools/util/include) | ||||||
|   if($ENV{ATEN_STATIC_CUDA}) |   if($ENV{ATEN_STATIC_CUDA}) | ||||||
|  | |||||||
							
								
								
									
										5
									
								
								third_party/cutlass.BUILD
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										5
									
								
								third_party/cutlass.BUILD
									
									
									
									
										vendored
									
									
								
							| @ -13,6 +13,11 @@ cc_library( | |||||||
|         "tools/util/include/**/*.hpp", |         "tools/util/include/**/*.hpp", | ||||||
|         "tools/util/include/**/*.inl", |         "tools/util/include/**/*.inl", | ||||||
|     ]), |     ]), | ||||||
|  |     defines = [ | ||||||
|  |         "CUTLASS_ENABLE_TENSOR_CORE_MMA=1", | ||||||
|  |         "CUTLASS_ENABLE_SM90_EXTENDED_MMA_SHAPES=1", | ||||||
|  |         "CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED", | ||||||
|  |     ], | ||||||
|     includes = [ |     includes = [ | ||||||
|         "include/", |         "include/", | ||||||
|         "tools/util/include/", |         "tools/util/include/", | ||||||
|  | |||||||
| @ -3063,6 +3063,8 @@ def _nvcc_compiler_options() -> List[str]: | |||||||
|     options = [ |     options = [ | ||||||
|         "-t=0", |         "-t=0", | ||||||
|         "-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1", |         "-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1", | ||||||
|  |         "-DCUTLASS_ENABLE_SM90_EXTENDED_MMA_SHAPES=1", | ||||||
|  |         "-DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED", | ||||||
|         "-w", |         "-w", | ||||||
|         f"-gencode=arch=compute_{arch},code=[{','.join(code)}]", |         f"-gencode=arch=compute_{arch},code=[{','.join(code)}]", | ||||||
|         config.cuda.compile_opt_level, |         config.cuda.compile_opt_level, | ||||||
|  | |||||||
		Reference in New Issue
	
	Block a user