[core] Support custom ascendc kernels in vllm-ascend (#233)

This PR adds a custom AscendC rotary_embedding kernel to vllm-ascend; the related
CMakeLists and setuptools changes are also included in this PR.

Related: https://github.com/vllm-project/vllm-ascend/issues/156
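
For reference, once the extension is built the kernel can be exercised directly through the registered torch library op. A minimal sketch mirroring the included test (shapes, dtype and device below are illustrative):

import torch
import torch_npu  # noqa: F401
import vllm_ascend.platform  # noqa: F401  (registers the custom ops)

num_tokens, num_heads, head_size, max_position = 16, 8, 128, 8192
positions = torch.randint(0, max_position, (num_tokens,), device="npu")
query = torch.randn(num_tokens, num_heads * head_size, dtype=torch.half, device="npu")
key = torch.randn(num_tokens, num_heads * head_size, dtype=torch.half, device="npu")
cos_sin_cache = torch.randn(max_position, head_size, dtype=torch.half, device="npu")

# Rotates query and key in place; only the neox style (is_neox=True) is supported for now.
torch.ops._C.rotary_embedding(positions, query, key, head_size, cos_sin_cache, True)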

---------

Signed-off-by: ganyi <pleaplusone.gy@gmail.com>
Pleaplusone
2025-04-03 14:52:34 +08:00
committed by GitHub
parent 14d9a64047
commit ce8259975e
15 changed files with 1378 additions and 9 deletions


@@ -62,6 +62,7 @@ jobs:
- name: Install system dependencies
run: |
apt-get -y install `cat packages.txt`
apt-get -y install gcc g++ cmake libnuma-dev
- name: Checkout vllm-project/vllm repo
uses: actions/checkout@v4

102
CMakeLists.txt Normal file

@@ -0,0 +1,102 @@
cmake_minimum_required(VERSION 3.16)
project(vllm_ascend_C)
# include(CheckCXXcompilerFlag)
# check_cxx_compiler_flag("-std=c++17", COMPILER_SUPPORTS_CXX17)
include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)
# Suppress potential warnings about unused manually-specified variables
set(ignoreMe "${VLLM_PYTHON_PATH}")
set(PYTHON_SUPPORTED_VERSIONS "3.9" "3.10" "3.11" "3.12")
find_package(pybind11 REQUIRED)
append_cmake_prefix_path("torch" "torch.utils.cmake_prefix_path")
set(VLLM_ASCEND_INSTALL_PATH "${CMAKE_INSTALL_PREFIX}")
find_package(Torch REQUIRED)
set(RUN_MODE "npu" CACHE STRING "cpu/sim/npu")
message(STATUS "Detected SOC version: ${SOC_VERSION}")
if (NOT CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE "Release" CACHE STRINGS "Build type Release/Debug (default Release)" FORCE)
endif()
if (CMAKE_INSTALL_PREFIX STREQUAL /usr/local)
set(CMAKE_INSTALL_PREFIX "${CMAKE_CURRENT_LIST_DIR}/out" CACHE STRINGS "path to install()")
endif()
set(ASCEND_CANN_PACKAGE_PATH ${ASCEND_HOME_PATH})
if(EXISTS ${ASCEND_HOME_PATH}/tools/tikcpp/ascendc_kernel_cmake)
set(ASCENDC_CMAKE_DIR ${ASCEND_HOME_PATH}/tools/tikcpp/ascendc_kernel_cmake)
elseif(EXISTS ${ASCEND_HOME_PATH}/compiler/tikcpp/ascendc_kernel_cmake)
set(ASCENDC_CMAKE_DIR ${ASCEND_HOME_PATH}/compiler/tikcpp/ascendc_kernel_cmake)
elseif(EXISTS ${ASCEND_HOME_PATH}/ascendc_devkit/tikcpp/samples/cmake)
set(ASCENDC_CMAKE_DIR ${ASCEND_HOME_PATH}/ascendc_devkit/tikcpp/samples/cmake)
else()
message(FATAL_ERROR "ascendc_kernel_cmake does not exist, please check whether the CANN package is installed.")
endif()
include(${ASCENDC_CMAKE_DIR}/ascendc.cmake)
file(GLOB KERNEL_FILES
${CMAKE_CURRENT_SOURCE_DIR}/csrc/kernels/pos_encoding_kernels.cpp)
ascendc_library(vllm_ascend_kernels SHARED
${KERNEL_FILES}
)
execute_process(COMMAND python3 -c "import os; import torch_npu; print(os.path.dirname(torch_npu.__file__))"
OUTPUT_STRIP_TRAILING_WHITESPACE
OUTPUT_VARIABLE TORCH_NPU_PATH
)
message("TORCH_NPU_PATH is ${TORCH_NPU_PATH}")
file(GLOB VLLM_ASCEND_SRC
${CMAKE_CURRENT_SOURCE_DIR}/csrc/*.cpp)
include_directories(
${pybind11_INCLUDE_DIRS}
${PYTHON_INCLUDE_PATH}
${TORCH_INCLUDE_DIRS}
${TORCH_NPU_PATH}/include
${ASCEND_HOME_PATH}/include
${ASCEND_HOME_PATH}/aarch64-linux/include/experiment/platform
${ASCEND_HOME_PATH}/x86_64-linux/include/experiment/platform
)
set(
INCLUDES
${TORCH_INCLUDE_DIRS}
${TORCH_NPU_INCLUDE_DIRS}
${ASCEND_HOME_PATH}/include
${ASCEND_HOME_PATH}/aarch64-linux/include/experiment/platform
)
pybind11_add_module(vllm_ascend_C ${VLLM_ASCEND_SRC})
target_link_directories(
vllm_ascend_C
PRIVATE
${TORCH_NPU_PATH}/lib/
${ASCEND_HOME_PATH}/lib64
)
target_link_libraries(
vllm_ascend_C
PUBLIC
${TORCH_LIBRARIES}
libtorch_npu.so
vllm_ascend_kernels
ascendcl
platform
)
target_link_options(vllm_ascend_C PRIVATE "-Wl,-rpath,$ORIGIN:$ORIGIN/lib")
install(TARGETS vllm_ascend_C vllm_ascend_kernels DESTINATION ${VLLM_ASCEND_INSTALL_PATH})

133
cmake/utils.cmake Normal file

@@ -0,0 +1,133 @@
#
# Attempt to find the python package that uses the same python executable as
# `EXECUTABLE` and is one of the `SUPPORTED_VERSIONS`.
#
macro (find_python_from_executable EXECUTABLE SUPPORTED_VERSIONS)
file(REAL_PATH ${EXECUTABLE} EXECUTABLE)
set(Python_EXECUTABLE ${EXECUTABLE})
find_package(Python COMPONENTS Interpreter Development.Module Development.SABIModule)
if (NOT Python_FOUND)
message(FATAL_ERROR "Unable to find python matching: ${EXECUTABLE}.")
endif()
set(_VER "${Python_VERSION_MAJOR}.${Python_VERSION_MINOR}")
set(_SUPPORTED_VERSIONS_LIST ${SUPPORTED_VERSIONS} ${ARGN})
if (NOT _VER IN_LIST _SUPPORTED_VERSIONS_LIST)
message(FATAL_ERROR
"Python version (${_VER}) is not one of the supported versions: "
"${_SUPPORTED_VERSIONS_LIST}.")
endif()
message(STATUS "Found python matching: ${EXECUTABLE}.")
endmacro()
#
# Run `EXPR` in python. The standard output of python is stored in `OUT` and
# has trailing whitespace stripped. If an error is encountered when running
# python, a fatal message `ERR_MSG` is issued.
#
function (run_python OUT EXPR ERR_MSG)
execute_process(
COMMAND
"${PYTHON_EXECUTABLE}" "-c" "${EXPR}"
OUTPUT_VARIABLE PYTHON_OUT
RESULT_VARIABLE PYTHON_ERROR_CODE
ERROR_VARIABLE PYTHON_STDERR
OUTPUT_STRIP_TRAILING_WHITESPACE)
if(NOT PYTHON_ERROR_CODE EQUAL 0)
message(FATAL_ERROR "${ERR_MSG}: ${PYTHON_STDERR}")
endif()
set(${OUT} ${PYTHON_OUT} PARENT_SCOPE)
endfunction()
# Run `EXPR` in python after importing `PKG`. Use the result of this to extend
# `CMAKE_PREFIX_PATH` so the torch cmake configuration can be imported.
macro (append_cmake_prefix_path PKG EXPR)
run_python(_PREFIX_PATH
"import ${PKG}; print(${EXPR})" "Failed to locate ${PKG} path")
list(APPEND CMAKE_PREFIX_PATH ${_PREFIX_PATH})
endmacro()
# This cmake function is adapted from vLLM's cmake/utils.cmake
# Define a target named `GPU_MOD_NAME` for a single extension. The
# arguments are:
#
# DESTINATION <dest> - Module destination directory.
# LANGUAGE <lang> - The GPU language for this module, e.g. CUDA, HIP,
# etc.
# SOURCES <sources> - List of source files relative to CMakeLists.txt
# directory.
#
# Optional arguments:
#
# ARCHITECTURES <arches> - A list of target GPU architectures in cmake
# format.
# Refer to the `CMAKE_CUDA_ARCHITECTURES` and
# `CMAKE_HIP_ARCHITECTURES` documentation for more info.
# ARCHITECTURES will use cmake's defaults if
# not provided.
# COMPILE_FLAGS <flags> - Extra compiler flags passed to NVCC/hip.
# INCLUDE_DIRECTORIES <dirs> - Extra include directories.
# LIBRARIES <libraries> - Extra link libraries.
# WITH_SOABI - Generate library with python SOABI suffix name.
# USE_SABI <version> - Use the Python stable API <version>
#
# Note: optimization level/debug info is set via cmake build type.
#
function (define_gpu_extension_target GPU_MOD_NAME)
cmake_parse_arguments(PARSE_ARGV 1
GPU
"WITH_SOABI"
"DESTINATION;LANGUAGE;USE_SABI"
"SOURCES;ARCHITECTURES;COMPILE_FLAGS;INCLUDE_DIRECTORIES;LIBRARIES")
# Add hipify preprocessing step when building with HIP/ROCm.
if (GPU_LANGUAGE STREQUAL "HIP")
hipify_sources_target(GPU_SOURCES ${GPU_MOD_NAME} "${GPU_SOURCES}")
endif()
if (GPU_WITH_SOABI)
set(GPU_WITH_SOABI WITH_SOABI)
else()
set(GPU_WITH_SOABI)
endif()
if (GPU_USE_SABI)
Python_add_library(${GPU_MOD_NAME} MODULE USE_SABI ${GPU_USE_SABI} ${GPU_WITH_SOABI} "${GPU_SOURCES}")
else()
Python_add_library(${GPU_MOD_NAME} MODULE ${GPU_WITH_SOABI} "${GPU_SOURCES}")
endif()
if (GPU_LANGUAGE STREQUAL "HIP")
# Make this target dependent on the hipify preprocessor step.
add_dependencies(${GPU_MOD_NAME} hipify${GPU_MOD_NAME})
endif()
if (GPU_ARCHITECTURES)
set_target_properties(${GPU_MOD_NAME} PROPERTIES
${GPU_LANGUAGE}_ARCHITECTURES "${GPU_ARCHITECTURES}")
endif()
set_property(TARGET ${GPU_MOD_NAME} PROPERTY CXX_STANDARD 17)
target_compile_options(${GPU_MOD_NAME} PRIVATE
$<$<COMPILE_LANGUAGE:${GPU_LANGUAGE}>:${GPU_COMPILE_FLAGS}>)
target_compile_definitions(${GPU_MOD_NAME} PRIVATE
"-DTORCH_EXTENSION_NAME=${GPU_MOD_NAME}")
target_include_directories(${GPU_MOD_NAME} PRIVATE csrc
${GPU_INCLUDE_DIRECTORIES})
target_link_libraries(${GPU_MOD_NAME} PRIVATE torch ${GPU_LIBRARIES})
# Don't use `TORCH_LIBRARIES` for CUDA since it pulls in a bunch of
# dependencies that are not necessary and may not be installed.
if (GPU_LANGUAGE STREQUAL "CUDA")
target_link_libraries(${GPU_MOD_NAME} PRIVATE CUDA::cudart CUDA::cuda_driver)
else()
target_link_libraries(${GPU_MOD_NAME} PRIVATE ${TORCH_LIBRARIES})
endif()
install(TARGETS ${GPU_MOD_NAME} LIBRARY DESTINATION ${GPU_DESTINATION} COMPONENT ${GPU_MOD_NAME})
endfunction()


@@ -0,0 +1,367 @@
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kernel_operator.h"
#include "kernel_tpipe_impl.h"
#include "kernel_tensor_impl.h"
#include "kernel_type.h"
#include "kernel_operator_intf.h"
#include "inner_interface/inner_kernel_operator_intf.h"
#include <stdio.h>
#include "types.h"
#include "utils.h"
using vllm_ascend::AccType;
using vllm_ascend::local_mem_copy;
template <typename scalar_t, bool isNeox> class RotaryEmbedding {
// NOTE(ganyi): we use 32K as the load stride for the pipe; need to find another way to
// retrieve this size from the runtime for broader SoC support
static int constexpr loadSize = 1024 * 4;
using dst_t = scalar_t;
using acc_t = typename AccType<scalar_t>::type;
// only half tensors have a cast instruction to int8, so hardcode acc_dst_t as half
using local_scalar_t = AscendC::LocalTensor<scalar_t>;
using local_acc_t = AscendC::LocalTensor<acc_t>;
using local_dst_t = AscendC::LocalTensor<dst_t>;
public:
__aicore__ inline RotaryEmbedding()
{
}
// Allocate buffers for the input/output queues and the temp buffers used during the kernel compute process;
// this init runs once per vector core before the per-token compute loop.
__aicore__ inline void init(__gm__ int64_t *positions, __gm__ void *queryDst, __gm__ void *keyDst,
__gm__ scalar_t *query, __gm__ scalar_t *key, __gm__ scalar_t *cosSinCache,
const int rotDim, const int64_t dstQueryStride,
const int64_t dstKeyStride, const int64_t queryStride, const int64_t keyStride,
const int numHeads, const int numKvHeads, const int headSize, AscendC::TPipe *pipe)
{
pipe_ = pipe;
rotDim_ = rotDim;
// query stride and key stride are used to handle strided tensors that are not contiguous on the num_tokens dim
queryStride_ = queryStride;
keyStride_ = keyStride;
dstQueryStride_ = dstQueryStride;
dstKeyStride_ = dstKeyStride;
numHeads_ = numHeads;
numKvHeads_ = numKvHeads;
headSize_ = headSize;
embedDim_ = rotDim / 2;
pipe_->InitBuffer(inQue_, 1 /* buffer_num */, loadSize /* buffer_size */);
pipe_->InitBuffer(inQueSinCos_, 1 /* buffer_num */, rotDim_ * sizeof(scalar_t) /* buffer_size */);
pipe_->InitBuffer(outQue_, 1 /* buffer_num */, loadSize /* buffer_size */);
// 2 temporary calculation buffers
calcTmpBufferOffset_ = 0;
// 1 upcast buffer for bf16 (headSize)
upcastInputBufferOffset_ = calcTmpBufferOffset_ + sizeof(acc_t) * embedDim_ * 2;
// 1 upcast temp buffer for bf16 (2 * embed_dim)
upcastTempBufferOffset_ = upcastInputBufferOffset_ + sizeof(acc_t) * headSize_;
// 2 sin cos upcast buffer for bf16
cosSinUpcastBufferOffset_ = upcastTempBufferOffset_ + sizeof(acc_t) * 2 * embedDim_;
// 2. bf16 path: needs space for 2 cos/sin upcast buffers
// 3. fp16 path: needs space for 2 temporary calculation buffers
tempBufferSize_ = cosSinUpcastBufferOffset_ + 2 * embedDim_ * sizeof(acc_t);
// need to consider upcasting bf16 to fp32, so we might need 4 buffers just in case:
// 2 temporary buffers, 2 input buffers, 1 cos buffer, 1 sin buffer, 2 scale buffers (headSize), 2 zp
// buffers (headSize int8), 1 dst_temp buffer (headSize, int32)
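// Illustrative sizing example: with headSize_ = 128, rotDim = 128 (embedDim_ = 64) and acc_t = float,
// the offsets above are 0, 512, 1024 and 1536 bytes, and tempBufferSize_ = 2048 bytes.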
pipe_->InitBuffer(calcBuf_, tempBufferSize_ /* buffer_size */);
if constexpr (!std::is_same_v<scalar_t, acc_t>) {
pipe_->InitBuffer(copyBuf_, loadSize);
}
}
__aicore__ inline void update_mem_offset(__gm__ int64_t *positions, __gm__ void *queryDst, __gm__ void *keyDst,
__gm__ scalar_t *query, __gm__ scalar_t *key, __gm__ scalar_t *cosSinCache,
const int rotDim, const int64_t dstQueryStride, const int64_t dstKeyStride,
const int64_t queryStride, const int64_t keyStride, const int numHeads,
const int numKvHeads, const int headSize, const int64_t idx)
{
int64_t pos = positions[idx];
cosSin_.SetGlobalBuffer(cosSinCache + pos * rotDim_, rotDim_);
query_.SetGlobalBuffer(query + queryStride * idx, headSize * numHeads_);
key_.SetGlobalBuffer(key + keyStride * idx, headSize * numKvHeads_);
queryDst_.SetGlobalBuffer(reinterpret_cast<__gm__ dst_t *>(queryDst) + dstQueryStride * idx,
headSize * numHeads_);
keyDst_.SetGlobalBuffer(reinterpret_cast<__gm__ dst_t *>(keyDst) + dstKeyStride * idx, headSize * numKvHeads_);
}
// compute per head for neox on bf16
template <typename acc_t_, typename std::enable_if<!std::is_same_v<acc_t_, scalar_t>, void>::type * = nullptr>
__aicore__ inline void
neox_compute(local_scalar_t src, local_dst_t dst, AscendC::LocalTensor<acc_t_> sin, AscendC::LocalTensor<acc_t_> cos,
AscendC::LocalTensor<acc_t_> upcastInputBuffer, AscendC::LocalTensor<acc_t_> calcTmpBuffer)
{
// slice dst
local_dst_t dstX = dst;
local_dst_t dstY = dst[embedDim_];
// slice src
local_scalar_t srcX = src;
local_scalar_t srcY = src[embedDim_];
// slice temp buffer
local_acc_t calcTmpBufferX = calcTmpBuffer;
local_acc_t calcTmpBufferY = calcTmpBuffer[embedDim_];
// slice upcast input buffer
local_acc_t upcastBufferX = upcastInputBuffer;
local_acc_t upcastBufferY = upcastBufferX[embedDim_];
// dst x calc
Cast(upcastInputBuffer, src, AscendC::RoundMode::CAST_NONE, headSize_);
Mul(calcTmpBufferX, upcastBufferX, cos, embedDim_);
Mul(calcTmpBufferY, upcastBufferY, sin, embedDim_);
Sub(calcTmpBufferX, calcTmpBufferX, calcTmpBufferY, embedDim_);
Cast(dstX, calcTmpBufferX, AscendC::RoundMode::CAST_TRUNC, embedDim_);
// dst y calc
Mul(calcTmpBufferX, upcastBufferX, sin, embedDim_);
Mul(calcTmpBufferY, upcastBufferY, cos, embedDim_);
Add(calcTmpBufferX, calcTmpBufferX, calcTmpBufferY, embedDim_);
Cast(dstY, calcTmpBufferX, AscendC::RoundMode::CAST_TRUNC, embedDim_);
}
// compute per head output for neox
template <typename acc_t_, typename std::enable_if<std::is_same_v<acc_t_, scalar_t>, void>::type * = nullptr>
__aicore__ inline void
neox_compute(local_scalar_t src, local_dst_t dst, AscendC::LocalTensor<acc_t_> sin, AscendC::LocalTensor<acc_t_> cos,
AscendC::LocalTensor<acc_t_> upcastInputBuffer, AscendC::LocalTensor<acc_t_> calcTmpBuffer)
{
// slice dst buffer
local_dst_t dstX = dst;
local_dst_t dstY = dst[embedDim_];
// slice src buffer
local_scalar_t srcX = src;
local_scalar_t srcY = src[embedDim_];
// slice temp buffer
local_acc_t calcTmpBufferX = calcTmpBuffer;
local_acc_t calcTmpBufferY = calcTmpBuffer[embedDim_];
// dst x calc
Mul(calcTmpBufferX, srcX, cos, embedDim_);
Mul(calcTmpBufferY, srcY, sin, embedDim_);
Sub(dstX, calcTmpBufferX, calcTmpBufferY, embedDim_);
// dst y calc
Mul(calcTmpBufferX, srcX, sin, embedDim_);
Mul(calcTmpBufferY, srcY, cos, embedDim_);
Add(dstY, calcTmpBufferX, calcTmpBufferY, embedDim_);
}
__aicore__ inline void compute_qk(AscendC::GlobalTensor<scalar_t> srcG, AscendC::GlobalTensor<dst_t> dstG,
local_acc_t localCos, local_acc_t localSin, local_acc_t upcastInputBuffer,
local_acc_t calcTmpBuffer, int loopCnt, int tailHeads, int loadStride,
int headNumPerLoad)
{
for (int loopNum = 0; loopNum < loopCnt; ++loopNum) {
local_scalar_t src = inQue_.AllocTensor<scalar_t>();
local_dst_t dst = outQue_.AllocTensor<dst_t>();
AscendC::DataCopy(src, srcG[loopNum * loadStride], loadStride);
inQue_.EnQue(src);
local_scalar_t srcDeque = inQue_.DeQue<scalar_t>();
if constexpr (!std::is_same_v<scalar_t, acc_t>) {
int elem_num = loadStride / sizeof(scalar_t);
AscendC::LocalTensor<acc_t> upBuffer = copyBuf_.GetWithOffset<acc_t>(elem_num, 0);
Cast(upBuffer, srcDeque, AscendC::RoundMode::CAST_TRUNC, elem_num);
Cast(dst, upBuffer, AscendC::RoundMode::CAST_TRUNC, elem_num);
} else {
local_mem_copy(dst, srcDeque, loadStride);
}
for (int i = 0; i < headNumPerLoad; ++i) {
neox_compute(srcDeque[i * headSize_], dst[i * headSize_], localSin, localCos, upcastInputBuffer,
calcTmpBuffer);
}
outQue_.EnQue(dst);
local_dst_t dstDeque = outQue_.DeQue<dst_t>();
AscendC::DataCopy(dstG[loopNum * loadStride], dstDeque, loadStride);
outQue_.FreeTensor(dstDeque);
inQue_.FreeTensor(srcDeque);
}
// process tail
{
local_scalar_t src = inQue_.AllocTensor<scalar_t>();
local_dst_t dst = outQue_.AllocTensor<dst_t>();
AscendC::DataCopy(src, srcG[loopCnt * loadStride], tailHeads * headSize_);
inQue_.EnQue(src);
local_scalar_t srcDeque = inQue_.DeQue<scalar_t>();
if constexpr (!std::is_same_v<scalar_t, acc_t>) {
int elem_num = tailHeads * headSize_ / sizeof(scalar_t);
AscendC::LocalTensor<acc_t> upBuffer = copyBuf_.GetWithOffset<acc_t>(elem_num, 0);
Cast(upBuffer, srcDeque, AscendC::RoundMode::CAST_TRUNC, elem_num);
Cast(dst, upBuffer, AscendC::RoundMode::CAST_TRUNC, elem_num);
} else {
local_mem_copy(dst, srcDeque, tailHeads * headSize_);
}
for (int i = 0; i < tailHeads; ++i) {
neox_compute(srcDeque[i * headSize_], dst[i * headSize_], localSin, localCos, upcastInputBuffer,
calcTmpBuffer);
}
outQue_.EnQue(dst);
local_dst_t dstDeque = outQue_.DeQue<dst_t>();
AscendC::DataCopy(dstG[loopCnt * loadStride], dstDeque, tailHeads * headSize_);
outQue_.FreeTensor(dstDeque);
inQue_.FreeTensor(srcDeque);
}
}
__aicore__ inline void compute_function()
{
local_scalar_t cosSinLocal = inQueSinCos_.AllocTensor<scalar_t>();
AscendC::DataCopy(cosSinLocal, cosSin_, embedDim_ * 2);
inQueSinCos_.EnQue(cosSinLocal);
local_scalar_t localSinCosDeque = inQueSinCos_.DeQue<scalar_t>();
local_scalar_t localCos = localSinCosDeque;
local_scalar_t localSin = localSinCosDeque[embedDim_];
local_acc_t calcTmpBuffer;
local_acc_t upcastInputBuffer;
local_acc_t upcastTempBuffer;
local_acc_t cosSinUpcastBuffer;
local_acc_t scaleBuffer;
local_acc_t offsetBuffer;
calcTmpBuffer = calcBuf_.GetWithOffset<acc_t>(embedDim_ * 2, calcTmpBufferOffset_);
upcastInputBuffer = calcBuf_.GetWithOffset<acc_t>(headSize_, upcastInputBufferOffset_);
upcastTempBuffer = calcBuf_.GetWithOffset<acc_t>(embedDim_ * 2, upcastTempBufferOffset_);
cosSinUpcastBuffer = calcBuf_.GetWithOffset<acc_t>(embedDim_ * 2, cosSinUpcastBufferOffset_);
local_acc_t cosAccBuffer;
local_acc_t sinAccBuffer;
if constexpr (!std::is_same_v<scalar_t, acc_t>) {
Cast(cosSinUpcastBuffer, localSinCosDeque, AscendC::RoundMode::CAST_NONE, 2 * embedDim_);
cosAccBuffer = cosSinUpcastBuffer;
sinAccBuffer = cosSinUpcastBuffer[embedDim_];
} else {
cosAccBuffer = localCos;
sinAccBuffer = localSin;
}
constexpr const int loadSizeByElem = loadSize / sizeof(scalar_t);
int64_t headNumPerLoad = loadSizeByElem / headSize_;
int64_t loopCnt = numHeads_ / headNumPerLoad;
int64_t tailHeads = numHeads_ - loopCnt * headNumPerLoad;
int64_t loadStride = headNumPerLoad * headSize_;
int64_t loopCntKv = numKvHeads_ / headNumPerLoad;
int64_t tailHeadsKv = numKvHeads_ - loopCntKv * headNumPerLoad;
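// Illustrative example: with loadSize = 4KB and half inputs, loadSizeByElem = 2048; for headSize_ = 128
// that gives headNumPerLoad = 16, so numHeads_ = 17 is processed as loopCnt = 1 full load plus tailHeads = 1.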
compute_qk(query_, queryDst_, cosAccBuffer, sinAccBuffer, upcastInputBuffer,
calcTmpBuffer, loopCnt, tailHeads, loadStride, headNumPerLoad);
compute_qk(key_, keyDst_, cosAccBuffer, sinAccBuffer, upcastInputBuffer, calcTmpBuffer,
loopCntKv, tailHeadsKv, loadStride, headNumPerLoad);
inQueSinCos_.FreeTensor(localSinCosDeque);
}
private:
AscendC::TPipe *pipe_;
AscendC::TQue<AscendC::QuePosition::VECIN, 1> inQue_, inQueSinCos_;
AscendC::TQue<AscendC::QuePosition::VECOUT, 1> outQue_;
AscendC::TBuf<AscendC::TPosition::VECCALC> calcBuf_;
AscendC::TBuf<AscendC::TPosition::VECCALC> copyBuf_;
AscendC::GlobalTensor<dst_t> queryDst_;
AscendC::GlobalTensor<dst_t> keyDst_;
AscendC::GlobalTensor<scalar_t> query_;
AscendC::GlobalTensor<scalar_t> key_;
AscendC::GlobalTensor<scalar_t> cosSin_;
int rotDim_;
int embedDim_;
int64_t queryStride_;
int64_t keyStride_;
int64_t dstQueryStride_;
int64_t dstKeyStride_;
int numHeads_;
int numKvHeads_;
int headSize_;
int calcTmpBufferOffset_;
int upcastInputBufferOffset_;
int upcastTempBufferOffset_;
int cosSinUpcastBufferOffset_;
int tempBufferSize_;
};
// Note: we need macros to instantiate all the target functions here, because the current build system does not support template calls in cpp.
// We use C-style symbols for kernel compilation; a cpp-style kernel entry may lead to compilation failure.
#define ROPE_CUSTOM_KERNEL_TYPE_DECLARE(TYPE, NEOX) \
extern "C" __global__ __aicore__ void rope_custom_##NEOX##_##TYPE( \
__gm__ int64_t* positions, __gm__ void* queryDst, __gm__ void* keyDst, __gm__ TYPE* query, __gm__ TYPE* key, \
__gm__ TYPE* cosSinCache, const int rotDim, const int64_t queryStride, const int64_t keyStride, \
const int64_t dstQueryStride, const int64_t dstKeyStride, const int numHeads, const int numKvHeads, \
const int headSize, const int64_t numTokens, const int loopNum, const int coreNum) \
{ \
AscendC::TPipe pipe; \
RotaryEmbedding<TYPE, NEOX> op{}; \
op.init(positions, queryDst, keyDst, query, key, cosSinCache, rotDim, dstQueryStride, dstKeyStride, \
queryStride, keyStride, numHeads, numKvHeads, headSize, &pipe); \
for (int64_t i = AscendC::GetBlockIdx(); i < numTokens; i += coreNum) { \
op.update_mem_offset(positions, queryDst, keyDst, query, key, cosSinCache, rotDim, dstQueryStride, dstKeyStride, \
queryStride, keyStride, numHeads, numKvHeads, headSize, i); \
op.compute_function(); \
} \
}
#define ROPE_CUSTOM_KERNEL_DECLARE(TYPE) \
ROPE_CUSTOM_KERNEL_TYPE_DECLARE(TYPE, true); \
ROPE_CUSTOM_KERNEL_TYPE_DECLARE(TYPE, false);
// Declare all the kernel entries here
ROPE_CUSTOM_KERNEL_DECLARE(half)
ROPE_CUSTOM_KERNEL_DECLARE(bfloat16_t)
namespace vllm_ascend {
#define ROTARY_EMBEDDING_KERNEL_CALL(TYPE) \
if (isNeox) \
rope_custom_true_##TYPE<<<blockDim, nullptr, stream>>>( \
positions, queryDst, keyDst, reinterpret_cast<TYPE *>(query), reinterpret_cast<TYPE *>(key), \
reinterpret_cast<TYPE *>(cosSinCache), rotDim, queryStride, keyStride, dstQueryStride, dstKeyStride, \
numHeads, numKvHeads, headSize, numTokens, loopCnt, blockDim); \
else \
rope_custom_false_##TYPE<<<blockDim, nullptr, stream>>>( \
positions, queryDst, keyDst, reinterpret_cast<TYPE *>(query), reinterpret_cast<TYPE *>(key), \
reinterpret_cast<TYPE *>(cosSinCache), rotDim, queryStride, keyStride, dstQueryStride, dstKeyStride, \
numHeads, numKvHeads, headSize, numTokens, loopCnt, blockDim);
// maximum block number the runtime allows when launching an AscendC kernel;
// we use this to cap the block dim
static const int64_t maxParallelSize = 65535;
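// Illustrative example: with numTokens = 100000 the launch below uses blockDim = 65535 and each block
// strides over tokens blockIdx, blockIdx + blockDim, ... in the kernel entry loop above.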
extern void rotary_embedding_impl(AscendType type, bool isNeox, void *stream, int64_t *positions, void *queryDst,
void *keyDst, void *query, void *key, void *cosSinCache, const int rotDim,
const int64_t queryStride, const int64_t keyStride, const int64_t dstQueryStride,
const int64_t dstKeyStride, const int numHeads, const int numKvHeads,
const int headSize, const int64_t numTokens, const uint32_t loopCnt,
uint32_t aivNum)
{
int blockDim = maxParallelSize > numTokens ? numTokens : maxParallelSize;
if (type == AscendType::FP16) {
ROTARY_EMBEDDING_KERNEL_CALL(half);
} else if (type == AscendType::BF16) {
ROTARY_EMBEDDING_KERNEL_CALL(bfloat16_t);
} else {
return;
}
}
} // namespace vllm_ascend

25
csrc/kernels/types.h Normal file

@@ -0,0 +1,25 @@
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
namespace vllm_ascend {
enum struct AscendType {
FP16 = 0,
BF16 = 1,
FP32 = 2,
};
}

49
csrc/kernels/utils.h Normal file

@@ -0,0 +1,49 @@
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "kernel_type.h"
namespace vllm_ascend {
template <typename scalar_t> struct AccType;
template <> struct AccType<bfloat16_t> {
using type = float;
};
template <> struct AccType<half> {
using type = half;
};
template <> struct AccType<float> {
using type = float;
};
template <> struct AccType<int8_t> {
using type = int;
};
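// local_mem_copy copies `size` elements between local tensors in 256-byte chunks.
// Illustrative example: for half (2 bytes) loadSize is 128 elements, so copying size = 300 elements
// issues one Copy with 2 repeats (256 elements) followed by a tail Copy of the remaining 44 elements.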
template <typename scalar_t>
__aicore__ inline void local_mem_copy(AscendC::LocalTensor<scalar_t> dst, AscendC::LocalTensor<scalar_t> src, int size)
{
constexpr int loadSize = 256 / sizeof(scalar_t);
int loopCnt = size / loadSize;
int tailSize = size % loadSize;
if (loopCnt)
AscendC::Copy(dst, src, loadSize, loopCnt, {1, 1, 8, 8});
AscendC::Copy(dst[loopCnt * loadSize], src[loopCnt * loadSize], tailSize, 1, {1, 1, 8, 8});
}
} // namespace vllm_ascend

32
csrc/ops.h Normal file

@@ -0,0 +1,32 @@
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <optional>
#include <torch/library.h>
#include <vector>
#include "kernels/types.h"
namespace vllm_ascend {
extern void rotary_embedding_impl(AscendType type, bool isNeox, void *stream, int64_t *positions, void *queryDst,
void *keyDst, void *query, void *key, void *cosSinCache, const int rotDim,
const int64_t queryStride, const int64_t keyStride, const int64_t dstQueryStride,
const int64_t dstKeyStride, const int numHeads, const int numKvHeads,
const int headSize, const int64_t numTokens, const uint32_t loopCnt,
uint32_t aivNum);
}

108
csrc/torch_binding.cpp Normal file

@@ -0,0 +1,108 @@
/*
* Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <torch/extension.h>
#include <torch/library.h>
#include <torch/version.h>
#include <torch_npu/csrc/core/npu/NPUStream.h>
#include <torch_npu/csrc/framework/OpCommand.h>
#include <torch_npu/csrc/npu/Module.h>
#include <pybind11/pybind11.h>
#include "acl/acl.h"
#include "tiling/platform/platform_ascendc.h"
#include "aclnn/opdev/platform.h"
#include "ops.h"
#include "utils.h"
namespace vllm_ascend {
void rotary_embedding(at::Tensor &positions, at::Tensor &query, at::Tensor &key,
int64_t head_size, at::Tensor &cos_sin_cache, bool is_neox)
{
int32_t deviceId = 0;
int64_t num_tokens = positions.numel();
int positions_ndim = positions.dim();
TORCH_CHECK(
positions_ndim == 1 || positions_ndim == 2,
"positions must have shape [num_tokens] or [batch_size, seq_len]");
if (positions_ndim == 1) {
TORCH_CHECK(
query.size(0) == positions.size(0) && key.size(0) == positions.size(0),
"query, key and positions must have the same number of tokens");
}
if (positions_ndim == 2) {
TORCH_CHECK(
query.size(0) == positions.size(0) &&
key.size(0) == positions.size(0) &&
query.size(1) == positions.size(1) &&
key.size(1) == positions.size(1),
"query, key and positions must have the same batch_size and seq_len");
}
int query_hidden_size = query.numel() / num_tokens;
int key_hidden_size = key.numel() / num_tokens;
TORCH_CHECK(query_hidden_size % head_size == 0);
TORCH_CHECK(key_hidden_size % head_size == 0);
// Make sure query and key have consistent number of heads
int num_heads = query_hidden_size / head_size;
int num_kv_heads = key_hidden_size / head_size;
TORCH_CHECK(num_heads % num_kv_heads == 0);
int rot_dim = cos_sin_cache.size(1);
int64_t *position_ids_ptr = positions.data_ptr<int64_t>();
void *query_ptr = query.data_ptr();
void *key_ptr = key.data_ptr();
void *cos_sin_cache_ptr = cos_sin_cache.data_ptr();
int64_t query_stride = query.stride(-2);
int64_t key_stride = key.stride(-2);
at::ScalarType scalar_type = query.scalar_type();
aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
at_npu::native::OpCommand cmd;
cmd.Name("rotary_embedding");
cmd.SetCustomHandler([scalar_type, is_neox, num_tokens, stream, position_ids_ptr,
query_ptr, key_ptr, cos_sin_cache_ptr, rot_dim, query_stride, key_stride,
num_heads, num_kv_heads, head_size]() -> int {
auto dtype_num = get_dtype_from_torch(scalar_type);
fe::PlatFormInfos platform_infos;
int device_id = 0;
fe::PlatformInfoManager::GeInstance().GetRuntimePlatformInfosByDevice(device_id, platform_infos);
uint32_t aivNum = platform_infos.GetCoreNumByType("aiv");
uint32_t loop_cnt = (num_tokens + aivNum - 1) / aivNum;
rotary_embedding_impl(dtype_num, is_neox, stream, position_ids_ptr, query_ptr, key_ptr, query_ptr,
key_ptr, cos_sin_cache_ptr, rot_dim, query_stride, key_stride, query_stride,
key_stride, num_heads, num_kv_heads, head_size, num_tokens, loop_cnt, aivNum);
return 0;
});
cmd.Run();
return ;
}
} // namespace vllm_ascend
TORCH_LIBRARY_EXPAND(_C, ops)
{
// vLLM-Ascend custom ops
// Rotary embedding
// Apply GPT-NeoX style rotary embedding to query and key.
ops.def(
"rotary_embedding(Tensor positions, Tensor! query,"
" Tensor! key, int head_size,"
" Tensor cos_sin_cache, bool is_neox) -> ()");
ops.impl("rotary_embedding", torch::kPrivateUse1, &vllm_ascend::rotary_embedding);
}
REGISTER_EXTENSION(_C)

43
csrc/utils.h Normal file

@@ -0,0 +1,43 @@
#pragma once
#include "kernels/types.h"
#include <c10/core/ScalarType.h>
#include <Python.h>
#define _CONCAT(A, B) A##B
#define CONCAT(A, B) _CONCAT(A, B)
#define _STRINGIFY(A) #A
#define STRINGIFY(A) _STRINGIFY(A)
// A version of the TORCH_LIBRARY macro that expands the NAME, i.e. so NAME
// could be a macro instead of a literal token.
#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)
// A version of the TORCH_LIBRARY_IMPL macro that expands the NAME, i.e. so NAME
// could be a macro instead of a literal token.
#define TORCH_LIBRARY_IMPL_EXPAND(NAME, DEVICE, MODULE) \
TORCH_LIBRARY_IMPL(NAME, DEVICE, MODULE)
// REGISTER_EXTENSION allows the shared library to be loaded and initialized
// via python's import statement.
#define REGISTER_EXTENSION(NAME) \
PyMODINIT_FUNC CONCAT(PyInit_, NAME)() { \
static struct PyModuleDef module = {PyModuleDef_HEAD_INIT, \
STRINGIFY(NAME), nullptr, 0, nullptr}; \
return PyModule_Create(&module); \
}
namespace vllm_ascend {
AscendType get_dtype_from_torch(at::ScalarType scalarType)
{
if (scalarType == at::ScalarType::Float) {
return AscendType::FP32;
} else if (scalarType == at::ScalarType::BFloat16) {
return AscendType::BF16;
} else {
return AscendType::FP16;
}
}
} // namespace vllm_ascend


@@ -3,9 +3,12 @@
requires = [
"setuptools>=64",
"setuptools-scm>=8",
"cmake>=3.26",
"pybind11",
"decorator",
"pyyaml",
"scipy",
"torch-npu >= 2.5.1rc1"
"torch_npu >= 2.5.1rc1",
"torch >= 2.5.1"
]
build-backend = "setuptools.build_meta"


@@ -1,6 +1,7 @@
decorator
pyyaml
scipy
pybind11
setuptools
setuptools-scm
numpy==1.26.4

278
setup.py

@@ -17,12 +17,265 @@
# limitations under the License.
#
import importlib.util
import logging
import os
from typing import List
import subprocess
import sys
from sysconfig import get_paths
from typing import Dict, List
from setuptools import find_packages, setup
from setuptools import Extension, find_packages, setup
from setuptools.command.build_ext import build_ext
from setuptools.command.develop import develop
from setuptools.command.install import install
from setuptools_scm import get_version
def load_module_from_path(module_name, path):
spec = importlib.util.spec_from_file_location(module_name, path)
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
spec.loader.exec_module(module)
return module
ROOT_DIR = os.path.dirname(__file__)
logger = logging.getLogger(__name__)
def check_or_set_default_env(cmake_args,
env_name,
env_variable,
default_path=""):
if env_variable is None:
logging.warning(
f"No {env_name} found in your environment, pleause try to set {env_name} "
"if you customize the installation path of this library, otherwise default "
"path will be adapted during build this project")
logging.warning(f"Set default {env_name}: {default_path}")
env_variable = default_path
else:
logging.info(f"Found existing {env_name}: {env_variable}")
# the CANN package seems to check this environment variable in cmake, so we need to write it back.
if env_name == "ASCEND_HOME_PATH":
os.environ["ASCEND_HOME_PATH"] = env_variable
cmake_args += [f"-D{env_name}={env_variable}"]
return cmake_args
envs = load_module_from_path("envs",
os.path.join(ROOT_DIR, "vllm_ascend", "envs.py"))
class CMakeExtension(Extension):
def __init__(self,
name: str,
cmake_lists_dir: str = ".",
**kwargs) -> None:
super().__init__(name, sources=[], py_limited_api=False, **kwargs)
self.cmake_lists_dir = os.path.abspath(cmake_lists_dir)
class cmake_build_ext(build_ext):
# A dict of extension directories that have been configured.
did_config: Dict[str, bool] = {}
#
# Determine number of compilation jobs
#
def compute_num_jobs(self):
# `num_jobs` is either the value of the MAX_JOBS environment variable
# (if defined) or the number of CPUs available.
num_jobs = envs.MAX_JOBS
if num_jobs is not None:
num_jobs = int(num_jobs)
logger.info("Using MAX_JOBS=%d as the number of jobs.", num_jobs)
else:
try:
# os.sched_getaffinity() isn't universally available, so fall
# back to os.cpu_count() if we get an error here.
num_jobs = len(os.sched_getaffinity(0))
except AttributeError:
num_jobs = os.cpu_count()
num_jobs = max(1, num_jobs)
return num_jobs
#
# Perform cmake configuration for a single extension.
#
def configure(self, ext: CMakeExtension) -> None:
build_temp = self.build_temp
os.makedirs(build_temp, exist_ok=True)
source_dir = os.path.abspath(ROOT_DIR)
python_executable = sys.executable
cmake_args = ["cmake"]
# Compile the csrc code in Release mode by default.
# Release, Debug and RelWithDebugInfo build types are supported.
if envs.CMAKE_BUILD_TYPE is None or envs.CMAKE_BUILD_TYPE not in [
"Debug",
"Release",
"RelWithDebugInfo",
]:
envs.CMAKE_BUILD_TYPE = "Release"
cmake_args += [f"-DCMAKE_BUILD_TYPE={envs.CMAKE_BUILD_TYPE}"]
# Dump the compile commands by default for LSP support
cmake_args += ["-DCMAKE_EXPORT_COMPILE_COMMANDS=1"]
if envs.VERBOSE:
cmake_args += ["-DCMAKE_VERBOSE_MAKEFILE=ON"]
# find ASCEND_HOME_PATH
check_or_set_default_env(
cmake_args,
"ASCEND_HOME_PATH",
envs.ASCEND_HOME_PATH,
"/usr/local/Ascend/ascend-toolkit/latest",
)
# find PYTHON_EXECUTABLE
check_or_set_default_env(cmake_args, "PYTHON_EXECUTABLE",
sys.executable)
# find PYTHON_INCLUDE_PATH
check_or_set_default_env(cmake_args, "PYHTON_INCLUDE_PATH",
get_paths()["include"])
# ccache and ninja cannot be used for the ascendc kernels yet
try:
# if pybind11 is installed via pip
pybind11_cmake_path = (subprocess.check_output(
[python_executable, "-m", "pybind11",
"--cmake"]).decode().strip())
except subprocess.CalledProcessError as e:
# otherwise pybind11 was installed from source (e.g. on the CI container)
raise RuntimeError(f"CMake configuration failed: {e}")
# try to retrieve the SoC version from npu-smi
soc_command = [
"bash",
"-c",
"npu-smi info | grep OK | awk '{print $3}' | head -n 1",
]
try:
soc_version = subprocess.check_output(soc_command,
text=True).strip()
soc_version = soc_version.split("-")[0]
soc_version = "Ascend" + soc_version
except subprocess.CalledProcessError as e:
raise RuntimeError(f"Retrive Soc version failed: {e}")
# add SOC_VERSION
cmake_args += [f"-DSOC_VERSION={soc_version}"]
install_path = os.path.join(ROOT_DIR, self.build_lib)
if isinstance(self.distribution.get_command_obj("develop"), develop):
install_path = os.path.join(ROOT_DIR, "vllm_ascend")
# add CMAKE_INSTALL_PATH
cmake_args += [f"-DCMAKE_INSTALL_PREFIX={install_path}"]
cmake_args += [f"-DCMAKE_PREFIX_PATH={pybind11_cmake_path}"]
# Override the base directory for FetchContent downloads to $ROOT/.deps
# This allows sharing dependencies between profiles,
# and plays more nicely with sccache.
# To override this, set the FETCHCONTENT_BASE_DIR environment variable.
fc_base_dir = os.path.join(ROOT_DIR, ".deps")
fc_base_dir = os.environ.get("FETCHCONTENT_BASE_DIR", fc_base_dir)
cmake_args += ["-DFETCHCONTENT_BASE_DIR={}".format(fc_base_dir)]
build_tool = []
# TODO(ganyi): add ninja and ccache support for the AscendC auto codegen; for now only make builds are supported
# if which('ninja') is not None:
# build_tool += ['-G', 'Ninja']
# Default build tool to whatever cmake picks.
cmake_args += [source_dir]
logging.info(f"cmake config command: {cmake_args}")
try:
subprocess.check_call(cmake_args, cwd=self.build_temp)
except subprocess.CalledProcessError as e:
raise RuntimeError(f"CMake configuration failed: {e}")
subprocess.check_call(
["cmake", ext.cmake_lists_dir, *build_tool, *cmake_args],
cwd=self.build_temp,
)
def build_extensions(self) -> None:
if envs.COMPILE_CUSTOM_KERNELS is None:
return
# Ensure that CMake is present and working
try:
subprocess.check_output(["cmake", "--version"])
except OSError as e:
raise RuntimeError(f"Cannot find CMake executable: {e}")
# Create build directory if it does not exist.
if not os.path.exists(self.build_temp):
os.makedirs(self.build_temp)
targets = []
os.makedirs(os.path.join(self.build_lib, "vllm_ascend"), exist_ok=True)
def target_name(s: str) -> str:
return s.removeprefix("vllm_ascend.")
# Build all the extensions
for ext in self.extensions:
self.configure(ext)
targets.append(target_name(ext.name))
num_jobs = self.compute_num_jobs()
build_args = [
"--build",
".",
f"-j={num_jobs}",
*[f"--target={name}" for name in targets],
]
try:
subprocess.check_call(["cmake", *build_args], cwd=self.build_temp)
except OSError as e:
raise RuntimeError(f"Build library failed: {e}")
# Install the libraries
install_args = [
"cmake",
"--install",
".",
]
try:
subprocess.check_call(install_args, cwd=self.build_temp)
except OSError as e:
raise RuntimeError(f"Install library failed: {e}")
# copy back to build folder for editable build
if isinstance(self.distribution.get_command_obj("develop"), develop):
import shutil
for root, _, files in os.walk(self.build_temp):
for file in files:
if file.endswith(".so"):
src_path = os.path.join(root, file)
dst_path = os.path.join(self.build_lib, "vllm_ascend",
file)
shutil.copy(src_path, dst_path)
print(f"Copy: {src_path} -> {dst_path}")
def run(self):
# First, run the standard build_ext command to compile the extensions
super().run()
class custom_install(install):
def run(self):
self.run_command("build_ext")
install.run(self)
ROOT_DIR = os.path.dirname(__file__)
try:
VERSION = get_version(write_to="vllm_ascend/_version.py")
@@ -31,6 +284,10 @@ except LookupError:
# only checks out the commit. In this case, we set a dummy version.
VERSION = "0.0.0"
ext_modules = []
if envs.COMPILE_CUSTOM_KERNELS is not None:
ext_modules = [CMakeExtension(name="vllm_ascend.vllm_ascend_C")]
def get_path(*filepath) -> str:
return os.path.join(ROOT_DIR, *filepath)
@@ -69,8 +326,10 @@ def get_requirements() -> List[str]:
return requirements
cmdclass = {"build_ext": cmake_build_ext, "install": custom_install}
setup(
name='vllm_ascend',
name="vllm_ascend",
# Follow:
# https://packaging.python.org/en/latest/specifications/version-specifiers
version=VERSION,
@@ -95,12 +354,15 @@
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Scientific/Engineering :: Information Analysis",
],
packages=find_packages(exclude=("docs", "examples", "tests*")),
packages=find_packages(exclude=("docs", "examples", "tests*", "csrc")),
python_requires=">=3.9",
install_requires=get_requirements(),
ext_modules=ext_modules,
cmdclass=cmdclass,
extras_require={},
entry_points={
'vllm.platform_plugins': ["ascend = vllm_ascend:register"],
'vllm.general_plugins':
["ascend_enhanced_model = vllm_ascend:register_model"]
})
"vllm.platform_plugins": ["ascend = vllm_ascend:register"],
"vllm.general_plugins":
["ascend_enhanced_model = vllm_ascend:register_model"],
},
)


@@ -0,0 +1,204 @@
# Adapted from
# https://github.com/vllm-project/vllm/blob/main/vllm/tests/kernels/test_rotary_embedding.py
# Copyright 2023 The vLLM team.
# Copyright (c) Huawei Technologies Co., Ltd. 2024-2025. All rights reserved.
from typing import Optional, Tuple, Union
import pytest
import torch
import torch.nn as nn
import torch_npu # noqa: F401
import vllm_ascend.platform # noqa: F401
# Only the neox-style (is_neox_style=True) scenario is supported for now
IS_NEOX_STYLE = [True]
DTYPES = [torch.half]
HEAD_SIZES = [64, 96, 128, 256]
ROTARY_DIMS = [None, 32] # None means rotary dim == head size
NUM_HEADS = [17] # Arbitrary values for testing
BATCH_SIZES = [5] # Arbitrary values for testing
SEQ_LENS = [11, 4096] # Arbitrary values for testing
SEEDS = [0]
DEVICES = [f"npu:{0}"]
# Tolerances for comparing against the reference implementation
DEFAULT_ATOL = 1e-3
DEFAULT_RTOL = 1e-3
def _apply_rotary_emb(
x: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
is_neox_style: bool,
) -> torch.Tensor:
"""
Args:
x: [num_tokens, num_heads, head_size]
cos: [num_tokens, head_size // 2]
sin: [num_tokens, head_size // 2]
is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
positional embeddings.
"""
cos = cos.unsqueeze(-2).to(x.dtype)
sin = sin.unsqueeze(-2).to(x.dtype)
if is_neox_style:
x1, x2 = torch.chunk(x, 2, dim=-1)
else:
x1 = x[..., ::2]
x2 = x[..., 1::2]
o1 = x1 * cos - x2 * sin
o2 = x2 * cos + x1 * sin
if is_neox_style:
return torch.cat((o1, o2), dim=-1)
else:
return torch.stack((o1, o2), dim=-1).flatten(-2)
# adapted from https://github.com/vllm-project/vllm/vllm/model_executor/layers/rotary_embedding.py
class RotaryEmbedding(nn.Module):
"""Original rotary positional embedding."""
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
base: int,
is_neox_style: bool,
dtype: torch.dtype,
) -> None:
super().__init__()
self.head_size = head_size
self.rotary_dim = rotary_dim
self.max_position_embeddings = max_position_embeddings
self.base = base
self.is_neox_style = is_neox_style
self.dtype = dtype
cache = self._compute_cos_sin_cache()
cache = cache.to(dtype)
self.cos_sin_cache: torch.Tensor
self.register_buffer("cos_sin_cache", cache, persistent=False)
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
"""Compute the inverse frequency."""
# NOTE(woosuk): To exactly match the HF implementation, we need to
# use CPU to compute the cache and then move it to GPU. However, we
# create the cache on GPU for faster initialization. This may cause
# a slight numerical difference between the HF implementation and ours.
inv_freq = 1.0 / (base**(torch.arange(
0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))
return inv_freq
def _compute_cos_sin_cache(self) -> torch.Tensor:
"""Compute the cos and sin cache."""
inv_freq = self._compute_inv_freq(self.base)
t = torch.arange(self.max_position_embeddings, dtype=torch.float)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos()
sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1)
return cache
def forward_native(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""A PyTorch-native implementation of forward()."""
if offsets is not None:
positions = positions + offsets
positions = positions.flatten()
num_tokens = positions.shape[0]
cos_sin = self.cos_sin_cache.index_select(0, positions)
cos, sin = cos_sin.chunk(2, dim=-1)
query_shape = query.shape
query = query.view(num_tokens, -1, self.head_size)
query_rot = query[..., :self.rotary_dim]
query_pass = query[..., self.rotary_dim:]
query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
key_shape = key.shape
key = key.view(num_tokens, -1, self.head_size)
key_rot = key[..., :self.rotary_dim]
key_pass = key[..., self.rotary_dim:]
key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
return query, key
# test with a leading dimension, merging seq_len and batch_size into num_tokens
# TODO(ganyi): enable this test in the future
@pytest.mark.skip(
reason=
"skipped by default for now because of a CI issue; will be enabled in the future"
)
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("seq_len", SEQ_LENS)
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", DEVICES)
@torch.inference_mode()
def test_rotary_embedding_quant_with_leading_dim(
is_neox_style: bool,
batch_size: int,
seq_len: int,
num_heads: int,
head_size: int,
rotary_dim: Optional[int],
dtype: torch.dtype,
seed: int,
device: str,
max_position: int = 8192,
base: int = 10000,
) -> None:
if rotary_dim is None:
rotary_dim = head_size
torch.set_default_device(device)
rope = RotaryEmbedding(head_size, rotary_dim, max_position, base,
is_neox_style, dtype)
rope = rope.to(dtype=dtype)
num_tokens = batch_size * seq_len
positions = torch.randint(0, max_position, (batch_size * seq_len, ))
qkv_tensor = torch.randn(num_tokens,
num_heads * head_size * 3,
dtype=dtype)
query, key, _ = qkv_tensor.split(
[num_heads * head_size, num_heads * head_size, num_heads * head_size],
dim=-1,
)
ref_query, ref_key = rope.forward_native(positions, query, key)
torch.ops._C.rotary_embedding(
positions,
query,
key,
rope.head_size,
rope.cos_sin_cache,
rope.is_neox_style,
)
# Compare the results.
torch.testing.assert_close(query,
ref_query,
atol=DEFAULT_ATOL,
rtol=DEFAULT_RTOL)
torch.testing.assert_close(key,
ref_key,
atol=DEFAULT_ATOL,
rtol=DEFAULT_RTOL)

25
vllm_ascend/envs.py Normal file

@@ -0,0 +1,25 @@
import os
from typing import Any, Callable, Dict
env_variables: Dict[str, Callable[[], Any]] = {
# max compile thread num
"MAX_JOBS": lambda: os.getenv("MAX_JOBS", None),
"CMAKE_BUILD_TYPE": lambda: os.getenv("CMAKE_BUILD_TYPE"),
"COMPILE_CUSTOM_KERNELS":
lambda: os.getenv("COMPILE_CUSTOM_KERNELS", None),
# If set, vllm-ascend will print verbose logs during compilation
"VERBOSE": lambda: bool(int(os.getenv('VERBOSE', '0'))),
"ASCEND_HOME_PATH": lambda: os.getenv("ASCEND_HOME_PATH", None),
"LD_LIBRARY_PATH": lambda: os.getenv("LD_LIBRARY_PATH", None),
}
def __getattr__(name: str):
# lazy evaluation of environment variables
if name in env_variables:
return env_variables[name]()
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
def __dir__():
return list(env_variables.keys())
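
The entries above are evaluated lazily on attribute access, so a build switch takes effect as long as the environment variable is set before the attribute is first read. A minimal sketch (illustrative; assumes the vllm_ascend package is importable):

import os

os.environ["COMPILE_CUSTOM_KERNELS"] = "1"

from vllm_ascend import envs

# Resolved at access time through the module-level __getattr__ above.
assert envs.COMPILE_CUSTOM_KERNELS == "1"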


@@ -15,6 +15,7 @@
# limitations under the License.
#
import logging
import os
from typing import TYPE_CHECKING, Optional, Tuple
@@ -23,6 +24,19 @@ import torch_npu  # noqa: F401
import vllm.envs as envs
from vllm.config import CompilationLevel
from vllm.logger import init_logger
try:
# register custom ops into torch_library here
import vllm_ascend.vllm_ascend_C # type: ignore # noqa: F401
except ImportError as e:
if not str(
e
) == "dynamic module does not define module export function (PyInit_vllm_ascend_C)":
logging.warning(
"Warning: Failed to register custom ops, all custom ops will be disabled"
)
from vllm.platforms import Platform, PlatformEnum
if TYPE_CHECKING: