mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
Hopper (sm90) support (#87736)
Essentially a follow-up of #87436.
CC @xwang233 @ptrblck
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87736
Approved by: https://github.com/xwang233, https://github.com/malfet
This commit is contained in:
committed by
PyTorch MergeBot
parent
19d7941e37
commit
a7420d2ccb
@ -94,23 +94,28 @@ if(CUDA_VERSION VERSION_GREATER "10.5")
|
||||
endif()
|
||||
|
||||
if(NOT CUDA_VERSION VERSION_LESS "11.1")
|
||||
list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "8.6" "8.6+PTX")
|
||||
list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "8.6")
|
||||
list(APPEND CUDA_ALL_GPU_ARCHITECTURES "8.6")
|
||||
set(CUDA_LIMIT_GPU_ARCHITECUTRE "8.6")
|
||||
|
||||
if(CUDA_VERSION VERSION_LESS "11.8")
|
||||
set(CUDA_LIMIT_GPU_ARCHITECTURE "8.9")
|
||||
list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "8.6+PTX")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if(NOT CUDA_VERSION VERSION_LESS "11.8")
|
||||
list(APPEND CUDA_KNOWN_GPU_ARCHITECTURES "Ada")
|
||||
list(APPEND CUDA_KNOWN_GPU_ARCHITECTURES "Hopper")
|
||||
list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "8.9")
|
||||
list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "9.0")
|
||||
list(APPEND CUDA_ALL_GPU_ARCHITECTURES "8.9")
|
||||
list(APPEND CUDA_ALL_GPU_ARCHITECTURES "9.0")
|
||||
|
||||
if(CUDA_VERSION VERSION_LESS "12.0")
|
||||
set(CUDA_LIMIT_GPU_ARCHITECTURE "9.0")
|
||||
list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "8.9+PTX")
|
||||
list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "9.0+PTX")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
@ -248,6 +253,12 @@ function(CUDA_SELECT_NVCC_ARCH_FLAGS out_variable)
|
||||
elseif(${arch_name} STREQUAL "Ampere")
|
||||
set(arch_bin 8.0)
|
||||
set(arch_ptx 8.0)
|
||||
elseif(${arch_name} STREQUAL "Ada")
|
||||
set(arch_bin 8.9)
|
||||
set(arch_ptx 8.9)
|
||||
elseif(${arch_name} STREQUAL "Hopper")
|
||||
set(arch_bin 9.0)
|
||||
set(arch_ptx 9.0)
|
||||
else()
|
||||
message(SEND_ERROR "Unknown CUDA Architecture Name ${arch_name} in CUDA_SELECT_NVCC_ARCH_FLAGS")
|
||||
endif()
|
||||
|
@ -1730,10 +1730,11 @@ def _get_cuda_arch_flags(cflags: Optional[List[str]] = None) -> List[str]:
|
||||
('Turing', '7.5+PTX'),
|
||||
('Ampere', '8.0;8.6+PTX'),
|
||||
('Ada', '8.9+PTX'),
|
||||
('Hopper', '9.0+PTX'),
|
||||
])
|
||||
|
||||
supported_arches = ['3.5', '3.7', '5.0', '5.2', '5.3', '6.0', '6.1', '6.2',
|
||||
'7.0', '7.2', '7.5', '8.0', '8.6', '8.9']
|
||||
'7.0', '7.2', '7.5', '8.0', '8.6', '8.9', '9.0']
|
||||
valid_arch_strings = supported_arches + [s + "+PTX" for s in supported_arches]
|
||||
|
||||
# The default is sm_30 for CUDA 9.x and 10.x
|
||||
|
Reference in New Issue
Block a user