mirror of
https://github.com/vllm-project/vllm-ascend.git
synced 2025-10-20 05:33:51 +08:00
[core] Support custom ascendc kernels in vllm-ascend (#233)
This PR add custom ascendc kernel rotary_embedding support in vllm-ascend, related CMakeLists and setuptools is also added in this PR. Related: https://github.com/vllm-project/vllm-ascend/issues/156 --------- Signed-off-by: ganyi <pleaplusone.gy@gmail.com>
This commit is contained in:
1
.github/workflows/vllm_ascend_test_main.yaml
vendored
1
.github/workflows/vllm_ascend_test_main.yaml
vendored
@ -62,6 +62,7 @@ jobs:
|
||||
- name: Install system dependencies
|
||||
run: |
|
||||
apt-get -y install `cat packages.txt`
|
||||
apt-get -y install gcc g++ cmake libnuma-dev
|
||||
|
||||
- name: Checkout vllm-project/vllm repo
|
||||
uses: actions/checkout@v4
|
||||
|
102
CMakeLists.txt
Normal file
102
CMakeLists.txt
Normal file
@ -0,0 +1,102 @@
|
||||
cmake_minimum_required(VERSION 3.16)
|
||||
project(vllm_ascend_C)
|
||||
|
||||
# include(CheckCXXcompilerFlag)
|
||||
# check_cxx_compiler_flag("-std=c++17", COMPILER_SUPPORTS_CXX17)
|
||||
|
||||
|
||||
include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)
|
||||
|
||||
# Suppress potential warnings about unused manually-specified variables
|
||||
set(ignoreMe "${VLLM_PYTHON_PATH}")
|
||||
|
||||
set(PYTHON_SUPPORTED_VERSIONS "3.9" "3.10" "3.11" "3.12")
|
||||
|
||||
find_package(pybind11 REQUIRED)
|
||||
|
||||
append_cmake_prefix_path("torch" "torch.utils.cmake_prefix_path")
|
||||
set(VLLM_ASCEND_INSTALL_PATH "${CMAKE_INSTALL_PREFIX}")
|
||||
|
||||
find_package(Torch REQUIRED)
|
||||
|
||||
set(RUN_MODE "npu" CACHE STRING "cpu/sim/npu")
|
||||
message(STATUS "Detected SOC version: ${SOC_VERSION}")
|
||||
|
||||
if (NOT CMAKE_BUILD_TYPE)
|
||||
set(CMAKE_BUILD_TYPE "Release" CACHE STRINGS "Build type Release/Debug (default Release)" FORCE)
|
||||
endif()
|
||||
|
||||
if (CMAKE_INSTALL_PREFIX STREQUAL /usr/local)
|
||||
set(CMAKE_INSTALL_PREFIX "${CMAKE_CURRENT_LIST_DIR}/out" CACHE STRINGS "path to install()")
|
||||
endif()
|
||||
|
||||
set(ASCEND_CANN_PACKAGE_PATH ${ASCEND_HOME_PATH})
|
||||
if(EXISTS ${ASCEND_HOME_PATH}/tools/tikcpp/ascendc_kernel_cmake)
|
||||
set(ASCENDC_CMAKE_DIR ${ASCEND_HOME_PATH}/tools/tikcpp/ascendc_kernel_cmake)
|
||||
elseif(EXISTS ${ASCEND_HOME_PATH}/compiler/tikcpp/ascendc_kernel_cmake)
|
||||
set(ASCENDC_CMAKE_DIR ${ASCEND_HOME_PATH}/compiler/tikcpp/ascendc_kernel_cmake)
|
||||
elseif(EXISTS ${ASCEND_HOME_PATH}/ascendc_devkit/tikcpp/samples/cmake)
|
||||
set(ASCENDC_CMAKE_DIR ${ASCEND_HOME_PATH}/ascendc_devkit/tikcpp/samples/cmake)
|
||||
else()
|
||||
message(FATAL_ERROR "ascendc_kernel_cmake does not exist, please check whether the cann package is installed.")
|
||||
endif()
|
||||
|
||||
include(${ASCENDC_CMAKE_DIR}/ascendc.cmake)
|
||||
file(GLOB KERNEL_FILES
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/csrc/kernels/pos_encoding_kernels.cpp)
|
||||
|
||||
ascendc_library(vllm_ascend_kernels SHARED
|
||||
${KERNEL_FILES}
|
||||
)
|
||||
|
||||
execute_process(COMMAND python3 -c "import os; import torch_npu; print(os.path.dirname(torch_npu.__file__))"
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE
|
||||
OUTPUT_VARIABLE TORCH_NPU_PATH
|
||||
)
|
||||
message("TORCH_NPU_PATH is ${TORCH_NPU_PATH}")
|
||||
|
||||
file(GLOB VLLM_ASCEND_SRC
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/csrc/*.cpp)
|
||||
|
||||
include_directories(
|
||||
${pybind11_INCLUDE_DIRS}
|
||||
${PYTHON_INCLUDE_PATH}
|
||||
${TORCH_INCLUDE_DIRS}
|
||||
${TORCH_NPU_PATH}/include
|
||||
${ASCEND_HOME_PATH}/include
|
||||
${ASCEND_HOME_PATH}/aarch64-linux/include/experiment/platform
|
||||
${ASCEND_HOME_PATH}/x86_64-linux/include/experiment/platform
|
||||
)
|
||||
|
||||
set(
|
||||
INCLUDES
|
||||
${TORCH_INCLUDE_DIRS}
|
||||
${TORCH_NPU_INCLUDE_DIRS}
|
||||
${ASCEND_HOME_PATH}/include
|
||||
${ASCEND_HOME_PATH}/aarch64-linux/include/experiment/platform
|
||||
)
|
||||
|
||||
pybind11_add_module(vllm_ascend_C ${VLLM_ASCEND_SRC})
|
||||
|
||||
target_link_directories(
|
||||
vllm_ascend_C
|
||||
PRIVATE
|
||||
${TORCH_NPU_PATH}/lib/
|
||||
${ASCEND_HOME_PATH}/lib64
|
||||
)
|
||||
|
||||
target_link_libraries(
|
||||
vllm_ascend_C
|
||||
PUBLIC
|
||||
${TORCH_LIBRARIES}
|
||||
libtorch_npu.so
|
||||
vllm_ascend_kernels
|
||||
ascendcl
|
||||
platform
|
||||
)
|
||||
|
||||
target_link_options(vllm_ascend_C PRIVATE "-Wl,-rpath,$ORIGIN:$ORIGIN/lib")
|
||||
|
||||
install(TARGETS vllm_ascend_C vllm_ascend_kernels DESTINATION ${VLLM_ASCEND_INSTALL_PATH})
|
||||
|
||||
|
133
cmake/utils.cmake
Normal file
133
cmake/utils.cmake
Normal file
@ -0,0 +1,133 @@
|
||||
#
|
||||
# Attempt to find the python package that uses the same python executable as
|
||||
# `EXECUTABLE` and is one of the `SUPPORTED_VERSIONS`.
|
||||
#
|
||||
macro (find_python_from_executable EXECUTABLE SUPPORTED_VERSIONS)
|
||||
file(REAL_PATH ${EXECUTABLE} EXECUTABLE)
|
||||
set(Python_EXECUTABLE ${EXECUTABLE})
|
||||
find_package(Python COMPONENTS Interpreter Development.Module Development.SABIModule)
|
||||
if (NOT Python_FOUND)
|
||||
message(FATAL_ERROR "Unable to find python matching: ${EXECUTABLE}.")
|
||||
endif()
|
||||
set(_VER "${Python_VERSION_MAJOR}.${Python_VERSION_MINOR}")
|
||||
set(_SUPPORTED_VERSIONS_LIST ${SUPPORTED_VERSIONS} ${ARGN})
|
||||
if (NOT _VER IN_LIST _SUPPORTED_VERSIONS_LIST)
|
||||
message(FATAL_ERROR
|
||||
"Python version (${_VER}) is not one of the supported versions: "
|
||||
"${_SUPPORTED_VERSIONS_LIST}.")
|
||||
endif()
|
||||
message(STATUS "Found python matching: ${EXECUTABLE}.")
|
||||
endmacro()
|
||||
|
||||
#
|
||||
# Run `EXPR` in python. The standard output of python is stored in `OUT` and
|
||||
# has trailing whitespace stripped. If an error is encountered when running
|
||||
# python, a fatal message `ERR_MSG` is issued.
|
||||
#
|
||||
function (run_python OUT EXPR ERR_MSG)
|
||||
execute_process(
|
||||
COMMAND
|
||||
"${PYTHON_EXECUTABLE}" "-c" "${EXPR}"
|
||||
OUTPUT_VARIABLE PYTHON_OUT
|
||||
RESULT_VARIABLE PYTHON_ERROR_CODE
|
||||
ERROR_VARIABLE PYTHON_STDERR
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE)
|
||||
|
||||
if(NOT PYTHON_ERROR_CODE EQUAL 0)
|
||||
message(FATAL_ERROR "${ERR_MSG}: ${PYTHON_STDERR}")
|
||||
endif()
|
||||
set(${OUT} ${PYTHON_OUT} PARENT_SCOPE)
|
||||
endfunction()
|
||||
|
||||
# Run `EXPR` in python after importing `PKG`. Use the result of this to extend
|
||||
# `CMAKE_PREFIX_PATH` so the torch cmake configuration can be imported.
|
||||
macro (append_cmake_prefix_path PKG EXPR)
|
||||
run_python(_PREFIX_PATH
|
||||
"import ${PKG}; print(${EXPR})" "Failed to locate ${PKG} path")
|
||||
list(APPEND CMAKE_PREFIX_PATH ${_PREFIX_PATH})
|
||||
endmacro()
|
||||
|
||||
|
||||
# This cmake function is adapted from vllm /Users/ganyi/workspace/vllm-ascend/cmake/utils.cmake
|
||||
# Define a target named `GPU_MOD_NAME` for a single extension. The
|
||||
# arguments are:
|
||||
#
|
||||
# DESTINATION <dest> - Module destination directory.
|
||||
# LANGUAGE <lang> - The GPU language for this module, e.g CUDA, HIP,
|
||||
# etc.
|
||||
# SOURCES <sources> - List of source files relative to CMakeLists.txt
|
||||
# directory.
|
||||
#
|
||||
# Optional arguments:
|
||||
#
|
||||
# ARCHITECTURES <arches> - A list of target GPU architectures in cmake
|
||||
# format.
|
||||
# Refer `CMAKE_CUDA_ARCHITECTURES` documentation
|
||||
# and `CMAKE_HIP_ARCHITECTURES` for more info.
|
||||
# ARCHITECTURES will use cmake's defaults if
|
||||
# not provided.
|
||||
# COMPILE_FLAGS <flags> - Extra compiler flags passed to NVCC/hip.
|
||||
# INCLUDE_DIRECTORIES <dirs> - Extra include directories.
|
||||
# LIBRARIES <libraries> - Extra link libraries.
|
||||
# WITH_SOABI - Generate library with python SOABI suffix name.
|
||||
# USE_SABI <version> - Use python stable api <version>
|
||||
#
|
||||
# Note: optimization level/debug info is set via cmake build type.
|
||||
#
|
||||
function (define_gpu_extension_target GPU_MOD_NAME)
|
||||
cmake_parse_arguments(PARSE_ARGV 1
|
||||
GPU
|
||||
"WITH_SOABI"
|
||||
"DESTINATION;LANGUAGE;USE_SABI"
|
||||
"SOURCES;ARCHITECTURES;COMPILE_FLAGS;INCLUDE_DIRECTORIES;LIBRARIES")
|
||||
|
||||
# Add hipify preprocessing step when building with HIP/ROCm.
|
||||
if (GPU_LANGUAGE STREQUAL "HIP")
|
||||
hipify_sources_target(GPU_SOURCES ${GPU_MOD_NAME} "${GPU_SOURCES}")
|
||||
endif()
|
||||
|
||||
if (GPU_WITH_SOABI)
|
||||
set(GPU_WITH_SOABI WITH_SOABI)
|
||||
else()
|
||||
set(GPU_WITH_SOABI)
|
||||
endif()
|
||||
|
||||
if (GPU_USE_SABI)
|
||||
Python_add_library(${GPU_MOD_NAME} MODULE USE_SABI ${GPU_USE_SABI} ${GPU_WITH_SOABI} "${GPU_SOURCES}")
|
||||
else()
|
||||
Python_add_library(${GPU_MOD_NAME} MODULE ${GPU_WITH_SOABI} "${GPU_SOURCES}")
|
||||
endif()
|
||||
|
||||
if (GPU_LANGUAGE STREQUAL "HIP")
|
||||
# Make this target dependent on the hipify preprocessor step.
|
||||
add_dependencies(${GPU_MOD_NAME} hipify${GPU_MOD_NAME})
|
||||
endif()
|
||||
|
||||
if (GPU_ARCHITECTURES)
|
||||
set_target_properties(${GPU_MOD_NAME} PROPERTIES
|
||||
${GPU_LANGUAGE}_ARCHITECTURES "${GPU_ARCHITECTURES}")
|
||||
endif()
|
||||
|
||||
set_property(TARGET ${GPU_MOD_NAME} PROPERTY CXX_STANDARD 17)
|
||||
|
||||
target_compile_options(${GPU_MOD_NAME} PRIVATE
|
||||
$<$<COMPILE_LANGUAGE:${GPU_LANGUAGE}>:${GPU_COMPILE_FLAGS}>)
|
||||
|
||||
target_compile_definitions(${GPU_MOD_NAME} PRIVATE
|
||||
"-DTORCH_EXTENSION_NAME=${GPU_MOD_NAME}")
|
||||
|
||||
target_include_directories(${GPU_MOD_NAME} PRIVATE csrc
|
||||
${GPU_INCLUDE_DIRECTORIES})
|
||||
|
||||
target_link_libraries(${GPU_MOD_NAME} PRIVATE torch ${GPU_LIBRARIES})
|
||||
|
||||
# Don't use `TORCH_LIBRARIES` for CUDA since it pulls in a bunch of
|
||||
# dependencies that are not necessary and may not be installed.
|
||||
if (GPU_LANGUAGE STREQUAL "CUDA")
|
||||
target_link_libraries(${GPU_MOD_NAME} PRIVATE CUDA::cudart CUDA::cuda_driver)
|
||||
else()
|
||||
target_link_libraries(${GPU_MOD_NAME} PRIVATE ${TORCH_LIBRARIES})
|
||||
endif()
|
||||
|
||||
install(TARGETS ${GPU_MOD_NAME} LIBRARY DESTINATION ${GPU_DESTINATION} COMPONENT ${GPU_MOD_NAME})
|
||||
endfunction()
|
367
csrc/kernels/pos_encoding_kernels.cpp
Normal file
367
csrc/kernels/pos_encoding_kernels.cpp
Normal file
@ -0,0 +1,367 @@
|
||||
/*
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "kernel_operator.h"
|
||||
#include "kernel_tpipe_impl.h"
|
||||
#include "kernel_tensor_impl.h"
|
||||
#include "kernel_type.h"
|
||||
#include "kernel_operator_intf.h"
|
||||
#include "inner_interface/inner_kernel_operator_intf.h"
|
||||
#include <stdio.h>
|
||||
#include "types.h"
|
||||
#include "utils.h"
|
||||
|
||||
|
||||
using vllm_ascend::AccType;
|
||||
using vllm_ascend::local_mem_copy;
|
||||
template <typename scalar_t, bool isNeox> class RotaryEmbedding {
|
||||
// NOTE(ganyi): we use 32K as load stride for pipe, need to find another way to
|
||||
// retrive this size from runtime for more Soc support
|
||||
static int constexpr loadSize = 1024 * 4;
|
||||
using dst_t = scalar_t;
|
||||
using acc_t = typename AccType<scalar_t>::type;
|
||||
// only half tensor have cast instruct to int8, hardcode acc_dst_t as half
|
||||
using local_scalar_t = AscendC::LocalTensor<scalar_t>;
|
||||
using local_acc_t = AscendC::LocalTensor<acc_t>;
|
||||
using local_dst_t = AscendC::LocalTensor<dst_t>;
|
||||
|
||||
public:
|
||||
__aicore__ inline RotaryEmbedding()
|
||||
{
|
||||
}
|
||||
|
||||
// Allocate buffers for input and output queue and the temp buffer used during kernel compute process,
|
||||
// this init process happens only in the kernel compute on a single vector core.
|
||||
__aicore__ inline void init(__gm__ int64_t *positions, __gm__ void *queryDst, __gm__ void *keyDst,
|
||||
__gm__ scalar_t *query, __gm__ scalar_t *key, __gm__ scalar_t *cosSinCache,
|
||||
const int rotDim, const int64_t dstQueryStride,
|
||||
const int64_t dstKeyStride, const int64_t queryStride, const int64_t keyStride,
|
||||
const int numHeads, const int numKvHeads, const int headSize, AscendC::TPipe *pipe)
|
||||
{
|
||||
pipe_ = pipe;
|
||||
rotDim_ = rotDim;
|
||||
// query stride and key stride is used to handle the strided tensor which is not contiguous on num_tokens dim
|
||||
queryStride_ = queryStride;
|
||||
keyStride_ = keyStride;
|
||||
dstQueryStride_ = dstQueryStride;
|
||||
dstKeyStride_ = dstKeyStride;
|
||||
numHeads_ = numHeads;
|
||||
numKvHeads_ = numKvHeads;
|
||||
headSize_ = headSize;
|
||||
embedDim_ = rotDim / 2;
|
||||
|
||||
pipe_->InitBuffer(inQue_, 1 /* buffer_num */, loadSize /* buffer_size */);
|
||||
pipe_->InitBuffer(inQueSinCos_, 1 /* buffer_num */, rotDim_ * sizeof(scalar_t) /* buffer_size */);
|
||||
pipe_->InitBuffer(outQue_, 1 /* buffer_num */, loadSize /* buffer_size */);
|
||||
// 2 temperary calculation buffer
|
||||
calcTmpBufferOffset_ = 0;
|
||||
// 1 upcast buffer for bf16 (headSize)
|
||||
upcastInputBufferOffset_ = calcTmpBufferOffset_ + sizeof(acc_t) * embedDim_ * 2;
|
||||
// 1 upcast temp buffer for bf16 (2 * embed_dim)
|
||||
upcastTempBufferOffset_ = upcastInputBufferOffset_ + sizeof(acc_t) * headSize_;
|
||||
// 2 sin cos upcast buffer for bf16
|
||||
cosSinUpcastBufferOffset_ = upcastTempBufferOffset_ + sizeof(acc_t) * 2 * embedDim_;
|
||||
// 2. bf16 path: needs 2 cos sin upcast buffer size
|
||||
// 3. fp16 path: needs 2 temperary calculation buffer size
|
||||
tempBufferSize_ = cosSinUpcastBufferOffset_ + 2 * embedDim_ * sizeof(acc_t);
|
||||
// need to consider upcast the bf16 to fp32, so we might need 4 buffer just in case
|
||||
// 2 temperary buffer, 2 input buffer, 1 cos buffer, 1 sin buffer, 2 scale buffer (headSize), 2 zp
|
||||
// buffer(headSize int8), 1 dst_temp buffer(headSize, int32)
|
||||
pipe_->InitBuffer(calcBuf_, tempBufferSize_ /* buffer_size */);
|
||||
if constexpr (!std::is_same_v<scalar_t, acc_t>) {
|
||||
pipe_->InitBuffer(copyBuf_, loadSize);
|
||||
}
|
||||
}
|
||||
__aicore__ inline void update_mem_offset(__gm__ int64_t *positions, __gm__ void *queryDst, __gm__ void *keyDst,
|
||||
__gm__ scalar_t *query, __gm__ scalar_t *key, __gm__ scalar_t *cosSinCache,
|
||||
const int rotDim, const int64_t dstQueryStride, const int64_t dstKeyStride,
|
||||
const int64_t queryStride, const int64_t keyStride, const int numHeads,
|
||||
const int numKvHeads, const int headSize, const int64_t idx)
|
||||
{
|
||||
int64_t pos = positions[idx];
|
||||
cosSin_.SetGlobalBuffer(cosSinCache + pos * rotDim_, rotDim_);
|
||||
query_.SetGlobalBuffer(query + queryStride * idx, headSize * numHeads_);
|
||||
key_.SetGlobalBuffer(key + keyStride * idx, headSize * numKvHeads_);
|
||||
queryDst_.SetGlobalBuffer(reinterpret_cast<__gm__ dst_t *>(queryDst) + dstQueryStride * idx,
|
||||
headSize * numHeads_);
|
||||
keyDst_.SetGlobalBuffer(reinterpret_cast<__gm__ dst_t *>(keyDst) + dstKeyStride * idx, headSize * numKvHeads_);
|
||||
}
|
||||
|
||||
// compute per head for neox on bf16
|
||||
template <typename acc_t_, typename std::enable_if<!std::is_same_v<acc_t_, scalar_t>, void>::type * = nullptr>
|
||||
__aicore__ inline void
|
||||
neox_compute(local_scalar_t src, local_dst_t dst, AscendC::LocalTensor<acc_t_> sin, AscendC::LocalTensor<acc_t_> cos,
|
||||
AscendC::LocalTensor<acc_t_> upcastInputBuffer, AscendC::LocalTensor<acc_t_> calcTmpBuffer)
|
||||
{
|
||||
// slice dst
|
||||
local_dst_t dstX = dst;
|
||||
local_dst_t dstY = dst[embedDim_];
|
||||
|
||||
// slice src
|
||||
local_scalar_t srcX = src;
|
||||
local_scalar_t srcY = src[embedDim_];
|
||||
|
||||
// slice temp buffer
|
||||
local_acc_t calcTmpBufferX = calcTmpBuffer;
|
||||
local_acc_t calcTmpBufferY = calcTmpBuffer[embedDim_];
|
||||
|
||||
// slice upcast input buffer
|
||||
local_acc_t upcastBufferX = upcastInputBuffer;
|
||||
local_acc_t upcastBufferY = upcastBufferX[embedDim_];
|
||||
|
||||
// dst x calc
|
||||
Cast(upcastInputBuffer, src, AscendC::RoundMode::CAST_NONE, headSize_);
|
||||
Mul(calcTmpBufferX, upcastBufferX, cos, embedDim_);
|
||||
Mul(calcTmpBufferY, upcastBufferY, sin, embedDim_);
|
||||
Sub(calcTmpBufferX, calcTmpBufferX, calcTmpBufferY, embedDim_);
|
||||
Cast(dstX, calcTmpBufferX, AscendC::RoundMode::CAST_TRUNC, embedDim_);
|
||||
|
||||
// dst y calc
|
||||
Mul(calcTmpBufferX, upcastBufferX, sin, embedDim_);
|
||||
Mul(calcTmpBufferY, upcastBufferY, cos, embedDim_);
|
||||
Add(calcTmpBufferX, calcTmpBufferX, calcTmpBufferY, embedDim_);
|
||||
Cast(dstY, calcTmpBufferX, AscendC::RoundMode::CAST_TRUNC, embedDim_);
|
||||
}
|
||||
|
||||
// compute per head output for neox
|
||||
template <typename acc_t_, typename std::enable_if<std::is_same_v<acc_t_, scalar_t>, void>::type * = nullptr>
|
||||
__aicore__ inline void
|
||||
neox_compute(local_scalar_t src, local_dst_t dst, AscendC::LocalTensor<acc_t_> sin, AscendC::LocalTensor<acc_t_> cos,
|
||||
AscendC::LocalTensor<acc_t_> upcastInputBuffer, AscendC::LocalTensor<acc_t_> calcTmpBuffer)
|
||||
{
|
||||
// slice dst buffer
|
||||
local_dst_t dstX = dst;
|
||||
local_dst_t dstY = dst[embedDim_];
|
||||
// slice src buffer
|
||||
local_scalar_t srcX = src;
|
||||
local_scalar_t srcY = src[embedDim_];
|
||||
// slice temp buffer
|
||||
local_acc_t calcTmpBufferX = calcTmpBuffer;
|
||||
local_acc_t calcTmpBufferY = calcTmpBuffer[embedDim_];
|
||||
|
||||
// dst x calc
|
||||
Mul(calcTmpBufferX, srcX, cos, embedDim_);
|
||||
Mul(calcTmpBufferY, srcY, sin, embedDim_);
|
||||
Sub(dstX, calcTmpBufferX, calcTmpBufferY, embedDim_);
|
||||
|
||||
// dst y calc
|
||||
Mul(calcTmpBufferX, srcX, sin, embedDim_);
|
||||
Mul(calcTmpBufferY, srcY, cos, embedDim_);
|
||||
Add(dstY, calcTmpBufferX, calcTmpBufferY, embedDim_);
|
||||
}
|
||||
|
||||
__aicore__ inline void compute_qk(AscendC::GlobalTensor<scalar_t> srcG, AscendC::GlobalTensor<dst_t> dstG,
|
||||
local_acc_t localCos, local_acc_t localSin, local_acc_t upcastInputBuffer,
|
||||
local_acc_t calcTmpBuffer, int loopCnt, int tailHeads, int loadStride,
|
||||
int headNumPerLoad)
|
||||
{
|
||||
for (int loopNum = 0; loopNum < loopCnt; ++loopNum) {
|
||||
local_scalar_t src = inQue_.AllocTensor<scalar_t>();
|
||||
local_dst_t dst = outQue_.AllocTensor<dst_t>();
|
||||
AscendC::DataCopy(src, srcG[loopNum * loadStride], loadStride);
|
||||
inQue_.EnQue(src);
|
||||
|
||||
local_scalar_t srcDeque = inQue_.DeQue<scalar_t>();
|
||||
if constexpr (!std::is_same_v<scalar_t, acc_t>) {
|
||||
int elem_num = loadStride / sizeof(scalar_t);
|
||||
AscendC::LocalTensor<acc_t> upBuffer = copyBuf_.GetWithOffset<acc_t>(elem_num, 0);
|
||||
Cast(upBuffer, srcDeque, AscendC::RoundMode::CAST_TRUNC, elem_num);
|
||||
Cast(dst, upBuffer, AscendC::RoundMode::CAST_TRUNC, elem_num);
|
||||
} else {
|
||||
local_mem_copy(dst, srcDeque, loadStride);
|
||||
}
|
||||
for (int i = 0; i < headNumPerLoad; ++i) {
|
||||
neox_compute(srcDeque[i * headSize_], dst[i * headSize_], localSin, localCos, upcastInputBuffer,
|
||||
calcTmpBuffer);
|
||||
}
|
||||
outQue_.EnQue(dst);
|
||||
local_dst_t dstDeque = outQue_.DeQue<dst_t>();
|
||||
AscendC::DataCopy(dstG[loopNum * loadStride], dstDeque, loadStride);
|
||||
outQue_.FreeTensor(dstDeque);
|
||||
inQue_.FreeTensor(srcDeque);
|
||||
}
|
||||
// process tail
|
||||
{
|
||||
local_scalar_t src = inQue_.AllocTensor<scalar_t>();
|
||||
local_dst_t dst = outQue_.AllocTensor<dst_t>();
|
||||
|
||||
AscendC::DataCopy(src, srcG[loopCnt * loadStride], tailHeads * headSize_);
|
||||
inQue_.EnQue(src);
|
||||
local_scalar_t srcDeque = inQue_.DeQue<scalar_t>();
|
||||
|
||||
if constexpr (!std::is_same_v<scalar_t, acc_t>) {
|
||||
int elem_num = tailHeads * headSize_ / sizeof(scalar_t);
|
||||
AscendC::LocalTensor<acc_t> upBuffer = copyBuf_.GetWithOffset<acc_t>(elem_num, 0);
|
||||
Cast(upBuffer, srcDeque, AscendC::RoundMode::CAST_TRUNC, elem_num);
|
||||
Cast(dst, upBuffer, AscendC::RoundMode::CAST_TRUNC, elem_num);
|
||||
} else {
|
||||
local_mem_copy(dst, srcDeque, tailHeads * headSize_);
|
||||
}
|
||||
|
||||
for (int i = 0; i < tailHeads; ++i) {
|
||||
neox_compute(srcDeque[i * headSize_], dst[i * headSize_], localSin, localCos, upcastInputBuffer,
|
||||
calcTmpBuffer);
|
||||
}
|
||||
outQue_.EnQue(dst);
|
||||
local_dst_t dstDeque = outQue_.DeQue<dst_t>();
|
||||
AscendC::DataCopy(dstG[loopCnt * loadStride], dstDeque, tailHeads * headSize_);
|
||||
outQue_.FreeTensor(dstDeque);
|
||||
inQue_.FreeTensor(srcDeque);
|
||||
}
|
||||
}
|
||||
|
||||
__aicore__ inline void compute_function()
|
||||
{
|
||||
local_scalar_t cosSinLocal = inQueSinCos_.AllocTensor<scalar_t>();
|
||||
|
||||
AscendC::DataCopy(cosSinLocal, cosSin_, embedDim_ * 2);
|
||||
|
||||
inQueSinCos_.EnQue(cosSinLocal);
|
||||
local_scalar_t localSinCosDeque = inQueSinCos_.DeQue<scalar_t>();
|
||||
local_scalar_t localCos = localSinCosDeque;
|
||||
local_scalar_t localSin = localSinCosDeque[embedDim_];
|
||||
|
||||
local_acc_t calcTmpBuffer;
|
||||
local_acc_t upcastInputBuffer;
|
||||
local_acc_t upcastTempBuffer;
|
||||
local_acc_t cosSinUpcastBuffer;
|
||||
local_acc_t scaleBuffer;
|
||||
local_acc_t offsetBuffer;
|
||||
calcTmpBuffer = calcBuf_.GetWithOffset<acc_t>(embedDim_ * 2, calcTmpBufferOffset_);
|
||||
upcastInputBuffer = calcBuf_.GetWithOffset<acc_t>(headSize_, upcastInputBufferOffset_);
|
||||
upcastTempBuffer = calcBuf_.GetWithOffset<acc_t>(embedDim_ * 2, upcastTempBufferOffset_);
|
||||
cosSinUpcastBuffer = calcBuf_.GetWithOffset<acc_t>(embedDim_ * 2, cosSinUpcastBufferOffset_);
|
||||
|
||||
local_acc_t cosAccBuffer;
|
||||
local_acc_t sinAccBuffer;
|
||||
|
||||
if constexpr (!std::is_same_v<scalar_t, acc_t>) {
|
||||
Cast(cosSinUpcastBuffer, localSinCosDeque, AscendC::RoundMode::CAST_NONE, 2 * embedDim_);
|
||||
cosAccBuffer = cosSinUpcastBuffer;
|
||||
sinAccBuffer = cosSinUpcastBuffer[embedDim_];
|
||||
} else {
|
||||
cosAccBuffer = localCos;
|
||||
sinAccBuffer = localSin;
|
||||
}
|
||||
|
||||
constexpr const int loadSizeByElem = loadSize / sizeof(scalar_t);
|
||||
int64_t headNumPerLoad = loadSizeByElem / headSize_;
|
||||
int64_t loopCnt = numHeads_ / headNumPerLoad;
|
||||
int64_t tailHeads = numHeads_ - loopCnt * headNumPerLoad;
|
||||
int64_t loadStride = headNumPerLoad * headSize_;
|
||||
int64_t loopCntKv = numKvHeads_ / headNumPerLoad;
|
||||
int64_t tailHeadsKv = numKvHeads_ - loopCntKv * headNumPerLoad;
|
||||
compute_qk(query_, queryDst_, cosAccBuffer, sinAccBuffer, upcastInputBuffer,
|
||||
calcTmpBuffer, loopCnt, tailHeads, loadStride, headNumPerLoad);
|
||||
|
||||
compute_qk(key_, keyDst_, cosAccBuffer, sinAccBuffer, upcastInputBuffer, calcTmpBuffer,
|
||||
loopCntKv, tailHeadsKv, loadStride, headNumPerLoad);
|
||||
|
||||
inQueSinCos_.FreeTensor(localSinCosDeque);
|
||||
}
|
||||
|
||||
private:
|
||||
AscendC::TPipe *pipe_;
|
||||
AscendC::TQue<AscendC::QuePosition::VECIN, 1> inQue_, inQueSinCos_;
|
||||
AscendC::TQue<AscendC::QuePosition::VECOUT, 1> outQue_;
|
||||
AscendC::TBuf<AscendC::TPosition::VECCALC> calcBuf_;
|
||||
AscendC::TBuf<AscendC::TPosition::VECCALC> copyBuf_;
|
||||
AscendC::GlobalTensor<dst_t> queryDst_;
|
||||
AscendC::GlobalTensor<dst_t> keyDst_;
|
||||
AscendC::GlobalTensor<scalar_t> query_;
|
||||
AscendC::GlobalTensor<scalar_t> key_;
|
||||
AscendC::GlobalTensor<scalar_t> cosSin_;
|
||||
int rotDim_;
|
||||
int embedDim_;
|
||||
int64_t queryStride_;
|
||||
int64_t keyStride_;
|
||||
int64_t dstQueryStride_;
|
||||
int64_t dstKeyStride_;
|
||||
int numHeads_;
|
||||
int numKvHeads_;
|
||||
int headSize_;
|
||||
int calcTmpBufferOffset_;
|
||||
int upcastInputBufferOffset_;
|
||||
int upcastTempBufferOffset_;
|
||||
int cosSinUpcastBufferOffset_;
|
||||
int tempBufferSize_;
|
||||
};
|
||||
|
||||
// Note: Need to use macro to instaniate all the target functions here, for the current build system dose not support template call in cpp
|
||||
// We use C style symbol here for kernel compilation, cpp style kernel entry may lead to compilation failure
|
||||
#define ROPE_CUSTOM_KERNEL_TYPE_DECLARE(TYPE, NEOX) \
|
||||
extern "C" __global__ __aicore__ void rope_custom_##NEOX##_##TYPE( \
|
||||
__gm__ int64_t* positions, __gm__ void* queryDst, __gm__ void* keyDst, __gm__ TYPE* query, __gm__ TYPE* key, \
|
||||
__gm__ TYPE* cosSinCache, const int rotDim, const int64_t queryStride, const int64_t keyStride, \
|
||||
const int64_t dstQueryStride, const int64_t dstKeyStride, const int numHeads, const int numKvHeads, \
|
||||
const int headSize, const int64_t numTokens, const int loopNum, const int coreNum) \
|
||||
{ \
|
||||
AscendC::TPipe pipe; \
|
||||
RotaryEmbedding<TYPE, NEOX> op{}; \
|
||||
op.init(positions, queryDst, keyDst, query, key, cosSinCache, rotDim, dstQueryStride, dstKeyStride, \
|
||||
queryStride, keyStride, numHeads, numKvHeads, headSize, &pipe); \
|
||||
for (int64_t i = AscendC::GetBlockIdx(); i < numTokens; i += coreNum) { \
|
||||
op.update_mem_offset(positions, queryDst, keyDst, query, key, cosSinCache, rotDim, dstQueryStride, dstKeyStride, \
|
||||
queryStride, keyStride, numHeads, numKvHeads, headSize, i); \
|
||||
op.compute_function(); \
|
||||
} \
|
||||
}
|
||||
|
||||
#define ROPE_CUSTOM_KERNEL_DECLARE(TYPE) \
|
||||
ROPE_CUSTOM_KERNEL_TYPE_DECLARE(TYPE, true); \
|
||||
ROPE_CUSTOM_KERNEL_TYPE_DECLARE(TYPE, false);
|
||||
|
||||
// Declare all the kernel entry here
|
||||
ROPE_CUSTOM_KERNEL_DECLARE(half)
|
||||
ROPE_CUSTOM_KERNEL_DECLARE(bfloat16_t)
|
||||
|
||||
namespace vllm_ascend {
|
||||
|
||||
#define ROTARY_EMBEDDING_KERNEL_CALL(TYPE) \
|
||||
if (isNeox) \
|
||||
rope_custom_true_##TYPE<<<blockDim, nullptr, stream>>>( \
|
||||
positions, queryDst, keyDst, reinterpret_cast<TYPE *>(query), reinterpret_cast<TYPE *>(key), \
|
||||
reinterpret_cast<TYPE *>(cosSinCache), rotDim, queryStride, keyStride, dstQueryStride, dstKeyStride, \
|
||||
numHeads, numKvHeads, headSize, numTokens, loopCnt, blockDim); \
|
||||
else \
|
||||
rope_custom_false_##TYPE<<<blockDim, nullptr, stream>>>( \
|
||||
positions, queryDst, keyDst, reinterpret_cast<TYPE *>(query), reinterpret_cast<TYPE *>(key), \
|
||||
reinterpret_cast<TYPE *>(cosSinCache), rotDim, queryStride, keyStride, dstQueryStride, dstKeyStride, \
|
||||
numHeads, numKvHeads, headSize, numTokens, loopCnt, blockDim);
|
||||
|
||||
// maximum number for runtime to launch a ascendc kernel.
|
||||
// we use this to constrain the maximum number of block size
|
||||
static const int64_t maxParallelSize = 65535;
|
||||
|
||||
extern void rotary_embedding_impl(AscendType type, bool isNeox, void *stream, int64_t *positions, void *queryDst,
|
||||
void *keyDst, void *query, void *key, void *cosSinCache, const int rotDim,
|
||||
const int64_t queryStride, const int64_t keyStride, const int64_t dstQueryStride,
|
||||
const int64_t dstKeyStride, const int numHeads, const int numKvHeads,
|
||||
const int headSize, const int64_t numTokens, const uint32_t loopCnt,
|
||||
uint32_t aivNum)
|
||||
{
|
||||
|
||||
int blockDim = maxParallelSize > numTokens ? numTokens : maxParallelSize;
|
||||
if (type == AscendType::FP16) {
|
||||
ROTARY_EMBEDDING_KERNEL_CALL(half);
|
||||
} else if (type == AscendType::BF16) {
|
||||
ROTARY_EMBEDDING_KERNEL_CALL(bfloat16_t);
|
||||
} else {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm_ascend
|
25
csrc/kernels/types.h
Normal file
25
csrc/kernels/types.h
Normal file
@ -0,0 +1,25 @@
|
||||
/*
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace vllm_ascend {
|
||||
enum struct AscendType {
|
||||
FP16 = 0,
|
||||
BF16 = 1,
|
||||
FP32 = 2,
|
||||
};
|
||||
}
|
49
csrc/kernels/utils.h
Normal file
49
csrc/kernels/utils.h
Normal file
@ -0,0 +1,49 @@
|
||||
/*
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include "kernel_type.h"
|
||||
namespace vllm_ascend {
|
||||
|
||||
template <typename scalar_t> struct AccType;
|
||||
|
||||
template <> struct AccType<bfloat16_t> {
|
||||
using type = float;
|
||||
};
|
||||
|
||||
template <> struct AccType<half> {
|
||||
using type = half;
|
||||
};
|
||||
|
||||
template <> struct AccType<float> {
|
||||
using type = float;
|
||||
};
|
||||
|
||||
template <> struct AccType<int8_t> {
|
||||
using type = int;
|
||||
};
|
||||
|
||||
template <typename scalar_t>
|
||||
__aicore__ inline void local_mem_copy(AscendC::LocalTensor<scalar_t> dst, AscendC::LocalTensor<scalar_t> src, int size)
|
||||
{
|
||||
constexpr int loadSize = 256 / sizeof(scalar_t);
|
||||
int loopCnt = size / loadSize;
|
||||
int tailSize = size % loadSize;
|
||||
if (loopCnt)
|
||||
AscendC::Copy(dst, src, loadSize, loopCnt, {1, 1, 8, 8});
|
||||
AscendC::Copy(dst[loopCnt * loadSize], src[loopCnt * loadSize], tailSize, 1, {1, 1, 8, 8});
|
||||
}
|
||||
} // namespace vllm_ascend
|
32
csrc/ops.h
Normal file
32
csrc/ops.h
Normal file
@ -0,0 +1,32 @@
|
||||
/*
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <optional>
|
||||
#include <torch/library.h>
|
||||
|
||||
#include <vector>
|
||||
#include "kernels/types.h"
|
||||
|
||||
namespace vllm_ascend {
|
||||
extern void rotary_embedding_impl(AscendType type, bool isNeox, void *stream, int64_t *positions, void *queryDst,
|
||||
void *keyDst, void *query, void *key, void *cosSinCache, const int rotDim,
|
||||
const int64_t queryStride, const int64_t keyStride, const int64_t dstQueryStride,
|
||||
const int64_t dstKeyStride, const int numHeads, const int numKvHeads,
|
||||
const int headSize, const int64_t numTokens, const uint32_t loopCnt,
|
||||
uint32_t aivNum);
|
||||
}
|
108
csrc/torch_binding.cpp
Normal file
108
csrc/torch_binding.cpp
Normal file
@ -0,0 +1,108 @@
|
||||
/*
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <torch/extension.h>
|
||||
#include <torch/library.h>
|
||||
#include <torch/version.h>
|
||||
#include <torch_npu/csrc/core/npu/NPUStream.h>
|
||||
#include <torch_npu/csrc/framework/OpCommand.h>
|
||||
#include <torch_npu/csrc/npu/Module.h>
|
||||
#include <pybind11/pybind11.h>
|
||||
#include "acl/acl.h"
|
||||
#include "tiling/platform/platform_ascendc.h"
|
||||
#include "aclnn/opdev/platform.h"
|
||||
#include "ops.h"
|
||||
#include "utils.h"
|
||||
|
||||
namespace vllm_ascend {
|
||||
|
||||
void rotary_embedding(at::Tensor &positions, at::Tensor &query, at::Tensor &key,
|
||||
int64_t head_size, at::Tensor &cos_sin_cache, bool is_neox)
|
||||
{
|
||||
int32_t deviceId = 0;
|
||||
int64_t num_tokens = positions.numel();
|
||||
int positions_ndim = positions.dim();
|
||||
TORCH_CHECK(
|
||||
positions_ndim == 1 || positions_ndim == 2,
|
||||
"positions must have shape [num_tokens] or [batch_size, seq_len]");
|
||||
if (positions_ndim == 1) {
|
||||
TORCH_CHECK(
|
||||
query.size(0) == positions.size(0) && key.size(0) == positions.size(0),
|
||||
"query, key and positions must have the same number of tokens");
|
||||
}
|
||||
if (positions_ndim == 2) {
|
||||
TORCH_CHECK(
|
||||
query.size(0) == positions.size(0) &&
|
||||
key.size(0) == positions.size(0) &&
|
||||
query.size(1) == positions.size(1) &&
|
||||
key.size(1) == positions.size(1),
|
||||
"query, key and positions must have the same batch_size and seq_len");
|
||||
}
|
||||
|
||||
int query_hidden_size = query.numel() / num_tokens;
|
||||
int key_hidden_size = key.numel() / num_tokens;
|
||||
TORCH_CHECK(query_hidden_size % head_size == 0);
|
||||
TORCH_CHECK(key_hidden_size % head_size == 0);
|
||||
|
||||
// Make sure query and key have consistent number of heads
|
||||
int num_heads = query_hidden_size / head_size;
|
||||
int num_kv_heads = key_hidden_size / head_size;
|
||||
TORCH_CHECK(num_heads % num_kv_heads == 0);
|
||||
|
||||
int rot_dim = cos_sin_cache.size(1);
|
||||
int64_t *position_ids_ptr = positions.data_ptr<int64_t>();
|
||||
void *query_ptr = query.data_ptr();
|
||||
void *key_ptr = key.data_ptr();
|
||||
void *cos_sin_cache_ptr = cos_sin_cache.data_ptr();
|
||||
int64_t query_stride = query.stride(-2);
|
||||
int64_t key_stride = key.stride(-2);
|
||||
at::ScalarType scalar_type = query.scalar_type();
|
||||
aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
|
||||
at_npu::native::OpCommand cmd;
|
||||
cmd.Name("rotary_embedding");
|
||||
cmd.SetCustomHandler([scalar_type, is_neox, num_tokens, stream, position_ids_ptr,
|
||||
query_ptr, key_ptr, cos_sin_cache_ptr, rot_dim, query_stride, key_stride,
|
||||
num_heads, num_kv_heads, head_size]() -> int {
|
||||
auto dtype_num = get_dtype_from_torch(scalar_type);
|
||||
fe::PlatFormInfos platform_infos;
|
||||
int device_id = 0;
|
||||
fe::PlatformInfoManager::GeInstance().GetRuntimePlatformInfosByDevice(device_id, platform_infos);
|
||||
uint32_t aivNum = platform_infos.GetCoreNumByType("aiv");
|
||||
uint32_t loop_cnt = (num_tokens + aivNum - 1) / aivNum;
|
||||
rotary_embedding_impl(dtype_num, is_neox, stream, position_ids_ptr, query_ptr, key_ptr, query_ptr,
|
||||
key_ptr, cos_sin_cache_ptr, rot_dim, query_stride, key_stride, query_stride,
|
||||
key_stride, num_heads, num_kv_heads, head_size, num_tokens, loop_cnt, aivNum);
|
||||
return 0;
|
||||
});
|
||||
cmd.Run();
|
||||
return ;
|
||||
}
|
||||
} // namespace vllm_ascend
|
||||
|
||||
TORCH_LIBRARY_EXPAND(_C, ops)
|
||||
{
|
||||
// vLLM-Ascend custom ops
|
||||
|
||||
// Rotary embedding
|
||||
// Apply GPT-NeoX style rotary embedding to query and key.
|
||||
ops.def(
|
||||
"rotary_embedding(Tensor positions, Tensor! query,"
|
||||
" Tensor! key, int head_size,"
|
||||
" Tensor cos_sin_cache, bool is_neox) -> ()");
|
||||
ops.impl("rotary_embedding", torch::kPrivateUse1, &vllm_ascend::rotary_embedding);
|
||||
}
|
||||
|
||||
REGISTER_EXTENSION(_C)
|
43
csrc/utils.h
Normal file
43
csrc/utils.h
Normal file
@ -0,0 +1,43 @@
|
||||
#pragma once
|
||||
|
||||
#include "kernels/types.h"
|
||||
#include <c10/core/ScalarType.h>
|
||||
#include <Python.h>
|
||||
|
||||
#define _CONCAT(A, B) A##B
|
||||
#define CONCAT(A, B) _CONCAT(A, B)
|
||||
|
||||
#define _STRINGIFY(A) #A
|
||||
#define STRINGIFY(A) _STRINGIFY(A)
|
||||
|
||||
// A version of the TORCH_LIBRARY macro that expands the NAME, i.e. so NAME
|
||||
// could be a macro instead of a literal token.
|
||||
#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)
|
||||
|
||||
// A version of the TORCH_LIBRARY_IMPL macro that expands the NAME, i.e. so NAME
|
||||
// could be a macro instead of a literal token.
|
||||
#define TORCH_LIBRARY_IMPL_EXPAND(NAME, DEVICE, MODULE) \
|
||||
TORCH_LIBRARY_IMPL(NAME, DEVICE, MODULE)
|
||||
|
||||
// REGISTER_EXTENSION allows the shared library to be loaded and initialized
|
||||
// via python's import statement.
|
||||
#define REGISTER_EXTENSION(NAME) \
|
||||
PyMODINIT_FUNC CONCAT(PyInit_, NAME)() { \
|
||||
static struct PyModuleDef module = {PyModuleDef_HEAD_INIT, \
|
||||
STRINGIFY(NAME), nullptr, 0, nullptr}; \
|
||||
return PyModule_Create(&module); \
|
||||
}
|
||||
|
||||
|
||||
namespace vllm_ascend {
|
||||
AscendType get_dtype_from_torch(at::ScalarType scalarType)
|
||||
{
|
||||
if (scalarType == at::ScalarType::Float) {
|
||||
return AscendType::FP32;
|
||||
} else if (scalarType == at::ScalarType::BFloat16) {
|
||||
return AscendType::BF16;
|
||||
} else {
|
||||
return AscendType::FP16;
|
||||
}
|
||||
}
|
||||
} // namespace vllm_ascend
|
@ -3,9 +3,12 @@
|
||||
requires = [
|
||||
"setuptools>=64",
|
||||
"setuptools-scm>=8",
|
||||
"cmake>=3.26",
|
||||
"pybind11",
|
||||
"decorator",
|
||||
"pyyaml",
|
||||
"scipy",
|
||||
"torch-npu >= 2.5.1rc1"
|
||||
"torch_npu >= 2.5.1rc1",
|
||||
"torch >= 2.5.1"
|
||||
]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
@ -1,6 +1,7 @@
|
||||
decorator
|
||||
pyyaml
|
||||
scipy
|
||||
pybind11
|
||||
setuptools
|
||||
setuptools-scm
|
||||
numpy==1.26.4
|
||||
|
278
setup.py
278
setup.py
@ -17,12 +17,265 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import importlib.util
|
||||
import logging
|
||||
import os
|
||||
from typing import List
|
||||
import subprocess
|
||||
import sys
|
||||
from sysconfig import get_paths
|
||||
from typing import Dict, List
|
||||
|
||||
from setuptools import find_packages, setup
|
||||
from setuptools import Extension, find_packages, setup
|
||||
from setuptools.command.build_ext import build_ext
|
||||
from setuptools.command.develop import develop
|
||||
from setuptools.command.install import install
|
||||
from setuptools_scm import get_version
|
||||
|
||||
|
||||
def load_module_from_path(module_name, path):
|
||||
spec = importlib.util.spec_from_file_location(module_name, path)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
sys.modules[module_name] = module
|
||||
spec.loader.exec_module(module)
|
||||
return module
|
||||
|
||||
|
||||
ROOT_DIR = os.path.dirname(__file__)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def check_or_set_default_env(cmake_args,
|
||||
env_name,
|
||||
env_variable,
|
||||
default_path=""):
|
||||
if env_variable is None:
|
||||
logging.warning(
|
||||
f"No {env_name} found in your environment, pleause try to set {env_name} "
|
||||
"if you customize the installation path of this library, otherwise default "
|
||||
"path will be adapted during build this project")
|
||||
logging.warning(f"Set default {env_name}: {default_path}")
|
||||
env_variable = default_path
|
||||
else:
|
||||
logging.info(f"Found existing {env_name}: {env_variable}")
|
||||
# cann package seems will check this environments in cmake, need write this env variable back.
|
||||
if env_name == "ASCEND_HOME_PATH":
|
||||
os.environ["ASCEND_HOME_PATH"] = env_variable
|
||||
cmake_args += [f"-D{env_name}={env_variable}"]
|
||||
return cmake_args
|
||||
|
||||
|
||||
envs = load_module_from_path("envs",
|
||||
os.path.join(ROOT_DIR, "vllm_ascend", "envs.py"))
|
||||
|
||||
|
||||
class CMakeExtension(Extension):
|
||||
|
||||
def __init__(self,
|
||||
name: str,
|
||||
cmake_lists_dir: str = ".",
|
||||
**kwargs) -> None:
|
||||
super().__init__(name, sources=[], py_limited_api=False, **kwargs)
|
||||
self.cmake_lists_dir = os.path.abspath(cmake_lists_dir)
|
||||
|
||||
|
||||
class cmake_build_ext(build_ext):
|
||||
# A dict of extension directories that have been configured.
|
||||
did_config: Dict[str, bool] = {}
|
||||
|
||||
#
|
||||
# Determine number of compilation jobs
|
||||
#
|
||||
def compute_num_jobs(self):
|
||||
# `num_jobs` is either the value of the MAX_JOBS environment variable
|
||||
# (if defined) or the number of CPUs available.
|
||||
num_jobs = envs.MAX_JOBS
|
||||
if num_jobs is not None:
|
||||
num_jobs = int(num_jobs)
|
||||
logger.info("Using MAX_JOBS=%d as the number of jobs.", num_jobs)
|
||||
else:
|
||||
try:
|
||||
# os.sched_getaffinity() isn't universally available, so fall
|
||||
# back to os.cpu_count() if we get an error here.
|
||||
num_jobs = len(os.sched_getaffinity(0))
|
||||
except AttributeError:
|
||||
num_jobs = os.cpu_count()
|
||||
num_jobs = max(1, num_jobs)
|
||||
|
||||
return num_jobs
|
||||
|
||||
#
|
||||
# Perform cmake configuration for a single extension.
|
||||
#
|
||||
def configure(self, ext: CMakeExtension) -> None:
|
||||
build_temp = self.build_temp
|
||||
os.makedirs(build_temp, exist_ok=True)
|
||||
source_dir = os.path.abspath(ROOT_DIR)
|
||||
python_executable = sys.executable
|
||||
cmake_args = ["cmake"]
|
||||
# Default use release mode to compile the csrc code
|
||||
# Turbo now support compiled with Release, Debug and RelWithDebugInfo
|
||||
if envs.CMAKE_BUILD_TYPE is None or envs.CMAKE_BUILD_TYPE not in [
|
||||
"Debug",
|
||||
"Release",
|
||||
"RelWithDebugInfo",
|
||||
]:
|
||||
envs.CMAKE_BUILD_TYPE = "Release"
|
||||
cmake_args += [f"-DCMAKE_BUILD_TYPE={envs.CMAKE_BUILD_TYPE}"]
|
||||
# Default dump the compile commands for lsp
|
||||
cmake_args += ["-DCMAKE_EXPORT_COMPILE_COMMANDS=1"]
|
||||
if envs.VERBOSE:
|
||||
cmake_args += ["-DCMAKE_VERBOSE_MAKEFILE=ON"]
|
||||
|
||||
# find ASCEND_HOME_PATH
|
||||
check_or_set_default_env(
|
||||
cmake_args,
|
||||
"ASCEND_HOME_PATH",
|
||||
envs.ASCEND_HOME_PATH,
|
||||
"/usr/local/Ascend/ascend-toolkit/latest",
|
||||
)
|
||||
|
||||
# find PYTHON_EXECUTABLE
|
||||
check_or_set_default_env(cmake_args, "PYTHON_EXECUTABLE",
|
||||
sys.executable)
|
||||
|
||||
# find PYTHON_INCLUDE_PATH
|
||||
check_or_set_default_env(cmake_args, "PYHTON_INCLUDE_PATH",
|
||||
get_paths()["include"])
|
||||
|
||||
# ccache and ninja can not be applied at ascendc kernels now
|
||||
|
||||
try:
|
||||
# if pybind11 is installed via pip
|
||||
pybind11_cmake_path = (subprocess.check_output(
|
||||
[python_executable, "-m", "pybind11",
|
||||
"--cmake"]).decode().strip())
|
||||
except subprocess.CalledProcessError as e:
|
||||
# else specify pybind11 path installed from source code on CI container
|
||||
raise RuntimeError(f"CMake configuration failed: {e}")
|
||||
|
||||
# try retrive soc version from npu-smi
|
||||
soc_command = [
|
||||
"bash",
|
||||
"-c",
|
||||
"npu-smi info | grep OK | awk '{print $3}' | head -n 1",
|
||||
]
|
||||
try:
|
||||
soc_version = subprocess.check_output(soc_command,
|
||||
text=True).strip()
|
||||
soc_version = soc_version.split("-")[0]
|
||||
soc_version = "Ascend" + soc_version
|
||||
except subprocess.CalledProcessError as e:
|
||||
raise RuntimeError(f"Retrive Soc version failed: {e}")
|
||||
|
||||
# add SOC_VERSION
|
||||
cmake_args += [f"-DSOC_VERSION={soc_version}"]
|
||||
|
||||
install_path = os.path.join(ROOT_DIR, self.build_lib)
|
||||
if isinstance(self.distribution.get_command_obj("develop"), develop):
|
||||
install_path = os.path.join(ROOT_DIR, "vllm_ascend")
|
||||
# add CMAKE_INSTALL_PATH
|
||||
cmake_args += [f"-DCMAKE_INSTALL_PREFIX={install_path}"]
|
||||
|
||||
cmake_args += [f"-DCMAKE_PREFIX_PATH={pybind11_cmake_path}"]
|
||||
|
||||
# Override the base directory for FetchContent downloads to $ROOT/.deps
|
||||
# This allows sharing dependencies between profiles,
|
||||
# and plays more nicely with sccache.
|
||||
# To override this, set the FETCHCONTENT_BASE_DIR environment variable.
|
||||
fc_base_dir = os.path.join(ROOT_DIR, ".deps")
|
||||
fc_base_dir = os.environ.get("FETCHCONTENT_BASE_DIR", fc_base_dir)
|
||||
cmake_args += ["-DFETCHCONTENT_BASE_DIR={}".format(fc_base_dir)]
|
||||
|
||||
build_tool = []
|
||||
# TODO(ganyi): ninja and ccache support for ascend c auto codegen. now we can only use make build
|
||||
# if which('ninja') is not None:
|
||||
# build_tool += ['-G', 'Ninja']
|
||||
# Default build tool to whatever cmake picks.
|
||||
|
||||
cmake_args += [source_dir]
|
||||
logging.info(f"cmake config command: {cmake_args}")
|
||||
try:
|
||||
subprocess.check_call(cmake_args, cwd=self.build_temp)
|
||||
except subprocess.CalledProcessError as e:
|
||||
raise RuntimeError(f"CMake configuration failed: {e}")
|
||||
|
||||
subprocess.check_call(
|
||||
["cmake", ext.cmake_lists_dir, *build_tool, *cmake_args],
|
||||
cwd=self.build_temp,
|
||||
)
|
||||
|
||||
def build_extensions(self) -> None:
|
||||
if envs.COMPILE_CUSTOM_KERNELS is None:
|
||||
return
|
||||
# Ensure that CMake is present and working
|
||||
try:
|
||||
subprocess.check_output(["cmake", "--version"])
|
||||
except OSError as e:
|
||||
raise RuntimeError(f"Cannot find CMake executable: {e}")
|
||||
|
||||
# Create build directory if it does not exist.
|
||||
if not os.path.exists(self.build_temp):
|
||||
os.makedirs(self.build_temp)
|
||||
|
||||
targets = []
|
||||
|
||||
os.makedirs(os.path.join(self.build_lib, "vllm_ascend"), exist_ok=True)
|
||||
|
||||
def target_name(s: str) -> str:
|
||||
return s.removeprefix("vllm_ascend.")
|
||||
|
||||
# Build all the extensions
|
||||
for ext in self.extensions:
|
||||
self.configure(ext)
|
||||
targets.append(target_name(ext.name))
|
||||
|
||||
num_jobs = self.compute_num_jobs()
|
||||
|
||||
build_args = [
|
||||
"--build",
|
||||
".",
|
||||
f"-j={num_jobs}",
|
||||
*[f"--target={name}" for name in targets],
|
||||
]
|
||||
try:
|
||||
subprocess.check_call(["cmake", *build_args], cwd=self.build_temp)
|
||||
except OSError as e:
|
||||
raise RuntimeError(f"Build library failed: {e}")
|
||||
# Install the libraries
|
||||
install_args = [
|
||||
"cmake",
|
||||
"--install",
|
||||
".",
|
||||
]
|
||||
try:
|
||||
subprocess.check_call(install_args, cwd=self.build_temp)
|
||||
except OSError as e:
|
||||
raise RuntimeError(f"Install library failed: {e}")
|
||||
|
||||
# copy back to build folder for editable build
|
||||
if isinstance(self.distribution.get_command_obj("develop"), develop):
|
||||
import shutil
|
||||
for root, _, files in os.walk(self.build_temp):
|
||||
for file in files:
|
||||
if file.endswith(".so"):
|
||||
src_path = os.path.join(root, file)
|
||||
dst_path = os.path.join(self.build_lib, "vllm_ascend",
|
||||
file)
|
||||
shutil.copy(src_path, dst_path)
|
||||
print(f"Copy: {src_path} -> {dst_path}")
|
||||
|
||||
def run(self):
|
||||
# First, run the standard build_ext command to compile the extensions
|
||||
super().run()
|
||||
|
||||
|
||||
class custom_install(install):
|
||||
|
||||
def run(self):
|
||||
self.run_command("build_ext")
|
||||
install.run(self)
|
||||
|
||||
|
||||
ROOT_DIR = os.path.dirname(__file__)
|
||||
try:
|
||||
VERSION = get_version(write_to="vllm_ascend/_version.py")
|
||||
@ -31,6 +284,10 @@ except LookupError:
|
||||
# only checks out the commit. In this case, we set a dummy version.
|
||||
VERSION = "0.0.0"
|
||||
|
||||
ext_modules = []
|
||||
if envs.COMPILE_CUSTOM_KERNELS is not None:
|
||||
ext_modules = [CMakeExtension(name="vllm_ascend.vllm_ascend_C")]
|
||||
|
||||
|
||||
def get_path(*filepath) -> str:
|
||||
return os.path.join(ROOT_DIR, *filepath)
|
||||
@ -69,8 +326,10 @@ def get_requirements() -> List[str]:
|
||||
return requirements
|
||||
|
||||
|
||||
cmdclass = {"build_ext": cmake_build_ext, "install": custom_install}
|
||||
|
||||
setup(
|
||||
name='vllm_ascend',
|
||||
name="vllm_ascend",
|
||||
# Follow:
|
||||
# https://packaging.python.org/en/latest/specifications/version-specifiers
|
||||
version=VERSION,
|
||||
@ -95,12 +354,15 @@ setup(
|
||||
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
||||
"Topic :: Scientific/Engineering :: Information Analysis",
|
||||
],
|
||||
packages=find_packages(exclude=("docs", "examples", "tests*")),
|
||||
packages=find_packages(exclude=("docs", "examples", "tests*", "csrc")),
|
||||
python_requires=">=3.9",
|
||||
install_requires=get_requirements(),
|
||||
ext_modules=ext_modules,
|
||||
cmdclass=cmdclass,
|
||||
extras_require={},
|
||||
entry_points={
|
||||
'vllm.platform_plugins': ["ascend = vllm_ascend:register"],
|
||||
'vllm.general_plugins':
|
||||
["ascend_enhanced_model = vllm_ascend:register_model"]
|
||||
})
|
||||
"vllm.platform_plugins": ["ascend = vllm_ascend:register"],
|
||||
"vllm.general_plugins":
|
||||
["ascend_enhanced_model = vllm_ascend:register_model"],
|
||||
},
|
||||
)
|
||||
|
204
tests/ops/test_rotary_embedding.py
Normal file
204
tests/ops/test_rotary_embedding.py
Normal file
@ -0,0 +1,204 @@
|
||||
# Adapted from
|
||||
# https://github.com/vllm-project/vllm/blob/main/vllm/tests/kernels/test_rotary_embedding.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
|
||||
# Copyright (c) Huawei Technologies Co., Ltd. 2024-2025. All rights reserved.
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch_npu # noqa: F401
|
||||
|
||||
import vllm_ascend.platform # noqa: F401
|
||||
|
||||
# Only Neox style true scenario is supported for now
|
||||
IS_NEOX_STYLE = [True]
|
||||
DTYPES = [torch.half]
|
||||
HEAD_SIZES = [64, 96, 128, 256]
|
||||
ROTARY_DIMS = [None, 32] # None means rotary dim == head size
|
||||
NUM_HEADS = [17] # Arbitrary values for testing
|
||||
BATCH_SIZES = [5] # Arbitrary values for testing
|
||||
SEQ_LENS = [11, 4096] # Arbitrary values for testing
|
||||
SEEDS = [0]
|
||||
DEVICES = [f"npu:{0}"]
|
||||
# Set tolerance to 1 for quant ops
|
||||
DEFAULT_ATOL = 1e-3
|
||||
DEFAULT_RTOL = 1e-3
|
||||
|
||||
|
||||
def _apply_rotary_emb(
|
||||
x: torch.Tensor,
|
||||
cos: torch.Tensor,
|
||||
sin: torch.Tensor,
|
||||
is_neox_style: bool,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x: [num_tokens, num_heads, head_size]
|
||||
cos: [num_tokens, head_size // 2]
|
||||
sin: [num_tokens, head_size // 2]
|
||||
is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
|
||||
positional embeddings.
|
||||
"""
|
||||
cos = cos.unsqueeze(-2).to(x.dtype)
|
||||
sin = sin.unsqueeze(-2).to(x.dtype)
|
||||
if is_neox_style:
|
||||
x1, x2 = torch.chunk(x, 2, dim=-1)
|
||||
else:
|
||||
x1 = x[..., ::2]
|
||||
x2 = x[..., 1::2]
|
||||
o1 = x1 * cos - x2 * sin
|
||||
o2 = x2 * cos + x1 * sin
|
||||
if is_neox_style:
|
||||
return torch.cat((o1, o2), dim=-1)
|
||||
else:
|
||||
return torch.stack((o1, o2), dim=-1).flatten(-2)
|
||||
|
||||
|
||||
# adapted from https://github.com/vllm-project/vllm/vllm/model_executor/layers/rotary_embedding.py
|
||||
class RotaryEmbedding(nn.Module):
|
||||
"""Original rotary positional embedding."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
head_size: int,
|
||||
rotary_dim: int,
|
||||
max_position_embeddings: int,
|
||||
base: int,
|
||||
is_neox_style: bool,
|
||||
dtype: torch.dtype,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.head_size = head_size
|
||||
self.rotary_dim = rotary_dim
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.base = base
|
||||
self.is_neox_style = is_neox_style
|
||||
self.dtype = dtype
|
||||
|
||||
cache = self._compute_cos_sin_cache()
|
||||
cache = cache.to(dtype)
|
||||
self.cos_sin_cache: torch.Tensor
|
||||
self.register_buffer("cos_sin_cache", cache, persistent=False)
|
||||
|
||||
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
|
||||
"""Compute the inverse frequency."""
|
||||
# NOTE(woosuk): To exactly match the HF implementation, we need to
|
||||
# use CPU to compute the cache and then move it to GPU. However, we
|
||||
# create the cache on GPU for faster initialization. This may cause
|
||||
# a slight numerical difference between the HF implementation and ours.
|
||||
inv_freq = 1.0 / (base**(torch.arange(
|
||||
0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))
|
||||
return inv_freq
|
||||
|
||||
def _compute_cos_sin_cache(self) -> torch.Tensor:
|
||||
"""Compute the cos and sin cache."""
|
||||
inv_freq = self._compute_inv_freq(self.base)
|
||||
t = torch.arange(self.max_position_embeddings, dtype=torch.float)
|
||||
|
||||
freqs = torch.einsum("i,j -> ij", t, inv_freq)
|
||||
cos = freqs.cos()
|
||||
sin = freqs.sin()
|
||||
cache = torch.cat((cos, sin), dim=-1)
|
||||
return cache
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
offsets: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""A PyTorch-native implementation of forward()."""
|
||||
if offsets is not None:
|
||||
positions = positions + offsets
|
||||
positions = positions.flatten()
|
||||
num_tokens = positions.shape[0]
|
||||
cos_sin = self.cos_sin_cache.index_select(0, positions)
|
||||
cos, sin = cos_sin.chunk(2, dim=-1)
|
||||
|
||||
query_shape = query.shape
|
||||
query = query.view(num_tokens, -1, self.head_size)
|
||||
query_rot = query[..., :self.rotary_dim]
|
||||
query_pass = query[..., self.rotary_dim:]
|
||||
query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
|
||||
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)
|
||||
|
||||
key_shape = key.shape
|
||||
key = key.view(num_tokens, -1, self.head_size)
|
||||
key_rot = key[..., :self.rotary_dim]
|
||||
key_pass = key[..., self.rotary_dim:]
|
||||
key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
|
||||
key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
|
||||
return query, key
|
||||
|
||||
|
||||
# test with leading dimension and merge seqlen and batch_size as num_tokens
|
||||
# TODO(ganyi): open this test in the future
|
||||
@pytest.mark.skip(
|
||||
reason=
|
||||
"skip this test by default for now because of ci issue, will enable it in the future"
|
||||
)
|
||||
@pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
|
||||
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
|
||||
@pytest.mark.parametrize("seq_len", SEQ_LENS)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_rotary_embedding_quant_with_leading_dim(
|
||||
is_neox_style: bool,
|
||||
batch_size: int,
|
||||
seq_len: int,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
rotary_dim: Optional[int],
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
device: str,
|
||||
max_position: int = 8192,
|
||||
base: int = 10000,
|
||||
) -> None:
|
||||
if rotary_dim is None:
|
||||
rotary_dim = head_size
|
||||
|
||||
torch.set_default_device(device)
|
||||
if rotary_dim is None:
|
||||
rotary_dim = head_size
|
||||
rope = RotaryEmbedding(head_size, rotary_dim, max_position, base,
|
||||
is_neox_style, dtype)
|
||||
rope = rope.to(dtype=dtype)
|
||||
num_tokens = batch_size * seq_len
|
||||
positions = torch.randint(0, max_position, (batch_size * seq_len, ))
|
||||
qkv_tensor = torch.randn(num_tokens,
|
||||
num_heads * head_size * 3,
|
||||
dtype=dtype)
|
||||
query, key, _ = qkv_tensor.split(
|
||||
[num_heads * head_size, num_heads * head_size, num_heads * head_size],
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
ref_query, ref_key = rope.forward_native(positions, query, key)
|
||||
torch.ops._C.rotary_embedding(
|
||||
positions,
|
||||
query,
|
||||
key,
|
||||
rope.head_size,
|
||||
rope.cos_sin_cache,
|
||||
rope.is_neox_style,
|
||||
)
|
||||
|
||||
# Compare the results.
|
||||
torch.testing.assert_close(query,
|
||||
ref_query,
|
||||
atol=DEFAULT_ATOL,
|
||||
rtol=DEFAULT_RTOL)
|
||||
torch.testing.assert_close(key,
|
||||
ref_key,
|
||||
atol=DEFAULT_ATOL,
|
||||
rtol=DEFAULT_RTOL)
|
25
vllm_ascend/envs.py
Normal file
25
vllm_ascend/envs.py
Normal file
@ -0,0 +1,25 @@
|
||||
import os
|
||||
from typing import Any, Callable, Dict
|
||||
|
||||
env_variables: Dict[str, Callable[[], Any]] = {
|
||||
# max compile thread num
|
||||
"MAX_JOBS": lambda: os.getenv("MAX_JOBS", None),
|
||||
"CMAKE_BUILD_TYPE": lambda: os.getenv("CMAKE_BUILD_TYPE"),
|
||||
"COMPILE_CUSTOM_KERNELS":
|
||||
lambda: os.getenv("COMPILE_CUSTOM_KERNELS", None),
|
||||
# If set, vllm-ascend will print verbose logs during compliation
|
||||
"VERBOSE": lambda: bool(int(os.getenv('VERBOSE', '0'))),
|
||||
"ASCEND_HOME_PATH": lambda: os.getenv("ASCEND_HOME_PATH", None),
|
||||
"LD_LIBRARY_PATH": lambda: os.getenv("LD_LIBRARY_PATH", None),
|
||||
}
|
||||
|
||||
|
||||
def __getattr__(name: str):
|
||||
# lazy evaluation of environment variables
|
||||
if name in env_variables:
|
||||
return env_variables[name]()
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
|
||||
|
||||
def __dir__():
|
||||
return list(env_variables.keys())
|
@ -15,6 +15,7 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Optional, Tuple
|
||||
|
||||
@ -23,6 +24,19 @@ import torch_npu # noqa: F401
|
||||
import vllm.envs as envs
|
||||
from vllm.config import CompilationLevel
|
||||
from vllm.logger import init_logger
|
||||
|
||||
try:
|
||||
# register custom ops into torch_library here
|
||||
import vllm_ascend.vllm_ascend_C # type: ignore # noqa: F401
|
||||
|
||||
except ImportError as e:
|
||||
if not str(
|
||||
e
|
||||
) == "dynamic module does not define module export function (PyInit_vllm_ascend_C)":
|
||||
logging.warning(
|
||||
"Warning: Failed to register custom ops, all custom ops will be disabled"
|
||||
)
|
||||
|
||||
from vllm.platforms import Platform, PlatformEnum
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
Reference in New Issue
Block a user