mirror of
https://github.com/vllm-project/vllm-ascend.git
synced 2025-10-20 13:43:53 +08:00
Add Custom Kernels For LoRA Performance (#2325)
### What this PR does / why we need it?
Add two custom operators (sgmv_shrink and sgmv_expand) to address the
performance issues of LoRA. Meanwhile, enable the graph mode for LoRA
operators to enter ACL, so as to improve the model inference
performance.
### Does this PR introduce _any_ user-facing change?
no user-facing change
### How was this patch tested?
Based on the actual test of the QWen2.5 7B model using vllm-ascend
version v0.9.2.rc1, in acl graph mode, the TTFT, TPOT and throughput
have increased by about 100%.
Signed-off-by: liuchn <909698896@qq.com>
- vLLM version: v0.10.0
- vLLM main:
1f83e7d849
---------
Signed-off-by: liuchn <909698896@qq.com>
Co-authored-by: liuchn <909698896@qq.com>
This commit is contained in:
383
csrc/kernels/sgmv_expand.cpp
Normal file
383
csrc/kernels/sgmv_expand.cpp
Normal file
@ -0,0 +1,383 @@
|
||||
/*
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "kernel_operator.h"
|
||||
#include "types.h"
|
||||
|
||||
template <typename scalar_t>
|
||||
class SGMVExpand {
|
||||
public:
|
||||
using X_T = float;
|
||||
using W_T = scalar_t;
|
||||
using Y_T = scalar_t;
|
||||
|
||||
static constexpr uint64_t LORA_RANK_8 = 8;
|
||||
static constexpr uint64_t LORA_RANK_16 = 16;
|
||||
static constexpr uint64_t LORA_RANK_32 = 32;
|
||||
static constexpr uint64_t LORA_RANK_64 = 64;
|
||||
static constexpr uint64_t SUPPORTED_RANKS[] = {LORA_RANK_8, LORA_RANK_16, LORA_RANK_32, LORA_RANK_64};
|
||||
static constexpr int32_t BUFFER_NUM = 2;
|
||||
|
||||
// The vector unit reads 8 blocks (32 bytes each and 256 bytes in total) of contiguous data each time.
|
||||
static constexpr int32_t NUM_BYTES_PER_REPEAT = 256;
|
||||
static constexpr int32_t NUM_BLOCKS_PER_REPEAT = 8;
|
||||
// The maximum number of elements in a single iteration is 256 / sizeof(intermediate data type).
|
||||
static constexpr int32_t NUM_ELEMENTS_PER_REPEAT = NUM_BYTES_PER_REPEAT / sizeof(float);
|
||||
// Mask is used to control the elements that participate in computation in each iteration.
|
||||
static constexpr int32_t MASK_COUNT = NUM_BYTES_PER_REPEAT / sizeof(float);
|
||||
// Refer to numOutputElementsPerInputTile_ initialization for the constraints on the following constants.
|
||||
static constexpr int32_t W_IN_TILE_NUM_ELEMENTS = 8192;
|
||||
static constexpr int32_t Y_OUT_TILE_NUM_ELEMENTS = 4096;
|
||||
static constexpr int32_t BLOCK_REDUCE_NUM_REPEATS = W_IN_TILE_NUM_ELEMENTS / NUM_ELEMENTS_PER_REPEAT;
|
||||
// BlockReduceSum would generate(BLOCK_REDUCE_NUM_REPEATS * NUM_BLOCKS_PER_REPEAT)floats.
|
||||
// So need to read them all and apply PairReduceSum
|
||||
static constexpr int32_t PAIR_REDUCE_NUM_REPEATS_16 =
|
||||
(BLOCK_REDUCE_NUM_REPEATS * NUM_BLOCKS_PER_REPEAT + NUM_ELEMENTS_PER_REPEAT - 1) / NUM_ELEMENTS_PER_REPEAT;
|
||||
// The second PairReduceSum for rank=32, needs half of the repetition that happened for rank=16.
|
||||
// Same for rank=64, we do not support ranks greater than 64.
|
||||
static constexpr int32_t PAIR_REDUCE_NUM_REPEATS_32 = (PAIR_REDUCE_NUM_REPEATS_16 + 1) / 2;
|
||||
|
||||
public:
|
||||
__aicore__ inline SGMVExpand(AscendC::TPipe* pipe) : pipe_(pipe) {}
|
||||
|
||||
__aicore__ inline void Init(__gm__ void* x, __gm__ void* weight, __gm__ void* loraIndices,
|
||||
__gm__ void* seqLen, __gm__ void* yIn, __gm__ void* yOut,
|
||||
uint32_t batchSize, uint32_t numTokensPerCore, uint32_t maxLoRARank,
|
||||
uint32_t outputHiddenDim, uint32_t sliceOffset, uint32_t outputFullDim)
|
||||
{
|
||||
batchSize_ = batchSize;
|
||||
numTokensPerCore_ = numTokensPerCore;
|
||||
maxLoRARank_ = maxLoRARank;
|
||||
outputHiddenDim_ = outputHiddenDim;
|
||||
sliceOffset_ = sliceOffset;
|
||||
outputFullDim_ = outputFullDim;
|
||||
singleLoRAWeightLen_ = maxLoRARank_ * outputHiddenDim_;
|
||||
|
||||
xGm_.SetGlobalBuffer((__gm__ X_T *)x);
|
||||
wGm_.SetGlobalBuffer((__gm__ W_T *)weight);
|
||||
yInGm_.SetGlobalBuffer((__gm__ Y_T *)yIn);
|
||||
yOutGm_.SetGlobalBuffer((__gm__ Y_T *)yOut);
|
||||
loraIndicesGm_.SetGlobalBuffer((__gm__ int64_t *)loraIndices);
|
||||
seqLenGm_.SetGlobalBuffer((__gm__ int64_t *)seqLen);
|
||||
|
||||
pipe_->InitBuffer(inQueueX_, 1, NUM_ELEMENTS_PER_REPEAT * sizeof(X_T));
|
||||
pipe_->InitBuffer(inQueueW_, BUFFER_NUM, W_IN_TILE_NUM_ELEMENTS * sizeof(W_T));
|
||||
pipe_->InitBuffer(inQueueY_, BUFFER_NUM, Y_OUT_TILE_NUM_ELEMENTS * sizeof(Y_T));
|
||||
pipe_->InitBuffer(outQueueY_, BUFFER_NUM, Y_OUT_TILE_NUM_ELEMENTS * sizeof(Y_T));
|
||||
|
||||
pipe_->InitBuffer(dupBufferX_, NUM_ELEMENTS_PER_REPEAT * sizeof(float));
|
||||
pipe_->InitBuffer(tmpBufferW_, W_IN_TILE_NUM_ELEMENTS * sizeof(float));
|
||||
pipe_->InitBuffer(inBufferY_, Y_OUT_TILE_NUM_ELEMENTS * sizeof(float));
|
||||
pipe_->InitBuffer(tmpBufferY_, Y_OUT_TILE_NUM_ELEMENTS * sizeof(float));
|
||||
|
||||
// Each compute iteration would generate not one, but several output elements.
|
||||
// Therefore, the following variable would determine how many output elements are calculated in each iteration.
|
||||
numOutputElementsPerInputTile_ = BLOCK_REDUCE_NUM_REPEATS * (NUM_ELEMENTS_PER_REPEAT / maxLoRARank_);
|
||||
numStreamInPerOutputTile_ = Y_OUT_TILE_NUM_ELEMENTS / numOutputElementsPerInputTile_;
|
||||
|
||||
}
|
||||
|
||||
__aicore__ inline void Process()
|
||||
{
|
||||
int64_t blockIdx = AscendC::GetBlockIdx();
|
||||
int64_t startIdx = blockIdx * numTokensPerCore_;
|
||||
int64_t endIdx = startIdx + numTokensPerCore_;
|
||||
if (endIdx > batchSize_) {
|
||||
endIdx = batchSize_;
|
||||
}
|
||||
for (int64_t idx = startIdx; idx < endIdx; idx++) {
|
||||
yOffset_ = outputFullDim_ * idx + sliceOffset_;
|
||||
|
||||
// Set up LoRA index
|
||||
CopyInIndex(idx);
|
||||
if (reqLoRAIndex_ < 0) {
|
||||
continue;
|
||||
}
|
||||
reqLoRAWeightOffset_ = reqLoRAIndex_ * singleLoRAWeightLen_;
|
||||
|
||||
CopyInX(idx);
|
||||
int32_t numStreamOut = outputHiddenDim_ / Y_OUT_TILE_NUM_ELEMENTS;
|
||||
for (int32_t i = 0; i < numStreamOut; i++) {
|
||||
CopyInY(i);
|
||||
for (int32_t j = 0; j < numStreamInPerOutputTile_; j++) {
|
||||
CopyInW(i * numStreamInPerOutputTile_ + j);
|
||||
Compute(j * numOutputElementsPerInputTile_);
|
||||
}
|
||||
ScaleOutput();
|
||||
CopyOut(i);
|
||||
}
|
||||
ComputeLastIteration();
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
__aicore__ inline void CopyInIndex(const int64_t idx)
|
||||
{
|
||||
// Look up the LoRA index
|
||||
int64_t weightIdx = idx;
|
||||
uint64_t i = 0;
|
||||
for (; i < seqLenGm_.GetSize(); i++) {
|
||||
int64_t repeatValue = seqLenGm_.GetValue(i);
|
||||
if (weightIdx >= repeatValue) {
|
||||
weightIdx -= repeatValue;
|
||||
continue;
|
||||
}
|
||||
break;
|
||||
}
|
||||
reqLoRAIndex_ = (i < seqLenGm_.GetSize()) ? loraIndicesGm_.GetValue(i) : -1;
|
||||
}
|
||||
|
||||
__aicore__ inline void ComputeLastIteration()
|
||||
{
|
||||
int32_t remainingY = outputHiddenDim_ % Y_OUT_TILE_NUM_ELEMENTS;
|
||||
if (remainingY == 0) {
|
||||
return;
|
||||
}
|
||||
int32_t numStreamOut = outputHiddenDim_ / Y_OUT_TILE_NUM_ELEMENTS;
|
||||
int32_t remainingW = remainingY * maxLoRARank_;
|
||||
int32_t numCompleteWTileInForLastIteration = remainingW / W_IN_TILE_NUM_ELEMENTS;
|
||||
int32_t remainingWForLastRepeat = remainingW % W_IN_TILE_NUM_ELEMENTS;
|
||||
|
||||
CopyInY(numStreamOut, remainingY);
|
||||
|
||||
int32_t outputIdx = 0;
|
||||
for (outputIdx = 0; outputIdx < numCompleteWTileInForLastIteration; outputIdx++) {
|
||||
CopyInW(numStreamOut * numStreamInPerOutputTile_ + outputIdx);
|
||||
Compute(outputIdx * numOutputElementsPerInputTile_);
|
||||
}
|
||||
|
||||
if (remainingWForLastRepeat != 0) {
|
||||
CopyInW(numStreamOut * numStreamInPerOutputTile_ + numCompleteWTileInForLastIteration,
|
||||
remainingWForLastRepeat);
|
||||
int32_t lastRepeatCount = remainingWForLastRepeat / NUM_ELEMENTS_PER_REPEAT;
|
||||
int32_t pairReduceRepeat16 =
|
||||
(lastRepeatCount * NUM_BLOCKS_PER_REPEAT + NUM_ELEMENTS_PER_REPEAT - 1) / NUM_ELEMENTS_PER_REPEAT;
|
||||
int32_t pairReduceRepeat32 = (pairReduceRepeat16 + 1) / 2;
|
||||
int32_t lastComputeOutputElement = outputIdx * numOutputElementsPerInputTile_;
|
||||
Compute(lastComputeOutputElement, lastRepeatCount, pairReduceRepeat16, pairReduceRepeat32);
|
||||
}
|
||||
|
||||
ScaleOutput(remainingY);
|
||||
CopyOut(numStreamOut, remainingY);
|
||||
}
|
||||
|
||||
__aicore__ inline void CopyInX(const int64_t idx)
|
||||
{
|
||||
AscendC::LocalTensor<X_T> xLocal = inQueueX_.AllocTensor<X_T>();
|
||||
if constexpr (std::is_same_v<X_T, float>) {
|
||||
DataCopy(xLocal, xGm_[maxLoRARank_ * idx], maxLoRARank_);
|
||||
} else {
|
||||
uint16_t blockLen = static_cast<uint16_t>(maxLoRARank_ * sizeof(X_T));
|
||||
DataCopyPad(xLocal, xGm_[maxLoRARank_ * idx], {1, blockLen, 0, 0}, {});
|
||||
}
|
||||
inQueueX_.EnQue(xLocal);
|
||||
xLocal = inQueueX_.DeQue<X_T>();
|
||||
AscendC::LocalTensor<float> xDup = dupBufferX_.Get<float>();
|
||||
|
||||
// As we are generating multiple output elements with one API invocation,
|
||||
// we need to duplicate the X vector multiple times to fill one NUM_BYTES_PER_REPEAT
|
||||
if constexpr (std::is_same_v<X_T, float>) {
|
||||
for (int32_t i = 0; i < NUM_ELEMENTS_PER_REPEAT; i += maxLoRARank_) {
|
||||
for (int32_t j = 0; j < maxLoRARank_; j++) {
|
||||
float entry = xLocal.GetValue(j);
|
||||
xDup.SetValue(i + j, entry);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
Cast(xDup, xLocal, AscendC::RoundMode::CAST_NONE, maxLoRARank_);
|
||||
pipe_barrier(PIPE_V);
|
||||
|
||||
for (int32_t i = maxLoRARank_; i < NUM_ELEMENTS_PER_REPEAT; i += maxLoRARank_) {
|
||||
for (int32_t j = 0; j < maxLoRARank_; j++) {
|
||||
float entry = xDup.GetValue(j);
|
||||
xDup.SetValue(i + j, entry);
|
||||
}
|
||||
}
|
||||
}
|
||||
inQueueX_.FreeTensor(xLocal);
|
||||
}
|
||||
|
||||
__aicore__ inline void CopyInY(int32_t progress, int32_t numElements = Y_OUT_TILE_NUM_ELEMENTS)
|
||||
{
|
||||
AscendC::LocalTensor<Y_T> yInLocal = inQueueY_.AllocTensor<Y_T>();
|
||||
DataCopy(yInLocal, yInGm_[yOffset_ + progress * Y_OUT_TILE_NUM_ELEMENTS], numElements);
|
||||
inQueueY_.EnQue(yInLocal);
|
||||
}
|
||||
|
||||
__aicore__ inline void CopyInW(int32_t progress, int32_t numElements = W_IN_TILE_NUM_ELEMENTS)
|
||||
{
|
||||
AscendC::LocalTensor<W_T> wLocal = inQueueW_.AllocTensor<W_T>();
|
||||
DataCopy(wLocal, wGm_[reqLoRAWeightOffset_ + progress * W_IN_TILE_NUM_ELEMENTS], numElements);
|
||||
inQueueW_.EnQue(wLocal);
|
||||
}
|
||||
|
||||
__aicore__ inline void ScaleOutput(int32_t numElements = Y_OUT_TILE_NUM_ELEMENTS)
|
||||
{
|
||||
AscendC::LocalTensor<float> yLocal = tmpBufferY_.Get<float>();
|
||||
AscendC::LocalTensor<Y_T> yInLocal = inQueueY_.DeQue<Y_T>();
|
||||
AscendC::LocalTensor<float> yInLocalFP32 = inBufferY_.Get<float>();
|
||||
Cast(yInLocalFP32, yInLocal, AscendC::RoundMode::CAST_NONE, numElements);
|
||||
pipe_barrier(PIPE_V);
|
||||
inQueueY_.FreeTensor(yInLocal);
|
||||
|
||||
Add(yLocal, yLocal, yInLocalFP32, numElements);
|
||||
pipe_barrier(PIPE_V);
|
||||
|
||||
AscendC::LocalTensor<Y_T> yOutLocal = outQueueY_.AllocTensor<Y_T>();
|
||||
Cast(yOutLocal, yLocal, AscendC::RoundMode::CAST_RINT, numElements);
|
||||
pipe_barrier(PIPE_V);
|
||||
|
||||
outQueueY_.EnQue<Y_T>(yOutLocal);
|
||||
}
|
||||
|
||||
__aicore__ inline void Compute(int32_t progress,
|
||||
int32_t blockReduceRepeatCount=BLOCK_REDUCE_NUM_REPEATS,
|
||||
int32_t pairReduceRepeat16=PAIR_REDUCE_NUM_REPEATS_16,
|
||||
int32_t pairReduceRepeat32=PAIR_REDUCE_NUM_REPEATS_32)
|
||||
{
|
||||
AscendC::LocalTensor<float> yLocal = tmpBufferY_.Get<float>();
|
||||
AscendC::LocalTensor<float> xDup = dupBufferX_.Get<float>();
|
||||
AscendC::LocalTensor<W_T> wLocal = inQueueW_.DeQue<W_T>();
|
||||
AscendC::LocalTensor<float> wTmpTensor = tmpBufferW_.Get<float>();
|
||||
|
||||
Cast(wTmpTensor, wLocal, AscendC::RoundMode::CAST_NONE, MASK_COUNT, blockReduceRepeatCount, castParams_);
|
||||
pipe_barrier(PIPE_V);
|
||||
inQueueW_.FreeTensor(wLocal);
|
||||
|
||||
Mul(wTmpTensor, xDup, wTmpTensor, MASK_COUNT, blockReduceRepeatCount, dotProductParams_);
|
||||
pipe_barrier(PIPE_V);
|
||||
|
||||
if (maxLoRARank_ == LORA_RANK_8) {
|
||||
BlockReduceSum(yLocal[progress], wTmpTensor, blockReduceRepeatCount, MASK_COUNT,
|
||||
reduceSumParams_.dstRepStride, reduceSumParams_.srcBlkStride, reduceSumParams_.srcRepStride);
|
||||
pipe_barrier(PIPE_V);
|
||||
} else if (maxLoRARank_ == LORA_RANK_16) {
|
||||
BlockReduceSum(wTmpTensor, wTmpTensor, blockReduceRepeatCount, MASK_COUNT,
|
||||
reduceSumParams_.dstRepStride, reduceSumParams_.srcBlkStride, reduceSumParams_.srcRepStride);
|
||||
pipe_barrier(PIPE_V);
|
||||
PairReduceSum(yLocal[progress], wTmpTensor, pairReduceRepeat16, MASK_COUNT,
|
||||
reduceSumParams_.dstRepStride, reduceSumParams_.srcBlkStride, reduceSumParams_.srcRepStride);
|
||||
pipe_barrier(PIPE_V);
|
||||
} else if (maxLoRARank_ == LORA_RANK_32) {
|
||||
BlockReduceSum(wTmpTensor, wTmpTensor, blockReduceRepeatCount, MASK_COUNT,
|
||||
reduceSumParams_.dstRepStride, reduceSumParams_.srcBlkStride, reduceSumParams_.srcRepStride);
|
||||
pipe_barrier(PIPE_V);
|
||||
PairReduceSum(wTmpTensor, wTmpTensor, pairReduceRepeat16, MASK_COUNT,
|
||||
reduceSumParams_.dstRepStride, reduceSumParams_.srcBlkStride, reduceSumParams_.srcRepStride);
|
||||
pipe_barrier(PIPE_V);
|
||||
PairReduceSum(yLocal[progress], wTmpTensor, pairReduceRepeat32, MASK_COUNT,
|
||||
reduceSumParams_.dstRepStride, reduceSumParams_.srcBlkStride, reduceSumParams_.srcRepStride);
|
||||
pipe_barrier(PIPE_V);
|
||||
} else if (maxLoRARank_ == LORA_RANK_64) {
|
||||
BlockReduceSum(wTmpTensor, wTmpTensor, blockReduceRepeatCount, MASK_COUNT,
|
||||
reduceSumParams_.dstRepStride, reduceSumParams_.srcBlkStride, reduceSumParams_.srcRepStride);
|
||||
pipe_barrier(PIPE_V);
|
||||
BlockReduceSum(yLocal[progress], wTmpTensor, pairReduceRepeat16, MASK_COUNT,
|
||||
reduceSumParams_.dstRepStride, reduceSumParams_.srcBlkStride, reduceSumParams_.srcRepStride);
|
||||
pipe_barrier(PIPE_V);
|
||||
}
|
||||
}
|
||||
|
||||
__aicore__ inline void CopyOut(int32_t progress, int32_t numElements = Y_OUT_TILE_NUM_ELEMENTS)
|
||||
{
|
||||
AscendC::LocalTensor<Y_T> yOutLocal = outQueueY_.DeQue<Y_T>();
|
||||
DataCopy(yOutGm_[yOffset_ + progress * Y_OUT_TILE_NUM_ELEMENTS], yOutLocal, numElements);
|
||||
outQueueY_.FreeTensor(yOutLocal);
|
||||
}
|
||||
|
||||
private:
|
||||
AscendC::TPipe* pipe_;
|
||||
AscendC::TQue<AscendC::QuePosition::VECIN, BUFFER_NUM> inQueueY_, inQueueW_;
|
||||
AscendC::TQue<AscendC::QuePosition::VECIN, 1> inQueueX_;
|
||||
AscendC::TQue<AscendC::QuePosition::VECOUT, BUFFER_NUM> outQueueY_;
|
||||
AscendC::TBuf<AscendC::QuePosition::VECCALC> tmpBufferW_, dupBufferX_, inBufferY_, tmpBufferY_;
|
||||
AscendC::GlobalTensor<X_T> xGm_;
|
||||
AscendC::GlobalTensor<W_T> wGm_;
|
||||
AscendC::GlobalTensor<Y_T> yInGm_;
|
||||
AscendC::GlobalTensor<Y_T> yOutGm_;
|
||||
AscendC::GlobalTensor<int64_t> loraIndicesGm_;
|
||||
AscendC::GlobalTensor<int64_t> seqLenGm_;
|
||||
uint32_t batchSize_;
|
||||
uint32_t numTokensPerCore_;
|
||||
uint32_t maxLoRARank_;
|
||||
uint32_t outputHiddenDim_;
|
||||
uint32_t sliceOffset_;
|
||||
uint32_t outputFullDim_;
|
||||
uint32_t singleLoRAWeightLen_;
|
||||
int64_t reqLoRAIndex_;
|
||||
uint64_t reqLoRAWeightOffset_;
|
||||
uint32_t numOutputElementsPerInputTile_;
|
||||
uint32_t numStreamInPerOutputTile_;
|
||||
uint64_t yOffset_;
|
||||
|
||||
// The block stride is set to 1, and 8 blocks in the same repeat are processed continuously.
|
||||
// The repeat stride is 8, so the vector unit reads 8 consecutive blocks in the first repeat,
|
||||
// reads next 8 consecutive blocks in the second repeat.
|
||||
AscendC::UnaryRepeatParams castParams_ = {1, 1, 8, 4};
|
||||
|
||||
// For each repeat in BlockReduceSum and PairReduceSum we should move forward only one block,
|
||||
// so we set dstRepStride = 1
|
||||
AscendC::UnaryRepeatParams reduceSumParams_ = {1, 1, 1, 8};
|
||||
|
||||
// When the repeat stride is 0, the vector unit repeatedly reads and computes the first 8 consecutive blocks.
|
||||
// For xDup we repeatedly use it, so we set src0RepStride = 0
|
||||
AscendC::BinaryRepeatParams dotProductParams_ = {1, 1, 1, 8, 0, 8};
|
||||
|
||||
};
|
||||
|
||||
#define SGMV_EXPAND_TYPE_DECLARE(TYPE) \
|
||||
extern "C" __global__ __aicore__ void sgmv_expand_##TYPE(__gm__ void* x, __gm__ void* weight, \
|
||||
__gm__ void* loraIndices, __gm__ void* seqLen, \
|
||||
__gm__ void* yIn, __gm__ void* yOut, \
|
||||
uint32_t batchSize, uint32_t numTokensPerCore, \
|
||||
uint32_t maxLoRARank, uint32_t outputHiddenDim, \
|
||||
uint32_t sliceOffset, uint32_t outputFullDim) \
|
||||
{ \
|
||||
AscendC::TPipe pipe; \
|
||||
SGMVExpand<TYPE> op(&pipe); \
|
||||
op.Init(x, weight, loraIndices, seqLen, yIn, yOut, batchSize, numTokensPerCore, maxLoRARank, \
|
||||
outputHiddenDim, sliceOffset, outputFullDim); \
|
||||
op.Process(); \
|
||||
}
|
||||
|
||||
// declare all dtype kernel
|
||||
SGMV_EXPAND_TYPE_DECLARE(half)
|
||||
#if (__CCE_AICORE__ >= 220)
|
||||
SGMV_EXPAND_TYPE_DECLARE(bfloat16_t)
|
||||
#endif
|
||||
|
||||
namespace vllm_ascend {
|
||||
extern void sgmv_expand_impl(AscendType type, void* stream, void* x, void* weight, void* loraIndices, void* seqLen,
|
||||
void* yIn, void* yOut, uint32_t batchSize, uint32_t numTokensPerCore, uint32_t maxLoRARank,
|
||||
uint32_t outputHiddenDim, uint32_t sliceOffset, uint32_t outputFullDim)
|
||||
{
|
||||
uint32_t blockDim = (batchSize + numTokensPerCore - 1) / numTokensPerCore;
|
||||
if (type == AscendType::FP16) {
|
||||
sgmv_expand_half<<<blockDim, nullptr, stream>>>(x, weight, loraIndices, seqLen, yIn, yOut, batchSize,
|
||||
numTokensPerCore, maxLoRARank, outputHiddenDim, sliceOffset,
|
||||
outputFullDim);
|
||||
} else if (type == AscendType::BF16) {
|
||||
#if (__CCE_AICORE__ >= 220)
|
||||
sgmv_expand_bfloat16_t<<<blockDim, nullptr, stream>>>(x, weight, loraIndices, seqLen, yIn, yOut, batchSize,
|
||||
numTokensPerCore, maxLoRARank, outputHiddenDim,
|
||||
sliceOffset, outputFullDim);
|
||||
#endif
|
||||
} else {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm_ascend
|
267
csrc/kernels/sgmv_shrink.cpp
Normal file
267
csrc/kernels/sgmv_shrink.cpp
Normal file
@ -0,0 +1,267 @@
|
||||
/*
|
||||
* Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "kernel_operator.h"
|
||||
#include "types.h"
|
||||
|
||||
template <typename scalar_t>
|
||||
class SGMVShrink {
|
||||
public:
|
||||
using X_T = scalar_t;
|
||||
using W_T = scalar_t;
|
||||
using Y_T = float;
|
||||
|
||||
static constexpr uint64_t BUFFER_NUM = 1;
|
||||
static constexpr uint64_t TILE_LENGTH = 11776; // optimal performance tile length
|
||||
|
||||
public:
|
||||
__aicore__ inline SGMVShrink(AscendC::TPipe *pipe) : pipe_(pipe) {}
|
||||
__aicore__ inline void Init(__gm__ void *x, __gm__ void *weight, __gm__ void *loraIndices, __gm__ void *seqLen,
|
||||
__gm__ void *y, uint32_t batchSize, uint32_t numTokensPerCore, uint32_t inputHiddenDim,
|
||||
uint32_t maxLoRARank, float scale)
|
||||
{
|
||||
batchSize_ = batchSize;
|
||||
numTokensPerCore_ = numTokensPerCore;
|
||||
inputHiddenDim_ = inputHiddenDim;
|
||||
maxLoRARank_ = maxLoRARank;
|
||||
scale_ = scale;
|
||||
singleLoRAWeightLen_ = inputHiddenDim_ * maxLoRARank_;
|
||||
incremental_ = inputHiddenDim_ > TILE_LENGTH;
|
||||
|
||||
xGm_.SetGlobalBuffer((__gm__ X_T *)x);
|
||||
yOutGm_.SetGlobalBuffer((__gm__ Y_T *)y);
|
||||
wGm_.SetGlobalBuffer((__gm__ W_T *)weight);
|
||||
loraIndicesGm_.SetGlobalBuffer((__gm__ int64_t *)loraIndices);
|
||||
seqLenGm_.SetGlobalBuffer((__gm__ int64_t *)seqLen);
|
||||
|
||||
pipe_->InitBuffer(inQueueX_, BUFFER_NUM, TILE_LENGTH * sizeof(X_T));
|
||||
pipe_->InitBuffer(inQueueW_, BUFFER_NUM, TILE_LENGTH * sizeof(W_T));
|
||||
pipe_->InitBuffer(tmpBufferX_, TILE_LENGTH * sizeof(float));
|
||||
pipe_->InitBuffer(tmpBufferW_, TILE_LENGTH * sizeof(float));
|
||||
|
||||
pipe_->InitBuffer(outQueueY_, 1, maxLoRARank_ * sizeof(Y_T));
|
||||
pipe_->InitBuffer(outBufferY_, maxLoRARank_ * sizeof(float));
|
||||
}
|
||||
|
||||
__aicore__ inline void Process()
|
||||
{
|
||||
int64_t blockIdx = AscendC::GetBlockIdx();
|
||||
int64_t startIdx = blockIdx * numTokensPerCore_;
|
||||
int64_t endIdx = startIdx + numTokensPerCore_;
|
||||
if (endIdx > batchSize_) {
|
||||
endIdx = batchSize_;
|
||||
}
|
||||
for (int64_t idx = startIdx; idx < endIdx; idx++) {
|
||||
// set up LoRA index
|
||||
CopyInIndex(idx);
|
||||
if (reqLoRAIndex_ < 0) {
|
||||
continue;
|
||||
}
|
||||
reqLoRAWeightOffset_ = reqLoRAIndex_ * singleLoRAWeightLen_;
|
||||
|
||||
if (incremental_) {
|
||||
ProcessImpl<true>(idx);
|
||||
} else {
|
||||
ProcessImpl<false>(idx);
|
||||
}
|
||||
|
||||
ScaleOutput();
|
||||
CopyOut(idx);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
template <bool INCREMENTAL_MODE>
|
||||
__aicore__ inline void ProcessImpl(const int64_t idx)
|
||||
{
|
||||
AscendC::LocalTensor<float> yOutLocal = outBufferY_.Get<float>();
|
||||
if constexpr (!INCREMENTAL_MODE) {
|
||||
CopyInX(idx, 0, inputHiddenDim_);
|
||||
AscendC::LocalTensor<float> xTmpTensor = tmpBufferX_.Get<float>();
|
||||
AscendC::LocalTensor<X_T> xLocal = inQueueX_.DeQue<X_T>();
|
||||
Cast(xTmpTensor, xLocal, AscendC::RoundMode::CAST_NONE, inputHiddenDim_);
|
||||
pipe_barrier(PIPE_V);
|
||||
inQueueX_.FreeTensor(xLocal);
|
||||
}
|
||||
|
||||
for (int i = 0; i < maxLoRARank_; i++) {
|
||||
float acc(0);
|
||||
for (int32_t j = 0; j < inputHiddenDim_ / TILE_LENGTH; j++) {
|
||||
if constexpr (INCREMENTAL_MODE) {
|
||||
CopyInX(idx, j);
|
||||
}
|
||||
CopyInW(i, j);
|
||||
Compute<INCREMENTAL_MODE>(acc);
|
||||
}
|
||||
CopyAndComputeLastIteration<INCREMENTAL_MODE>(idx, i, acc);
|
||||
yOutLocal.SetValue(i, acc);
|
||||
}
|
||||
}
|
||||
|
||||
__aicore__ inline void CopyInIndex(const int64_t idx)
|
||||
{
|
||||
// look up the LoRA index
|
||||
int64_t weightIdx = idx;
|
||||
uint64_t i = 0;
|
||||
for (; i < seqLenGm_.GetSize(); i++) {
|
||||
int64_t repeatValue = seqLenGm_.GetValue(i);
|
||||
if (weightIdx >= repeatValue) {
|
||||
weightIdx -= repeatValue;
|
||||
continue;
|
||||
}
|
||||
break;
|
||||
}
|
||||
reqLoRAIndex_ = (i < seqLenGm_.GetSize()) ? loraIndicesGm_.GetValue(i) : -1;
|
||||
}
|
||||
|
||||
__aicore__ inline void CopyInX(const int64_t idx, int32_t colIdx, int32_t numElements = TILE_LENGTH)
|
||||
{
|
||||
AscendC::LocalTensor<X_T> xLocal = inQueueX_.AllocTensor<X_T>();
|
||||
DataCopy(xLocal, xGm_[inputHiddenDim_ * idx + colIdx * TILE_LENGTH], numElements);
|
||||
inQueueX_.EnQue(xLocal);
|
||||
}
|
||||
|
||||
__aicore__ inline void CopyInW(int32_t rowIdx, int32_t colIdx, int32_t numElements = TILE_LENGTH)
|
||||
{
|
||||
AscendC::LocalTensor<W_T> wLocal = inQueueW_.AllocTensor<W_T>();
|
||||
DataCopy(wLocal, wGm_[reqLoRAWeightOffset_ + rowIdx * inputHiddenDim_ + colIdx * TILE_LENGTH], numElements);
|
||||
inQueueW_.EnQue(wLocal);
|
||||
}
|
||||
|
||||
template <bool INCREMENTAL_MODE>
|
||||
__aicore__ inline void Compute(float &acc, int32_t numElements = TILE_LENGTH)
|
||||
{
|
||||
AscendC::LocalTensor<W_T> wLocal = inQueueW_.DeQue<W_T>();
|
||||
AscendC::LocalTensor<float> xTmpTensor = tmpBufferX_.Get<float>();
|
||||
AscendC::LocalTensor<float> wTmpTensor = tmpBufferW_.Get<float>();
|
||||
|
||||
if constexpr (INCREMENTAL_MODE) {
|
||||
AscendC::LocalTensor<X_T> xLocal = inQueueX_.DeQue<X_T>();
|
||||
Cast(xTmpTensor, xLocal, AscendC::RoundMode::CAST_NONE, numElements);
|
||||
Cast(wTmpTensor, wLocal, AscendC::RoundMode::CAST_NONE, numElements);
|
||||
pipe_barrier(PIPE_V);
|
||||
inQueueX_.FreeTensor(xLocal);
|
||||
inQueueW_.FreeTensor(wLocal);
|
||||
} else {
|
||||
Cast(wTmpTensor, wLocal, AscendC::RoundMode::CAST_NONE, numElements);
|
||||
pipe_barrier(PIPE_V);
|
||||
inQueueW_.FreeTensor(wLocal);
|
||||
}
|
||||
// dot product of the one tile of X and W
|
||||
Mul(wTmpTensor, xTmpTensor, wTmpTensor, numElements);
|
||||
pipe_barrier(PIPE_V);
|
||||
// reduce sum generate one number, which is the summation of all the dot product
|
||||
ReduceSum<float>(wTmpTensor, wTmpTensor, wTmpTensor, numElements);
|
||||
pipe_barrier(PIPE_V);
|
||||
|
||||
acc += wTmpTensor.GetValue(0);
|
||||
}
|
||||
|
||||
template <bool INCREMENTAL_MODE>
|
||||
__aicore__ inline void CopyAndComputeLastIteration(const int64_t idx, int32_t rowIdx, float &acc)
|
||||
{
|
||||
int32_t colIdx = inputHiddenDim_ / TILE_LENGTH;
|
||||
int32_t remaining = inputHiddenDim_ % TILE_LENGTH;
|
||||
if (remaining == 0) {
|
||||
return;
|
||||
}
|
||||
if constexpr (INCREMENTAL_MODE) {
|
||||
CopyInX(idx, colIdx, remaining);
|
||||
}
|
||||
CopyInW(rowIdx, colIdx, remaining);
|
||||
Compute<INCREMENTAL_MODE>(acc, remaining);
|
||||
}
|
||||
|
||||
__aicore__ inline void ScaleOutput()
|
||||
{
|
||||
AscendC::LocalTensor<float> yLocal = outBufferY_.Get<float>();
|
||||
AscendC::LocalTensor<Y_T> yOutLocal = outQueueY_.AllocTensor<Y_T>();
|
||||
|
||||
Muls(yOutLocal, yLocal, scale_, maxLoRARank_);
|
||||
pipe_barrier(PIPE_V);
|
||||
|
||||
outQueueY_.EnQue<Y_T>(yOutLocal);
|
||||
}
|
||||
|
||||
__aicore__ inline void CopyOut(const int64_t idx)
|
||||
{
|
||||
AscendC::LocalTensor<Y_T> yOutLocal = outQueueY_.DeQue<Y_T>();
|
||||
DataCopy(yOutGm_[maxLoRARank_ * idx], yOutLocal, maxLoRARank_);
|
||||
outQueueY_.FreeTensor(yOutLocal);
|
||||
}
|
||||
|
||||
private:
|
||||
AscendC::TPipe *pipe_;
|
||||
AscendC::TQue<AscendC::QuePosition::VECIN, BUFFER_NUM> inQueueX_, inQueueW_;
|
||||
AscendC::TQue<AscendC::QuePosition::VECOUT, 1> outQueueY_;
|
||||
AscendC::TBuf<AscendC::QuePosition::VECCALC> tmpBufferX_, tmpBufferW_, outBufferY_;
|
||||
AscendC::GlobalTensor<X_T> xGm_;
|
||||
AscendC::GlobalTensor<W_T> wGm_;
|
||||
AscendC::GlobalTensor<int64_t> loraIndicesGm_;
|
||||
AscendC::GlobalTensor<int64_t> seqLenGm_;
|
||||
AscendC::GlobalTensor<Y_T> yOutGm_;
|
||||
uint32_t batchSize_;
|
||||
uint32_t numTokensPerCore_;
|
||||
uint32_t inputHiddenDim_;
|
||||
uint32_t maxLoRARank_;
|
||||
float scale_;
|
||||
uint32_t singleLoRAWeightLen_;
|
||||
int64_t reqLoRAIndex_;
|
||||
uint64_t reqLoRAWeightOffset_;
|
||||
bool incremental_;
|
||||
};
|
||||
|
||||
#define SGMV_SHRINK_TYPE_DECLARE(TYPE) \
|
||||
extern "C" __global__ __aicore__ void sgmv_shrink_##TYPE(__gm__ void* x, __gm__ void* weight, \
|
||||
__gm__ void* loraIndices, __gm__ void* seqLen, \
|
||||
__gm__ void* y, uint32_t batchSize, \
|
||||
uint32_t numTokensPerCore, uint32_t inputHiddenDim, \
|
||||
uint32_t maxLoRARank, float scale) \
|
||||
{ \
|
||||
AscendC::TPipe pipe; \
|
||||
SGMVShrink<TYPE> op(&pipe); \
|
||||
op.Init(x, weight, loraIndices, seqLen,y, batchSize, numTokensPerCore, inputHiddenDim, maxLoRARank, scale); \
|
||||
op.Process(); \
|
||||
}
|
||||
|
||||
// declare all dtype kernel
|
||||
SGMV_SHRINK_TYPE_DECLARE(half)
|
||||
#if (__CCE_AICORE__ >= 220)
|
||||
SGMV_SHRINK_TYPE_DECLARE(bfloat16_t)
|
||||
#endif
|
||||
|
||||
namespace vllm_ascend {
|
||||
extern void sgmv_shrink_impl(AscendType type, void* stream, void* x, void* weight, void* loraIndices, void* seqLen,
|
||||
void* y, uint32_t batchSize, uint32_t numTokensPerCore, uint32_t inputHiddenDim,
|
||||
uint32_t maxLoRARank, float scale)
|
||||
{
|
||||
uint32_t blockDim = (batchSize + numTokensPerCore - 1) / numTokensPerCore;
|
||||
if (type == AscendType::FP16) {
|
||||
sgmv_shrink_half<<<blockDim, nullptr, stream>>>(x, weight, loraIndices, seqLen, y, batchSize,
|
||||
numTokensPerCore, inputHiddenDim, maxLoRARank,
|
||||
scale);
|
||||
} else if (type == AscendType::BF16) {
|
||||
#if (__CCE_AICORE__ >= 220)
|
||||
sgmv_shrink_bfloat16_t<<<blockDim, nullptr, stream>>>(x, weight, loraIndices, seqLen, y, batchSize,
|
||||
numTokensPerCore, inputHiddenDim, maxLoRARank,
|
||||
scale);
|
||||
#endif
|
||||
} else {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm_ascend
|
30
csrc/ops.h
30
csrc/ops.h
@ -88,4 +88,34 @@ namespace vllm_ascend {
|
||||
uint32_t output_hidden_dim,
|
||||
uint32_t slice_offset,
|
||||
uint32_t output_full_dim);
|
||||
|
||||
extern void sgmv_shrink_impl(
|
||||
AscendType type,
|
||||
void *stream,
|
||||
void *x,
|
||||
void *weight,
|
||||
void *loraIndices,
|
||||
void *seqLen,
|
||||
void *y,
|
||||
uint32_t batch_size,
|
||||
uint32_t num_tokens_per_core,
|
||||
uint32_t input_hidden_dim,
|
||||
uint32_t lora_rank,
|
||||
float scale);
|
||||
|
||||
extern void sgmv_expand_impl(
|
||||
AscendType type,
|
||||
void *stream,
|
||||
void *x,
|
||||
void *weight,
|
||||
void *loraIndices,
|
||||
void *seqLen,
|
||||
void *y,
|
||||
void *y_out,
|
||||
uint32_t batch_size,
|
||||
uint32_t num_tokens_per_core,
|
||||
uint32_t lora_rank,
|
||||
uint32_t output_hidden_dim,
|
||||
uint32_t slice_offset,
|
||||
uint32_t output_full_dim);
|
||||
}
|
||||
|
@ -294,6 +294,87 @@ at::Tensor bgmv_expand(at::Tensor &x, at::Tensor &weight, at::Tensor &indices, a
|
||||
cmd.Run();
|
||||
return y_out;
|
||||
}
|
||||
|
||||
void sgmv_shrink(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indices, at::Tensor &seq_len,
|
||||
at::Tensor &y, double scale)
|
||||
{
|
||||
at::ScalarType scalar_type = x.scalar_type();
|
||||
TORCH_CHECK(scalar_type == torch::kHalf || scalar_type == torch::kBFloat16, "only support half and bf16");
|
||||
TORCH_CHECK(x.dim() == 2, "x should be [batch_size, hidden_in]");
|
||||
TORCH_CHECK(weight.dim() == 3 || weight.dim() == 4,
|
||||
"weight should be [num_loras, hidden_out, hidden_in] or [num_loras, 1, hidden_out, hidden_in]");
|
||||
TORCH_CHECK(y.dim() == 2, "y should be [batch_size, hidden_out]");
|
||||
TORCH_CHECK(x.size(1) > y.size(1), "hidden in should be greater than hidden out");
|
||||
void* x_ptr = x.data_ptr();
|
||||
void* weight_ptr = weight.data_ptr();
|
||||
void* lora_indices_ptr = lora_indices.data_ptr();
|
||||
void* seq_len_ptr = seq_len.data_ptr();
|
||||
void* y_ptr = y.data_ptr();
|
||||
int batch_size = x.size(0);
|
||||
int input_hidden_token = x.size(1);
|
||||
uint32_t lora_rank = y.size(1);
|
||||
float scale_f = static_cast<float>(scale);
|
||||
aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
|
||||
at_npu::native::OpCommand cmd;
|
||||
cmd.Name("sgmv_shrink");
|
||||
cmd.SetCustomHandler([scalar_type, stream, x_ptr, weight_ptr, lora_indices_ptr, seq_len_ptr, y_ptr,
|
||||
batch_size, input_hidden_token, lora_rank, scale_f]() -> int {
|
||||
auto dtype = get_dtype_from_torch(scalar_type);
|
||||
int device_id = 0;
|
||||
int64_t aiv_num = 0;
|
||||
TORCH_CHECK(aclGetDeviceCapability(device_id, ACL_DEVICE_INFO_VECTOR_CORE_NUM, &aiv_num) == ACL_SUCCESS);
|
||||
int num_tokens_per_core = (batch_size + aiv_num - 1) / aiv_num;
|
||||
TORCH_CHECK("num_tokens_per_core != 0", "num_tokens_per_core should not be 0");
|
||||
sgmv_shrink_impl(dtype, stream, x_ptr, weight_ptr, lora_indices_ptr, seq_len_ptr, y_ptr, batch_size,
|
||||
num_tokens_per_core, input_hidden_token, lora_rank, scale_f);
|
||||
return 0;
|
||||
});
|
||||
cmd.Run();
|
||||
return;
|
||||
}
|
||||
|
||||
at::Tensor sgmv_expand(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indices, at::Tensor &seq_len,
|
||||
at::Tensor &y, int64_t slice_offset, int64_t slice_size)
|
||||
{
|
||||
at::ScalarType scalar_type = y.scalar_type();
|
||||
TORCH_CHECK(scalar_type == torch::kHalf || scalar_type == torch::kBFloat16, "only support half and bf16");
|
||||
TORCH_CHECK(x.dim() == 2, "x should be [batch_size, hidden_in]");
|
||||
TORCH_CHECK(weight.dim() == 3 || weight.dim() == 4,
|
||||
"weight should be [num_loras, hidden_out, hidden_in] or [num_loras, 1, hidden_out, hidden_in]");
|
||||
TORCH_CHECK(y.dim() == 2, "y should be [batch_size, hidden_out]");
|
||||
TORCH_CHECK(x.size(1) <= slice_size, "hidden in should be smaller than hidden out");
|
||||
TORCH_CHECK(slice_offset >= 0, "slice offset should be no smaller than 0");
|
||||
TORCH_CHECK((slice_size + slice_offset) <= y.size(1),
|
||||
"slice_size + slice_offset should be smaller than the second dimension of y")
|
||||
|
||||
at::Tensor y_out = y;
|
||||
void* x_ptr = x.data_ptr();
|
||||
void* weight_ptr = weight.data_ptr();
|
||||
void* lora_indices_ptr = lora_indices.data_ptr();
|
||||
void* seq_len_ptr = seq_len.data_ptr();
|
||||
void* y_ptr = y.data_ptr();
|
||||
void* y_out_ptr = y_out.data_ptr();
|
||||
int batch_size = x.size(0);
|
||||
int lora_rank = x.size(1);
|
||||
int output_full_dim = y.size(1);
|
||||
aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
|
||||
at_npu::native::OpCommand cmd;
|
||||
cmd.Name("sgmv_expand");
|
||||
cmd.SetCustomHandler([scalar_type, stream, x_ptr, weight_ptr, lora_indices_ptr, seq_len_ptr, y_ptr, y_out_ptr,
|
||||
batch_size, lora_rank, slice_offset, slice_size, output_full_dim]() -> int {
|
||||
auto dtype = get_dtype_from_torch(scalar_type);
|
||||
int device_id = 0;
|
||||
int64_t aiv_num = 0;
|
||||
TORCH_CHECK(aclGetDeviceCapability(device_id, ACL_DEVICE_INFO_VECTOR_CORE_NUM, &aiv_num) == ACL_SUCCESS);
|
||||
int num_tokens_per_core = (batch_size + aiv_num - 1) / aiv_num;
|
||||
TORCH_CHECK("num_tokens_per_core != 0", "num_tokens_per_core should not be 0");
|
||||
sgmv_expand_impl(dtype, stream, x_ptr, weight_ptr, lora_indices_ptr, seq_len_ptr, y_ptr, y_out_ptr,
|
||||
batch_size, num_tokens_per_core, lora_rank, slice_size, slice_offset, output_full_dim);
|
||||
return 0;
|
||||
});
|
||||
cmd.Run();
|
||||
return y_out;
|
||||
}
|
||||
} // namespace vllm_ascend
|
||||
|
||||
TORCH_LIBRARY_EXPAND(_C, ops)
|
||||
@ -326,6 +407,14 @@ TORCH_LIBRARY_EXPAND(_C, ops)
|
||||
"bgmv_expand(Tensor! x, Tensor! weight, Tensor! indices, Tensor! y,"
|
||||
" int slice_offset, int slice_size) -> Tensor");
|
||||
ops.impl("bgmv_expand", torch::kPrivateUse1, &vllm_ascend::bgmv_expand);
|
||||
|
||||
ops.def("sgmv_shrink(Tensor! x, Tensor! weight, Tensor! lora_indices, Tensor! seq_len, Tensor! y, float scale) -> ()");
|
||||
ops.impl("sgmv_shrink", torch::kPrivateUse1, &vllm_ascend::sgmv_shrink);
|
||||
|
||||
ops.def(
|
||||
"sgmv_expand(Tensor! x, Tensor! weight, Tensor! lora_indices, Tensor! seq_len, Tensor! y,"
|
||||
" int slice_offset, int slice_size) -> Tensor");
|
||||
ops.impl("sgmv_expand", torch::kPrivateUse1, &vllm_ascend::sgmv_expand);
|
||||
}
|
||||
|
||||
REGISTER_EXTENSION(_C)
|
||||
|
@ -69,6 +69,18 @@ std::tuple<at::Tensor, at::Tensor> get_masked_input_and_mask_meta(
|
||||
return {masked_input, mask};
|
||||
}
|
||||
|
||||
at::Tensor bgmv_expand_meta(at::Tensor &x, at::Tensor &weight, at::Tensor &indices, at::Tensor &y,
|
||||
int64_t slice_offset, int64_t slice_size) {
|
||||
at::Tensor y_out = at::empty_like(y);
|
||||
return y_out;
|
||||
}
|
||||
|
||||
at::Tensor sgmv_expand_meta(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_indices, at::Tensor &seq_len,
|
||||
at::Tensor &y, int64_t slice_offset, int64_t slice_size) {
|
||||
at::Tensor y_out = at::empty_like(y);
|
||||
return y_out;
|
||||
}
|
||||
|
||||
|
||||
} // namespace meta
|
||||
} // namespace vllm_ascend
|
||||
@ -81,6 +93,10 @@ namespace {
|
||||
ops.impl("rotary_embedding", &vllm_ascend::meta::rotary_embedding_meta);
|
||||
// Masked input and mask meta implementation
|
||||
ops.impl("get_masked_input_and_mask", &vllm_ascend::meta::get_masked_input_and_mask_meta);
|
||||
// Bgmv expand
|
||||
ops.impl("bgmv_expand", &vllm_ascend::meta::bgmv_expand_meta);
|
||||
// Sgmv expand
|
||||
ops.impl("sgmv_expand", &vllm_ascend::meta::sgmv_expand_meta);
|
||||
|
||||
}
|
||||
}
|
@ -52,9 +52,14 @@ def bgmv_expand_slice(inputs: torch.Tensor,
|
||||
slice_offset: int,
|
||||
slice_size: int,
|
||||
add_inputs: bool = True):
|
||||
return torch.ops._C.bgmv_expand(inputs, lora_b_weights,
|
||||
lora_indices_tensor, output_tensor,
|
||||
slice_offset, slice_size)
|
||||
return torch.ops._C.bgmv_expand(
|
||||
inputs,
|
||||
lora_b_weights,
|
||||
lora_indices_tensor,
|
||||
output_tensor,
|
||||
slice_offset,
|
||||
slice_size
|
||||
)
|
||||
|
||||
|
||||
def sgmv_shrink(
|
||||
@ -69,11 +74,8 @@ def sgmv_shrink(
|
||||
token_nums: int,
|
||||
scaling: float,
|
||||
):
|
||||
exploded_indices = torch.repeat_interleave(lora_indices_tensor,
|
||||
seq_len_tensor)
|
||||
|
||||
bgmv_shrink(inputs, lora_a_weights, output_tensor, exploded_indices,
|
||||
scaling)
|
||||
return torch.ops._C.sgmv_shrink(inputs, lora_a_weights, lora_indices_tensor,
|
||||
seq_len_tensor, output_tensor, scaling)
|
||||
|
||||
|
||||
def sgmv_expand(inputs: torch.Tensor,
|
||||
@ -86,11 +88,15 @@ def sgmv_expand(inputs: torch.Tensor,
|
||||
max_seq_length: int,
|
||||
token_nums: int,
|
||||
add_inputs: bool = False):
|
||||
exploded_indices = torch.repeat_interleave(lora_indices_tensor,
|
||||
seq_len_tensor)
|
||||
|
||||
bgmv_expand(inputs, lora_b_weights, output_tensor, exploded_indices,
|
||||
add_inputs)
|
||||
return torch.ops._C.sgmv_expand(
|
||||
inputs,
|
||||
lora_b_weights,
|
||||
lora_indices_tensor,
|
||||
seq_len_tensor,
|
||||
output_tensor,
|
||||
0,
|
||||
output_tensor.size(1),
|
||||
)
|
||||
|
||||
|
||||
def sgmv_expand_slice(inputs: torch.Tensor,
|
||||
@ -105,8 +111,12 @@ def sgmv_expand_slice(inputs: torch.Tensor,
|
||||
slice_offset: int,
|
||||
slice_size: int,
|
||||
add_inputs: bool = False):
|
||||
exploded_indices = torch.repeat_interleave(lora_indices_tensor,
|
||||
seq_len_tensor)
|
||||
|
||||
bgmv_expand_slice(inputs, lora_b_weights, output_tensor, exploded_indices,
|
||||
slice_offset, slice_size, add_inputs)
|
||||
return torch.ops._C.sgmv_expand(
|
||||
inputs,
|
||||
lora_b_weights,
|
||||
lora_indices_tensor,
|
||||
seq_len_tensor,
|
||||
output_tensor,
|
||||
slice_offset,
|
||||
slice_size
|
||||
)
|
||||
|
@ -22,8 +22,8 @@ from vllm.lora.punica_wrapper.punica_base import PunicaWrapperBase
|
||||
# inherit this class
|
||||
class PunicaWrapperNPU(PunicaWrapperBase):
|
||||
"""
|
||||
PunicaWrapperNPU is designed to manage and provide metadata for the punica
|
||||
kernel. The main function is to maintain the state information for
|
||||
PunicaWrapperNPU is designed to manage and provide metadata for the punica
|
||||
kernel. The main function is to maintain the state information for
|
||||
Multi-LoRA, and to provide the interface for the pytorch punica ops.
|
||||
"""
|
||||
|
||||
@ -130,7 +130,7 @@ class PunicaWrapperNPU(PunicaWrapperBase):
|
||||
add_inputs: bool = True,
|
||||
):
|
||||
"""
|
||||
Perform the ` y[:,y_offset:y_offset+y_slice_size]+=x@w_t_all`
|
||||
Perform the ` y[:,y_offset:y_offset+y_slice_size]+=x@w_t_all`
|
||||
computation, which is suitable for the
|
||||
GEMM of lora'b.
|
||||
"""
|
||||
@ -166,11 +166,11 @@ class PunicaWrapperNPU(PunicaWrapperBase):
|
||||
prefill stage, and the `_shrink_prefill` function should be called.
|
||||
Otherwise, it is the decode stage, and the _shrink_decode function
|
||||
should be called.
|
||||
|
||||
|
||||
Semantics:
|
||||
for i in range(len(lora_a_stacked)):
|
||||
y[i] += (x @ lora_a_stacked[i]) * scale
|
||||
|
||||
|
||||
Args:
|
||||
y (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Output tensors
|
||||
x (torch.Tensor): Input tensor
|
||||
@ -195,19 +195,19 @@ class PunicaWrapperNPU(PunicaWrapperBase):
|
||||
**kwargs) -> None:
|
||||
"""
|
||||
Performs GEMM and bias addition for multiple slices of lora_b.
|
||||
|
||||
|
||||
Semantics:
|
||||
for i in range(len(lora_b_stacked)):
|
||||
slice = output_slices[i]
|
||||
y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] +
|
||||
lora_bias_stacked[i]
|
||||
y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] +
|
||||
lora_bias_stacked[i]
|
||||
offset += slice
|
||||
|
||||
|
||||
Args:
|
||||
y (torch.Tensor): Output tensor.
|
||||
x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors
|
||||
lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight
|
||||
lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]):
|
||||
lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]):
|
||||
bias's weight
|
||||
output_slices (Tuple[int, ...]): Every slice's size
|
||||
add_inputs (bool): Defaults to True.
|
||||
@ -266,7 +266,7 @@ class PunicaWrapperNPU(PunicaWrapperBase):
|
||||
buffer: Optional[Tuple[torch.Tensor, ...]] = None,
|
||||
**kwargs) -> None:
|
||||
"""
|
||||
Applicable to linear-related lora.
|
||||
Applicable to linear-related lora.
|
||||
|
||||
Semantics:
|
||||
for i in range(len(lora_a_stacked)):
|
||||
|
@ -80,7 +80,30 @@ def get_masked_input_and_mask_meta(input: torch.Tensor,
|
||||
|
||||
return masked_input, mask
|
||||
|
||||
def bgmv_expand_meta(x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
indices: torch.Tensor,
|
||||
y: torch.Tensor,
|
||||
slice_offset: int,
|
||||
slice_size: int):
|
||||
|
||||
y_out = torch.empty_like(y)
|
||||
return y_out
|
||||
|
||||
def sgmv_expand_meta(x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
lora_indices: torch.Tensor,
|
||||
seq_len: torch.Tensor,
|
||||
y: torch.Tensor,
|
||||
slice_offset: int,
|
||||
slice_size: int):
|
||||
|
||||
y_out = torch.empty_like(y)
|
||||
return y_out
|
||||
|
||||
|
||||
register_meta_if_necessary("_C", "rotary_embedding", rotary_embedding_meta)
|
||||
register_meta_if_necessary("_C", "get_masked_input_and_mask",
|
||||
get_masked_input_and_mask_meta)
|
||||
register_meta_if_necessary("_C", "bgmv_expand", bgmv_expand_meta)
|
||||
register_meta_if_necessary("_C", "sgmv_expand", sgmv_expand_meta)
|
||||
|
Reference in New Issue
Block a user