mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[Build] Allow shipping PTX on a per-file basis (#18155)
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
This commit is contained in:
@ -228,11 +228,26 @@ macro(set_gencode_flags_for_srcs)
|
||||
"${multiValueArgs}" ${ARGN} )
|
||||
|
||||
foreach(_ARCH ${arg_CUDA_ARCHS})
|
||||
string(REPLACE "." "" _ARCH "${_ARCH}")
|
||||
set_gencode_flag_for_srcs(
|
||||
SRCS ${arg_SRCS}
|
||||
ARCH "compute_${_ARCH}"
|
||||
CODE "sm_${_ARCH}")
|
||||
# handle +PTX suffix: generate both sm and ptx codes if requested
|
||||
string(FIND "${_ARCH}" "+PTX" _HAS_PTX)
|
||||
if(NOT _HAS_PTX EQUAL -1)
|
||||
string(REPLACE "+PTX" "" _BASE_ARCH "${_ARCH}")
|
||||
string(REPLACE "." "" _STRIPPED_ARCH "${_BASE_ARCH}")
|
||||
set_gencode_flag_for_srcs(
|
||||
SRCS ${arg_SRCS}
|
||||
ARCH "compute_${_STRIPPED_ARCH}"
|
||||
CODE "sm_${_STRIPPED_ARCH}")
|
||||
set_gencode_flag_for_srcs(
|
||||
SRCS ${arg_SRCS}
|
||||
ARCH "compute_${_STRIPPED_ARCH}"
|
||||
CODE "compute_${_STRIPPED_ARCH}")
|
||||
else()
|
||||
string(REPLACE "." "" _STRIPPED_ARCH "${_ARCH}")
|
||||
set_gencode_flag_for_srcs(
|
||||
SRCS ${arg_SRCS}
|
||||
ARCH "compute_${_STRIPPED_ARCH}"
|
||||
CODE "sm_${_STRIPPED_ARCH}")
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
if (${arg_BUILD_PTX_FOR_ARCH})
|
||||
@ -251,7 +266,10 @@ endmacro()
|
||||
#
|
||||
# For the given `SRC_CUDA_ARCHS` list of gencode versions in the form
|
||||
# `<major>.<minor>[letter]` compute the "loose intersection" with the
|
||||
# `TGT_CUDA_ARCHS` list of gencodes.
|
||||
# `TGT_CUDA_ARCHS` list of gencodes. We also support the `+PTX` suffix in
|
||||
# `SRC_CUDA_ARCHS` which indicates that the PTX code should be built when there
|
||||
# is a CUDA_ARCH in `TGT_CUDA_ARCHS` that is equal to or larger than the
|
||||
# architecture in `SRC_CUDA_ARCHS`.
|
||||
# The loose intersection is defined as:
|
||||
# { max{ x \in tgt | x <= y } | y \in src, { x \in tgt | x <= y } != {} }
|
||||
# where `<=` is the version comparison operator.
|
||||
@ -268,44 +286,63 @@ endmacro()
|
||||
# cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
|
||||
# OUT_CUDA_ARCHS="8.0;8.6;9.0;9.0a"
|
||||
#
|
||||
# Example With PTX:
|
||||
# SRC_CUDA_ARCHS="8.0+PTX"
|
||||
# TGT_CUDA_ARCHS="9.0"
|
||||
# cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
|
||||
# OUT_CUDA_ARCHS="8.0+PTX"
|
||||
#
|
||||
function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
|
||||
list(REMOVE_DUPLICATES SRC_CUDA_ARCHS)
|
||||
set(TGT_CUDA_ARCHS_ ${TGT_CUDA_ARCHS})
|
||||
set(_SRC_CUDA_ARCHS "${SRC_CUDA_ARCHS}")
|
||||
set(_TGT_CUDA_ARCHS ${TGT_CUDA_ARCHS})
|
||||
|
||||
# handle +PTX suffix: separate base arch for matching, record PTX requests
|
||||
set(_PTX_ARCHS)
|
||||
foreach(_arch ${_SRC_CUDA_ARCHS})
|
||||
if(_arch MATCHES "\\+PTX$")
|
||||
string(REPLACE "+PTX" "" _base "${_arch}")
|
||||
list(APPEND _PTX_ARCHS "${_base}")
|
||||
list(REMOVE_ITEM _SRC_CUDA_ARCHS "${_arch}")
|
||||
list(APPEND _SRC_CUDA_ARCHS "${_base}")
|
||||
endif()
|
||||
endforeach()
|
||||
list(REMOVE_DUPLICATES _PTX_ARCHS)
|
||||
list(REMOVE_DUPLICATES _SRC_CUDA_ARCHS)
|
||||
|
||||
# if x.0a is in SRC_CUDA_ARCHS and x.0 is in CUDA_ARCHS then we should
|
||||
# remove x.0a from SRC_CUDA_ARCHS and add x.0a to _CUDA_ARCHS
|
||||
set(_CUDA_ARCHS)
|
||||
if ("9.0a" IN_LIST SRC_CUDA_ARCHS)
|
||||
list(REMOVE_ITEM SRC_CUDA_ARCHS "9.0a")
|
||||
if ("9.0" IN_LIST TGT_CUDA_ARCHS_)
|
||||
list(REMOVE_ITEM TGT_CUDA_ARCHS_ "9.0")
|
||||
if ("9.0a" IN_LIST _SRC_CUDA_ARCHS)
|
||||
list(REMOVE_ITEM _SRC_CUDA_ARCHS "9.0a")
|
||||
if ("9.0" IN_LIST TGT_CUDA_ARCHS)
|
||||
list(REMOVE_ITEM _TGT_CUDA_ARCHS "9.0")
|
||||
set(_CUDA_ARCHS "9.0a")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if ("10.0a" IN_LIST SRC_CUDA_ARCHS)
|
||||
list(REMOVE_ITEM SRC_CUDA_ARCHS "10.0a")
|
||||
if ("10.0a" IN_LIST _SRC_CUDA_ARCHS)
|
||||
list(REMOVE_ITEM _SRC_CUDA_ARCHS "10.0a")
|
||||
if ("10.0" IN_LIST TGT_CUDA_ARCHS)
|
||||
list(REMOVE_ITEM TGT_CUDA_ARCHS_ "10.0")
|
||||
list(REMOVE_ITEM _TGT_CUDA_ARCHS "10.0")
|
||||
set(_CUDA_ARCHS "10.0a")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
list(SORT SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)
|
||||
list(SORT _SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)
|
||||
|
||||
# for each ARCH in TGT_CUDA_ARCHS find the highest arch in SRC_CUDA_ARCHS that
|
||||
# is less or equal to ARCH (but has the same major version since SASS binary
|
||||
# compatibility is only forward compatible within the same major version).
|
||||
foreach(_ARCH ${TGT_CUDA_ARCHS_})
|
||||
foreach(_ARCH ${_TGT_CUDA_ARCHS})
|
||||
set(_TMP_ARCH)
|
||||
# Extract the major version of the target arch
|
||||
string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" TGT_ARCH_MAJOR "${_ARCH}")
|
||||
foreach(_SRC_ARCH ${SRC_CUDA_ARCHS})
|
||||
foreach(_SRC_ARCH ${_SRC_CUDA_ARCHS})
|
||||
# Extract the major version of the source arch
|
||||
string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" SRC_ARCH_MAJOR "${_SRC_ARCH}")
|
||||
# Check major-version match AND version-less-or-equal
|
||||
# Check version-less-or-equal, and allow PTX arches to match across majors
|
||||
if (_SRC_ARCH VERSION_LESS_EQUAL _ARCH)
|
||||
if (SRC_ARCH_MAJOR STREQUAL TGT_ARCH_MAJOR)
|
||||
if (_SRC_ARCH IN_LIST _PTX_ARCHS OR SRC_ARCH_MAJOR STREQUAL TGT_ARCH_MAJOR)
|
||||
set(_TMP_ARCH "${_SRC_ARCH}")
|
||||
endif()
|
||||
else()
|
||||
@ -321,6 +358,18 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR
|
||||
endforeach()
|
||||
|
||||
list(REMOVE_DUPLICATES _CUDA_ARCHS)
|
||||
|
||||
# reapply +PTX suffix to architectures that requested PTX
|
||||
set(_FINAL_ARCHS)
|
||||
foreach(_arch ${_CUDA_ARCHS})
|
||||
if(_arch IN_LIST _PTX_ARCHS)
|
||||
list(APPEND _FINAL_ARCHS "${_arch}+PTX")
|
||||
else()
|
||||
list(APPEND _FINAL_ARCHS "${_arch}")
|
||||
endif()
|
||||
endforeach()
|
||||
set(_CUDA_ARCHS ${_FINAL_ARCHS})
|
||||
|
||||
set(${OUT_CUDA_ARCHS} ${_CUDA_ARCHS} PARENT_SCOPE)
|
||||
endfunction()
|
||||
|
||||
|
Reference in New Issue
Block a user