From ce8259975e0befc4830152f631c826ccc046e808 Mon Sep 17 00:00:00 2001
From: Pleaplusone <38376071+ganyi1996ppo@users.noreply.github.com>
Date: Thu, 3 Apr 2025 14:52:34 +0800
Subject: [PATCH] [core] Support custom ascendc kernels in vllm-ascend (#233)

This PR adds custom AscendC kernel rotary_embedding support to vllm-ascend; the related CMakeLists and setuptools changes are also included in this PR.

Related: https://github.com/vllm-project/vllm-ascend/issues/156

---------

Signed-off-by: ganyi
---
 .github/workflows/vllm_ascend_test_main.yaml | 1 +
 CMakeLists.txt | 102 ++++++
 cmake/utils.cmake | 133 +++++++
 csrc/kernels/pos_encoding_kernels.cpp | 367 +++++++++++++++++++
 csrc/kernels/types.h | 25 ++
 csrc/kernels/utils.h | 49 +++
 csrc/ops.h | 32 ++
 csrc/torch_binding.cpp | 108 ++++++
 csrc/utils.h | 43 +++
 pyproject.toml | 5 +-
 requirements.txt | 1 +
 setup.py | 278 +++++++++++++-
 tests/ops/test_rotary_embedding.py | 204 +++++++++++
 vllm_ascend/envs.py | 25 ++
 vllm_ascend/platform.py | 14 +
 15 files changed, 1378 insertions(+), 9 deletions(-)
 create mode 100644 CMakeLists.txt
 create mode 100644 cmake/utils.cmake
 create mode 100644 csrc/kernels/pos_encoding_kernels.cpp
 create mode 100644 csrc/kernels/types.h
 create mode 100644 csrc/kernels/utils.h
 create mode 100644 csrc/ops.h
 create mode 100644 csrc/torch_binding.cpp
 create mode 100644 csrc/utils.h
 create mode 100644 tests/ops/test_rotary_embedding.py
 create mode 100644 vllm_ascend/envs.py

diff --git a/.github/workflows/vllm_ascend_test_main.yaml b/.github/workflows/vllm_ascend_test_main.yaml
index 7dcc8a287..4c6886d63 100644
--- a/.github/workflows/vllm_ascend_test_main.yaml
+++ b/.github/workflows/vllm_ascend_test_main.yaml
@@ -62,6 +62,7 @@ jobs:
       - name: Install system dependencies
         run: |
           apt-get -y install `cat packages.txt`
+          apt-get -y install gcc g++ cmake libnuma-dev
       - name: Checkout vllm-project/vllm repo
         uses: actions/checkout@v4

diff --git a/CMakeLists.txt b/CMakeLists.txt
new file mode 100644
index 000000000..1814e4c98
--- /dev/null
+++ b/CMakeLists.txt
@@ -0,0 +1,102 @@
+cmake_minimum_required(VERSION 3.16)
+project(vllm_ascend_C)
+
+# include(CheckCXXCompilerFlag)
+# check_cxx_compiler_flag("-std=c++17", COMPILER_SUPPORTS_CXX17)
+
+
+include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)
+
+# Suppress potential warnings about unused manually-specified variables
+set(ignoreMe "${VLLM_PYTHON_PATH}")
+
+set(PYTHON_SUPPORTED_VERSIONS "3.9" "3.10" "3.11" "3.12")
+
+find_package(pybind11 REQUIRED)
+
+append_cmake_prefix_path("torch" "torch.utils.cmake_prefix_path")
+set(VLLM_ASCEND_INSTALL_PATH "${CMAKE_INSTALL_PREFIX}")
+
+find_package(Torch REQUIRED)
+
+set(RUN_MODE "npu" CACHE STRING "cpu/sim/npu")
+message(STATUS "Detected SOC version: ${SOC_VERSION}")
+
+if (NOT CMAKE_BUILD_TYPE)
+    set(CMAKE_BUILD_TYPE "Release" CACHE STRING "Build type Release/Debug (default Release)" FORCE)
+endif()
+
+if (CMAKE_INSTALL_PREFIX STREQUAL /usr/local)
+    set(CMAKE_INSTALL_PREFIX "${CMAKE_CURRENT_LIST_DIR}/out" CACHE STRING "path to install()")
+endif()
+
+set(ASCEND_CANN_PACKAGE_PATH ${ASCEND_HOME_PATH})
+if(EXISTS ${ASCEND_HOME_PATH}/tools/tikcpp/ascendc_kernel_cmake)
+    set(ASCENDC_CMAKE_DIR ${ASCEND_HOME_PATH}/tools/tikcpp/ascendc_kernel_cmake)
+elseif(EXISTS ${ASCEND_HOME_PATH}/compiler/tikcpp/ascendc_kernel_cmake)
+    set(ASCENDC_CMAKE_DIR ${ASCEND_HOME_PATH}/compiler/tikcpp/ascendc_kernel_cmake)
+elseif(EXISTS ${ASCEND_HOME_PATH}/ascendc_devkit/tikcpp/samples/cmake)
+    set(ASCENDC_CMAKE_DIR ${ASCEND_HOME_PATH}/ascendc_devkit/tikcpp/samples/cmake)
+else()
+
message(FATAL_ERROR "ascendc_kernel_cmake does not exist, please check whether the cann package is installed.") +endif() + +include(${ASCENDC_CMAKE_DIR}/ascendc.cmake) +file(GLOB KERNEL_FILES +${CMAKE_CURRENT_SOURCE_DIR}/csrc/kernels/pos_encoding_kernels.cpp) + +ascendc_library(vllm_ascend_kernels SHARED + ${KERNEL_FILES} +) + +execute_process(COMMAND python3 -c "import os; import torch_npu; print(os.path.dirname(torch_npu.__file__))" + OUTPUT_STRIP_TRAILING_WHITESPACE + OUTPUT_VARIABLE TORCH_NPU_PATH +) +message("TORCH_NPU_PATH is ${TORCH_NPU_PATH}") + +file(GLOB VLLM_ASCEND_SRC +${CMAKE_CURRENT_SOURCE_DIR}/csrc/*.cpp) + +include_directories( + ${pybind11_INCLUDE_DIRS} + ${PYTHON_INCLUDE_PATH} + ${TORCH_INCLUDE_DIRS} + ${TORCH_NPU_PATH}/include + ${ASCEND_HOME_PATH}/include + ${ASCEND_HOME_PATH}/aarch64-linux/include/experiment/platform + ${ASCEND_HOME_PATH}/x86_64-linux/include/experiment/platform +) + +set( + INCLUDES + ${TORCH_INCLUDE_DIRS} + ${TORCH_NPU_INCLUDE_DIRS} + ${ASCEND_HOME_PATH}/include + ${ASCEND_HOME_PATH}/aarch64-linux/include/experiment/platform +) + +pybind11_add_module(vllm_ascend_C ${VLLM_ASCEND_SRC}) + +target_link_directories( + vllm_ascend_C + PRIVATE + ${TORCH_NPU_PATH}/lib/ + ${ASCEND_HOME_PATH}/lib64 +) + +target_link_libraries( + vllm_ascend_C + PUBLIC + ${TORCH_LIBRARIES} + libtorch_npu.so + vllm_ascend_kernels + ascendcl + platform +) + +target_link_options(vllm_ascend_C PRIVATE "-Wl,-rpath,$ORIGIN:$ORIGIN/lib") + +install(TARGETS vllm_ascend_C vllm_ascend_kernels DESTINATION ${VLLM_ASCEND_INSTALL_PATH}) + + diff --git a/cmake/utils.cmake b/cmake/utils.cmake new file mode 100644 index 000000000..62078fd31 --- /dev/null +++ b/cmake/utils.cmake @@ -0,0 +1,133 @@ +# +# Attempt to find the python package that uses the same python executable as +# `EXECUTABLE` and is one of the `SUPPORTED_VERSIONS`. +# +macro (find_python_from_executable EXECUTABLE SUPPORTED_VERSIONS) + file(REAL_PATH ${EXECUTABLE} EXECUTABLE) + set(Python_EXECUTABLE ${EXECUTABLE}) + find_package(Python COMPONENTS Interpreter Development.Module Development.SABIModule) + if (NOT Python_FOUND) + message(FATAL_ERROR "Unable to find python matching: ${EXECUTABLE}.") + endif() + set(_VER "${Python_VERSION_MAJOR}.${Python_VERSION_MINOR}") + set(_SUPPORTED_VERSIONS_LIST ${SUPPORTED_VERSIONS} ${ARGN}) + if (NOT _VER IN_LIST _SUPPORTED_VERSIONS_LIST) + message(FATAL_ERROR + "Python version (${_VER}) is not one of the supported versions: " + "${_SUPPORTED_VERSIONS_LIST}.") + endif() + message(STATUS "Found python matching: ${EXECUTABLE}.") +endmacro() + +# +# Run `EXPR` in python. The standard output of python is stored in `OUT` and +# has trailing whitespace stripped. If an error is encountered when running +# python, a fatal message `ERR_MSG` is issued. +# +function (run_python OUT EXPR ERR_MSG) + execute_process( + COMMAND + "${PYTHON_EXECUTABLE}" "-c" "${EXPR}" + OUTPUT_VARIABLE PYTHON_OUT + RESULT_VARIABLE PYTHON_ERROR_CODE + ERROR_VARIABLE PYTHON_STDERR + OUTPUT_STRIP_TRAILING_WHITESPACE) + + if(NOT PYTHON_ERROR_CODE EQUAL 0) + message(FATAL_ERROR "${ERR_MSG}: ${PYTHON_STDERR}") + endif() + set(${OUT} ${PYTHON_OUT} PARENT_SCOPE) +endfunction() + +# Run `EXPR` in python after importing `PKG`. Use the result of this to extend +# `CMAKE_PREFIX_PATH` so the torch cmake configuration can be imported. 
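+# For example, the top-level CMakeLists.txt calls
+#   append_cmake_prefix_path("torch" "torch.utils.cmake_prefix_path")
+# so that the subsequent find_package(Torch REQUIRED) can locate torch's CMake config.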
+macro (append_cmake_prefix_path PKG EXPR) + run_python(_PREFIX_PATH + "import ${PKG}; print(${EXPR})" "Failed to locate ${PKG} path") + list(APPEND CMAKE_PREFIX_PATH ${_PREFIX_PATH}) +endmacro() + + +# This cmake function is adapted from vllm /Users/ganyi/workspace/vllm-ascend/cmake/utils.cmake +# Define a target named `GPU_MOD_NAME` for a single extension. The +# arguments are: +# +# DESTINATION - Module destination directory. +# LANGUAGE - The GPU language for this module, e.g CUDA, HIP, +# etc. +# SOURCES - List of source files relative to CMakeLists.txt +# directory. +# +# Optional arguments: +# +# ARCHITECTURES - A list of target GPU architectures in cmake +# format. +# Refer `CMAKE_CUDA_ARCHITECTURES` documentation +# and `CMAKE_HIP_ARCHITECTURES` for more info. +# ARCHITECTURES will use cmake's defaults if +# not provided. +# COMPILE_FLAGS - Extra compiler flags passed to NVCC/hip. +# INCLUDE_DIRECTORIES - Extra include directories. +# LIBRARIES - Extra link libraries. +# WITH_SOABI - Generate library with python SOABI suffix name. +# USE_SABI - Use python stable api +# +# Note: optimization level/debug info is set via cmake build type. +# +function (define_gpu_extension_target GPU_MOD_NAME) + cmake_parse_arguments(PARSE_ARGV 1 + GPU + "WITH_SOABI" + "DESTINATION;LANGUAGE;USE_SABI" + "SOURCES;ARCHITECTURES;COMPILE_FLAGS;INCLUDE_DIRECTORIES;LIBRARIES") + + # Add hipify preprocessing step when building with HIP/ROCm. + if (GPU_LANGUAGE STREQUAL "HIP") + hipify_sources_target(GPU_SOURCES ${GPU_MOD_NAME} "${GPU_SOURCES}") + endif() + + if (GPU_WITH_SOABI) + set(GPU_WITH_SOABI WITH_SOABI) + else() + set(GPU_WITH_SOABI) + endif() + + if (GPU_USE_SABI) + Python_add_library(${GPU_MOD_NAME} MODULE USE_SABI ${GPU_USE_SABI} ${GPU_WITH_SOABI} "${GPU_SOURCES}") + else() + Python_add_library(${GPU_MOD_NAME} MODULE ${GPU_WITH_SOABI} "${GPU_SOURCES}") + endif() + + if (GPU_LANGUAGE STREQUAL "HIP") + # Make this target dependent on the hipify preprocessor step. + add_dependencies(${GPU_MOD_NAME} hipify${GPU_MOD_NAME}) + endif() + + if (GPU_ARCHITECTURES) + set_target_properties(${GPU_MOD_NAME} PROPERTIES + ${GPU_LANGUAGE}_ARCHITECTURES "${GPU_ARCHITECTURES}") + endif() + + set_property(TARGET ${GPU_MOD_NAME} PROPERTY CXX_STANDARD 17) + + target_compile_options(${GPU_MOD_NAME} PRIVATE + $<$:${GPU_COMPILE_FLAGS}>) + + target_compile_definitions(${GPU_MOD_NAME} PRIVATE + "-DTORCH_EXTENSION_NAME=${GPU_MOD_NAME}") + + target_include_directories(${GPU_MOD_NAME} PRIVATE csrc + ${GPU_INCLUDE_DIRECTORIES}) + + target_link_libraries(${GPU_MOD_NAME} PRIVATE torch ${GPU_LIBRARIES}) + + # Don't use `TORCH_LIBRARIES` for CUDA since it pulls in a bunch of + # dependencies that are not necessary and may not be installed. + if (GPU_LANGUAGE STREQUAL "CUDA") + target_link_libraries(${GPU_MOD_NAME} PRIVATE CUDA::cudart CUDA::cuda_driver) + else() + target_link_libraries(${GPU_MOD_NAME} PRIVATE ${TORCH_LIBRARIES}) + endif() + + install(TARGETS ${GPU_MOD_NAME} LIBRARY DESTINATION ${GPU_DESTINATION} COMPONENT ${GPU_MOD_NAME}) +endfunction() diff --git a/csrc/kernels/pos_encoding_kernels.cpp b/csrc/kernels/pos_encoding_kernels.cpp new file mode 100644 index 000000000..cce08ca97 --- /dev/null +++ b/csrc/kernels/pos_encoding_kernels.cpp @@ -0,0 +1,367 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "kernel_operator.h" +#include "kernel_tpipe_impl.h" +#include "kernel_tensor_impl.h" +#include "kernel_type.h" +#include "kernel_operator_intf.h" +#include "inner_interface/inner_kernel_operator_intf.h" +#include +#include "types.h" +#include "utils.h" + + +using vllm_ascend::AccType; +using vllm_ascend::local_mem_copy; +template class RotaryEmbedding { + // NOTE(ganyi): we use 32K as load stride for pipe, need to find another way to + // retrive this size from runtime for more Soc support + static int constexpr loadSize = 1024 * 4; + using dst_t = scalar_t; + using acc_t = typename AccType::type; + // only half tensor have cast instruct to int8, hardcode acc_dst_t as half + using local_scalar_t = AscendC::LocalTensor; + using local_acc_t = AscendC::LocalTensor; + using local_dst_t = AscendC::LocalTensor; + +public: + __aicore__ inline RotaryEmbedding() + { + } + + // Allocate buffers for input and output queue and the temp buffer used during kernel compute process, + // this init process happens only in the kernel compute on a single vector core. + __aicore__ inline void init(__gm__ int64_t *positions, __gm__ void *queryDst, __gm__ void *keyDst, + __gm__ scalar_t *query, __gm__ scalar_t *key, __gm__ scalar_t *cosSinCache, + const int rotDim, const int64_t dstQueryStride, + const int64_t dstKeyStride, const int64_t queryStride, const int64_t keyStride, + const int numHeads, const int numKvHeads, const int headSize, AscendC::TPipe *pipe) + { + pipe_ = pipe; + rotDim_ = rotDim; + // query stride and key stride is used to handle the strided tensor which is not contiguous on num_tokens dim + queryStride_ = queryStride; + keyStride_ = keyStride; + dstQueryStride_ = dstQueryStride; + dstKeyStride_ = dstKeyStride; + numHeads_ = numHeads; + numKvHeads_ = numKvHeads; + headSize_ = headSize; + embedDim_ = rotDim / 2; + + pipe_->InitBuffer(inQue_, 1 /* buffer_num */, loadSize /* buffer_size */); + pipe_->InitBuffer(inQueSinCos_, 1 /* buffer_num */, rotDim_ * sizeof(scalar_t) /* buffer_size */); + pipe_->InitBuffer(outQue_, 1 /* buffer_num */, loadSize /* buffer_size */); + // 2 temperary calculation buffer + calcTmpBufferOffset_ = 0; + // 1 upcast buffer for bf16 (headSize) + upcastInputBufferOffset_ = calcTmpBufferOffset_ + sizeof(acc_t) * embedDim_ * 2; + // 1 upcast temp buffer for bf16 (2 * embed_dim) + upcastTempBufferOffset_ = upcastInputBufferOffset_ + sizeof(acc_t) * headSize_; + // 2 sin cos upcast buffer for bf16 + cosSinUpcastBufferOffset_ = upcastTempBufferOffset_ + sizeof(acc_t) * 2 * embedDim_; + // 2. bf16 path: needs 2 cos sin upcast buffer size + // 3. 
fp16 path: needs 2 temperary calculation buffer size + tempBufferSize_ = cosSinUpcastBufferOffset_ + 2 * embedDim_ * sizeof(acc_t); + // need to consider upcast the bf16 to fp32, so we might need 4 buffer just in case + // 2 temperary buffer, 2 input buffer, 1 cos buffer, 1 sin buffer, 2 scale buffer (headSize), 2 zp + // buffer(headSize int8), 1 dst_temp buffer(headSize, int32) + pipe_->InitBuffer(calcBuf_, tempBufferSize_ /* buffer_size */); + if constexpr (!std::is_same_v) { + pipe_->InitBuffer(copyBuf_, loadSize); + } + } + __aicore__ inline void update_mem_offset(__gm__ int64_t *positions, __gm__ void *queryDst, __gm__ void *keyDst, + __gm__ scalar_t *query, __gm__ scalar_t *key, __gm__ scalar_t *cosSinCache, + const int rotDim, const int64_t dstQueryStride, const int64_t dstKeyStride, + const int64_t queryStride, const int64_t keyStride, const int numHeads, + const int numKvHeads, const int headSize, const int64_t idx) + { + int64_t pos = positions[idx]; + cosSin_.SetGlobalBuffer(cosSinCache + pos * rotDim_, rotDim_); + query_.SetGlobalBuffer(query + queryStride * idx, headSize * numHeads_); + key_.SetGlobalBuffer(key + keyStride * idx, headSize * numKvHeads_); + queryDst_.SetGlobalBuffer(reinterpret_cast<__gm__ dst_t *>(queryDst) + dstQueryStride * idx, + headSize * numHeads_); + keyDst_.SetGlobalBuffer(reinterpret_cast<__gm__ dst_t *>(keyDst) + dstKeyStride * idx, headSize * numKvHeads_); + } + + // compute per head for neox on bf16 + template , void>::type * = nullptr> + __aicore__ inline void + neox_compute(local_scalar_t src, local_dst_t dst, AscendC::LocalTensor sin, AscendC::LocalTensor cos, + AscendC::LocalTensor upcastInputBuffer, AscendC::LocalTensor calcTmpBuffer) + { + // slice dst + local_dst_t dstX = dst; + local_dst_t dstY = dst[embedDim_]; + + // slice src + local_scalar_t srcX = src; + local_scalar_t srcY = src[embedDim_]; + + // slice temp buffer + local_acc_t calcTmpBufferX = calcTmpBuffer; + local_acc_t calcTmpBufferY = calcTmpBuffer[embedDim_]; + + // slice upcast input buffer + local_acc_t upcastBufferX = upcastInputBuffer; + local_acc_t upcastBufferY = upcastBufferX[embedDim_]; + + // dst x calc + Cast(upcastInputBuffer, src, AscendC::RoundMode::CAST_NONE, headSize_); + Mul(calcTmpBufferX, upcastBufferX, cos, embedDim_); + Mul(calcTmpBufferY, upcastBufferY, sin, embedDim_); + Sub(calcTmpBufferX, calcTmpBufferX, calcTmpBufferY, embedDim_); + Cast(dstX, calcTmpBufferX, AscendC::RoundMode::CAST_TRUNC, embedDim_); + + // dst y calc + Mul(calcTmpBufferX, upcastBufferX, sin, embedDim_); + Mul(calcTmpBufferY, upcastBufferY, cos, embedDim_); + Add(calcTmpBufferX, calcTmpBufferX, calcTmpBufferY, embedDim_); + Cast(dstY, calcTmpBufferX, AscendC::RoundMode::CAST_TRUNC, embedDim_); + } + + // compute per head output for neox + template , void>::type * = nullptr> + __aicore__ inline void + neox_compute(local_scalar_t src, local_dst_t dst, AscendC::LocalTensor sin, AscendC::LocalTensor cos, + AscendC::LocalTensor upcastInputBuffer, AscendC::LocalTensor calcTmpBuffer) + { + // slice dst buffer + local_dst_t dstX = dst; + local_dst_t dstY = dst[embedDim_]; + // slice src buffer + local_scalar_t srcX = src; + local_scalar_t srcY = src[embedDim_]; + // slice temp buffer + local_acc_t calcTmpBufferX = calcTmpBuffer; + local_acc_t calcTmpBufferY = calcTmpBuffer[embedDim_]; + + // dst x calc + Mul(calcTmpBufferX, srcX, cos, embedDim_); + Mul(calcTmpBufferY, srcY, sin, embedDim_); + Sub(dstX, calcTmpBufferX, calcTmpBufferY, embedDim_); + + // dst y calc + Mul(calcTmpBufferX, srcX, 
sin, embedDim_); + Mul(calcTmpBufferY, srcY, cos, embedDim_); + Add(dstY, calcTmpBufferX, calcTmpBufferY, embedDim_); + } + + __aicore__ inline void compute_qk(AscendC::GlobalTensor srcG, AscendC::GlobalTensor dstG, + local_acc_t localCos, local_acc_t localSin, local_acc_t upcastInputBuffer, + local_acc_t calcTmpBuffer, int loopCnt, int tailHeads, int loadStride, + int headNumPerLoad) + { + for (int loopNum = 0; loopNum < loopCnt; ++loopNum) { + local_scalar_t src = inQue_.AllocTensor(); + local_dst_t dst = outQue_.AllocTensor(); + AscendC::DataCopy(src, srcG[loopNum * loadStride], loadStride); + inQue_.EnQue(src); + + local_scalar_t srcDeque = inQue_.DeQue(); + if constexpr (!std::is_same_v) { + int elem_num = loadStride / sizeof(scalar_t); + AscendC::LocalTensor upBuffer = copyBuf_.GetWithOffset(elem_num, 0); + Cast(upBuffer, srcDeque, AscendC::RoundMode::CAST_TRUNC, elem_num); + Cast(dst, upBuffer, AscendC::RoundMode::CAST_TRUNC, elem_num); + } else { + local_mem_copy(dst, srcDeque, loadStride); + } + for (int i = 0; i < headNumPerLoad; ++i) { + neox_compute(srcDeque[i * headSize_], dst[i * headSize_], localSin, localCos, upcastInputBuffer, + calcTmpBuffer); + } + outQue_.EnQue(dst); + local_dst_t dstDeque = outQue_.DeQue(); + AscendC::DataCopy(dstG[loopNum * loadStride], dstDeque, loadStride); + outQue_.FreeTensor(dstDeque); + inQue_.FreeTensor(srcDeque); + } + // process tail + { + local_scalar_t src = inQue_.AllocTensor(); + local_dst_t dst = outQue_.AllocTensor(); + + AscendC::DataCopy(src, srcG[loopCnt * loadStride], tailHeads * headSize_); + inQue_.EnQue(src); + local_scalar_t srcDeque = inQue_.DeQue(); + + if constexpr (!std::is_same_v) { + int elem_num = tailHeads * headSize_ / sizeof(scalar_t); + AscendC::LocalTensor upBuffer = copyBuf_.GetWithOffset(elem_num, 0); + Cast(upBuffer, srcDeque, AscendC::RoundMode::CAST_TRUNC, elem_num); + Cast(dst, upBuffer, AscendC::RoundMode::CAST_TRUNC, elem_num); + } else { + local_mem_copy(dst, srcDeque, tailHeads * headSize_); + } + + for (int i = 0; i < tailHeads; ++i) { + neox_compute(srcDeque[i * headSize_], dst[i * headSize_], localSin, localCos, upcastInputBuffer, + calcTmpBuffer); + } + outQue_.EnQue(dst); + local_dst_t dstDeque = outQue_.DeQue(); + AscendC::DataCopy(dstG[loopCnt * loadStride], dstDeque, tailHeads * headSize_); + outQue_.FreeTensor(dstDeque); + inQue_.FreeTensor(srcDeque); + } + } + + __aicore__ inline void compute_function() + { + local_scalar_t cosSinLocal = inQueSinCos_.AllocTensor(); + + AscendC::DataCopy(cosSinLocal, cosSin_, embedDim_ * 2); + + inQueSinCos_.EnQue(cosSinLocal); + local_scalar_t localSinCosDeque = inQueSinCos_.DeQue(); + local_scalar_t localCos = localSinCosDeque; + local_scalar_t localSin = localSinCosDeque[embedDim_]; + + local_acc_t calcTmpBuffer; + local_acc_t upcastInputBuffer; + local_acc_t upcastTempBuffer; + local_acc_t cosSinUpcastBuffer; + local_acc_t scaleBuffer; + local_acc_t offsetBuffer; + calcTmpBuffer = calcBuf_.GetWithOffset(embedDim_ * 2, calcTmpBufferOffset_); + upcastInputBuffer = calcBuf_.GetWithOffset(headSize_, upcastInputBufferOffset_); + upcastTempBuffer = calcBuf_.GetWithOffset(embedDim_ * 2, upcastTempBufferOffset_); + cosSinUpcastBuffer = calcBuf_.GetWithOffset(embedDim_ * 2, cosSinUpcastBufferOffset_); + + local_acc_t cosAccBuffer; + local_acc_t sinAccBuffer; + + if constexpr (!std::is_same_v) { + Cast(cosSinUpcastBuffer, localSinCosDeque, AscendC::RoundMode::CAST_NONE, 2 * embedDim_); + cosAccBuffer = cosSinUpcastBuffer; + sinAccBuffer = cosSinUpcastBuffer[embedDim_]; + 
} else { + cosAccBuffer = localCos; + sinAccBuffer = localSin; + } + + constexpr const int loadSizeByElem = loadSize / sizeof(scalar_t); + int64_t headNumPerLoad = loadSizeByElem / headSize_; + int64_t loopCnt = numHeads_ / headNumPerLoad; + int64_t tailHeads = numHeads_ - loopCnt * headNumPerLoad; + int64_t loadStride = headNumPerLoad * headSize_; + int64_t loopCntKv = numKvHeads_ / headNumPerLoad; + int64_t tailHeadsKv = numKvHeads_ - loopCntKv * headNumPerLoad; + compute_qk(query_, queryDst_, cosAccBuffer, sinAccBuffer, upcastInputBuffer, + calcTmpBuffer, loopCnt, tailHeads, loadStride, headNumPerLoad); + + compute_qk(key_, keyDst_, cosAccBuffer, sinAccBuffer, upcastInputBuffer, calcTmpBuffer, + loopCntKv, tailHeadsKv, loadStride, headNumPerLoad); + + inQueSinCos_.FreeTensor(localSinCosDeque); + } + +private: + AscendC::TPipe *pipe_; + AscendC::TQue inQue_, inQueSinCos_; + AscendC::TQue outQue_; + AscendC::TBuf calcBuf_; + AscendC::TBuf copyBuf_; + AscendC::GlobalTensor queryDst_; + AscendC::GlobalTensor keyDst_; + AscendC::GlobalTensor query_; + AscendC::GlobalTensor key_; + AscendC::GlobalTensor cosSin_; + int rotDim_; + int embedDim_; + int64_t queryStride_; + int64_t keyStride_; + int64_t dstQueryStride_; + int64_t dstKeyStride_; + int numHeads_; + int numKvHeads_; + int headSize_; + int calcTmpBufferOffset_; + int upcastInputBufferOffset_; + int upcastTempBufferOffset_; + int cosSinUpcastBufferOffset_; + int tempBufferSize_; +}; + +// Note: Need to use macro to instaniate all the target functions here, for the current build system dose not support template call in cpp +// We use C style symbol here for kernel compilation, cpp style kernel entry may lead to compilation failure +#define ROPE_CUSTOM_KERNEL_TYPE_DECLARE(TYPE, NEOX) \ + extern "C" __global__ __aicore__ void rope_custom_##NEOX##_##TYPE( \ + __gm__ int64_t* positions, __gm__ void* queryDst, __gm__ void* keyDst, __gm__ TYPE* query, __gm__ TYPE* key, \ + __gm__ TYPE* cosSinCache, const int rotDim, const int64_t queryStride, const int64_t keyStride, \ + const int64_t dstQueryStride, const int64_t dstKeyStride, const int numHeads, const int numKvHeads, \ + const int headSize, const int64_t numTokens, const int loopNum, const int coreNum) \ + { \ + AscendC::TPipe pipe; \ + RotaryEmbedding op{}; \ + op.init(positions, queryDst, keyDst, query, key, cosSinCache, rotDim, dstQueryStride, dstKeyStride, \ + queryStride, keyStride, numHeads, numKvHeads, headSize, &pipe); \ + for (int64_t i = AscendC::GetBlockIdx(); i < numTokens; i += coreNum) { \ + op.update_mem_offset(positions, queryDst, keyDst, query, key, cosSinCache, rotDim, dstQueryStride, dstKeyStride, \ + queryStride, keyStride, numHeads, numKvHeads, headSize, i); \ + op.compute_function(); \ + } \ + } + +#define ROPE_CUSTOM_KERNEL_DECLARE(TYPE) \ + ROPE_CUSTOM_KERNEL_TYPE_DECLARE(TYPE, true); \ + ROPE_CUSTOM_KERNEL_TYPE_DECLARE(TYPE, false); + +// Declare all the kernel entry here +ROPE_CUSTOM_KERNEL_DECLARE(half) +ROPE_CUSTOM_KERNEL_DECLARE(bfloat16_t) + +namespace vllm_ascend { + +#define ROTARY_EMBEDDING_KERNEL_CALL(TYPE) \ + if (isNeox) \ + rope_custom_true_##TYPE<<>>( \ + positions, queryDst, keyDst, reinterpret_cast(query), reinterpret_cast(key), \ + reinterpret_cast(cosSinCache), rotDim, queryStride, keyStride, dstQueryStride, dstKeyStride, \ + numHeads, numKvHeads, headSize, numTokens, loopCnt, blockDim); \ + else \ + rope_custom_false_##TYPE<<>>( \ + positions, queryDst, keyDst, reinterpret_cast(query), reinterpret_cast(key), \ + reinterpret_cast(cosSinCache), 
rotDim, queryStride, keyStride, dstQueryStride, dstKeyStride, \ + numHeads, numKvHeads, headSize, numTokens, loopCnt, blockDim); + +// maximum number for runtime to launch a ascendc kernel. +// we use this to constrain the maximum number of block size +static const int64_t maxParallelSize = 65535; + +extern void rotary_embedding_impl(AscendType type, bool isNeox, void *stream, int64_t *positions, void *queryDst, + void *keyDst, void *query, void *key, void *cosSinCache, const int rotDim, + const int64_t queryStride, const int64_t keyStride, const int64_t dstQueryStride, + const int64_t dstKeyStride, const int numHeads, const int numKvHeads, + const int headSize, const int64_t numTokens, const uint32_t loopCnt, + uint32_t aivNum) +{ + + int blockDim = maxParallelSize > numTokens ? numTokens : maxParallelSize; + if (type == AscendType::FP16) { + ROTARY_EMBEDDING_KERNEL_CALL(half); + } else if (type == AscendType::BF16) { + ROTARY_EMBEDDING_KERNEL_CALL(bfloat16_t); + } else { + return; + } +} + +} // namespace vllm_ascend \ No newline at end of file diff --git a/csrc/kernels/types.h b/csrc/kernels/types.h new file mode 100644 index 000000000..7072e8c1c --- /dev/null +++ b/csrc/kernels/types.h @@ -0,0 +1,25 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +namespace vllm_ascend { +enum struct AscendType { + FP16 = 0, + BF16 = 1, + FP32 = 2, +}; +} \ No newline at end of file diff --git a/csrc/kernels/utils.h b/csrc/kernels/utils.h new file mode 100644 index 000000000..8b8cf217c --- /dev/null +++ b/csrc/kernels/utils.h @@ -0,0 +1,49 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once +#include "kernel_type.h" +namespace vllm_ascend { + +template struct AccType; + +template <> struct AccType { + using type = float; +}; + +template <> struct AccType { + using type = half; +}; + +template <> struct AccType { + using type = float; +}; + +template <> struct AccType { + using type = int; +}; + +template +__aicore__ inline void local_mem_copy(AscendC::LocalTensor dst, AscendC::LocalTensor src, int size) +{ + constexpr int loadSize = 256 / sizeof(scalar_t); + int loopCnt = size / loadSize; + int tailSize = size % loadSize; + if (loopCnt) + AscendC::Copy(dst, src, loadSize, loopCnt, {1, 1, 8, 8}); + AscendC::Copy(dst[loopCnt * loadSize], src[loopCnt * loadSize], tailSize, 1, {1, 1, 8, 8}); +} +} // namespace vllm_ascend \ No newline at end of file diff --git a/csrc/ops.h b/csrc/ops.h new file mode 100644 index 000000000..4296796e5 --- /dev/null +++ b/csrc/ops.h @@ -0,0 +1,32 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +#include +#include "kernels/types.h" + +namespace vllm_ascend { + extern void rotary_embedding_impl(AscendType type, bool isNeox, void *stream, int64_t *positions, void *queryDst, + void *keyDst, void *query, void *key, void *cosSinCache, const int rotDim, + const int64_t queryStride, const int64_t keyStride, const int64_t dstQueryStride, + const int64_t dstKeyStride, const int numHeads, const int numKvHeads, + const int headSize, const int64_t numTokens, const uint32_t loopCnt, + uint32_t aivNum); +} \ No newline at end of file diff --git a/csrc/torch_binding.cpp b/csrc/torch_binding.cpp new file mode 100644 index 000000000..a4dc3a322 --- /dev/null +++ b/csrc/torch_binding.cpp @@ -0,0 +1,108 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include +#include +#include +#include +#include +#include +#include "acl/acl.h" +#include "tiling/platform/platform_ascendc.h" +#include "aclnn/opdev/platform.h" +#include "ops.h" +#include "utils.h" + +namespace vllm_ascend { + +void rotary_embedding(at::Tensor &positions, at::Tensor &query, at::Tensor &key, + int64_t head_size, at::Tensor &cos_sin_cache, bool is_neox) +{ + int32_t deviceId = 0; + int64_t num_tokens = positions.numel(); + int positions_ndim = positions.dim(); + TORCH_CHECK( + positions_ndim == 1 || positions_ndim == 2, + "positions must have shape [num_tokens] or [batch_size, seq_len]"); + if (positions_ndim == 1) { + TORCH_CHECK( + query.size(0) == positions.size(0) && key.size(0) == positions.size(0), + "query, key and positions must have the same number of tokens"); + } + if (positions_ndim == 2) { + TORCH_CHECK( + query.size(0) == positions.size(0) && + key.size(0) == positions.size(0) && + query.size(1) == positions.size(1) && + key.size(1) == positions.size(1), + "query, key and positions must have the same batch_size and seq_len"); + } + + int query_hidden_size = query.numel() / num_tokens; + int key_hidden_size = key.numel() / num_tokens; + TORCH_CHECK(query_hidden_size % head_size == 0); + TORCH_CHECK(key_hidden_size % head_size == 0); + + // Make sure query and key have consistent number of heads + int num_heads = query_hidden_size / head_size; + int num_kv_heads = key_hidden_size / head_size; + TORCH_CHECK(num_heads % num_kv_heads == 0); + + int rot_dim = cos_sin_cache.size(1); + int64_t *position_ids_ptr = positions.data_ptr(); + void *query_ptr = query.data_ptr(); + void *key_ptr = key.data_ptr(); + void *cos_sin_cache_ptr = cos_sin_cache.data_ptr(); + int64_t query_stride = query.stride(-2); + int64_t key_stride = key.stride(-2); + at::ScalarType scalar_type = query.scalar_type(); + aclrtStream stream = c10_npu::getCurrentNPUStream().stream(); + at_npu::native::OpCommand cmd; + cmd.Name("rotary_embedding"); + cmd.SetCustomHandler([scalar_type, is_neox, num_tokens, stream, position_ids_ptr, + query_ptr, key_ptr, cos_sin_cache_ptr, rot_dim, query_stride, key_stride, + num_heads, num_kv_heads, head_size]() -> int { + auto dtype_num = get_dtype_from_torch(scalar_type); + fe::PlatFormInfos platform_infos; + int device_id = 0; + fe::PlatformInfoManager::GeInstance().GetRuntimePlatformInfosByDevice(device_id, platform_infos); + uint32_t aivNum = platform_infos.GetCoreNumByType("aiv"); + uint32_t loop_cnt = (num_tokens + aivNum - 1) / aivNum; + rotary_embedding_impl(dtype_num, is_neox, stream, position_ids_ptr, query_ptr, key_ptr, query_ptr, + key_ptr, cos_sin_cache_ptr, rot_dim, query_stride, key_stride, query_stride, + key_stride, num_heads, num_kv_heads, head_size, num_tokens, loop_cnt, aivNum); + return 0; + }); + cmd.Run(); + return ; +} +} // namespace vllm_ascend + +TORCH_LIBRARY_EXPAND(_C, ops) +{ + // vLLM-Ascend custom ops + + // Rotary embedding + // Apply GPT-NeoX style rotary embedding to query and key. + ops.def( + "rotary_embedding(Tensor positions, Tensor! query," + " Tensor! 
key, int head_size," + " Tensor cos_sin_cache, bool is_neox) -> ()"); + ops.impl("rotary_embedding", torch::kPrivateUse1, &vllm_ascend::rotary_embedding); +} + +REGISTER_EXTENSION(_C) diff --git a/csrc/utils.h b/csrc/utils.h new file mode 100644 index 000000000..e94ad2d84 --- /dev/null +++ b/csrc/utils.h @@ -0,0 +1,43 @@ +#pragma once + +#include "kernels/types.h" +#include +#include + +#define _CONCAT(A, B) A##B +#define CONCAT(A, B) _CONCAT(A, B) + +#define _STRINGIFY(A) #A +#define STRINGIFY(A) _STRINGIFY(A) + +// A version of the TORCH_LIBRARY macro that expands the NAME, i.e. so NAME +// could be a macro instead of a literal token. +#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE) + +// A version of the TORCH_LIBRARY_IMPL macro that expands the NAME, i.e. so NAME +// could be a macro instead of a literal token. +#define TORCH_LIBRARY_IMPL_EXPAND(NAME, DEVICE, MODULE) \ + TORCH_LIBRARY_IMPL(NAME, DEVICE, MODULE) + +// REGISTER_EXTENSION allows the shared library to be loaded and initialized +// via python's import statement. +#define REGISTER_EXTENSION(NAME) \ + PyMODINIT_FUNC CONCAT(PyInit_, NAME)() { \ + static struct PyModuleDef module = {PyModuleDef_HEAD_INIT, \ + STRINGIFY(NAME), nullptr, 0, nullptr}; \ + return PyModule_Create(&module); \ + } + + +namespace vllm_ascend { +AscendType get_dtype_from_torch(at::ScalarType scalarType) +{ + if (scalarType == at::ScalarType::Float) { + return AscendType::FP32; + } else if (scalarType == at::ScalarType::BFloat16) { + return AscendType::BF16; + } else { + return AscendType::FP16; + } +} +} // namespace vllm_ascend diff --git a/pyproject.toml b/pyproject.toml index 9210a2e25..f1e638b6f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,9 +3,12 @@ requires = [ "setuptools>=64", "setuptools-scm>=8", + "cmake>=3.26", + "pybind11", "decorator", "pyyaml", "scipy", - "torch-npu >= 2.5.1rc1" + "torch_npu >= 2.5.1rc1", + "torch >= 2.5.1" ] build-backend = "setuptools.build_meta" diff --git a/requirements.txt b/requirements.txt index 19642660a..c8f71ffe5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,7 @@ decorator pyyaml scipy +pybind11 setuptools setuptools-scm numpy==1.26.4 diff --git a/setup.py b/setup.py index 492934ac9..0bcfd88a8 100644 --- a/setup.py +++ b/setup.py @@ -17,12 +17,265 @@ # limitations under the License. 
 #
+import importlib.util
+import logging
 import os
-from typing import List
+import subprocess
+import sys
+from sysconfig import get_paths
+from typing import Dict, List
 
-from setuptools import find_packages, setup
+from setuptools import Extension, find_packages, setup
+from setuptools.command.build_ext import build_ext
+from setuptools.command.develop import develop
+from setuptools.command.install import install
 from setuptools_scm import get_version
 
+
+def load_module_from_path(module_name, path):
+    spec = importlib.util.spec_from_file_location(module_name, path)
+    module = importlib.util.module_from_spec(spec)
+    sys.modules[module_name] = module
+    spec.loader.exec_module(module)
+    return module
+
+
+ROOT_DIR = os.path.dirname(__file__)
+logger = logging.getLogger(__name__)
+
+
+def check_or_set_default_env(cmake_args,
+                             env_name,
+                             env_variable,
+                             default_path=""):
+    if env_variable is None:
+        logging.warning(
+            f"No {env_name} found in your environment, please try to set {env_name} "
+            "if you customize the installation path of this library, otherwise the "
+            "default path will be used when building this project")
+        logging.warning(f"Set default {env_name}: {default_path}")
+        env_variable = default_path
+    else:
+        logging.info(f"Found existing {env_name}: {env_variable}")
+    # The CANN package appears to check this environment variable in CMake, so write it back.
+    if env_name == "ASCEND_HOME_PATH":
+        os.environ["ASCEND_HOME_PATH"] = env_variable
+    cmake_args += [f"-D{env_name}={env_variable}"]
+    return cmake_args
+
+
+envs = load_module_from_path("envs",
+                             os.path.join(ROOT_DIR, "vllm_ascend", "envs.py"))
+
+
+class CMakeExtension(Extension):
+
+    def __init__(self,
+                 name: str,
+                 cmake_lists_dir: str = ".",
+                 **kwargs) -> None:
+        super().__init__(name, sources=[], py_limited_api=False, **kwargs)
+        self.cmake_lists_dir = os.path.abspath(cmake_lists_dir)
+
+
+class cmake_build_ext(build_ext):
+    # A dict of extension directories that have been configured.
+    did_config: Dict[str, bool] = {}
+
+    #
+    # Determine the number of compilation jobs.
+    #
+    def compute_num_jobs(self):
+        # `num_jobs` is either the value of the MAX_JOBS environment variable
+        # (if defined) or the number of CPUs available.
+        num_jobs = envs.MAX_JOBS
+        if num_jobs is not None:
+            num_jobs = int(num_jobs)
+            logger.info("Using MAX_JOBS=%d as the number of jobs.", num_jobs)
+        else:
+            try:
+                # os.sched_getaffinity() isn't universally available, so fall
+                # back to os.cpu_count() if we get an error here.
+                num_jobs = len(os.sched_getaffinity(0))
+            except AttributeError:
+                num_jobs = os.cpu_count()
+            num_jobs = max(1, num_jobs)
+
+        return num_jobs
+
+    #
+    # Perform cmake configuration for a single extension.
+    #
+    def configure(self, ext: CMakeExtension) -> None:
+        build_temp = self.build_temp
+        os.makedirs(build_temp, exist_ok=True)
+        source_dir = os.path.abspath(ROOT_DIR)
+        python_executable = sys.executable
+        cmake_args = ["cmake"]
+        # Use Release mode to compile the csrc code by default.
+        # The build supports Release, Debug and RelWithDebInfo.
+        if envs.CMAKE_BUILD_TYPE is None or envs.CMAKE_BUILD_TYPE not in [
+                "Debug",
+                "Release",
+                "RelWithDebInfo",
+        ]:
+            envs.CMAKE_BUILD_TYPE = "Release"
+        cmake_args += [f"-DCMAKE_BUILD_TYPE={envs.CMAKE_BUILD_TYPE}"]
+        # Dump the compile commands by default for LSP tooling.
+        cmake_args += ["-DCMAKE_EXPORT_COMPILE_COMMANDS=1"]
+        if envs.VERBOSE:
+            cmake_args += ["-DCMAKE_VERBOSE_MAKEFILE=ON"]
+
+        # find ASCEND_HOME_PATH
+        check_or_set_default_env(
+            cmake_args,
+            "ASCEND_HOME_PATH",
+            envs.ASCEND_HOME_PATH,
+            "/usr/local/Ascend/ascend-toolkit/latest",
+        )
+
+        # find PYTHON_EXECUTABLE
+        check_or_set_default_env(cmake_args, "PYTHON_EXECUTABLE",
+                                 sys.executable)
+
+        # find PYTHON_INCLUDE_PATH
+        check_or_set_default_env(cmake_args, "PYTHON_INCLUDE_PATH",
+                                 get_paths()["include"])
+
+        # ccache and ninja cannot be used for the AscendC kernels yet
+
+        try:
+            # pybind11 is expected to be installed via pip
+            pybind11_cmake_path = (subprocess.check_output(
+                [python_executable, "-m", "pybind11",
+                 "--cmake"]).decode().strip())
+        except subprocess.CalledProcessError as e:
+            # otherwise pybind11 would need to be located from a source install (e.g. on the CI container)
+            raise RuntimeError(f"CMake configuration failed: {e}")
+
+        # try to retrieve the SoC version from npu-smi
+        soc_command = [
+            "bash",
+            "-c",
+            "npu-smi info | grep OK | awk '{print $3}' | head -n 1",
+        ]
+        try:
+            soc_version = subprocess.check_output(soc_command,
+                                                  text=True).strip()
+            soc_version = soc_version.split("-")[0]
+            soc_version = "Ascend" + soc_version
+        except subprocess.CalledProcessError as e:
+            raise RuntimeError(f"Retrieve SoC version failed: {e}")
+
+        # add SOC_VERSION
+        cmake_args += [f"-DSOC_VERSION={soc_version}"]
+
+        install_path = os.path.join(ROOT_DIR, self.build_lib)
+        if isinstance(self.distribution.get_command_obj("develop"), develop):
+            install_path = os.path.join(ROOT_DIR, "vllm_ascend")
+        # add CMAKE_INSTALL_PREFIX
+        cmake_args += [f"-DCMAKE_INSTALL_PREFIX={install_path}"]
+
+        cmake_args += [f"-DCMAKE_PREFIX_PATH={pybind11_cmake_path}"]
+
+        # Override the base directory for FetchContent downloads to $ROOT/.deps
+        # This allows sharing dependencies between profiles,
+        # and plays more nicely with sccache.
+        # To override this, set the FETCHCONTENT_BASE_DIR environment variable.
+        fc_base_dir = os.path.join(ROOT_DIR, ".deps")
+        fc_base_dir = os.environ.get("FETCHCONTENT_BASE_DIR", fc_base_dir)
+        cmake_args += ["-DFETCHCONTENT_BASE_DIR={}".format(fc_base_dir)]
+
+        build_tool = []
+        # TODO(ganyi): ninja and ccache support for the AscendC auto codegen; for now only the make build is supported.
+        # if which('ninja') is not None:
+        #     build_tool += ['-G', 'Ninja']
+        # Default build tool to whatever cmake picks.
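+        # Roughly, the resulting configure command looks like (flag values are
+        # illustrative placeholders, not literal):
+        #   cmake -DCMAKE_BUILD_TYPE=Release -DASCEND_HOME_PATH=<cann_path> \
+        #         -DSOC_VERSION=<AscendXXX> -DCMAKE_INSTALL_PREFIX=<install_path> <source_dir>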
+ + cmake_args += [source_dir] + logging.info(f"cmake config command: {cmake_args}") + try: + subprocess.check_call(cmake_args, cwd=self.build_temp) + except subprocess.CalledProcessError as e: + raise RuntimeError(f"CMake configuration failed: {e}") + + subprocess.check_call( + ["cmake", ext.cmake_lists_dir, *build_tool, *cmake_args], + cwd=self.build_temp, + ) + + def build_extensions(self) -> None: + if envs.COMPILE_CUSTOM_KERNELS is None: + return + # Ensure that CMake is present and working + try: + subprocess.check_output(["cmake", "--version"]) + except OSError as e: + raise RuntimeError(f"Cannot find CMake executable: {e}") + + # Create build directory if it does not exist. + if not os.path.exists(self.build_temp): + os.makedirs(self.build_temp) + + targets = [] + + os.makedirs(os.path.join(self.build_lib, "vllm_ascend"), exist_ok=True) + + def target_name(s: str) -> str: + return s.removeprefix("vllm_ascend.") + + # Build all the extensions + for ext in self.extensions: + self.configure(ext) + targets.append(target_name(ext.name)) + + num_jobs = self.compute_num_jobs() + + build_args = [ + "--build", + ".", + f"-j={num_jobs}", + *[f"--target={name}" for name in targets], + ] + try: + subprocess.check_call(["cmake", *build_args], cwd=self.build_temp) + except OSError as e: + raise RuntimeError(f"Build library failed: {e}") + # Install the libraries + install_args = [ + "cmake", + "--install", + ".", + ] + try: + subprocess.check_call(install_args, cwd=self.build_temp) + except OSError as e: + raise RuntimeError(f"Install library failed: {e}") + + # copy back to build folder for editable build + if isinstance(self.distribution.get_command_obj("develop"), develop): + import shutil + for root, _, files in os.walk(self.build_temp): + for file in files: + if file.endswith(".so"): + src_path = os.path.join(root, file) + dst_path = os.path.join(self.build_lib, "vllm_ascend", + file) + shutil.copy(src_path, dst_path) + print(f"Copy: {src_path} -> {dst_path}") + + def run(self): + # First, run the standard build_ext command to compile the extensions + super().run() + + +class custom_install(install): + + def run(self): + self.run_command("build_ext") + install.run(self) + + ROOT_DIR = os.path.dirname(__file__) try: VERSION = get_version(write_to="vllm_ascend/_version.py") @@ -31,6 +284,10 @@ except LookupError: # only checks out the commit. In this case, we set a dummy version. 
VERSION = "0.0.0" +ext_modules = [] +if envs.COMPILE_CUSTOM_KERNELS is not None: + ext_modules = [CMakeExtension(name="vllm_ascend.vllm_ascend_C")] + def get_path(*filepath) -> str: return os.path.join(ROOT_DIR, *filepath) @@ -69,8 +326,10 @@ def get_requirements() -> List[str]: return requirements +cmdclass = {"build_ext": cmake_build_ext, "install": custom_install} + setup( - name='vllm_ascend', + name="vllm_ascend", # Follow: # https://packaging.python.org/en/latest/specifications/version-specifiers version=VERSION, @@ -95,12 +354,15 @@ setup( "Topic :: Scientific/Engineering :: Artificial Intelligence", "Topic :: Scientific/Engineering :: Information Analysis", ], - packages=find_packages(exclude=("docs", "examples", "tests*")), + packages=find_packages(exclude=("docs", "examples", "tests*", "csrc")), python_requires=">=3.9", install_requires=get_requirements(), + ext_modules=ext_modules, + cmdclass=cmdclass, extras_require={}, entry_points={ - 'vllm.platform_plugins': ["ascend = vllm_ascend:register"], - 'vllm.general_plugins': - ["ascend_enhanced_model = vllm_ascend:register_model"] - }) + "vllm.platform_plugins": ["ascend = vllm_ascend:register"], + "vllm.general_plugins": + ["ascend_enhanced_model = vllm_ascend:register_model"], + }, +) diff --git a/tests/ops/test_rotary_embedding.py b/tests/ops/test_rotary_embedding.py new file mode 100644 index 000000000..5f4b3f916 --- /dev/null +++ b/tests/ops/test_rotary_embedding.py @@ -0,0 +1,204 @@ +# Adapted from +# https://github.com/vllm-project/vllm/blob/main/vllm/tests/kernels/test_rotary_embedding.py +# Copyright 2023 The vLLM team. + +# Copyright (c) Huawei Technologies Co., Ltd. 2024-2025. All rights reserved. + +from typing import Optional, Tuple, Union + +import pytest +import torch +import torch.nn as nn +import torch_npu # noqa: F401 + +import vllm_ascend.platform # noqa: F401 + +# Only Neox style true scenario is supported for now +IS_NEOX_STYLE = [True] +DTYPES = [torch.half] +HEAD_SIZES = [64, 96, 128, 256] +ROTARY_DIMS = [None, 32] # None means rotary dim == head size +NUM_HEADS = [17] # Arbitrary values for testing +BATCH_SIZES = [5] # Arbitrary values for testing +SEQ_LENS = [11, 4096] # Arbitrary values for testing +SEEDS = [0] +DEVICES = [f"npu:{0}"] +# Set tolerance to 1 for quant ops +DEFAULT_ATOL = 1e-3 +DEFAULT_RTOL = 1e-3 + + +def _apply_rotary_emb( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + is_neox_style: bool, +) -> torch.Tensor: + """ + Args: + x: [num_tokens, num_heads, head_size] + cos: [num_tokens, head_size // 2] + sin: [num_tokens, head_size // 2] + is_neox_style: Whether to use the Neox-style or GPT-J-style rotary + positional embeddings. 
+ """ + cos = cos.unsqueeze(-2).to(x.dtype) + sin = sin.unsqueeze(-2).to(x.dtype) + if is_neox_style: + x1, x2 = torch.chunk(x, 2, dim=-1) + else: + x1 = x[..., ::2] + x2 = x[..., 1::2] + o1 = x1 * cos - x2 * sin + o2 = x2 * cos + x1 * sin + if is_neox_style: + return torch.cat((o1, o2), dim=-1) + else: + return torch.stack((o1, o2), dim=-1).flatten(-2) + + +# adapted from https://github.com/vllm-project/vllm/vllm/model_executor/layers/rotary_embedding.py +class RotaryEmbedding(nn.Module): + """Original rotary positional embedding.""" + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + dtype: torch.dtype, + ) -> None: + super().__init__() + self.head_size = head_size + self.rotary_dim = rotary_dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.is_neox_style = is_neox_style + self.dtype = dtype + + cache = self._compute_cos_sin_cache() + cache = cache.to(dtype) + self.cos_sin_cache: torch.Tensor + self.register_buffer("cos_sin_cache", cache, persistent=False) + + def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: + """Compute the inverse frequency.""" + # NOTE(woosuk): To exactly match the HF implementation, we need to + # use CPU to compute the cache and then move it to GPU. However, we + # create the cache on GPU for faster initialization. This may cause + # a slight numerical difference between the HF implementation and ours. + inv_freq = 1.0 / (base**(torch.arange( + 0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim)) + return inv_freq + + def _compute_cos_sin_cache(self) -> torch.Tensor: + """Compute the cos and sin cache.""" + inv_freq = self._compute_inv_freq(self.base) + t = torch.arange(self.max_position_embeddings, dtype=torch.float) + + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() + sin = freqs.sin() + cache = torch.cat((cos, sin), dim=-1) + return cache + + def forward_native( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """A PyTorch-native implementation of forward().""" + if offsets is not None: + positions = positions + offsets + positions = positions.flatten() + num_tokens = positions.shape[0] + cos_sin = self.cos_sin_cache.index_select(0, positions) + cos, sin = cos_sin.chunk(2, dim=-1) + + query_shape = query.shape + query = query.view(num_tokens, -1, self.head_size) + query_rot = query[..., :self.rotary_dim] + query_pass = query[..., self.rotary_dim:] + query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style) + query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) + + key_shape = key.shape + key = key.view(num_tokens, -1, self.head_size) + key_rot = key[..., :self.rotary_dim] + key_pass = key[..., self.rotary_dim:] + key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style) + key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) + return query, key + + +# test with leading dimension and merge seqlen and batch_size as num_tokens +# TODO(ganyi): open this test in the future +@pytest.mark.skip( + reason= + "skip this test by default for now because of ci issue, will enable it in the future" +) +@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE) +@pytest.mark.parametrize("batch_size", BATCH_SIZES) +@pytest.mark.parametrize("seq_len", SEQ_LENS) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", 
HEAD_SIZES) +@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", DEVICES) +@torch.inference_mode() +def test_rotary_embedding_quant_with_leading_dim( + is_neox_style: bool, + batch_size: int, + seq_len: int, + num_heads: int, + head_size: int, + rotary_dim: Optional[int], + dtype: torch.dtype, + seed: int, + device: str, + max_position: int = 8192, + base: int = 10000, +) -> None: + if rotary_dim is None: + rotary_dim = head_size + + torch.set_default_device(device) + if rotary_dim is None: + rotary_dim = head_size + rope = RotaryEmbedding(head_size, rotary_dim, max_position, base, + is_neox_style, dtype) + rope = rope.to(dtype=dtype) + num_tokens = batch_size * seq_len + positions = torch.randint(0, max_position, (batch_size * seq_len, )) + qkv_tensor = torch.randn(num_tokens, + num_heads * head_size * 3, + dtype=dtype) + query, key, _ = qkv_tensor.split( + [num_heads * head_size, num_heads * head_size, num_heads * head_size], + dim=-1, + ) + + ref_query, ref_key = rope.forward_native(positions, query, key) + torch.ops._C.rotary_embedding( + positions, + query, + key, + rope.head_size, + rope.cos_sin_cache, + rope.is_neox_style, + ) + + # Compare the results. + torch.testing.assert_close(query, + ref_query, + atol=DEFAULT_ATOL, + rtol=DEFAULT_RTOL) + torch.testing.assert_close(key, + ref_key, + atol=DEFAULT_ATOL, + rtol=DEFAULT_RTOL) diff --git a/vllm_ascend/envs.py b/vllm_ascend/envs.py new file mode 100644 index 000000000..014bfd7d7 --- /dev/null +++ b/vllm_ascend/envs.py @@ -0,0 +1,25 @@ +import os +from typing import Any, Callable, Dict + +env_variables: Dict[str, Callable[[], Any]] = { + # max compile thread num + "MAX_JOBS": lambda: os.getenv("MAX_JOBS", None), + "CMAKE_BUILD_TYPE": lambda: os.getenv("CMAKE_BUILD_TYPE"), + "COMPILE_CUSTOM_KERNELS": + lambda: os.getenv("COMPILE_CUSTOM_KERNELS", None), + # If set, vllm-ascend will print verbose logs during compliation + "VERBOSE": lambda: bool(int(os.getenv('VERBOSE', '0'))), + "ASCEND_HOME_PATH": lambda: os.getenv("ASCEND_HOME_PATH", None), + "LD_LIBRARY_PATH": lambda: os.getenv("LD_LIBRARY_PATH", None), +} + + +def __getattr__(name: str): + # lazy evaluation of environment variables + if name in env_variables: + return env_variables[name]() + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +def __dir__(): + return list(env_variables.keys()) diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index b00c4b8df..326508e40 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -15,6 +15,7 @@ # limitations under the License. # +import logging import os from typing import TYPE_CHECKING, Optional, Tuple @@ -23,6 +24,19 @@ import torch_npu # noqa: F401 import vllm.envs as envs from vllm.config import CompilationLevel from vllm.logger import init_logger + +try: + # register custom ops into torch_library here + import vllm_ascend.vllm_ascend_C # type: ignore # noqa: F401 + +except ImportError as e: + if not str( + e + ) == "dynamic module does not define module export function (PyInit_vllm_ascend_C)": + logging.warning( + "Warning: Failed to register custom ops, all custom ops will be disabled" + ) + from vllm.platforms import Platform, PlatformEnum if TYPE_CHECKING: