Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 12:54:11 +08:00
[ATen][CUDA] CUTLASS matmuls: add sm_103a flag (#162956)
This PR adds an `sm_103a` flag for GroupMM and RowwiseScaledMM. Unlike #161399, it only adds the flag, since support for `sm_103a` matmuls is slated to land in CUTLASS v4.2 (see https://github.com/pytorch/pytorch/pull/161399#issuecomment-3252892937).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162956
Approved by: https://github.com/eqy, https://github.com/Skylion007
Committed by PyTorch MergeBot · parent e3783a9575 · commit 6926710adf
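Before the full diff, here is a minimal, self-contained sketch of the version gate the commit message describes: `sm_103a` is only understood by nvcc from CUDA 12.9 onward, so the corresponding `-gencode` flag should only be appended when the toolkit is new enough. The standalone project scaffolding and the `CUDAToolkit_VERSION` variable are assumptions for illustration; the actual change below keys off `CUDA_VERSION` inside PyTorch's build.

```cmake
# Illustrative sketch only, not PyTorch's build code: gate an sm_103a
# gencode flag on the detected CUDA toolkit version.
cmake_minimum_required(VERSION 3.18)
project(sm103a_gate LANGUAGES CXX)

find_package(CUDAToolkit REQUIRED)  # provides CUDAToolkit_VERSION

set(_file_compile_flags "")
# sm_103a targets exist only in nvcc from CUDA 12.9 onward, so check the version.
if(CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.9)
  list(APPEND _file_compile_flags "-gencode;arch=compute_103a,code=sm_103a")
endif()
message(STATUS "Per-file CUDA flags: ${_file_compile_flags}")
```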
The diff, reconstructed as two hunks:

```diff
@@ -107,6 +107,12 @@ if(INTERN_BUILD_ATEN_OPS)
           list(APPEND _file_compile_flags "-gencode;arch=compute_100a,code=sm_100a")
         endif()
       endif()
+      # We will need to gate against CUDA version, because sm_103a is available on CUDA 12.9+
+      if("${_arch}" STREQUAL "103a" AND CUDA_VERSION VERSION_GREATER_EQUAL 12.9)
+        if(_existing_arch_flags MATCHES ".*compute_100.*")
+          list(APPEND _file_compile_flags "-gencode;arch=compute_103a,code=sm_103a")
+        endif()
+      endif()
       if("${_arch}" STREQUAL "120a")
         if(_existing_arch_flags MATCHES ".*compute_120.*")
           list(APPEND _file_compile_flags "-gencode;arch=compute_120a,code=sm_120a")
@@ -120,13 +126,13 @@ if(INTERN_BUILD_ATEN_OPS)
 
   _BUILD_FOR_ADDITIONAL_ARCHS(
     "${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/RowwiseScaledMM.cu"
-    "89;90a;100a;120a")
+    "89;90a;100a;103a;120a")
   _BUILD_FOR_ADDITIONAL_ARCHS(
     "${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/ScaledGroupMM.cu"
     "90a")
   _BUILD_FOR_ADDITIONAL_ARCHS(
     "${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/GroupMM.cu"
-    "90a;100a")
+    "90a;100a;103a")
 
 endif()
 
```
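For context on the flag itself: each `-gencode;arch=compute_103a,code=sm_103a` entry asks nvcc to compile against the `compute_103a` virtual architecture and embed native SASS for `sm_103a`. The embedded `;` makes the string a two-element CMake list, so a list-valued property hands nvcc the pair `-gencode arch=compute_103a,code=sm_103a`. Below is a hypothetical sketch of one common way to attach such per-file options; whether `_BUILD_FOR_ADDITIONAL_ARCHS` uses exactly this mechanism is an assumption, not something the diff shows.

```cmake
# Hypothetical illustration, not the actual _BUILD_FOR_ADDITIONAL_ARCHS body:
# attach extra gencode options to a single .cu source via a source-file property.
set(_file_compile_flags "-gencode;arch=compute_103a,code=sm_103a")
set_source_files_properties(
  "${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/cuda/RowwiseScaledMM.cu"
  PROPERTIES COMPILE_OPTIONS "${_file_compile_flags}")
```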