Files
vllm-ascend/csrc/mla_preprocess/op_kernel/mla_preprocess_mix_bf16.hpp
Chen Chen bcc313e8f2 add mla_preprocess kernel (#3226)
### What this PR does / why we need it?

- Adds the `mla_preprocess` custom kernel to provide an optimized
pre-processing operator for Multi-head Latent Attention (MLA) on Ascend
NPUs.
- Wires the new kernel into the C++ extension pipeline so vLLM can
invoke it directly, cutting the Python-side tensor shuffling and memory
copies that previously bottlenecked the MLA preprocessing path.

### Does this PR introduce any user-facing change?

- No. The change only introduces a low-level kernel; public APIs and
inference behavior remain unchanged.

### How was this patch tested?

- Dedicated Ascend kernels are not covered by our CI yet, so no extra
automated tests were added. Future MLA-focused regression runs will
cover this path.

- vLLM version: v0.11.0

Signed-off-by: Chen Chen <0109chenchen@gmail.com>
2025-10-12 07:39:45 +08:00

// Adapted from
// https://gitee.com/ascend/ascend-transformer-boost
//
// Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
// This file is a part of the CANN Open Software.
// Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
// Please refer to the License for details. You may not use this file except in compliance with the License.
// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
// See LICENSE in the root of the software repository for the full text of the License.
//
#include "kernel/common.h"
#include "kernel/iterator.h"
#include "kernel/mem.h"
#include "kernel/mma.h"
#include "kernel/utils.h"
#include "kernel/simd.h"
#include "kernel/kernel_utils.h"
#include "lib/matmul_intf.h"
#include "mla_preprocess.h"
#include "../op_host/tiling/mla_preprocess_tiling.h"
namespace MLAPO_BF16 {
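// Applies rotary position embedding (RoPE) to the rope slice of each query head and
// either concatenates it behind the nope slice in the output (CACHE_MODE_KVCACHE) or
// writes it to a separate contiguous buffer. The GM addressing below (row stride 192,
// offset 128) suggests a 192-wide head split into a 128-element nope part and a
// 64-element rope part. Heads are distributed across vector cores via MlaTilingData.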
template <typename QkDtype, typename CosDtype, typename QOutDtype, int8_t CacheMode>
class RopeFp16
{
public:
__aicore__ inline RopeFp16() : blockIdx_(AscendC::GetBlockIdx()) {}
__aicore__ inline void RopeInit(GM_ADDR qGm, AscendC::GlobalTensor<CosDtype> &cosGm,
AscendC::GlobalTensor<CosDtype> &sinGm,
AscendC::GlobalTensor<QOutDtype> &outRopeConcatGm,
AscendC::GlobalTensor<QkDtype> &outRopeConcatGm2, MlaTilingData &ropeConcatParams)
{
qGm_.SetGlobalBuffer(reinterpret_cast<__gm__ QkDtype *>(qGm));
this->cosGm_ = cosGm;
this->sinGm_ = sinGm;
this->outRopeConcatGm_ = outRopeConcatGm;
this->outRopeConcatGm2_ = outRopeConcatGm2;
headDim = ropeConcatParams.headDim;
headNumQ = ropeConcatParams.headNumQ;
rotaryCoeff = ropeConcatParams.rotaryCoeff;
ntokens = ropeConcatParams.ntokens;
realCore = ropeConcatParams.realCore;
nlCoreRun = ropeConcatParams.nlCoreRun;
lCoreRun = ropeConcatParams.lCoreRun;
maxNPerLoopForUb = ropeConcatParams.maxNPerLoopForUb;
preCoreLoopTime = ropeConcatParams.preCoreLoopTime;
preCoreLoopNLast = ropeConcatParams.preCoreLoopNLast;
lastCoreLoopTime = ropeConcatParams.lastCoreLoopTime;
lastCoreLoopNLast = ropeConcatParams.lastCoreLoopNLast;
concatSize = ropeConcatParams.concatSize;
blockIdx_ = (blockIdx_ / 2) * 2 + static_cast<uint64_t>(GetSubBlockidx());
loopTime = (blockIdx_ == realCore - 1) ? lastCoreLoopTime : preCoreLoopTime;
lastLoopN = (blockIdx_ == realCore - 1) ? lastCoreLoopNLast : preCoreLoopNLast;
this->repeatSize_ = 64; // 64 = 256B / sizeof(fp32)
this->rotateStride_ = this->headDim / this->rotaryCoeff;
headBlockLen = static_cast<uint16_t>(this->headDim / ELE_NUM_FP16);
headBlockLenFP32 = static_cast<uint16_t>(this->headDim / ELE_NUM_FP32);
rotaryLen = static_cast<uint16_t>(this->rotateStride_ / ELE_NUM_FP32);
concatBlockLen = static_cast<uint16_t>(this->concatSize / ELE_NUM_FP16);
outLineOffset = this->headDim + this->concatSize;
uint32_t dataNum = this->headDim * this->maxNPerLoopForUb;
dataSizeFp16 = dataNum * sizeof(QkDtype);
dataSizeFp32 = dataNum * sizeof(float);
uint32_t concatDataSize = this->concatSize * sizeof(QkDtype) * this->maxNPerLoopForUb;
}
__aicore__ inline void Process()
{
if (blockIdx_ >= realCore) {
return;
}
uint64_t startCoreLineIndex = this->blockIdx_ * this->nlCoreRun;
// neg sign pattern of shape [maxNPerLoopForUb, headDim]
AscendC::LocalTensor<float> negLocal =
buf.GetBuffer<BufferType::ASCEND_UB, float>(dataSizeFp32 * 4 + dataSizeFp16 * 3);
ExpandNeg(negLocal, this->maxNPerLoopForUb);
SET_FLAG(MTE3, MTE2, EVENT_ID1);
for (uint32_t zz = 0; zz < this->loopTime; ++zz) {
uint16_t loopN = (zz == this->loopTime - 1) ? this->lastLoopN : this->maxNPerLoopForUb;
uint64_t startHead = startCoreLineIndex + zz * this->maxNPerLoopForUb;
uint64_t endHead = startHead + loopN;
// move in Q
AscendC::LocalTensor<QkDtype> inputQ = buf.GetBuffer<BufferType::ASCEND_UB, QkDtype>(0);
AscendC::LocalTensor<float> inputQCastFP32 = buf.GetBuffer<BufferType::ASCEND_UB, float>(dataSizeFp16);
AscendC::LocalTensor<float> reverseQ =
buf.GetBuffer<BufferType::ASCEND_UB, float>(dataSizeFp32 + dataSizeFp16);
uint64_t qOffset = startHead * 192 + 128;
CopyQGenReverseQ(inputQ, inputQCastFP32, reverseQ, qOffset, loopN);
// move in cos/sin
AscendC::LocalTensor<QkDtype> inputCos =
buf.GetBuffer<BufferType::ASCEND_UB, QkDtype>(dataSizeFp32 * 2 + dataSizeFp16);
AscendC::LocalTensor<QkDtype> inputSin =
buf.GetBuffer<BufferType::ASCEND_UB, QkDtype>(dataSizeFp32 * 2 + dataSizeFp16 * 2);
uint64_t startSinCosHeadIndex = startHead;
uint64_t headRemain = startHead % this->headNumQ;
uint64_t localStartAddr = 0;
if (headRemain != 0) {
uint64_t preProcessHeadNum = this->headNumQ - headRemain;
uint64_t needToProcessHead = preProcessHeadNum > loopN ? loopN : preProcessHeadNum;
CopyCosSin(inputCos, inputSin, localStartAddr, (startSinCosHeadIndex / this->headNumQ) * this->headDim,
needToProcessHead);
startSinCosHeadIndex += needToProcessHead;
localStartAddr += needToProcessHead * this->headDim;
}
if (startSinCosHeadIndex < endHead) {
uint64_t startSinCosIndex = startSinCosHeadIndex / this->headNumQ;
uint64_t endSinCosIndex = (endHead + this->headNumQ - 1) / this->headNumQ;
for (uint32_t index = startSinCosIndex; index < endSinCosIndex; ++index) {
uint32_t repeatNum =
index == endSinCosIndex - 1 ? endHead - index * this->headNumQ : this->headNumQ;
CopyCosSin(inputCos, inputSin, localStartAddr, index * this->headDim, repeatNum);
localStartAddr += this->headDim * this->headNumQ;
}
}
AscendC::LocalTensor<float> inputCosCastFP32 =
buf.GetBuffer<BufferType::ASCEND_UB, float>(dataSizeFp32 * 2 + dataSizeFp16 * 3);
AscendC::LocalTensor<float> inputSinCastFP32 =
buf.GetBuffer<BufferType::ASCEND_UB, float>(dataSizeFp32 * 3 + dataSizeFp16 * 3);
AscendC::Cast(inputCosCastFP32, inputCos, AscendC::RoundMode::CAST_NONE, loopN * this->headDim);
AscendC::Cast(inputSinCastFP32, inputSin, AscendC::RoundMode::CAST_NONE, loopN * this->headDim);
AscendC::PipeBarrier<PIPE_V>();
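// Rotate-half RoPE computed in fp32:
//   q_out = q * cos + (signPattern * reversedHalves(q)) * sin
// where signPattern is the [-1, ..., -1, +1, ..., +1] vector built by ExpandNeg and
// reversedHalves(q) swaps the two rotary halves (prepared in CopyQGenReverseQ).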
uint32_t repeatTime = this->headDim * loopN;
AscendC::Mul(inputQCastFP32, inputCosCastFP32, inputQCastFP32, repeatTime);
AscendC::Mul(reverseQ, negLocal, reverseQ, repeatTime);
AscendC::PipeBarrier<PIPE_V>();
AscendC::Mul(reverseQ, inputSinCastFP32, reverseQ, repeatTime);
AscendC::PipeBarrier<PIPE_V>();
AscendC::Add(inputQCastFP32, reverseQ, inputQCastFP32, repeatTime);
AscendC::PipeBarrier<PIPE_V>();
AscendC::Cast(inputQ, inputQCastFP32, AscendC::RoundMode::CAST_RINT, loopN * this->headDim);
AscendC::PipeBarrier<PIPE_V>();
uint64_t outQOffset = startHead * outLineOffset + this->concatSize;
uint64_t outQOffset2 = startHead * this->headDim;
SET_FLAG(V, MTE3, EVENT_ID1);
WAIT_FLAG(V, MTE3, EVENT_ID1);
if constexpr (CacheMode == CACHE_MODE_KVCACHE) {
AscendC::DataCopy(this->outRopeConcatGm_[outQOffset], inputQ, {loopN, headBlockLen, 0, concatBlockLen});
} else {
AscendC::DataCopy(this->outRopeConcatGm2_[outQOffset2], inputQ, loopN * this->headDim);
}
SET_FLAG(MTE3, MTE2, EVENT_ID1);
}
WAIT_FLAG(MTE3, MTE2, EVENT_ID1);
}
// Build a per-head sign pattern: rotateStride_ copies of -1 followed by rotateStride_ copies
// of +1, then broadcast it to headNumTemp heads.
template <typename BUF_TYPE>
__aicore__ inline void ExpandNeg(const AscendC::LocalTensor<BUF_TYPE> &tempBuf, uint32_t headNumTemp)
{
for (uint32_t i = 0; i < this->rotateStride_; ++i) {
tempBuf.SetValue(i, (BUF_TYPE)-1);
tempBuf.SetValue(i + this->rotateStride_, (BUF_TYPE)1);
}
AscendC::SetFlag<HardEvent::S_V>(EVENT_ID1);
AscendC::WaitFlag<HardEvent::S_V>(EVENT_ID1);
AscendC::Copy(tempBuf[this->headDim], tempBuf, this->headDim, headNumTemp - 1, {1, 1, headBlockLenFP32, 0});
AscendC::PipeBarrier<PIPE_V>();
}
template <typename BUF_TYPE>
__aicore__ inline void CopyQGenReverseQ(const AscendC::LocalTensor<BUF_TYPE> &tempBufQ,
const AscendC::LocalTensor<float> &tempBufQCast,
const AscendC::LocalTensor<float> &tempBufReverseQ, uint64_t qOffset,
uint16_t loopN)
{
// move in Q
WAIT_FLAG(MTE3, MTE2, EVENT_ID1);
AscendC::DataCopy(tempBufQ, this->qGm_[qOffset], {loopN, headBlockLen, 128 / 16, 0});
SET_FLAG(MTE2, V, EVENT_ID1);
WAIT_FLAG(MTE2, V, EVENT_ID1);
// cast fp32
AscendC::Cast(tempBufQCast, tempBufQ, AscendC::RoundMode::CAST_NONE, loopN * this->headDim);
AscendC::PipeBarrier<PIPE_V>();
// build reverseQ by swapping the two rotary halves of Q
AscendC::DataCopy(tempBufReverseQ, tempBufQCast[this->rotateStride_], {loopN, rotaryLen, rotaryLen, rotaryLen});
AscendC::DataCopy(tempBufReverseQ[this->rotateStride_], tempBufQCast, {loopN, rotaryLen, rotaryLen, rotaryLen});
AscendC::PipeBarrier<PIPE_V>();
}
template <typename BUF_TYPE>
__aicore__ inline void CopyCosSin(const AscendC::LocalTensor<BUF_TYPE> &tempBufCos,
const AscendC::LocalTensor<BUF_TYPE> &tempBufSin, uint64_t localStartAddr,
uint64_t gmStartAddr, uint64_t repeatNum)
{
AscendC::DataCopy(tempBufCos[localStartAddr], this->cosGm_[gmStartAddr], {1, headBlockLen, 0, 0});
AscendC::DataCopy(tempBufSin[localStartAddr], this->sinGm_[gmStartAddr], {1, headBlockLen, 0, 0});
SET_FLAG(MTE2, V, EVENT_ID1);
WAIT_FLAG(MTE2, V, EVENT_ID1);
AscendC::Copy(tempBufCos[localStartAddr + this->headDim], tempBufCos[localStartAddr], this->headDim,
repeatNum - 1, {1, 1, headBlockLen, 0});
AscendC::Copy(tempBufSin[localStartAddr + this->headDim], tempBufSin[localStartAddr], this->headDim,
repeatNum - 1, {1, 1, headBlockLen, 0});
AscendC::PipeBarrier<PIPE_V>();
}
private:
AsdopsBuffer<ArchType::ASCEND_V220> buf;
AscendC::GlobalTensor<QkDtype> qGm_;
AscendC::GlobalTensor<CosDtype> cosGm_;
AscendC::GlobalTensor<CosDtype> sinGm_;
AscendC::GlobalTensor<QOutDtype> outRopeConcatGm_;
AscendC::GlobalTensor<QkDtype> outRopeConcatGm2_;
uint32_t repeatSize_{0};
uint32_t rotateStride_{0}; // this->headDim / rope conf
uint32_t headDim;
uint32_t headNumQ;
uint32_t rotaryCoeff;
uint32_t ntokens;
uint32_t realCore;
uint32_t nlCoreRun;
uint32_t lCoreRun;
uint32_t maxNPerLoopForUb;
uint32_t preCoreLoopTime;
uint32_t preCoreLoopNLast;
uint32_t lastCoreLoopTime;
uint32_t lastCoreLoopNLast;
uint32_t concatSize;
uint32_t blockIdx_;
uint32_t loopTime{0};
uint32_t lastLoopN{0};
uint32_t dataSizeFp32;
uint32_t dataSizeFp16;
uint16_t headBlockLen{0};
uint16_t headBlockLenFP32{0};
uint16_t rotaryLen{0};
uint16_t concatBlockLen{0};
uint64_t outLineOffset{0};
};
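// Sums `count` fp32 values: vector adds fold the input into a single 64-lane block in
// work_local, a tail add handles the remainder, and one cadd_v reduction collapses the
// block into dst_local[0]. Compiled only for the vector core (__DAV_C220_VEC__).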
__aicore__ inline void ReduceSumCustom(const AscendC::LocalTensor<float> &dst_local,
const AscendC::LocalTensor<float> &src_local,
const AscendC::LocalTensor<float> &work_local, int32_t count)
{
#ifdef __DAV_C220_VEC__
uint64_t mask = NUM_PER_REP_FP32;
int32_t repeatTimes = count / NUM_PER_REP_FP32;
int32_t tailCount = count % NUM_PER_REP_FP32;
int32_t bodyCount = repeatTimes * NUM_PER_REP_FP32;
AscendC::BinaryRepeatParams repeatParams;
repeatParams.src0RepStride = AscendC::ONE_REPEAT_BYTE_SIZE / AscendC::ONE_BLK_SIZE;
repeatParams.src0BlkStride = 1;
repeatParams.src1RepStride = 0;
repeatParams.src1BlkStride = 1;
repeatParams.dstRepStride = 0;
repeatParams.dstBlkStride = 1;
Duplicate(work_local, ZERO, NUM_PER_REP_FP32);
AscendC::PipeBarrier<PIPE_V>();
if (likely(repeatTimes > 0)) {
Add(work_local, src_local, work_local, mask, repeatTimes, repeatParams);
AscendC::PipeBarrier<PIPE_V>();
}
if (unlikely(tailCount != 0)) {
Add(work_local, src_local[bodyCount], work_local, tailCount, 1, repeatParams);
AscendC::PipeBarrier<PIPE_V>();
}
AscendC::AscendCUtils::SetMask<float>(NUM_PER_REP_FP32);
cadd_v<ArchType::ASCEND_V220, float>(dst_local, // dst
work_local, // src
1, // repeat
0, // dstRepeatStride
1, // srcBlockStride
0); // srcRepeatStride
AscendC::PipeBarrier<PIPE_V>();
#endif
}
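// Row-wise int8 quantization. When NEED_DEQUANT is set, the input is an int32 matmul
// result that is first dequantized with a per-channel descale vector and a per-token
// descale scalar. quantMode then selects per-tensor asymmetric quantization (scale and
// offset loaded from GM) or per-token symmetric quantization (scale taken from the row
// max and written back to perTokenDescaleGm for the later dequant).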
template <typename T, bool WITH_BETA, bool FastComputeMode = false,
QuantMode quantMode = QuantMode::PER_TENSOR_ASYMM_QUANT, bool NEED_DEQUANT = false>
class Quant
{
public:
__aicore__ inline Quant() {}
__aicore__ inline void Init(AscendC::GlobalTensor<T> &quantScaleGmTensor,
AscendC::GlobalTensor<int8_t> &quantOffsetGmTensor, GM_ADDR perTokenDescaleGm,
GM_ADDR perChannelDescaleGm, GM_ADDR gmInput, GM_ADDR gmOutput, uint32_t stride,
uint32_t num_col, uint64_t gm_offset, uint64_t gm_out_offset, uint32_t row_work_,
const MlaTilingData &mlaParams_)
{
this->quantScaleGmTensor = quantScaleGmTensor;
this->quantOffsetGmTensor = quantOffsetGmTensor;
this->perTokenDescaleGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ float *>(perTokenDescaleGm));
this->perChannelDescaleGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ float *>(perChannelDescaleGm));
if constexpr (!NEED_DEQUANT) {
inputGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ T *>(gmInput));
} else {
mmGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(gmInput));
}
outputGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int8_t *>(gmOutput));
num_col_ = num_col;
quantMin_ = -128;
this->num_row_ = mlaParams_.n;
this->row_work = row_work_;
this->row_work_ = row_work_;
gm_offset_ = gm_offset;
gm_out_offset_ = gm_out_offset;
num_col_align_int8 = (num_col_ + REPEAT_TIME_256 - 1) / REPEAT_TIME_256 * REPEAT_TIME_256;
num_col_align_f16 = (num_col_ + REPEAT_TIME_128 - 1) / REPEAT_TIME_128 * REPEAT_TIME_128;
num_col_align_f32 = (num_col_ + REPEAT_TIME_64 - 1) / REPEAT_TIME_64 * REPEAT_TIME_64;
input_stride_ = stride;
num_col_align_withStride_int8 =
(num_col_ - input_stride_ + REPEAT_TIME_256 - 1) / REPEAT_TIME_256 * REPEAT_TIME_256;
num_col_align_withStride_fp16 =
(num_col_ - input_stride_ + REPEAT_TIME_128 - 1) / REPEAT_TIME_128 * REPEAT_TIME_128;
num_col_align_withStride_fp32 =
(num_col_ - input_stride_ + REPEAT_TIME_64 - 1) / REPEAT_TIME_64 * REPEAT_TIME_64;
}
__aicore__ inline void Launch(const AscendC::LocalTensor<int8_t> &dstTensor,
const AscendC::LocalTensor<T> &srcTensor,
const AscendC::LocalTensor<T> &quantScaleTensor,
const AscendC::LocalTensor<int8_t> &quantOffsetTensor,
const AscendC::LocalTensor<float> &res1Tensor,
const AscendC::LocalTensor<float> &res3Tensor)
{
this->dstTensor = dstTensor;
this->srcTensor = srcTensor;
this->fp32_xy = res1Tensor;
this->buf = res3Tensor;
AscendC::LocalTensor<float> g = buf[OFFSET_GAMMA * num_col_align_withStride_fp32]; // 0
AscendC::LocalTensor<float> sqx = buf[OFFSET_SQX * num_col_align_withStride_fp32]; // 1
AscendC::LocalTensor<float> work = buf[OFFSET_SUM * num_col_align_withStride_fp32]; // 2
AscendC::LocalTensor<float> abs = buf[OFFSET_ABS * num_col_align_withStride_fp32]; // 3
AscendC::LocalTensor<float> sum = buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32]; // 4
AscendC::LocalTensor<float> max = buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 8]; // 5
AscendC::LocalTensor<float> perTokenDescaleTensor =
buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16]; // 6
SET_FLAG(MTE2, V, EVENT_ID1);
if constexpr (quantMode == QuantMode::PER_TENSOR_ASYMM_QUANT) {
AscendC::DataCopy(quantScaleTensor, quantScaleGmTensor, AscendC::DataCopyParams(1, 1, 0, 0));
AscendC::DataCopy(quantOffsetTensor, quantOffsetGmTensor, AscendC::DataCopyParams(1, 1, 0, 0));
}
if constexpr (NEED_DEQUANT) {
mmTensor = buf.ReinterpretCast<int32_t>()[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16];
deScaleTensor = buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16 + MM1_OUT_SIZE];
perTokenDescaleTensor = buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16 + MM1_OUT_SIZE * 2];
AscendC::DataCopy(deScaleTensor, perChannelDescaleGmTensor, AscendC::DataCopyParams(1, num_col_ / 8, 0, 0));
}
if constexpr (quantMode == QuantMode::PER_TENSOR_ASYMM_QUANT) {
if (std::is_same<T, __bf16>::value) {
SET_FLAG(MTE2, V, EVENT_ID0);
WAIT_FLAG(MTE2, V, EVENT_ID0);
Cast(g, quantScaleTensor, AscendC::RoundMode::CAST_NONE, 1);
AscendC::SetFlag<HardEvent::V_S>(EVENT_ID0);
AscendC::WaitFlag<HardEvent::V_S>(EVENT_ID0);
input_scale_ = 1 / (float)(g.GetValue(0));
input_offset_ = (float)(quantOffsetTensor.GetValue(0));
} else {
SET_FLAG(MTE2, S, EVENT_ID0);
WAIT_FLAG(MTE2, S, EVENT_ID0);
input_scale_ = 1 / (float)(quantScaleTensor.GetValue(0));
input_offset_ = (float)(quantOffsetTensor.GetValue(0));
}
AscendC::SetFlag<HardEvent::S_V>(EVENT_ID0);
AscendC::WaitFlag<HardEvent::S_V>(EVENT_ID0);
}
WAIT_FLAG(MTE2, V, EVENT_ID1);
uint64_t pid = 0;
SET_FLAG(MTE3, MTE2, EVENT_ID0);
while (pid < row_work_) {
uint64_t offset = pid * num_col_;
uint64_t outOffset = pid * (num_col_ - input_stride_);
WAIT_FLAG(MTE3, MTE2, EVENT_ID0);
if constexpr (!NEED_DEQUANT) {
AscendC::DataCopy(srcTensor, inputGmTensor[gm_offset_ + offset],
AscendC::DataCopyParams(1, num_col_ / BLOCK_SIZE_16, 0, 0));
SET_FLAG(MTE2, V, EVENT_ID0);
WAIT_FLAG(MTE2, V, EVENT_ID0);
} else {
/* Dequant start */
AscendC::DataCopy(mmTensor, mmGmTensor[gm_offset_ + offset],
AscendC::DataCopyParams(1, num_col_ / 8, 0, 0)); // 2112
SET_FLAG(MTE2, V, EVENT_ID0);
WAIT_FLAG(MTE2, V, EVENT_ID0);
AscendC::Cast(mmTensor.ReinterpretCast<float>(), mmTensor, AscendC::RoundMode::CAST_NONE, num_col_);
AscendC::PipeBarrier<PIPE_V>();
AscendC::Mul(mmTensor.ReinterpretCast<float>(), mmTensor.ReinterpretCast<float>(), deScaleTensor,
num_col_);
SET_FLAG(V, MTE2, EVENT_ID0);
WAIT_FLAG(V, MTE2, EVENT_ID0);
gm_to_ub_align<ArchType::ASCEND_V220, float>(perTokenDescaleTensor, perTokenDescaleGmTensor[pid],
0, // sid
1, // nBurst
sizeof(float), // lenBurst
0, // leftPaddingNum
0, // rightPaddingNum
0, // srcGap
0 // dstGap
);
SET_FLAG(MTE2, S, EVENT_ID0);
WAIT_FLAG(MTE2, S, EVENT_ID0);
float perTokenDescale = perTokenDescaleTensor.GetValue(0);
SET_FLAG(S, V, EVENT_ID0);
WAIT_FLAG(S, V, EVENT_ID0);
AscendC::Muls(mmTensor.ReinterpretCast<float>(), mmTensor.ReinterpretCast<float>(), perTokenDescale,
num_col_);
AscendC::PipeBarrier<PIPE_V>();
AscendC::Cast(srcTensor, mmTensor.ReinterpretCast<float>(), AscendC::RoundMode::CAST_RINT, num_col_);
AscendC::PipeBarrier<PIPE_V>();
}
Cast(fp32_xy, srcTensor[input_stride_], AscendC::RoundMode::CAST_NONE, REPEAT_TIME_64,
num_col_align_withStride_fp32 / REPEAT_TIME_64,
{1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE / OFFSET_SUM});
AscendC::PipeBarrier<PIPE_V>();
/* Quant start */
if constexpr (quantMode == QuantMode::PER_TENSOR_ASYMM_QUANT) {
Muls(fp32_xy, fp32_xy, input_scale_, REPEAT_TIME_64, num_col_align_withStride_fp32 / REPEAT_TIME_64,
{1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE});
AscendC::PipeBarrier<PIPE_V>();
Adds(fp32_xy, fp32_xy, input_offset_, REPEAT_TIME_64, num_col_align_withStride_fp32 / REPEAT_TIME_64,
{1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE});
AscendC::PipeBarrier<PIPE_V>();
} else if constexpr (quantMode == QuantMode::PER_TOKEN_SYMM_QUANT) {
Abs(abs, fp32_xy, REPEAT_TIME_64, num_col_align_withStride_fp32 / REPEAT_TIME_64,
{1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE});
AscendC::PipeBarrier<PIPE_V>();
ReduceMax(max, abs, work, num_col_ - input_stride_);
AscendC::PipeBarrier<PIPE_V>();
float scaleOut = max.GetValue(0) / 127;
SET_FLAG(S, V, EVENT_ID0);
WAIT_FLAG(S, V, EVENT_ID0);
Muls(fp32_xy, fp32_xy, (float)(1 / scaleOut), REPEAT_TIME_64,
num_col_align_withStride_fp32 / REPEAT_TIME_64,
{1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE});
AscendC::PipeBarrier<PIPE_V>();
perTokenDescaleTensor.SetValue(0, scaleOut);
SET_FLAG(S, MTE3, EVENT_ID0);
WAIT_FLAG(S, MTE3, EVENT_ID0);
if constexpr (!NEED_DEQUANT) {
ub_to_gm_align<ArchType::ASCEND_V220, float>(perTokenDescaleGmTensor[pid], perTokenDescaleTensor, 0,
1, // nBurst
1 * sizeof(float), // lenBurst
0, // leftPaddingNum
0, // rightPaddingNum
0, // srcGap
0 // dstGap
);
} else {
ub_to_gm_align<ArchType::ASCEND_V220, float>(perTokenDescaleGmTensor[num_row_ + pid],
perTokenDescaleTensor, 0,
1, // nBurst
1 * sizeof(float), // lenBurst
0, // leftPaddingNum
0, // rightPaddingNum
0, // srcGap
0 // dstGap
);
}
SET_FLAG(MTE3, V, EVENT_ID0);
WAIT_FLAG(MTE3, V, EVENT_ID0);
}
AscendC::LocalTensor<half> tmpfp16 =
buf.ReinterpretCast<half>()[OFFSET_SUM * num_col_align_withStride_fp32 * 2];
CastFrom32To16(tmpfp16, fp32_xy, num_col_align_withStride_fp32);
AscendC::PipeBarrier<PIPE_V>();
CastFromF16ToI8(dstTensor, tmpfp16, quantMin_, num_col_align_withStride_fp16);
AscendC::PipeBarrier<PIPE_V>();
SET_FLAG(V, MTE3, EVENT_ID0);
WAIT_FLAG(V, MTE3, EVENT_ID0);
AscendC::DataCopy(outputGmTensor[gm_out_offset_ + outOffset], dstTensor,
AscendC::DataCopyParams(1, (num_col_ - input_stride_) / 32, 0, 0));
SET_FLAG(MTE3, MTE2, EVENT_ID0);
++pid;
}
WAIT_FLAG(MTE3, MTE2, EVENT_ID0);
}
private:
AscendC::LocalTensor<int8_t> dstTensor;
AscendC::LocalTensor<T> srcTensor;
AscendC::LocalTensor<float> fp32_xy;
AscendC::LocalTensor<float> buf;
AscendC::LocalTensor<int32_t> mmTensor;
AscendC::LocalTensor<float> deScaleTensor;
AscendC::GlobalTensor<T> quantScaleGmTensor;
AscendC::GlobalTensor<int8_t> quantOffsetGmTensor;
AscendC::GlobalTensor<T> inputGmTensor;
AscendC::GlobalTensor<int8_t> outputGmTensor;
AscendC::GlobalTensor<float> perTokenDescaleGmTensor;
AscendC::GlobalTensor<float> perChannelDescaleGmTensor;
AscendC::GlobalTensor<int32_t> mmGmTensor;
uint32_t num_col_{0}; // input columns
uint32_t num_row_{0}; // input rows
uint32_t row_work_{0}; // rows need process
uint32_t row_work{0}; // rows need process
uint32_t row_step_{0}; // rows move in once
uint32_t row_tail_{0}; // rows move in last time
uint64_t gm_offset_{0}; // GM data offset
uint64_t gm_out_offset_{0}; // GM data offset
float avg_factor_{1.0}; // 1/num_col_
float input_scale_{1.0};
float input_offset_{0};
int32_t input_stride_{0};
float epsilon_{1e-12f};
uint32_t num_col_align_int8{0};
uint32_t num_col_align_f16{0};
uint32_t num_col_align_f32{0};
uint32_t num_col_align_f32_long{0};
uint32_t num_col_align_withStride_int8{0};
uint32_t num_col_align_withStride_fp16{0};
uint32_t num_col_align_withStride_fp32{0};
uint32_t num_col_temp;
half quantMin_{-128};
uint32_t num_slice_{0};
uint32_t tail_size_{0};
uint32_t tail_copy_{0};
};
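// Fused RMSNorm + quantization. Per row it computes
//   y = x / sqrt(avg_factor * sum(x^2) + epsilon) * gamma (+ beta when WITH_BETA)
// and then feeds y through the same optional dequant and int8 quant paths as Quant
// above.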
template <typename T, bool WITH_BETA, bool FastComputeMode = false,
QuantMode quantMode = QuantMode::PER_TENSOR_ASYMM_QUANT, bool NEED_DEQUANT = false>
class RmsNormQuant
{
public:
__aicore__ inline RmsNormQuant() {}
__aicore__ inline void Init(AscendC::GlobalTensor<T> &gammaGmTensor, AscendC::GlobalTensor<T> &betaGmTensor,
AscendC::GlobalTensor<T> &quantScaleGmTensor,
AscendC::GlobalTensor<int8_t> &quantOffsetGmTensor, GM_ADDR perTokenDescaleGm,
GM_ADDR perChannelDescaleGm, GM_ADDR gmInput, GM_ADDR gmOutput, uint32_t stride,
uint32_t num_col, float avg_factor, uint64_t gm_offset, uint64_t gm_out_offset,
uint32_t row_work_, const MlaTilingData &mlaParams_)
{
this->gammaGmTensor = gammaGmTensor;
this->betaGmTensor = betaGmTensor;
this->quantScaleGmTensor = quantScaleGmTensor;
this->quantOffsetGmTensor = quantOffsetGmTensor;
this->perTokenDescaleGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ float *>(perTokenDescaleGm));
this->perChannelDescaleGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ float *>(perChannelDescaleGm));
if constexpr (!NEED_DEQUANT) {
inputGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ T *>(gmInput));
} else {
mmGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(gmInput));
}
outputGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int8_t *>(gmOutput));
num_col_ = num_col;
avg_factor_ = avg_factor;
epsilon_ = 1e-6;
quantMin_ = -128;
this->num_row_ = mlaParams_.n;
this->row_work = row_work_;
this->row_work_ = row_work_;
gm_offset_ = gm_offset;
gm_out_offset_ = gm_out_offset;
num_col_align_int8 = (num_col_ + REPEAT_TIME_256 - 1) / REPEAT_TIME_256 * REPEAT_TIME_256;
num_col_align_f16 = (num_col_ + REPEAT_TIME_128 - 1) / REPEAT_TIME_128 * REPEAT_TIME_128;
num_col_align_f32 = (num_col_ + REPEAT_TIME_64 - 1) / REPEAT_TIME_64 * REPEAT_TIME_64;
input_stride_ = stride;
num_col_align_withStride_int8 =
(num_col_ - input_stride_ + REPEAT_TIME_256 - 1) / REPEAT_TIME_256 * REPEAT_TIME_256;
num_col_align_withStride_fp16 =
(num_col_ - input_stride_ + REPEAT_TIME_128 - 1) / REPEAT_TIME_128 * REPEAT_TIME_128;
num_col_align_withStride_fp32 =
(num_col_ - input_stride_ + REPEAT_TIME_64 - 1) / REPEAT_TIME_64 * REPEAT_TIME_64;
}
__aicore__ inline void Launch(const AscendC::LocalTensor<int8_t> &dstTensor,
const AscendC::LocalTensor<T> &srcTensor, const AscendC::LocalTensor<T> &gammaTensor,
const AscendC::LocalTensor<T> &betaTensor,
const AscendC::LocalTensor<T> &quantScaleTensor,
const AscendC::LocalTensor<int8_t> &quantOffsetTensor,
const AscendC::LocalTensor<float> &res1Tensor,
const AscendC::LocalTensor<float> &res3Tensor)
{
this->dstTensor = dstTensor;
this->srcTensor = srcTensor;
this->gammaTensor = gammaTensor;
this->betaTensor = betaTensor;
this->fp32_xy = res1Tensor;
this->buf = res3Tensor;
AscendC::LocalTensor<float> g = buf[OFFSET_GAMMA * num_col_align_withStride_fp32]; // 0
AscendC::LocalTensor<float> sqx = buf[OFFSET_SQX * num_col_align_withStride_fp32]; // 1
AscendC::LocalTensor<float> work = buf[OFFSET_SUM * num_col_align_withStride_fp32]; // 2
AscendC::LocalTensor<float> abs = buf[OFFSET_ABS * num_col_align_withStride_fp32]; // 3
AscendC::LocalTensor<float> sum = buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32]; // 4
AscendC::LocalTensor<float> max = buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 8]; // 5
AscendC::LocalTensor<float> perTokenDescaleTensor =
buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16]; // 6
AscendC::DataCopy(gammaTensor, gammaGmTensor,
AscendC::DataCopyParams(1, (num_col_ - input_stride_) / BLOCK_SIZE_16, 0, 0));
AscendC::DataCopy(betaTensor, betaGmTensor,
AscendC::DataCopyParams(1, (num_col_ - input_stride_) / BLOCK_SIZE_16, 0, 0));
SET_FLAG(MTE2, V, EVENT_ID1);
if constexpr (quantMode == QuantMode::PER_TENSOR_ASYMM_QUANT) {
AscendC::DataCopy(quantScaleTensor, quantScaleGmTensor, AscendC::DataCopyParams(1, 1, 0, 0));
AscendC::DataCopy(quantOffsetTensor, quantOffsetGmTensor, AscendC::DataCopyParams(1, 1, 0, 0));
}
if constexpr (NEED_DEQUANT) {
mmTensor = buf.ReinterpretCast<int32_t>()[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16];
deScaleTensor = buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16 + MM1_OUT_SIZE];
perTokenDescaleTensor = buf[OFFSET_WORKSPACE_BF16 * num_col_align_withStride_fp32 + 16 + MM1_OUT_SIZE * 2];
AscendC::DataCopy(deScaleTensor, perChannelDescaleGmTensor, AscendC::DataCopyParams(1, num_col_ / 8, 0, 0));
}
if constexpr (quantMode == QuantMode::PER_TENSOR_ASYMM_QUANT) {
if (std::is_same<T, __bf16>::value) {
SET_FLAG(MTE2, V, EVENT_ID0);
WAIT_FLAG(MTE2, V, EVENT_ID0);
Cast(g, quantScaleTensor, AscendC::RoundMode::CAST_NONE, 1);
AscendC::SetFlag<HardEvent::V_S>(EVENT_ID0);
AscendC::WaitFlag<HardEvent::V_S>(EVENT_ID0);
input_scale_ = 1 / (float)(g.GetValue(0));
input_offset_ = (float)(quantOffsetTensor.GetValue(0));
} else {
SET_FLAG(MTE2, S, EVENT_ID0);
WAIT_FLAG(MTE2, S, EVENT_ID0);
input_scale_ = 1 / (float)(quantScaleTensor.GetValue(0));
input_offset_ = (float)(quantOffsetTensor.GetValue(0));
}
AscendC::SetFlag<HardEvent::S_V>(EVENT_ID0);
AscendC::WaitFlag<HardEvent::S_V>(EVENT_ID0);
}
WAIT_FLAG(MTE2, V, EVENT_ID1);
Cast(buf[OFFSET_GAMMA * num_col_align_withStride_fp32], gammaTensor, AscendC::RoundMode::CAST_NONE,
REPEAT_TIME_64, num_col_align_withStride_fp32 / REPEAT_TIME_64,
{1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE / OFFSET_SUM});
AscendC::PipeBarrier<PIPE_V>();
uint64_t pid = 0;
SET_FLAG(MTE3, MTE2, EVENT_ID0);
while (pid < row_work_) {
uint64_t offset = pid * num_col_;
uint64_t outOffset = pid * (num_col_ - input_stride_);
WAIT_FLAG(MTE3, MTE2, EVENT_ID0);
if constexpr (!NEED_DEQUANT) {
AscendC::DataCopy(srcTensor, inputGmTensor[gm_offset_ + offset],
AscendC::DataCopyParams(1, num_col_ / BLOCK_SIZE_16, 0, 0));
SET_FLAG(MTE2, V, EVENT_ID0);
WAIT_FLAG(MTE2, V, EVENT_ID0);
} else {
/* Dequant start */
AscendC::DataCopy(mmTensor, mmGmTensor[gm_offset_ + offset],
AscendC::DataCopyParams(1, num_col_ / 8, 0, 0)); // 2112
SET_FLAG(MTE2, V, EVENT_ID0);
WAIT_FLAG(MTE2, V, EVENT_ID0);
AscendC::Cast(mmTensor.ReinterpretCast<float>(), mmTensor, AscendC::RoundMode::CAST_NONE, num_col_);
AscendC::PipeBarrier<PIPE_V>();
AscendC::Mul(mmTensor.ReinterpretCast<float>(), mmTensor.ReinterpretCast<float>(), deScaleTensor,
num_col_);
SET_FLAG(V, MTE2, EVENT_ID0);
WAIT_FLAG(V, MTE2, EVENT_ID0);
gm_to_ub_align<ArchType::ASCEND_V220, float>(perTokenDescaleTensor, perTokenDescaleGmTensor[pid],
0, // sid
1, // nBurst
sizeof(float), // lenBurst
0, // leftPaddingNum
0, // rightPaddingNum
0, // srcGap
0 // dstGap
);
SET_FLAG(MTE2, S, EVENT_ID0);
WAIT_FLAG(MTE2, S, EVENT_ID0);
float perTokenDescale = perTokenDescaleTensor.GetValue(0);
SET_FLAG(S, V, EVENT_ID0);
WAIT_FLAG(S, V, EVENT_ID0);
AscendC::Muls(mmTensor.ReinterpretCast<float>(), mmTensor.ReinterpretCast<float>(), perTokenDescale,
num_col_);
AscendC::PipeBarrier<PIPE_V>();
AscendC::Cast(srcTensor, mmTensor.ReinterpretCast<float>(), AscendC::RoundMode::CAST_RINT, num_col_);
AscendC::PipeBarrier<PIPE_V>();
}
Cast(fp32_xy, srcTensor[input_stride_], AscendC::RoundMode::CAST_NONE, REPEAT_TIME_64,
num_col_align_withStride_fp32 / REPEAT_TIME_64,
{1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE / OFFSET_SUM});
AscendC::PipeBarrier<PIPE_V>();
Mul(sqx, fp32_xy, fp32_xy, REPEAT_TIME_64, num_col_align_withStride_fp32 / REPEAT_TIME_64,
{1, 1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE,
AscendC::DEFAULT_REPEAT_STRIDE});
AscendC::PipeBarrier<PIPE_V>();
Muls(sqx, sqx, avg_factor_, num_col_ - input_stride_);
AscendC::PipeBarrier<PIPE_V>();
ReduceSumCustom(sum, sqx, work, num_col_ - input_stride_);
AscendC::PipeBarrier<PIPE_V>();
Adds(sum, sum, epsilon_, 1);
AscendC::PipeBarrier<PIPE_V>();
Sqrt(sum, sum, 1);
SET_FLAG(V, S, EVENT_ID0);
WAIT_FLAG(V, S, EVENT_ID0);
float factor = 1 / sum.GetValue(0);
SET_FLAG(S, V, EVENT_ID0);
WAIT_FLAG(S, V, EVENT_ID0);
Muls(fp32_xy, fp32_xy, factor, REPEAT_TIME_64, num_col_align_withStride_fp32 / REPEAT_TIME_64,
{1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE});
AscendC::PipeBarrier<PIPE_V>();
Mul(fp32_xy, fp32_xy, g, REPEAT_TIME_64, num_col_align_withStride_fp32 / REPEAT_TIME_64,
{1, 1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE,
AscendC::DEFAULT_REPEAT_STRIDE});
AscendC::PipeBarrier<PIPE_V>();
if constexpr (WITH_BETA) {
AscendC::LocalTensor<T> b = this->betaTensor;
Cast(work, b, AscendC::RoundMode::CAST_NONE, REPEAT_TIME_64,
num_col_align_withStride_fp32 / REPEAT_TIME_64,
{1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE / OFFSET_SUM});
AscendC::PipeBarrier<PIPE_V>();
Add(fp32_xy, fp32_xy, work, REPEAT_TIME_64, num_col_align_withStride_fp32 / REPEAT_TIME_64,
{1, 1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE,
AscendC::DEFAULT_REPEAT_STRIDE});
AscendC::PipeBarrier<PIPE_V>();
}
/* Quant start */
if constexpr (quantMode == QuantMode::PER_TENSOR_ASYMM_QUANT) {
Muls(fp32_xy, fp32_xy, input_scale_, REPEAT_TIME_64, num_col_align_withStride_fp32 / REPEAT_TIME_64,
{1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE});
AscendC::PipeBarrier<PIPE_V>();
Adds(fp32_xy, fp32_xy, input_offset_, REPEAT_TIME_64, num_col_align_withStride_fp32 / REPEAT_TIME_64,
{1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE});
AscendC::PipeBarrier<PIPE_V>();
} else if constexpr (quantMode == QuantMode::PER_TOKEN_SYMM_QUANT) {
Abs(abs, fp32_xy, REPEAT_TIME_64, num_col_align_withStride_fp32 / REPEAT_TIME_64,
{1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE});
AscendC::PipeBarrier<PIPE_V>();
ReduceMax(max, abs, work, num_col_ - input_stride_);
AscendC::PipeBarrier<PIPE_V>();
float scaleOut = max.GetValue(0) / 127;
SET_FLAG(S, V, EVENT_ID0);
WAIT_FLAG(S, V, EVENT_ID0);
Muls(fp32_xy, fp32_xy, (float)(1 / scaleOut), REPEAT_TIME_64,
num_col_align_withStride_fp32 / REPEAT_TIME_64,
{1, 1, AscendC::DEFAULT_REPEAT_STRIDE, AscendC::DEFAULT_REPEAT_STRIDE});
AscendC::PipeBarrier<PIPE_V>();
perTokenDescaleTensor.SetValue(0, scaleOut);
SET_FLAG(S, MTE3, EVENT_ID0);
WAIT_FLAG(S, MTE3, EVENT_ID0);
if constexpr (!NEED_DEQUANT) {
ub_to_gm_align<ArchType::ASCEND_V220, float>(perTokenDescaleGmTensor[pid], perTokenDescaleTensor, 0,
1, // nBurst
1 * sizeof(float), // lenBurst
0, // leftPaddingNum
0, // rightPaddingNum
0, // srcGap
0 // dstGap
);
} else {
ub_to_gm_align<ArchType::ASCEND_V220, float>(perTokenDescaleGmTensor[num_row_ + pid],
perTokenDescaleTensor, 0,
1, // nBurst
1 * sizeof(float), // lenBurst
0, // leftPaddingNum
0, // rightPaddingNum
0, // srcGap
0 // dstGap
);
}
SET_FLAG(MTE3, V, EVENT_ID0);
WAIT_FLAG(MTE3, V, EVENT_ID0);
}
AscendC::LocalTensor<half> tmpfp16 =
buf.ReinterpretCast<half>()[OFFSET_SUM * num_col_align_withStride_fp32 * 2];
CastFrom32To16(tmpfp16, fp32_xy, num_col_align_withStride_fp32);
AscendC::PipeBarrier<PIPE_V>();
CastFromF16ToI8(dstTensor, tmpfp16, quantMin_, num_col_align_withStride_fp16);
AscendC::PipeBarrier<PIPE_V>();
SET_FLAG(V, MTE3, EVENT_ID0);
WAIT_FLAG(V, MTE3, EVENT_ID0);
AscendC::DataCopy(outputGmTensor[gm_out_offset_ + outOffset], dstTensor,
AscendC::DataCopyParams(1, (num_col_ - input_stride_) / 32, 0, 0));
SET_FLAG(MTE3, MTE2, EVENT_ID0);
++pid;
}
WAIT_FLAG(MTE3, MTE2, EVENT_ID0);
}
private:
AscendC::LocalTensor<int8_t> dstTensor;
AscendC::LocalTensor<T> srcTensor;
AscendC::LocalTensor<T> gammaTensor;
AscendC::LocalTensor<T> betaTensor;
AscendC::LocalTensor<float> fp32_xy;
AscendC::LocalTensor<float> buf;
AscendC::LocalTensor<int32_t> mmTensor;
AscendC::LocalTensor<float> deScaleTensor;
AscendC::GlobalTensor<T> gammaGmTensor;
AscendC::GlobalTensor<T> betaGmTensor;
AscendC::GlobalTensor<T> quantScaleGmTensor;
AscendC::GlobalTensor<int8_t> quantOffsetGmTensor;
AscendC::GlobalTensor<T> inputGmTensor;
AscendC::GlobalTensor<int8_t> outputGmTensor;
AscendC::GlobalTensor<float> perTokenDescaleGmTensor;
AscendC::GlobalTensor<float> perChannelDescaleGmTensor;
AscendC::GlobalTensor<int32_t> mmGmTensor;
uint32_t num_col_{0};
uint32_t num_row_{0};
uint32_t row_work_{0};
uint32_t row_work{0};
uint32_t row_step_{0};
uint32_t row_tail_{0};
uint64_t gm_offset_{0};
uint64_t gm_out_offset_{0};
float avg_factor_{1.0};
float input_scale_{1.0};
float input_offset_{0};
int32_t input_stride_{0};
float epsilon_{1e-12f};
uint32_t num_col_align_int8{0};
uint32_t num_col_align_f16{0};
uint32_t num_col_align_f32{0};
uint32_t num_col_align_f32_long{0};
uint32_t num_col_align_withStride_int8{0};
uint32_t num_col_align_withStride_fp16{0};
uint32_t num_col_align_withStride_fp32{0};
uint32_t num_col_temp;
half quantMin_{-128};
uint32_t num_slice_{0};
uint32_t tail_size_{0};
uint32_t tail_copy_{0};
};
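// Quantizes the einsum output [batchNum, headNum, colNum] to int8 with one scale per
// head: the scale vector is cast to fp32 and broadcast to block layout via Brcb, each
// tile is multiplied by its head's scale, then cast fp32 -> fp16 -> int8. Copy-in,
// compute, and copy-out are double buffered with ping/pong event flags.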
template <typename InDtype, typename ScaleDtype>
class EinSumQuant
{
public:
__aicore__ explicit EinSumQuant() {}
__aicore__ __force_inline__ void Init(GM_ADDR einSumOutGm, GM_ADDR scaleGm, GM_ADDR quantOutGm,
const MlaTilingData &tilingData)
{
einSumOutGm_.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(einSumOutGm));
scaleGm_.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(scaleGm));
quantOutGm_.SetGlobalBuffer(reinterpret_cast<__gm__ int8_t *>(quantOutGm));
headNum = tilingData.esqHeadNum;
colNum = tilingData.esqColNum;
ubHeadLoop = tilingData.esqUbHeadLoop;
headPerLoop = tilingData.esqHeadPerLoop;
headTail = tilingData.esqHeadTail;
colLoop = tilingData.esqColLoop;
colTail = tilingData.esqColTail;
currentIdx = (AscendC::GetBlockIdx() / 2) * 2 + GetSubBlockidx();
if (currentIdx < tilingData.esqFrontCore) {
batchNum = tilingData.esqFrontCoreBatch;
currentCoreStartOffset = currentIdx * tilingData.esqFrontCoreBatch * headNum * colNum;
} else {
batchNum = tilingData.esqTailCoreBatch;
currentCoreStartOffset = (tilingData.esqFrontCore * tilingData.esqFrontCoreBatch +
(currentIdx - tilingData.esqFrontCore) * tilingData.esqTailCoreBatch) *
headNum * colNum;
}
calcRepeatStride = static_cast<uint8_t>(colNum / ELE_NUM_FP32);
padLen = RoundUp(headNum, ELE_NUM_FP16);
calcLength = headPerLoop * colNum;
// calc tensors' data size(bytes) and block
scaleBrcbFp32DataSize = padLen * ELE_NUM_FP32 * sizeof(float);
inputDataSize = calcLength * sizeof(InDtype);
inputDataBlock = calcLength * sizeof(InDtype) / BLOCK_SIZE_32;
inputFp32DataSize = calcLength * sizeof(float);
int8OutDataBlock = calcLength / BLOCK_SIZE_32;
headTailDataBlock = headTail * colNum * sizeof(InDtype) / BLOCK_SIZE_32;
int8TailOutDataBlock = headTail * colNum / BLOCK_SIZE_32;
if (padLen > headNum) {
scaleCopyParams = AscendC::DataCopyExtParams(1, static_cast<uint32_t>(headNum * sizeof(InDtype)), 0, 0, 0);
scalePadParams = AscendC::DataCopyPadExtParams<InDtype>(true, 0, static_cast<uint8_t>(padLen - headNum), 0);
}
}
__aicore__ __force_inline__ void Process()
{
if (batchNum == 0) {
return;
}
// init local tensor
scaleBrcbFp32_ = buf.GetBuffer<BufferType::ASCEND_UB, float>(0);
inputTensor_ = buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(scaleBrcbFp32DataSize);
inputFp32_ =
buf.GetBuffer<BufferType::ASCEND_UB, float>(scaleBrcbFp32DataSize + inputDataSize * ROPE_CONCAT_NUM_BUFFER);
int8OutTensor_ = buf.GetBuffer<BufferType::ASCEND_UB, int8_t>(
scaleBrcbFp32DataSize + (inputDataSize + inputFp32DataSize) * ROPE_CONCAT_NUM_BUFFER);
// scale copy in, cast, brcb[H, 1] --> [H, 8], use input ub space
if (headNum == padLen) {
AscendC::DataCopy(inputTensor_, scaleGm_, headNum);
} else {
AscendC::DataCopyPad(inputTensor_, scaleGm_, scaleCopyParams, scalePadParams);
}
SET_FLAG(MTE2, V, EVENT_ID0);
WAIT_FLAG(MTE2, V, EVENT_ID0);
AscendC::Cast(inputFp32_, inputTensor_, AscendC::RoundMode::CAST_NONE, padLen);
AscendC::PipeBarrier<PIPE_V>();
AscendC::Brcb(scaleBrcbFp32_, inputFp32_, padLen / ELE_NUM_FP32, {1, 8});
AscendC::PipeBarrier<PIPE_V>();
uint8_t pingFlag = 0;
// batch Loop
SET_FLAG(V, MTE2, EVENT_ID0); // input copy in wait vector release ub
SET_FLAG(V, MTE2, EVENT_ID1);
SET_FLAG(MTE3, V, EVENT_ID0); // quant calc wait last result copyout
SET_FLAG(MTE3, V, EVENT_ID1);
for (uint32_t batchIdx = 0; batchIdx < batchNum; batchIdx++) {
batchOffset = batchIdx * headNum * colNum;
// ub Loop
for (uint32_t ubLoopIdx = 0; ubLoopIdx < ubHeadLoop; ubLoopIdx++) {
scaleBrcbOffset = ubLoopIdx * headPerLoop * ELE_NUM_FP32;
inputLoopOffset = ubLoopIdx * headPerLoop * colNum;
calcStartOffset = currentCoreStartOffset + batchOffset + inputLoopOffset;
calcTmpOffset = pingFlag * calcLength;
// input CopyIn and Cast
WAIT_FLAG(V, MTE2, pingFlag);
AscendC::DataCopy(inputTensor_[calcTmpOffset], einSumOutGm_[calcStartOffset],
{1, inputDataBlock, 0, 0});
SET_FLAG(MTE2, V, pingFlag);
WAIT_FLAG(MTE2, V, pingFlag);
AscendC::Cast(inputFp32_[calcTmpOffset], inputTensor_[calcTmpOffset], AscendC::RoundMode::CAST_NONE,
calcLength);
AscendC::PipeBarrier<PIPE_V>();
SET_FLAG(V, MTE2, pingFlag);
// quant calc
for (uint32_t colIdx = 0; colIdx < colLoop; colIdx++) {
colOffset = colIdx * CONST_64;
AscendC::Mul(inputFp32_[calcTmpOffset + colOffset], inputFp32_[calcTmpOffset + colOffset],
scaleBrcbFp32_[scaleBrcbOffset], CONST_64, headPerLoop,
{1, 1, 0, calcRepeatStride, calcRepeatStride, 1});
}
AscendC::PipeBarrier<PIPE_V>();
// quant fp32 --> fp16 --> int8
CastFrom32To16(inputFp32_[calcTmpOffset].template ReinterpretCast<half>(), inputFp32_[calcTmpOffset],
calcLength);
AscendC::PipeBarrier<PIPE_V>();
WAIT_FLAG(MTE3, V, pingFlag); // wait last result copy out
CastFromF16ToI8(int8OutTensor_[calcTmpOffset],
inputFp32_[calcTmpOffset].template ReinterpretCast<half>(), quantMin_, calcLength);
AscendC::PipeBarrier<PIPE_V>();
SET_FLAG(V, MTE3, pingFlag);
WAIT_FLAG(V, MTE3, pingFlag);
// int8 CopyOut
AscendC::DataCopy(quantOutGm_[calcStartOffset], int8OutTensor_[calcTmpOffset],
{1, int8OutDataBlock, 0, 0});
SET_FLAG(MTE3, V, pingFlag);
pingFlag = 1 - pingFlag;
}
// deal with head tail
if (headTail > 0) {
scaleBrcbOffset = ubHeadLoop * headPerLoop * ELE_NUM_FP32;
inputLoopOffset = ubHeadLoop * headPerLoop * colNum;
calcStartOffset = currentCoreStartOffset + batchOffset + inputLoopOffset;
calcTmpOffset = pingFlag * calcLength;
// input CopyIn and Cast
WAIT_FLAG(V, MTE2, pingFlag);
AscendC::DataCopy(inputTensor_[calcTmpOffset], einSumOutGm_[calcStartOffset],
{1, headTailDataBlock, 0, 0});
SET_FLAG(MTE2, V, pingFlag);
WAIT_FLAG(MTE2, V, pingFlag);
AscendC::Cast(inputFp32_[calcTmpOffset], inputTensor_[calcTmpOffset], AscendC::RoundMode::CAST_NONE,
headTail * colNum);
AscendC::PipeBarrier<PIPE_V>();
SET_FLAG(V, MTE2, pingFlag);
// quant calc
for (uint32_t colIdx = 0; colIdx < colLoop; colIdx++) {
colOffset = colIdx * CONST_64;
AscendC::Mul(inputFp32_[calcTmpOffset + colOffset], inputFp32_[calcTmpOffset + colOffset],
scaleBrcbFp32_[scaleBrcbOffset], CONST_64, headTail,
{1, 1, 0, calcRepeatStride, calcRepeatStride, 1});
}
AscendC::PipeBarrier<PIPE_V>();
// quant fp32 --> fp16 --> int8
CastFrom32To16(inputFp32_[calcTmpOffset].template ReinterpretCast<half>(), inputFp32_[calcTmpOffset],
headTail * colNum);
AscendC::PipeBarrier<PIPE_V>();
WAIT_FLAG(MTE3, V, pingFlag); // wait last result copy out
CastFromF16ToI8(int8OutTensor_[calcTmpOffset],
inputFp32_[calcTmpOffset].template ReinterpretCast<half>(), quantMin_,
headTail * colNum);
AscendC::PipeBarrier<PIPE_V>();
SET_FLAG(V, MTE3, pingFlag);
WAIT_FLAG(V, MTE3, pingFlag);
// int8 CopyOut
AscendC::DataCopy(quantOutGm_[calcStartOffset], int8OutTensor_[calcTmpOffset],
{1, int8TailOutDataBlock, 0, 0});
SET_FLAG(MTE3, V, pingFlag);
pingFlag = 1 - pingFlag;
}
}
WAIT_FLAG(V, MTE2, EVENT_ID0);
WAIT_FLAG(V, MTE2, EVENT_ID1);
WAIT_FLAG(MTE3, V, EVENT_ID0);
WAIT_FLAG(MTE3, V, EVENT_ID1);
}
private:
AsdopsBuffer<ArchType::ASCEND_V220> buf;
AscendC::GlobalTensor<InDtype> einSumOutGm_;
AscendC::GlobalTensor<ScaleDtype> scaleGm_;
AscendC::GlobalTensor<int8_t> quantOutGm_;
AscendC::LocalTensor<float> scaleBrcbFp32_;
AscendC::LocalTensor<InDtype> inputTensor_;
AscendC::LocalTensor<float> inputFp32_;
AscendC::LocalTensor<int8_t> int8OutTensor_;
AscendC::DataCopyExtParams scaleCopyParams;
AscendC::DataCopyPadExtParams<InDtype> scalePadParams;
// data processed by a single core: [batchNum, headNum, colNum]
uint32_t batchNum; // number of batches processed by this core
uint32_t headNum;
uint32_t colNum; // number of columns per row
// ub loop
uint32_t ubHeadLoop; // number of full UB loops over the heads
uint32_t headPerLoop; // heads processed per UB loop
uint32_t headTail; // heads left for the final partial loop
// col loop
uint32_t colLoop; // number of iterations along the column direction
uint32_t colTail; // columns left for the final partial iteration
uint32_t currentIdx;
uint64_t currentCoreStartOffset;
uint32_t inputDataSize; // bytes moved in per UB loop
uint32_t inputFp32DataSize;
uint32_t scaleBrcbFp32DataSize;
uint16_t inputDataBlock; // 32-byte blocks moved in per copy
uint16_t int8OutDataBlock;
uint16_t headTailDataBlock;
uint16_t int8TailOutDataBlock;
// gm offset
uint64_t inputLoopOffset{0};
uint64_t batchOffset{0};
uint64_t calcStartOffset{0};
// double buffer tmp tensor length
uint32_t scaleBrcbOffset{0};
uint32_t calcLength{0};
uint32_t calcTmpOffset{0};
half quantMin_{-128};
uint32_t colOffset{0};
uint32_t padLen;
uint8_t calcRepeatStride;
};
#ifdef __DAV_C220_CUBE__
struct MatCoord {
uint64_t m{0};
uint64_t k{0};
uint64_t n{0};
};
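// Batched matmul for the einsum path on the cube core. Tiles of A and B are staged
// GM -> L1 -> L0A/L0B with ping-pong double buffering so the next tile's copy-in
// overlaps the current MAD, and partial products accumulate in L0C before one copy
// back to GM. swizzleDirect selects the block traversal order; splitGapA/splitGapC
// account for gaps between rows in the strided A and C layouts.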
template <typename InDtype, typename OutDtype, DataFormat formatB, bool transB, uint32_t swizzleDirect,
uint64_t splitGapA, uint64_t splitGapC>
class PpMatmulEinSum
{
using AccumDtype = float;
template <DataFormat srcFormat, DataFormat dstFormat>
using CopyGmToCbuf = gm_to_l1<ArchType::ASCEND_V220, InDtype, srcFormat, dstFormat>;
using LoadCbufToCa = l1_to_l0_a<ArchType::ASCEND_V220, InDtype, false, DataFormat::ZN, DataFormat::ZZ>;
using LoadCbufToCb = l1_to_l0_b<ArchType::ASCEND_V220, InDtype, transB, DataFormat::ZN, DataFormat::NZ>;
using Mad = mmad<ArchType::ASCEND_V220, InDtype, InDtype, AccumDtype, false>;
using CopyCcToGm = l0c_to_gm<ArchType::ASCEND_V220, DataFormat::ND, OutDtype, AccumDtype>;
static constexpr uint32_t L0_PINGPONG_BUFFER_LEN = 16384;
static constexpr uint32_t L1_PINGPONG_BUFFER_LEN = 131072;
static constexpr uint32_t CONST_16 = 16;
static constexpr uint32_t CONST_256 = 256;
public:
__aicore__ explicit PpMatmulEinSum() {}
__aicore__ __force_inline__ void Init(GM_ADDR gmA, GM_ADDR gmB, GM_ADDR gmC, const MlaTilingData &mlaParams)
{
#ifdef __DAV_C220_CUBE__
batch_size = mlaParams.mm3.numBatch;
m = mlaParams.mm3.m;
k = mlaParams.mm3.k;
n = mlaParams.mm3.n;
m0 = mlaParams.mm3.m0;
k0 = mlaParams.mm3.k0;
n0 = mlaParams.mm3.n0;
tdim.m = mlaParams.mm3.mLoop;
tdim.k = mlaParams.mm3.kLoop;
tdim.n = mlaParams.mm3.nLoop;
core_loop = mlaParams.mm3.coreLoop;
swizzle_cnt = mlaParams.mm3.swizzleCount;
num_core = mlaParams.mm3.blockDim;
core_idx = AscendC::GetBlockIdx();
ping_flag = 1;
gm_a.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(gmA));
gm_b.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(gmB));
gm_c.SetGlobalBuffer(reinterpret_cast<__gm__ OutDtype *>(gmC));
AsdopsBuffer<ArchType::ASCEND_V220> buf;
l1_base_a = buf.GetBuffer<BufferType::ASCEND_CB, InDtype>(0);
l1_base_b = buf.GetBuffer<BufferType::ASCEND_CB, InDtype>(RoundUp<CONST_256>(m0 * k0 * sizeof(InDtype)));
l0a_base = buf.GetBuffer<BufferType::ASCEND_L0A, InDtype>(0);
l0b_base = buf.GetBuffer<BufferType::ASCEND_L0B, InDtype>(0);
#endif
return;
}
__aicore__ __force_inline__ void Process()
{
#ifdef __DAV_C220_CUBE__
if (block_idx >= num_core) {
WaitFlagDev(AIC_MM3_START);
return;
}
using LocalTensor = AscendC::LocalTensor<InDtype>;
SET_FLAG(MTE1, MTE2, EVENT_ID0);
SET_FLAG(MTE1, MTE2, EVENT_ID1);
SET_FLAG(MTE1, MTE2, EVENT_ID2);
SET_FLAG(MTE1, MTE2, EVENT_ID3);
SET_FLAG(FIX, M, EVENT_ID0);
SET_FLAG(M, MTE1, EVENT_ID0);
SET_FLAG(M, MTE1, EVENT_ID1);
for (uint64_t loop_idx = core_idx; loop_idx < core_loop; loop_idx += num_core) {
uint64_t batch_idx = loop_idx / tdim.n / tdim.m;
MatCoord tidx{0};
GetBaseBlockIdx(loop_idx, tidx);
uint64_t offset_c = tidx.m * m0 * batch_size * (n + splitGapC) + batch_idx * (n + splitGapC) + tidx.n * n0;
uint64_t m_actual = (tidx.m == (tdim.m - 1)) ? (m - tidx.m * m0) : m0;
uint64_t n_actual = (tidx.n == (tdim.n - 1)) ? (n - tidx.n * n0) : n0;
uint64_t m_round = RoundUp<CONST_16>(m_actual);
uint64_t n_round = RoundUp<CONST_16>(n_actual);
uint64_t mn_max = m_round > n_round ? m_round : n_round;
uint64_t k_part_len = L0_PINGPONG_BUFFER_LEN / mn_max / CONST_16 * CONST_16;
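// Split K so that one m_round x k_part (or k_part x n_round) tile fits the
// 16K-element L0 ping-pong buffer, rounded down to a multiple of 16.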
uint64_t shuffle_k = en_shuffle_k ? (core_idx % tdim.k) : 0;
uint64_t k_actual = (shuffle_k == tdim.k - 1) ? k - shuffle_k * k0 : k0;
uint64_t k_round = (k_actual + CONST_16 - 1) / CONST_16 * CONST_16;
LocalTensor l1_buf_a = ping_flag ? l1_base_a : l1_base_a[L1_PINGPONG_BUFFER_LEN];
LocalTensor l1_buf_b = ping_flag ? l1_base_b : l1_base_b[L1_PINGPONG_BUFFER_LEN];
event_t event_id = ping_flag ? EVENT_ID0 : EVENT_ID1;
if (loop_idx == core_idx) {
WaitFlagDev(AIC_MM3_START);
// Copy A from gm to l1 buffer
uint64_t offset_a = GetOffsetA(batch_idx, tidx.m, shuffle_k);
WAIT_FLAG(MTE1, MTE2, event_id);
CopyTileA(l1_buf_a, gm_a[offset_a], m_actual, m_round, k_actual, k_round);
SET_FLAG(MTE2, MTE1, event_id);
// Copy B from gm to l1 buffer
uint64_t offset_b = GetOffsetB(batch_idx, shuffle_k, tidx.n);
WAIT_FLAG(MTE1, MTE2, event_id + 2);
CopyTileB(l1_buf_b, gm_b[offset_b], k_actual, k_round, n_actual, n_round);
SET_FLAG(MTE2, MTE1, event_id + 2);
}
for (tidx.k = 0; tidx.k < tdim.k; ++tidx.k) {
shuffle_k = en_shuffle_k ? (tidx.k + core_idx) % tdim.k : tidx.k;
uint64_t k_actual = (shuffle_k == (tdim.k - 1)) ? (k - shuffle_k * k0) : k0;
uint64_t k_round = (k_actual + CONST_16 - 1) / CONST_16 * CONST_16;
fdim.k = (k_actual + k_part_len - 1) / k_part_len;
LocalTensor l1_buf_a = ping_flag ? l1_base_a : l1_base_a[L1_PINGPONG_BUFFER_LEN];
LocalTensor l1_buf_b = ping_flag ? l1_base_b : l1_base_b[L1_PINGPONG_BUFFER_LEN];
auto event_id = ping_flag ? EVENT_ID0 : EVENT_ID1;
if (tidx.k < tdim.k - 1) {
uint64_t shuffle_k_next = en_shuffle_k ? (core_idx + tidx.k + 1) % tdim.k : (tidx.k + 1);
uint64_t offset_a_next = GetOffsetA(batch_idx, tidx.m, shuffle_k_next);
uint64_t offset_b_next = GetOffsetB(batch_idx, shuffle_k_next, tidx.n);
uint64_t k_actual_next = (shuffle_k_next == (tdim.k - 1)) ? (k - shuffle_k_next * k0) : k0;
uint64_t k_round_next = (k_actual_next + CONST_16 - 1) / CONST_16 * CONST_16;
LocalTensor l1_buf_a_next = (1 - ping_flag) ? l1_base_a : l1_base_a[L1_PINGPONG_BUFFER_LEN];
LocalTensor l1_buf_b_next = (1 - ping_flag) ? l1_base_b : l1_base_b[L1_PINGPONG_BUFFER_LEN];
event_t event_id_next = (1 - ping_flag) ? EVENT_ID0 : EVENT_ID1;
// Preload A from gm to l1 buffer.
WAIT_FLAG(MTE1, MTE2, event_id_next);
CopyTileA(l1_buf_a_next, gm_a[offset_a_next], m_actual, m_round, k_actual_next, k_round_next);
SET_FLAG(MTE2, MTE1, event_id_next);
// Preload B from gm to l1 buffer.
WAIT_FLAG(MTE1, MTE2, event_id_next + 2);
CopyTileB(l1_buf_b_next, gm_b[offset_b_next], k_actual_next, k_round_next, n_actual, n_round);
SET_FLAG(MTE2, MTE1, event_id_next + 2);
}
if (tidx.k == tdim.k - 1 && loop_idx + num_core < core_loop) {
uint64_t b_idx_next = (loop_idx + num_core) / tdim.n / tdim.m;
MatCoord tidx{0};
GetBaseBlockIdx(loop_idx + num_core, tidx);
uint64_t shuffle_k_next = en_shuffle_k ? (core_idx % tdim.k) : 0;
uint64_t m_actual_next = (tidx.m == (tdim.m - 1)) ? (m - tidx.m * m0) : m0;
uint64_t n_actual_next = (tidx.n == (tdim.n - 1)) ? (n - tidx.n * n0) : n0;
uint64_t m_round_next = (m_actual_next + CONST_16 - 1) / CONST_16 * CONST_16;
uint64_t n_round_next = (n_actual_next + CONST_16 - 1) / CONST_16 * CONST_16;
uint64_t k_actual_next = (shuffle_k_next == (tdim.k - 1)) ? (k - shuffle_k_next * k0) : k0;
uint64_t k_round_next = (k_actual_next + CONST_16 - 1) / CONST_16 * CONST_16;
uint64_t offset_a_next = GetOffsetA(b_idx_next, tidx.m, shuffle_k_next);
uint64_t offset_b_next = GetOffsetB(b_idx_next, shuffle_k_next, tidx.n);
LocalTensor l1_buf_a_next = (1 - ping_flag) ? l1_base_a : l1_base_a[L1_PINGPONG_BUFFER_LEN];
LocalTensor l1_buf_b_next = (1 - ping_flag) ? l1_base_b : l1_base_b[L1_PINGPONG_BUFFER_LEN];
event_t event_id_next = (1 - ping_flag) ? EVENT_ID0 : EVENT_ID1;
// Preload A from gm to l1 buffer.
WAIT_FLAG(MTE1, MTE2, event_id_next);
CopyTileA(l1_buf_a_next, gm_a[offset_a_next], m_actual_next, m_round_next, k_actual_next,
k_round_next);
SET_FLAG(MTE2, MTE1, event_id_next);
// Preload B from gm to l1 buffer.
WAIT_FLAG(MTE1, MTE2, event_id_next + 2);
CopyTileB(l1_buf_b_next, gm_b[offset_b_next], k_actual_next, k_round_next, n_actual_next,
n_round_next);
SET_FLAG(MTE2, MTE1, event_id_next + 2);
}
MatCoord fidx{0};
for (fidx.k = 0; fidx.k < fdim.k; ++fidx.k) {
uint32_t k0_round = (fidx.k < fdim.k - 1) ? k_part_len : k_round - fidx.k * k_part_len;
uint32_t k0_actual = (fidx.k < fdim.k - 1) ? k_part_len : k_actual - fidx.k * k_part_len;
auto mte1_mad_ping_flag = 1 - fidx.k % 2;
auto mte1_mad_event_id = mte1_mad_ping_flag ? EVENT_ID0 : EVENT_ID1;
LocalTensor l0a_buf = l0a_base[(fidx.k & 0b1) * L0_PINGPONG_BUFFER_LEN];
LocalTensor l0b_buf = l0b_base[(fidx.k & 0b1) * L0_PINGPONG_BUFFER_LEN];
// *** load matrix A from L1 to L0A
if (fidx.k == 0) {
WAIT_FLAG(MTE2, MTE1, event_id);
}
WAIT_FLAG(M, MTE1, mte1_mad_event_id);
if ((m == 1) || (m_actual == 1)) {
l1_to_l0_a<ArchType::ASCEND_V220, InDtype, false, DataFormat::VECTOR, DataFormat::VECTOR>(
l0a_buf, // dst
l1_buf_a[fidx.k * k_part_len], // src
0, // mTileCeil
CeilDiv<CONST_256>(k0_round), // kPartCeil
0, // mSrcStride
1, // kSrcStride
0, // mDstStride
0); // kDstStride
} else {
LoadCbufToCa(l0a_buf, // l0Tensor
l1_buf_a[fidx.k * k_part_len * m_round], // l1Tensor
m_round, // mTileCeil
k0_round, // kPartCeil
1, // mSrcStride
m_round / CONST_16, // kSrcStride
k0_round / CONST_16, // mDstStride
1); // kDstStride
}
if (fidx.k == fdim.k - 1) {
SET_FLAG(MTE1, MTE2, event_id);
}
// *** load matrix B from L1 to L0B
if (fidx.k == 0) {
WAIT_FLAG(MTE2, MTE1, event_id + 2);
}
if constexpr (transB) {
LoadCbufToCb(l0b_buf, // l0Tensor
l1_buf_b[fidx.k * k_part_len * n_round], // l1Tensor
n_round, // nTileCeil
k0_round, // kPartCeil
1, // nSrcStride
n_round / CONST_16, // kSrcStride
1, // nDstStride
k0_round / CONST_16); // kDstStride
} else {
LoadCbufToCb(l0b_buf, // l0Tensor
l1_buf_b[fidx.k * k_part_len * CONST_16], // l1Tensor
n_round, // nTileCeil
k0_round, // kPartCeil
k_round / CONST_16, // nSrcStride
1, // kSrcStride
1, // nDstStride
n_round / CONST_16); // kDstStride
}
if (fidx.k == fdim.k - 1) {
SET_FLAG(MTE1, MTE2, event_id + 2);
}
SET_FLAG(MTE1, M, mte1_mad_event_id);
WAIT_FLAG(MTE1, M, mte1_mad_event_id);
bool init_c = (tidx.k == 0 && fidx.k == 0);
if (init_c) {
WAIT_FLAG(FIX, M, EVENT_ID0);
}
Mad(l0c_buf, // c
l0a_buf, // a
l0b_buf, // b
m_actual, // mTileActual
n_actual, // nTileActual
k0_actual, // kTileActual
init_c); // initC
PIPE_BARRIER(M);
SET_FLAG(M, MTE1, mte1_mad_event_id);
}
ping_flag = 1 - ping_flag;
}
SET_FLAG(M, FIX, EVENT_ID0);
WAIT_FLAG(M, FIX, EVENT_ID0);
// copy from L0C to gm
CopyCcToGm(gm_c[offset_c], // dst
l0c_buf, // src
m_actual, // mTileActual
n_actual, // nTileActual
m_round, // mTileCeil
(n + splitGapC) * batch_size); // nActual
SET_FLAG(FIX, M, EVENT_ID0);
}
WAIT_FLAG(M, MTE1, EVENT_ID0);
WAIT_FLAG(M, MTE1, EVENT_ID1);
WAIT_FLAG(MTE1, MTE2, EVENT_ID0);
WAIT_FLAG(MTE1, MTE2, EVENT_ID1);
WAIT_FLAG(MTE1, MTE2, EVENT_ID2);
WAIT_FLAG(MTE1, MTE2, EVENT_ID3);
WAIT_FLAG(FIX, M, EVENT_ID0);
#endif
}
private:
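// Maps a linear block index to an (m, n) tile coordinate. Blocks are walked in bands
// of swizzle_cnt tile-rows (swizzleDirect == 0, "Zn") or tile-columns
// (swizzleDirect == 1, "Nz"), reversing direction on every other band; this
// serpentine order is a common trick to keep consecutive blocks reusing the same
// A or B tiles.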
__aicore__ __force_inline__ void GetBaseBlockIdx(uint64_t index, MatCoord &tidx)
{
uint64_t in_batch_idx = index % (tdim.m * tdim.n);
if constexpr (swizzleDirect == 0) { // Zn
uint64_t tile_block_loop = (tdim.m + swizzle_cnt - 1) / swizzle_cnt;
uint64_t tile_block_idx = in_batch_idx / (swizzle_cnt * tdim.n);
uint64_t in_tile_block_idx = in_batch_idx % (swizzle_cnt * tdim.n);
uint64_t n_row = swizzle_cnt;
if (tile_block_idx == tile_block_loop - 1) {
n_row = tdim.m - swizzle_cnt * tile_block_idx;
}
tidx.m = tile_block_idx * swizzle_cnt + in_tile_block_idx % n_row;
tidx.n = in_tile_block_idx / n_row;
if (tile_block_idx % 2 != 0) {
tidx.n = tdim.n - tidx.n - 1;
}
} else if constexpr (swizzleDirect == 1) { // Nz
uint64_t tile_block_loop = (tdim.n + swizzle_cnt - 1) / swizzle_cnt;
uint64_t tile_block_idx = in_batch_idx / (swizzle_cnt * tdim.m);
uint64_t in_tile_block_idx = in_batch_idx % (swizzle_cnt * tdim.m);
uint64_t n_col = swizzle_cnt;
if (tile_block_idx == tile_block_loop - 1) {
n_col = tdim.n - swizzle_cnt * tile_block_idx;
}
tidx.m = in_tile_block_idx / n_col;
tidx.n = tile_block_idx * swizzle_cnt + in_tile_block_idx % n_col;
if (tile_block_idx % 2 != 0) {
tidx.m = tdim.m - tidx.m - 1;
}
}
return;
}
__aicore__ __force_inline__ uint64_t GetOffsetA(const uint64_t bIdx, const uint64_t mIdx, const uint64_t kIdx)
{
return mIdx * m0 * batch_size * (k + splitGapA) + bIdx * (k + splitGapA) + kIdx * k0;
}
__aicore__ __force_inline__ uint64_t GetOffsetB(const uint64_t bIdx, const uint64_t kIdx, const uint64_t nIdx)
{
if constexpr (formatB == DataFormat::ND) {
if constexpr (transB) {
return bIdx * k * n + nIdx * n0 * k + kIdx * k0;
} else {
return bIdx * k * n + kIdx * k0 * n + nIdx * n0;
}
} else {
if constexpr (transB) {
return bIdx * RoundUp<CONST_16>(n) * RoundUp<CONST_16>(k) + kIdx * k0 * RoundUp<CONST_16>(n) +
nIdx * n0 * CONST_16;
} else {
return bIdx * RoundUp<CONST_16>(k) * RoundUp<CONST_16>(n) + nIdx * n0 * RoundUp<CONST_16>(k) +
kIdx * k0 * CONST_16;
}
}
}
__aicore__ __force_inline__ void CopyTileA(AscendC::LocalTensor<InDtype> &dstTensor,
const AscendC::GlobalTensor<InDtype> &srcTensor, const uint64_t m_actual,
const uint64_t m_round, const uint64_t k_actual, const uint64_t k_round)
{
if ((m == 1) || (m_actual == 1)) {
CopyGmToCbuf<DataFormat::ND, DataFormat::ND>(dstTensor, // dst
srcTensor, // src
1, // nTileActual
CONST_16, // nTileCeil
1, // nVal
k_actual, // kTileActual
k_round, // kTileCeil
k); // dVal
} else {
CopyGmToCbuf<DataFormat::ND, DataFormat::NZ>(dstTensor, // dst
srcTensor, // src
m_actual, // nTileActual
m_round, // nTileCeil
m, // nVal
k_actual, // dTileActual
k_round, // dTileCeil
(k + splitGapA) * batch_size); // dVal
}
}
__aicore__ __force_inline__ void CopyTileB(AscendC::LocalTensor<InDtype> &dstTensor,
const AscendC::GlobalTensor<InDtype> &srcTensor, const uint64_t k_actual,
const uint64_t k_round, const uint64_t n_actual, const uint64_t n_round)
{
if constexpr (formatB == DataFormat::ND) {
if constexpr (transB) {
CopyGmToCbuf<formatB, DataFormat::NZ>(dstTensor, // dst
srcTensor, // src
n_actual, // nTileActual
n_round, // nTileCeil
n, // nVal
k_actual, // dTileActual
k_round, // dTileCeil
k); // dVal
} else {
CopyGmToCbuf<formatB, DataFormat::NZ>(dstTensor, // dst
srcTensor, // src
k_actual, // nTileActual
k_round, // nTileCeil
k, // nVal
n_actual, // dTileActual
n_round, // dTileCeil
n); // dVal
}
} else {
if constexpr (transB) {
CopyGmToCbuf<formatB, DataFormat::NZ>(dstTensor, // dst
srcTensor, // src
n_actual, // nTileActual
n_round, // nTileCeil
RoundUp<CONST_16>(n), // nVal
k_actual, // dTileActual
k_round, // dTileCeil
RoundUp<CONST_16>(k)); // dVal
} else {
CopyGmToCbuf<formatB, DataFormat::NZ>(dstTensor, // dst
srcTensor, // src
k_actual, // nTileActual
k_round, // nTileCeil
RoundUp<CONST_16>(k), // nVal
n_actual, // dTileActual
n_round, // dTileCeil
RoundUp<CONST_16>(n)); // dVal
}
}
}
private:
AscendC::GlobalTensor<InDtype> gm_a;
AscendC::GlobalTensor<InDtype> gm_b;
AscendC::GlobalTensor<OutDtype> gm_c;
AscendC::LocalTensor<InDtype> l1_base_a;
AscendC::LocalTensor<InDtype> l1_base_b;
AscendC::LocalTensor<InDtype> l0a_base;
AscendC::LocalTensor<InDtype> l0b_base;
AscendC::LocalTensor<float> l0c_buf;
uint32_t num_core{0};
uint32_t batch_size{0};
uint32_t m{0};
uint32_t k{0};
uint32_t n{0};
uint32_t m0{0};
uint32_t k0{0};
uint32_t n0{0};
MatCoord tdim{0};
MatCoord fdim{0};
uint32_t core_loop{0};
uint32_t swizzle_cnt{1};
uint32_t core_idx{0};
uint32_t en_shuffle_k{0};
uint32_t ping_flag{0};
};
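// Cube-core half of the W8A8 matmul: int8 x int8 tiles accumulated in s32 on
// L0C and written back to GM as s32. Dequantization to the floating-point
// output is done by the paired vector-core class PpMatmulW8a8Aiv below.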
template <bool withSyncAll, uint32_t swizzleDir, DataFormat formatA = DataFormat::ND,
DataFormat formatB = DataFormat::NZ>
class PpMatmulW8a8Aic
{
using InDtype = int8_t;
using OutDtype = int32_t;
using AccumDtype = int32_t;
template <DataFormat srcFormat, DataFormat dstFormat>
using CopyGmToCbuf = gm_to_l1<ArchType::ASCEND_V220, InDtype, srcFormat, dstFormat>;
using LoadCbufToCa = l1_to_l0_a<ArchType::ASCEND_V220, InDtype, false, DataFormat::ZN, DataFormat::ZZ>;
using LoadCbufToCb = l1_to_l0_b<ArchType::ASCEND_V220, InDtype, true, DataFormat::ZN, DataFormat::NZ>;
using Mmad = mmad<ArchType::ASCEND_V220, InDtype, InDtype, AccumDtype, false>;
using CopyCcToGm = l0c_to_gm<ArchType::ASCEND_V220, DataFormat::ND, OutDtype, AccumDtype>;
static constexpr uint64_t L0_PINGPONG_BUFFER_LEN = 32768;
static constexpr uint64_t L1_PINGPONG_BUFFER_LEN = 262144;
static constexpr uint64_t BLOCK_SIZE_16 = 16;
static constexpr uint64_t BLOCK_SIZE_32 = 32;
static constexpr uint64_t CUBE_MATRIX_SIZE_512 = 512;
static constexpr uint64_t CONST_4 = 4;
static constexpr uint64_t CONST_8 = 8;
static constexpr uint64_t CONST_32 = 32;
static constexpr uint64_t CONST_64 = 64;
static constexpr uint64_t CONST_128 = 128;
public:
__aicore__ PpMatmulW8a8Aic() {};
__aicore__ __force_inline__ void Init(GM_ADDR gmA, GM_ADDR gmB, GM_ADDR gmC, PpMatmulTilingData &tilingdata,
uint32_t mode)
{
gm_a.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(gmA));
gm_b.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(gmB));
gm_c.SetGlobalBuffer(reinterpret_cast<__gm__ OutDtype *>(gmC));
batch_size = tilingdata.numBatch;
m = tilingdata.m;
k = tilingdata.k;
n = tilingdata.n;
m0 = tilingdata.m0;
k0 = tilingdata.k0;
n0 = tilingdata.n0;
m_loop = tilingdata.mLoop;
k_loop = tilingdata.kLoop;
n_loop = tilingdata.nLoop;
core_loop = tilingdata.coreLoop;
swizzle_cnt = tilingdata.swizzleCount;
en_shuffle_k = tilingdata.enShuffleK;
core_num = tilingdata.blockDim;
load_all_Amat_flag = tilingdata.enLoadAllAmat;
b0mat_pingpong_buffer_len = tilingdata.b0matPingPongBufferLen;
core_idx = AscendC::GetBlockIdx();
ping_flag = 1;
MM1_MM2_mode = mode; // MM1 or MM2
InitBuffer();
return;
}
__aicore__ __force_inline__ uint64_t GetOffsetA(const uint64_t batchIdx, const uint64_t mIdx, uint64_t kIdx)
{
return batchIdx * m * k + mIdx * m0 * k + kIdx * k0;
}
__aicore__ __force_inline__ uint64_t GetOffsetB(const uint64_t batchIdx, const uint64_t kIdx, uint64_t nIdx)
{
if constexpr (formatB == DataFormat::ND) {
return batchIdx * k * n + nIdx * n0 * k + kIdx * k0;
} else {
return batchIdx * RoundUp<16>(n) * RoundUp<32>(k) + kIdx * k0 * RoundUp<16>(n) + nIdx * n0 * CONST_32;
}
}
__aicore__ __force_inline__ void CopyTileA(AscendC::LocalTensor<InDtype> &dstTensor,
const AscendC::GlobalTensor<InDtype> &srcTensor, const uint64_t m_actual,
const uint64_t m_round, const uint64_t k_actual, const uint64_t k_round)
{
if ((m == 1) || (m_actual == 1)) {
CopyGmToCbuf<formatA, DataFormat::ND>(dstTensor, // dst
srcTensor, // src
1, // nTileActual
BLOCK_SIZE_16, // nTileCeil
1, // nVal
k_actual, // dTileActual
k_round, // dTileCeil
k); // dVal
} else {
CopyGmToCbuf<formatA, DataFormat::NZ>(dstTensor, // dst
srcTensor, // src
m_actual, // nTileActual
m_round, // nTileCeil
m, // nVal
k_actual, // dTileActual
k_round, // dTileCeil
k); // dVal
}
}
__aicore__ __force_inline__ void CopyTileB(const AscendC::LocalTensor<InDtype> &dstTensor,
const AscendC::GlobalTensor<InDtype> &srcTensor, const uint64_t k_actual,
const uint64_t k_round, const uint64_t n_actual, const uint64_t n_round)
{
if constexpr (formatB == DataFormat::ND) {
CopyGmToCbuf<formatB, DataFormat::NZ>(dstTensor, // dst
srcTensor, // src
n_actual, // nTileActual
n_round, // nTileCeil
n, // nVal
k_actual, // dTileActual
k_round, // dTileCeil
k); // dVal
} else {
CopyGmToCbuf<formatB, DataFormat::NZ>(dstTensor, // dst
srcTensor, // src
n_actual, // nTileActual
n_round, // nTileCeil
RoundUp<16>(n), // nVal
k_actual, // dTileActual
k_round, // dTileCeil
RoundUp<32>(k)); // dVal
}
}
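// Stage the first one or two k-tiles of the weight matrix into the L1
// ping/pong halves before Process() starts, so the weight loads overlap the
// work that runs before the MM start signal arrives from the vector cores.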
__aicore__ __force_inline__ void PreloadWeight()
{
if (core_idx < core_num) {
uint64_t m_idx = 0;
uint64_t n_idx = 0;
GetBaseBlockIdx(core_idx, m_idx, n_idx);
uint64_t shuffle_k = en_shuffle_k ? core_idx % k_loop : 0;
uint64_t offset_b = GetOffsetB(0, shuffle_k, n_idx);
uint64_t k_actual = (shuffle_k == k_loop - 1) ? k - shuffle_k * k0 : k0;
uint64_t k_round = RoundUp<BLOCK_SIZE_32>(k_actual);
uint64_t n_actual = (n_idx == (n_loop - 1)) ? (n - n_idx * n0) : n0;
uint64_t n_round = RoundUp<BLOCK_SIZE_16>(n_actual);
CopyTileB(l1_base_b, gm_b[offset_b], k_actual, k_round, n_actual, n_round);
}
if (core_idx < core_num && k_loop > 1) {
uint64_t m_idx = 0;
uint64_t n_idx = 0;
GetBaseBlockIdx(core_idx, m_idx, n_idx);
uint64_t shuffle_k = en_shuffle_k ? (core_idx + 1) % k_loop : 1;
uint64_t offset_b = GetOffsetB(0, shuffle_k, n_idx);
uint64_t k_actual = (shuffle_k == k_loop - 1) ? k - shuffle_k * k0 : k0;
uint64_t k_round = RoundUp<BLOCK_SIZE_32>(k_actual);
uint64_t n_actual = (n_idx == (n_loop - 1)) ? (n - n_idx * n0) : n0;
uint64_t n_round = RoundUp<BLOCK_SIZE_16>(n_actual);
CopyTileB(l1_base_b[b0mat_pingpong_buffer_len], gm_b[offset_b], k_actual, k_round, n_actual, n_round);
}
}
__aicore__ __force_inline__ void Process();
private:
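// Partition L1: A occupies either a single m0 x k0 tile (ping/pong) or, when
// the tiling enables it, the whole rounded-up A matrix; B starts right after A.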
__aicore__ __force_inline__ void InitBuffer()
{
AsdopsBuffer<ArchType::ASCEND_V220> buf;
l1_base_a = buf.template GetBuffer<BufferType::ASCEND_CB, InDtype>(0);
// try load all A matrix
uint32_t a_l1_size = RoundUp<BLOCK_SIZE_16>(m) * RoundUp<BLOCK_SIZE_32>(k);
if (!load_all_Amat_flag) {
a_l1_size = RoundUp<CUBE_MATRIX_SIZE_512>(m0 * k0);
}
l1_base_b = l1_base_a[a_l1_size];
l0a_base = buf.template GetBuffer<BufferType::ASCEND_L0A, InDtype>(0);
l0b_base = buf.template GetBuffer<BufferType::ASCEND_L0B, InDtype>(0);
l0c_buf = buf.template GetBuffer<BufferType::ASCEND_L0C, AccumDtype>(0);
}
__aicore__ __force_inline__ void GetBaseBlockIdx(uint64_t index, uint64_t &m_idx, uint64_t &n_idx)
{
uint64_t in_batch_idx = index % (m_loop * n_loop);
if constexpr (swizzleDir == 0) { // Zn
uint64_t tile_block_loop = (m_loop + swizzle_cnt - 1) / swizzle_cnt;
uint64_t tile_block_idx = in_batch_idx / (swizzle_cnt * n_loop);
uint64_t in_tile_block_idx = in_batch_idx % (swizzle_cnt * n_loop);
uint64_t n_row = swizzle_cnt;
if (tile_block_idx == tile_block_loop - 1) {
n_row = m_loop - swizzle_cnt * tile_block_idx;
}
m_idx = tile_block_idx * swizzle_cnt + in_tile_block_idx % n_row;
n_idx = in_tile_block_idx / n_row;
if ((tile_block_idx & 0b1) != 0) {
n_idx = n_loop - n_idx - 1;
}
} else { // Nz
uint64_t tile_block_loop = (n_loop + swizzle_cnt - 1) / swizzle_cnt;
uint64_t tile_block_idx = in_batch_idx / (swizzle_cnt * m_loop);
uint64_t in_tile_block_idx = in_batch_idx % (swizzle_cnt * m_loop);
uint64_t n_col = swizzle_cnt;
if (tile_block_idx == tile_block_loop - 1) {
n_col = n_loop - swizzle_cnt * tile_block_idx;
}
m_idx = in_tile_block_idx / n_col;
n_idx = tile_block_idx * swizzle_cnt + in_tile_block_idx % n_col;
if ((tile_block_idx & 0b1) != 0) {
m_idx = m_loop - m_idx - 1;
}
}
return;
}
private:
AscendC::GlobalTensor<InDtype> gm_a;
AscendC::GlobalTensor<InDtype> gm_b;
AscendC::GlobalTensor<OutDtype> gm_c;
AscendC::LocalTensor<InDtype> l1_base_a;
AscendC::LocalTensor<InDtype> l1_base_b;
AscendC::LocalTensor<InDtype> l0a_base;
AscendC::LocalTensor<InDtype> l0b_base;
AscendC::LocalTensor<AccumDtype> l0c_buf;
uint64_t bias_bt{0};
uint32_t core_num{0};
uint32_t batch_size{0};
uint32_t m{0};
uint32_t k{0};
uint32_t n{0};
uint32_t m0{0};
uint32_t k0{0};
uint32_t n0{0};
uint32_t m_loop{0};
uint32_t n_loop{0};
uint32_t k_loop{0};
uint32_t core_loop{0};
uint32_t core_idx{0};
uint32_t ping_flag{0};
uint32_t swizzle_cnt{1};
uint32_t en_shuffle_k{0};
uint32_t MM1_MM2_mode{0};
uint64_t b0mat_pingpong_buffer_len{0};
bool load_all_Amat_flag{false};
};
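// Main cube pipeline. Event flags express the dependency chain
// MTE2 (GM->L1) -> MTE1 (L1->L0) -> M (mmad) -> FIX (L0C->GM), with ping/pong
// event IDs so the copies for the next tile overlap compute on the current one.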
template <bool withSyncAll, uint32_t swizzleDir, DataFormat formatA, DataFormat formatB>
__aicore__ __force_inline__ void PpMatmulW8a8Aic<withSyncAll, swizzleDir, formatA, formatB>::Process()
{
using LocalTensor = AscendC::LocalTensor<InDtype>;
if (core_idx >= core_num) {
if (MM1_MM2_mode == 0) {
WaitFlagDev(AIC_MM1_START);
} else if (MM1_MM2_mode == 1) {
WaitFlagDev(AIC_MM2_START);
}
return;
}
SET_FLAG(MTE1, MTE2, EVENT_ID0);
SET_FLAG(MTE1, MTE2, EVENT_ID1);
SET_FLAG(MTE1, MTE2, EVENT_ID2);
SET_FLAG(MTE1, MTE2, EVENT_ID3);
SET_FLAG(M, MTE1, EVENT_ID0);
SET_FLAG(M, MTE1, EVENT_ID1);
SET_FLAG(FIX, M, EVENT_ID0);
SET_FLAG(FIX, MTE2, EVENT_ID0);
SET_FLAG(MTE1, MTE2, EVENT_ID7);
for (uint64_t loop_idx = core_idx; loop_idx < core_loop; loop_idx += core_num) {
uint64_t batch_idx = loop_idx / n_loop / m_loop;
uint64_t m_idx = 0;
uint64_t n_idx = 0;
GetBaseBlockIdx(loop_idx, m_idx, n_idx);
uint64_t offset_a;
uint64_t offset_b;
uint64_t offset_bias;
uint64_t offset_a_next;
uint64_t offset_b_next;
uint64_t offset_c = batch_idx * m * n + m_idx * m0 * n + n_idx * n0;
uint64_t m_actual = (m_idx == (m_loop - 1)) ? (m - m_idx * m0) : m0;
uint64_t n_actual = (n_idx == (n_loop - 1)) ? (n - n_idx * n0) : n0;
uint64_t m_round = 0;
uint64_t n_round = 0;
uint64_t shuffle_k = en_shuffle_k ? core_idx % k_loop : 0;
uint64_t m_round_16 = RoundUp<BLOCK_SIZE_16>(m_actual);
uint64_t m_round_32 = RoundUp<BLOCK_SIZE_32>(m_actual);
m_round = m_round_16;
n_round = RoundUp<BLOCK_SIZE_16>(n_actual);
uint64_t mn_max = m_round > n_round ? m_round : n_round;
uint64_t k_part_len = 0;
k_part_len = L0_PINGPONG_BUFFER_LEN / mn_max / BLOCK_SIZE_32 * BLOCK_SIZE_32;
offset_b = GetOffsetB(batch_idx, shuffle_k, n_idx);
offset_bias = batch_idx * n + n_idx * n0;
uint64_t k_actual = (shuffle_k == k_loop - 1) ? k - shuffle_k * k0 : k0;
uint64_t k_round = RoundUp<BLOCK_SIZE_32>(k_actual);
auto event_id = ping_flag ? EVENT_ID0 : EVENT_ID1;
// First iteration on this core: wait until the vector cores signal that the quantized input is ready.
if (loop_idx == core_idx) {
if (MM1_MM2_mode == 0) {
WaitFlagDev(AIC_MM1_START);
} else if (MM1_MM2_mode == 1) {
WaitFlagDev(AIC_MM2_START);
}
}
WAIT_FLAG(MTE1, MTE2, event_id);
LocalTensor l1_buf_a =
load_all_Amat_flag ? l1_base_a : (ping_flag ? l1_base_a : l1_base_a[L1_PINGPONG_BUFFER_LEN]);
LocalTensor l1_buf_b = ping_flag ? l1_base_b : l1_base_b[b0mat_pingpong_buffer_len];
if (load_all_Amat_flag) {
if (loop_idx == core_idx) {
offset_a = GetOffsetA(batch_idx, m_idx, 0);
uint64_t k_actual_first = k;
uint64_t k_round_first = RoundUp<BLOCK_SIZE_32>(k_actual_first);
CopyTileA(l1_buf_a, gm_a[offset_a], m_actual, m_round, k_actual_first, k_round_first);
}
} else {
offset_a = GetOffsetA(batch_idx, m_idx, shuffle_k);
CopyTileA(l1_buf_a, gm_a[offset_a], m_actual, m_round, k_actual, k_round);
}
SET_FLAG(MTE2, MTE1, event_id);
WAIT_FLAG(MTE1, MTE2, event_id + CONST_2);
// Skip the copy on the first iteration: the first weight tile was already staged by PreloadWeight().
if (loop_idx != core_idx) {
CopyTileB(l1_buf_b, gm_b[offset_b], k_actual, k_round, n_actual, n_round);
}
SET_FLAG(MTE2, MTE1, event_id + CONST_2);
for (uint64_t k_idx = 0; k_idx < k_loop; k_idx++) {
shuffle_k = en_shuffle_k ? (k_idx + core_idx) % k_loop : k_idx;
uint32_t k_actual = (shuffle_k == (k_loop - 1)) ? (k - shuffle_k * k0) : k0;
uint32_t k_round = RoundUp<BLOCK_SIZE_32>(k_actual);
uint32_t k_part_loop = (k_actual + k_part_len - 1) / k_part_len;
// When the whole A matrix resides in L1, advance the A address by k-tile instead of ping-ponging.
LocalTensor l1_buf_a = load_all_Amat_flag ? (l1_base_a[k_idx * m0 * k0 * sizeof(int8_t)])
: (ping_flag ? l1_base_a : l1_base_a[L1_PINGPONG_BUFFER_LEN]);
LocalTensor l1_buf_b = ping_flag ? l1_base_b : l1_base_b[b0mat_pingpong_buffer_len];
auto event_id = ping_flag ? EVENT_ID0 : EVENT_ID1;
if (k_idx < k_loop - 1) {
uint64_t shuffle_k_next = en_shuffle_k ? (core_idx + k_idx + 1) % k_loop : k_idx + 1;
offset_b_next = GetOffsetB(batch_idx, shuffle_k_next, n_idx);
uint32_t k_actual_next = (shuffle_k_next == (k_loop - 1)) ? (k - shuffle_k_next * k0) : k0;
uint32_t k_round_next = RoundUp<BLOCK_SIZE_32>(k_actual_next);
LocalTensor l1_buf_a_next =
load_all_Amat_flag ? l1_base_a : ((1 - ping_flag) ? l1_base_a : l1_base_a[L1_PINGPONG_BUFFER_LEN]);
LocalTensor l1_buf_b_next = (1 - ping_flag) ? l1_base_b : l1_base_b[b0mat_pingpong_buffer_len];
auto event_id_next = (1 - ping_flag) ? EVENT_ID0 : EVENT_ID1;
WAIT_FLAG(MTE1, MTE2, event_id_next);
if (!load_all_Amat_flag) {
offset_a_next = GetOffsetA(batch_idx, m_idx, shuffle_k_next);
CopyTileA(l1_buf_a_next, gm_a[offset_a_next], m_actual, m_round, k_actual_next, k_round_next);
}
SET_FLAG(MTE2, MTE1, event_id_next);
WAIT_FLAG(MTE1, MTE2, event_id_next + CONST_2);
// Skip the copy only on the very first iteration: the second weight tile was already staged by PreloadWeight().
if (loop_idx != core_idx || k_idx != 0) {
CopyTileB(l1_buf_b_next, gm_b[offset_b_next], k_actual_next, k_round_next, n_actual, n_round);
}
SET_FLAG(MTE2, MTE1, event_id_next + CONST_2);
}
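// Split the current k-tile into L0-sized slices: k_part_len is chosen above so
// that one slice of A and one of B each fit a 32 KB half of L0A/L0B, and the
// halves are ping-ponged across slices.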
for (int k_part_idx = 0; k_part_idx < k_part_loop; k_part_idx++) {
uint32_t k0_round = (k_part_idx < k_part_loop - 1) ? k_part_len : k_round - k_part_idx * k_part_len;
uint32_t k0_actual = (k_part_idx < k_part_loop - 1) ? k_part_len : k_actual - k_part_idx * k_part_len;
auto mte1_mad_ping_flag = 1 - k_part_idx % 2;
auto mte1_mad_event_id = mte1_mad_ping_flag ? EVENT_ID0 : EVENT_ID1;
AscendC::LocalTensor<InDtype> l0a_buf = l0a_base[(k_part_idx % 2) * L0_PINGPONG_BUFFER_LEN];
AscendC::LocalTensor<InDtype> l0b_buf = l0b_base[(k_part_idx % 2) * L0_PINGPONG_BUFFER_LEN];
// *** load matrix A from L1 to L0A
if (k_part_idx == 0) {
WAIT_FLAG(MTE2, MTE1, event_id);
}
WAIT_FLAG(M, MTE1, mte1_mad_event_id);
if ((m == 1) || (m_actual == 1)) {
l1_to_l0_a<ArchType::ASCEND_V220, InDtype, false, DataFormat::VECTOR, DataFormat::VECTOR>(
l0a_buf, l1_buf_a[k_part_idx * k_part_len],
0, // mTileCeil
CeilDiv<CUBE_MATRIX_SIZE_512>(k0_round), // kPartCeil
0, // mSrcStride
1, // kSrcStride
0, // mDstStride
0); // kDstStride
} else {
LoadCbufToCa(l0a_buf, // l0Tensor
l1_buf_a[k_part_idx * k_part_len * m_round], // l1Tensor
m_round, // mTileCeil
k0_round, // kPartCeil
1, // mSrcStride
m_round / BLOCK_SIZE_16, // kSrcStride
k0_round / BLOCK_SIZE_32, // mDstStride
1); // kDstStride
}
if (k_part_idx == k_part_loop - 1) {
SET_FLAG(MTE1, MTE2, event_id);
}
// *** load matrix B from L1 to L0B
if (k_part_idx == 0) {
WAIT_FLAG(MTE2, MTE1, event_id + CONST_2);
}
LoadCbufToCb(l0b_buf, // l0Tensor
l1_buf_b[k_part_idx * k_part_len * n_round], // l1Tensor
n_round, // nTileCeil
k0_round, // kPartCeil
1, // nSrcStride
n_round / BLOCK_SIZE_16, // kSrcStride
1, // nDstStride
k0_round / BLOCK_SIZE_32); // kDstStride
if (k_part_idx == k_part_loop - 1) {
SET_FLAG(MTE1, MTE2, event_id + CONST_2);
}
SET_FLAG(MTE1, M, mte1_mad_event_id);
WAIT_FLAG(MTE1, M, mte1_mad_event_id);
bool init_c = (k_idx == 0 && k_part_idx == 0);
if (init_c) {
WAIT_FLAG(FIX, M, EVENT_ID0);
}
Mmad(l0c_buf, l0a_buf, l0b_buf,
m_actual, // m
n_actual, // n
k0_actual, // k
init_c); // cmatrixInitVal
PIPE_BARRIER(M);
SET_FLAG(M, MTE1, mte1_mad_event_id);
}
ping_flag = 1 - ping_flag;
}
SET_FLAG(M, FIX, EVENT_ID0);
WAIT_FLAG(M, FIX, EVENT_ID0);
// copy from L0C to gm
CopyCcToGm(gm_c[offset_c], // dst
l0c_buf, // src
m_actual, // MSize
n_actual, // NSize
m_round_16, // srcStride
n); // dstStride_dst_D
SET_FLAG(FIX, M, EVENT_ID0);
if constexpr (!withSyncAll) {
FftsCrossCoreSync<PIPE_FIX, SYNC_MODE>(MMAIC);
if ((loop_idx / core_num + 1) % MAX_HW_SYNC_COUNTER == 0) {
WaitFlagDev(MMAIV);
}
}
}
WAIT_FLAG(MTE1, MTE2, EVENT_ID0);
WAIT_FLAG(MTE1, MTE2, EVENT_ID1);
WAIT_FLAG(MTE1, MTE2, EVENT_ID2);
WAIT_FLAG(MTE1, MTE2, EVENT_ID3);
WAIT_FLAG(M, MTE1, EVENT_ID0);
WAIT_FLAG(M, MTE1, EVENT_ID1);
WAIT_FLAG(FIX, M, EVENT_ID0);
WAIT_FLAG(FIX, MTE2, EVENT_ID0);
WAIT_FLAG(MTE1, MTE2, EVENT_ID7);
}
#endif
#if defined(__DAV_C220_VEC__)
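// Vector-core half of the W8A8 matmul: reads the s32 GEMM result, applies the
// per-tensor descale (plus bias for asymmetric quant) and, if enabled, the
// per-token descale, then casts to the 16-bit output dtype.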
template <typename OutDtype, bool withSyncAll, QuantMode quantMode>
class PpMatmulW8a8Aiv
{
using InDtype = int32_t;
using ScaleDtype = float;
using BiasDtype = int32_t;
public:
__aicore__ PpMatmulW8a8Aiv() {};
__aicore__ __force_inline__ void Init(GM_ADDR gmInput, GM_ADDR gmOutput, GM_ADDR gmDescale, GM_ADDR gmPerTensorBias,
GM_ADDR gmPertokenDescale, const PpMatmulTilingData &gmTilingData)
{
gmInput_.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(gmInput));
gmOutput_.SetGlobalBuffer(reinterpret_cast<__gm__ OutDtype *>(gmOutput));
gmPerTensorScale_.SetGlobalBuffer(reinterpret_cast<__gm__ ScaleDtype *>(gmDescale));
gmPerTensorBias_.SetGlobalBuffer(reinterpret_cast<__gm__ BiasDtype *>(gmPerTensorBias));
gmPerTokenScale_.SetGlobalBuffer(reinterpret_cast<__gm__ ScaleDtype *>(gmPertokenDescale));
batch_size = gmTilingData.numBatch;
m = gmTilingData.m;
k = gmTilingData.k;
n = gmTilingData.n;
m0 = gmTilingData.m0;
k0 = gmTilingData.k0;
n0 = gmTilingData.n0;
m_loop = gmTilingData.mLoop;
k_loop = gmTilingData.kLoop;
n_loop = gmTilingData.nLoop;
core_loop = gmTilingData.coreLoop;
swizzle_cnt = gmTilingData.swizzleCount;
swizzlDirect = gmTilingData.swizzleDirect;
en_shuffle_k = gmTilingData.enShuffleK;
AsdopsBuffer<ArchType::ASCEND_V220> buf;
ubInput_ = buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(0);
ubTempFp32_ = buf.GetBuffer<BufferType::ASCEND_UB, float>(94 * 1024);
ubOutput_ = buf.GetBuffer<BufferType::ASCEND_UB, OutDtype>(0);
ubPerTensorScale_ = buf.GetBuffer<BufferType::ASCEND_UB, float>(188 * 1024);
block_size = BLOCK_SIZE_32;
core_num = AscendC::GetBlockNum();
core_idx = AscendC::GetBlockIdx() / 2;
ping_flag = 1;
}
__aicore__ __force_inline__ void GetBlockIdx(uint32_t index, uint32_t &m_idx, uint32_t &n_idx)
{
uint32_t in_batch_idx = index % (m_loop * n_loop);
if (swizzlDirect == 0) { // Zn
uint32_t tile_block_loop = (m_loop + swizzle_cnt - 1) / swizzle_cnt;
uint32_t tile_block_idx = in_batch_idx / (swizzle_cnt * n_loop);
uint32_t in_tile_block_idx = in_batch_idx % (swizzle_cnt * n_loop);
uint32_t n_row = swizzle_cnt;
if (tile_block_idx == tile_block_loop - 1) {
n_row = m_loop - swizzle_cnt * tile_block_idx;
}
m_idx = tile_block_idx * swizzle_cnt + in_tile_block_idx % n_row;
n_idx = in_tile_block_idx / n_row;
if (tile_block_idx % 2 != 0) {
n_idx = n_loop - n_idx - 1;
}
} else { // Nz
uint32_t tile_block_loop = (n_loop + swizzle_cnt - 1) / swizzle_cnt;
uint32_t tile_block_idx = in_batch_idx / (swizzle_cnt * m_loop);
uint32_t in_tile_block_idx = in_batch_idx % (swizzle_cnt * m_loop);
uint32_t n_col = swizzle_cnt;
if (tile_block_idx == tile_block_loop - 1) {
n_col = n_loop - swizzle_cnt * tile_block_idx;
}
m_idx = in_tile_block_idx / n_col;
n_idx = tile_block_idx * swizzle_cnt + in_tile_block_idx % n_col;
if (tile_block_idx % 2 != 0) {
m_idx = m_loop - m_idx - 1;
}
}
}
__aicore__ __force_inline__ void Process();
private:
AscendC::GlobalTensor<ScaleDtype> gmPerTensorScale_;
AscendC::GlobalTensor<BiasDtype> gmPerTensorBias_;
AscendC::GlobalTensor<ScaleDtype> gmPerTokenScale_;
AscendC::GlobalTensor<InDtype> gmInput_;
AscendC::GlobalTensor<OutDtype> gmOutput_;
AscendC::LocalTensor<int32_t> ubInput_;
AscendC::LocalTensor<float> ubTempFp32_;
AscendC::LocalTensor<OutDtype> ubOutput_;
AscendC::LocalTensor<float> ubPerTensorScale_;
uint32_t core_num{0};
uint32_t batch_size{0};
uint32_t m{0};
uint32_t k{0};
uint32_t n{0};
uint32_t m0{0};
uint32_t k0{0};
uint32_t n0{0};
uint32_t m_loop{0};
uint32_t n_loop{0};
uint32_t k_loop{0};
uint32_t core_loop{0};
uint32_t core_idx{0};
uint32_t ping_flag{0};
uint32_t block_size{0};
uint32_t cube_matrix_size{0};
uint32_t swizzle_cnt{1};
uint32_t en_shuffle_k{0};
uint32_t swizzlDirect{0};
uint64_t L1_PINGPONG_BUFFER_LEN{0};
uint32_t L0AB_PINGPONG_BUFFER_LEN{0};
};
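// Per tile: the two vector sub-blocks split the tile's rows between them, copy
// their slice of the s32 result to UB, add bias / apply scales, cast, and copy
// the finished rows back to GM.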
template <typename OutDtype, bool withSyncAll, QuantMode quantMode>
__aicore__ __force_inline__ void PpMatmulW8a8Aiv<OutDtype, withSyncAll, quantMode>::Process()
{
uint32_t m_idx = 0;
uint32_t n_idx = 0;
SET_FLAG(V, MTE2, EVENT_ID0);
SET_FLAG(MTE3, V, EVENT_ID0);
SET_FLAG(MTE3, MTE2, EVENT_ID0);
for (uint32_t loop_idx = core_idx; loop_idx < core_loop; loop_idx += core_num) {
GetBlockIdx(loop_idx, m_idx, n_idx);
uint64_t batch_idx = loop_idx / n_loop / m_loop;
uint64_t offsetC = batch_idx * m * n + m_idx * m0 * n + n_idx * n0;
uint32_t m_actual = (m_idx == (m_loop - 1)) ? (m - m_idx * m0) : m0;
uint32_t n_actual = (n_idx == (n_loop - 1)) ? (n - n_idx * n0) : n0;
uint32_t m_round = RoundUp<CONST_8>(m_actual);
uint32_t n_round = RoundUp<CONST_8>(n_actual);
uint32_t n_round_16 = RoundUp<BLOCK_SIZE_16>(n_actual);
uint32_t m_actual_per_vec = m_actual / AscendC::GetTaskRation();
uint32_t m_offset = m + m_idx * m0;
if (GetSubBlockidx() != 0) {
offsetC += m_actual_per_vec * n;
m_offset += m_actual_per_vec;
m_actual_per_vec = m_actual - m_actual_per_vec;
}
if constexpr (!withSyncAll) {
if (m_actual_per_vec == 0) {
WaitFlagDev(MMAIC);
if ((loop_idx / core_num + 1) % MAX_HW_SYNC_COUNTER == 1) {
FftsCrossCoreSync<PIPE_MTE3, SYNC_MODE>(MMAIV);
}
continue;
}
}
uint64_t offsetScale = batch_idx * n + n_idx * n0;
bool aligned_s32 = ((n & 0b111) == 0); // rows 32B-aligned for 4-byte elements
bool aligned_f16 = ((n & 0b1111) == 0); // rows 32B-aligned for 2-byte elements
WAIT_FLAG(V, MTE2, EVENT_ID0);
if constexpr (quantMode == QuantMode::PER_TENSOR_ASYMM_QUANT) {
if (aligned_s32) {
gm_to_ub<ArchType::ASCEND_V220, BiasDtype>(ubPerTensorScale_.ReinterpretCast<BiasDtype>(),
gmPerTensorBias_[offsetScale],
0, // sid
1, // nBurst
n_round * sizeof(BiasDtype) / BLOCK_SIZE_32, // lenBurst
0, // srcStride
0); // dstStride
} else {
gm_to_ub_align<ArchType::ASCEND_V220, BiasDtype>(ubPerTensorScale_.ReinterpretCast<BiasDtype>(),
gmPerTensorBias_[offsetScale],
0, // sid
1, // nBurst
n_actual * sizeof(float), // lenBurst
0, // leftPaddingNum
0, // rightPaddingNum
0, // srcGap
0); // dstGap
}
} else {
if (aligned_s32) {
gm_to_ub<ArchType::ASCEND_V220, float>(ubPerTensorScale_, gmPerTensorScale_[offsetScale],
0, // sid
1, // nBurst
n_round * 4 / BLOCK_SIZE_32, // lenBurst
0, // srcStride
0); // dstStride
} else {
gm_to_ub_align<ArchType::ASCEND_V220, float>(ubPerTensorScale_, gmPerTensorScale_[offsetScale],
0, // sid
1, // nBurst
n_actual * sizeof(float), // lenBurst
0, // leftPaddingNum
0, // rightPaddingNum
0, // srcGap
0); // dstGap
}
}
if constexpr (!withSyncAll) {
WaitFlagDev(MMAIC);
}
WAIT_FLAG(MTE3, MTE2, EVENT_ID0);
if (aligned_s32) {
gm_to_ub<ArchType::ASCEND_V220, int32_t>(ubInput_, gmInput_[offsetC],
0, // sid
m_actual_per_vec, // nBurst
n_round / 8, // lenBurst
(n - n_round) / 8, // srcStride
0 // dstStride
);
} else {
gm_to_ub_align<ArchType::ASCEND_V220, int32_t>(ubInput_, gmInput_[offsetC],
0, // sid
m_actual_per_vec, // nBurst
n_actual * sizeof(int32_t), // lenBurst
0, // leftPaddingNum
0, // rightPaddingNum
(n - n_actual) * sizeof(int32_t), // srcGap
0 // dstGap
);
}
SET_FLAG(MTE2, V, EVENT_ID0);
WAIT_FLAG(MTE2, V, EVENT_ID0);
WAIT_FLAG(MTE3, V, EVENT_ID0);
uint32_t nRepeatCnt = CeilDiv<CONST_64>(n_actual);
if constexpr (quantMode == QuantMode::PER_TENSOR_ASYMM_QUANT) {
AscendC::SetMaskCount();
AscendC::SetVectorMask<BiasDtype, AscendC::MaskMode::COUNTER>(n_round);
for (uint32_t i = 0; i < m_actual_per_vec; ++i) {
AscendC::Add<BiasDtype, false>(ubInput_[i * n_round], ubInput_[i * n_round],
ubPerTensorScale_.ReinterpretCast<BiasDtype>(),
AscendC::MASK_PLACEHOLDER, 1,
AscendC::BinaryRepeatParams((uint8_t)1, (uint8_t)1, (uint8_t)1,
(uint8_t)8, (uint8_t)8, (uint8_t)8));
}
AscendC::ResetMask();
SetMasknorm();
SET_FLAG(V, MTE2, EVENT_ID0);
WAIT_FLAG(V, MTE2, EVENT_ID0);
if (aligned_s32) {
gm_to_ub<ArchType::ASCEND_V220, ScaleDtype>(ubPerTensorScale_, gmPerTensorScale_[offsetScale],
0, // sid
1, // nBurst
n_round * sizeof(ScaleDtype) / BLOCK_SIZE_32, // lenBurst
0, // srcStride
0 // dstStride
);
} else {
gm_to_ub_align<ArchType::ASCEND_V220, ScaleDtype>(ubPerTensorScale_, gmPerTensorScale_[offsetScale],
0, // sid
1, // nBurst
n_actual * sizeof(ScaleDtype), // lenBurst
0, // leftPaddingNum
0, // rightPaddingNum
0, // srcGap
0 // dstGap
);
}
SET_FLAG(MTE2, V, EVENT_ID0);
WAIT_FLAG(MTE2, V, EVENT_ID0);
}
// Dequant: cast s32 -> f32, multiply by the f32 scales, then narrow to the 16-bit output dtype.
constexpr uint32_t maxRepeat = 255;
constexpr uint32_t perRepeatNum = maxRepeat * 64;
uint32_t loopCnt = (m_actual_per_vec * n_actual + perRepeatNum - 1) / perRepeatNum;
for (uint32_t i = 0; i < loopCnt; i++) {
conv_v<ArchType::ASCEND_V220, int32_t, float>(ubInput_.ReinterpretCast<float>()[perRepeatNum * i],
ubInput_[perRepeatNum * i],
(uint8_t)maxRepeat, // repeat
(uint16_t)1, // dstBlockStride
(uint16_t)1, // srcBlockStride
(uint16_t)8, // dstRepeatStride
(uint16_t)8 // srcRepeatStride
);
}
AscendC::PipeBarrier<PIPE_V>();
for (uint32_t i = 0; i < m_actual_per_vec; ++i) {
mul_v<ArchType::ASCEND_V220, float>(ubTempFp32_[i * n_round],
ubInput_.ReinterpretCast<float>()[i * n_round],
ubPerTensorScale_.ReinterpretCast<float>(),
(uint8_t)(nRepeatCnt), // repeat
(uint8_t)1, // dstBlockStride
(uint8_t)1, // src0BlockStride
(uint8_t)1, // src1BlockStride
(uint8_t)8, // dstRepeatStride
(uint8_t)8, // src0RepeatStride
(uint8_t)8 // src1RepeatStride
);
if constexpr (quantMode == QuantMode::PER_TOKEN_SYMM_QUANT) {
AscendC::PipeBarrier<PIPE_V>();
float perTokenDescale = gmPerTokenScale_.GetValue(m_offset + i);
SET_FLAG(S, V, EVENT_ID0);
WAIT_FLAG(S, V, EVENT_ID0);
AscendC::Muls(ubTempFp32_[i * n_round], ubTempFp32_[i * n_round], perTokenDescale, n_round);
}
AscendC::PipeBarrier<PIPE_V>();
}
SET_FLAG(V, MTE2, EVENT_ID0);
AscendC::PipeBarrier<PIPE_V>();
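// If n_actual mod 16 is in (8, 16), the 8-aligned f32 row pitch (n_round)
// already equals the 16-aligned output pitch (n_round_16), so the whole buffer
// can be converted in bulk; otherwise convert row by row with distinct pitches.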
if (n_actual % 16 > 8) {
for (uint32_t i = 0; i < loopCnt; i++) {
if constexpr (std::is_same_v<OutDtype, __bf16>) {
convr_v<ArchType::ASCEND_V220, float, OutDtype>(ubOutput_[perRepeatNum * i],
ubTempFp32_[perRepeatNum * i],
(uint8_t)maxRepeat, // repeat
(uint16_t)1, // dstBlockStride
(uint16_t)1, // srcBlockStride
(uint16_t)4, // dstRepeatStride
(uint16_t)8); // srcRepeatStride
} else {
conv_v<ArchType::ASCEND_V220, float, OutDtype>(ubOutput_[perRepeatNum * i],
ubTempFp32_[perRepeatNum * i],
(uint8_t)maxRepeat, // repeat
(uint16_t)1, // dstBlockStride
(uint16_t)1, // srcBlockStride
(uint16_t)4, // dstRepeatStride
(uint16_t)8); // srcRepeatStride
}
}
} else {
for (uint32_t i = 0; i < m_actual_per_vec; i++) {
if constexpr (std::is_same_v<OutDtype, __bf16>) {
convr_v<ArchType::ASCEND_V220, float, OutDtype>(ubOutput_[n_round_16 * i], ubTempFp32_[n_round * i],
(uint8_t)nRepeatCnt, // repeat
(uint16_t)1, // dstBlockStride
(uint16_t)1, // srcBlockStride
(uint16_t)4, // dstRepeatStride
(uint16_t)8); // srcRepeatStride
} else {
conv_v<ArchType::ASCEND_V220, float, OutDtype>(ubOutput_[n_round_16 * i], ubTempFp32_[n_round * i],
(uint8_t)nRepeatCnt, // repeat
(uint16_t)1, // dstBlockStride
(uint16_t)1, // srcBlockStride
(uint16_t)4, // dstRepeatStride
(uint16_t)8); // srcRepeatStride
}
}
}
SET_FLAG(V, MTE3, EVENT_ID0);
WAIT_FLAG(V, MTE3, EVENT_ID0);
if (aligned_f16) {
ub_to_gm<ArchType::ASCEND_V220, OutDtype>(gmOutput_[offsetC], ubOutput_, 0,
m_actual_per_vec, // nBurst
n_round / 16, // lenBurst
0, // srcStride
(n - n_round) / 16 // dstStride
);
} else {
ub_to_gm_align<ArchType::ASCEND_V220, OutDtype>(gmOutput_[offsetC], ubOutput_, 0,
m_actual_per_vec, // nBurst
n_actual * sizeof(OutDtype), // lenBurst
0, // leftPaddingNum
0, // rightPaddingNum
0, // srcGap
(n - n_actual) * sizeof(OutDtype) // dstGap
);
}
SET_FLAG(MTE3, V, EVENT_ID0);
SET_FLAG(MTE3, MTE2, EVENT_ID0);
if constexpr (!withSyncAll) {
if ((loop_idx / core_num + 1) % MAX_HW_SYNC_COUNTER == 1) {
FftsCrossCoreSync<PIPE_MTE3, SYNC_MODE>(MMAIV);
}
}
}
WAIT_FLAG(V, MTE2, EVENT_ID0);
WAIT_FLAG(MTE3, V, EVENT_ID0);
WAIT_FLAG(MTE3, MTE2, EVENT_ID0);
}
#endif
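// Orchestrates the fused MLA pre-processing pipeline across cube and vector
// cores: quantize hidden states -> MM1 (W8A8) -> RMSNorm + quant -> MM2 (W8A8)
// -> RMSNorm/RoPE of the latent KV + cache scatter -> einsum Q projection ->
// RoPE of Q, with cross-core flags sequencing the stages.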
template <typename InDtype, int8_t CACHE_MODE, DataFormat weightFormat1 = DataFormat::NZ,
DataFormat weightFormat2 = DataFormat::NZ, DataFormat weightFormat3 = DataFormat::ND,
QuantMode quantMode = QuantMode::PER_TENSOR_ASYMM_QUANT>
class MLAOperation
{
static constexpr bool mm1WithSyncAll = (quantMode == QuantMode::PER_TOKEN_SYMM_QUANT);
static constexpr uint64_t splitGapC = CACHE_MODE == CACHE_MODE_KVCACHE ? CONST_64 : CONST_0;
using Q_OUT_DTYPE = typename std::conditional_t<CACHE_MODE == CACHE_MODE_INT8_NZCACHE, int8_t, InDtype>;
using K_NOPE_DTYPE = typename std::conditional_t<CACHE_MODE == CACHE_MODE_INT8_NZCACHE, int8_t, InDtype>;
public:
__aicore__ inline MLAOperation(const MlaTilingData &mlaParams_, GM_ADDR tilingGm)
{
blockIdx = AscendC::GetBlockIdx();
#ifdef __DAV_C220_VEC__
sub_block_idx = static_cast<uint64_t>(GetSubBlockidx());
#endif
vectorBlockIdx = (blockIdx / 2) * 2 + sub_block_idx;
this->n = mlaParams_.n;
this->num_core_ = mlaParams_.rmsNumCore1;
this->num_col_1 = mlaParams_.rmsNumCol1;
this->num_col_2 = mlaParams_.rmsNumCol2;
this->num_row = mlaParams_.n;
this->epsilon_ = 1e-6;
this->mlaParams = mlaParams_;
}
__aicore__ inline void Init(GM_ADDR hiddenStateGm, GM_ADDR gamma1Gm, GM_ADDR beta1Gm, GM_ADDR quantScale1Gm,
GM_ADDR quantOffset1Gm, GM_ADDR wdqkvGm, GM_ADDR bias1Gm, GM_ADDR gamma2Gm,
GM_ADDR beta2Gm, GM_ADDR quantScale2Gm, GM_ADDR quantOffset2Gm, GM_ADDR gamma3Gm,
GM_ADDR sin1Gm, GM_ADDR cos1Gm, GM_ADDR sin2Gm, GM_ADDR cos2Gm, GM_ADDR keycacheGm,
GM_ADDR slotMappingGm, GM_ADDR wuqGm, GM_ADDR bias2Gm, GM_ADDR wukGm,
GM_ADDR descale1Gm, GM_ADDR descale2Gm, GM_ADDR gmCtkvScale, GM_ADDR gmQnopeScale,
GM_ADDR qGm, GM_ADDR keycacheOutGm, GM_ADDR qGm2, GM_ADDR keycacheOutGm2, GM_ADDR s1Gm,
GM_ADDR s2Gm, GM_ADDR s3Gm, GM_ADDR s4Gm, GM_ADDR s5Gm)
{
quantScale3GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(gmCtkvScale));
gamma3GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(gamma3Gm));
sin1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(sin1Gm));
cos1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(cos1Gm));
keycacheGmTensor1.SetGlobalBuffer(reinterpret_cast<__gm__ K_NOPE_DTYPE *>(keycacheOutGm));
keycacheGmTensor2.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(keycacheOutGm2));
slotMappingGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(slotMappingGm));
descale1gmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ float *>(descale1Gm));
s2GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(s2Gm));
s3GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(s3Gm));
s5GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ float *>(s5Gm));
#ifdef __DAV_C220_CUBE__
mm_w8a8_aic_1.Init(s1Gm, wdqkvGm, s2Gm, mlaParams.mm1, 0);
mm_w8a8_aic_1.PreloadWeight();
if constexpr (quantMode == QuantMode::PER_TENSOR_ASYMM_QUANT) {
mm_w8a8_aic_2.Init(s1Gm, wuqGm, s2Gm, mlaParams.mm2, 1);
} else {
// quantMode == QuantMode::PER_TOKEN_SYMM_QUANT
mm_w8a8_aic_2.Init(s1Gm, wuqGm, s3Gm, mlaParams.mm2, 1);
}
if constexpr (CACHE_MODE == CACHE_MODE_INT8_NZCACHE) {
mm_ein_sum.Init(s4Gm, wukGm, s1Gm, mlaParams);
} else {
mm_ein_sum.Init(s4Gm, wukGm, qGm, mlaParams);
}
#endif
hiddenStateGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(hiddenStateGm));
gamma1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(gamma1Gm));
quantScale1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(quantScale1Gm));
quantOffset1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int8_t *>(quantOffset1Gm));
wdqkvGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int8_t *>(wdqkvGm));
gamma2GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(gamma2Gm));
quantScale2GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(quantScale2Gm));
quantOffset2GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int8_t *>(quantOffset2Gm));
sin2GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(sin2Gm));
cos2GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(cos2Gm));
wuqGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int8_t *>(wuqGm));
wukGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(wukGm));
descale2gmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ float *>(descale2Gm));
s1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int8_t *>(s1Gm));
s4GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(s4Gm));
qGmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ Q_OUT_DTYPE *>(qGm));
qGmTensor2.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(qGm2));
bias1gmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(bias1Gm));
bias2gmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(bias2Gm));
beta1GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(beta1Gm));
beta2GmTensor.SetGlobalBuffer(reinterpret_cast<__gm__ InDtype *>(beta2Gm));
#ifdef __DAV_C220_VEC__
if constexpr (quantMode == QuantMode::PER_TENSOR_ASYMM_QUANT) {
mm_w8a8_aiv_1.Init(s2Gm, s3Gm, descale1Gm, bias1Gm, s5Gm, mlaParams.mm1);
mm_w8a8_aiv_2.Init(s2Gm, s4Gm, descale2Gm, bias2Gm, s5Gm, mlaParams.mm2);
} else {
// quantMode == QuantMode::PER_TOKEN_SYMM_QUANT
mm_w8a8_aiv_2.Init(s3Gm, s4Gm, descale2Gm, bias2Gm, s5Gm, mlaParams.mm2);
}
row_work = (num_row + num_core_ - 1) / num_core_;
row_work_ = 0;
uint32_t need_core = (num_row + row_work - 1) / row_work;
if (vectorBlockIdx < need_core - 1) {
row_work_ = row_work;
} else if (vectorBlockIdx == need_core - 1) {
row_work_ = num_row - (need_core - 1) * row_work;
} else {
row_work_ = 0;
}
this->splitN = mlaParams.perTaskNum;
Quant1.Init(quantScale1GmTensor, quantOffset1GmTensor, s5Gm + row_work * vectorBlockIdx * sizeof(float),
descale1Gm, hiddenStateGm, s1Gm, 0, num_col_1,
vectorBlockIdx * static_cast<uint64_t>(row_work) * num_col_1,
vectorBlockIdx * static_cast<uint64_t>(row_work) * num_col_1, row_work_, mlaParams);
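// 0.000651041666 == 1/1536, presumably the averaging factor for the RMSNorm reduction length used by rmsNormQuant2.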
if constexpr (quantMode == QuantMode::PER_TENSOR_ASYMM_QUANT) {
rmsNormQuant2.Init(gamma2GmTensor, beta2GmTensor, quantScale2GmTensor, quantOffset2GmTensor,
s5Gm + row_work * vectorBlockIdx * sizeof(float), descale1Gm, s3Gm, s1Gm, SPLIT_SIZE_ONE,
num_col_2, 0.000651041666, vectorBlockIdx * static_cast<uint64_t>(row_work) * num_col_2,
vectorBlockIdx * static_cast<uint64_t>(row_work) * SPLIT_SIZE_TWO, row_work_, mlaParams);
} else {
// quantMode == QuantMode::PER_TOKEN_SYMM_QUANT
rmsNormQuant2.Init(gamma2GmTensor, beta2GmTensor, quantScale2GmTensor, quantOffset2GmTensor,
s5Gm + row_work * vectorBlockIdx * sizeof(float), descale1Gm, s2Gm, s1Gm, SPLIT_SIZE_ONE,
num_col_2, 0.000651041666, vectorBlockIdx * static_cast<uint64_t>(row_work) * num_col_2,
vectorBlockIdx * static_cast<uint64_t>(row_work) * SPLIT_SIZE_TWO, row_work_, mlaParams);
}
ropeFp16.RopeInit(s4Gm, cos2GmTensor, sin2GmTensor, qGmTensor, qGmTensor2, mlaParams);
einSumQuant.Init(s1Gm, gmQnopeScale, qGm, mlaParams);
#endif
}
__aicore__ inline void ProcessCube();
__aicore__ inline void ProcessVector();
private:
constexpr static uint32_t C0_SIZE = 16;
constexpr static uint32_t I8_C0_SIZE = 32;
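// Per token: (optionally dequantize the MM1 output), RMS-normalize the first
// SPLIT_RMSNRORM_SIZE_ONE elements against gamma3, rotate the trailing
// SPLIT_RMSNRORM_SIZE_TWO elements with RoPE, then scatter both parts into the
// key cache at the slot-mapping position, in the layout selected by CACHE_MODE.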
template <class T1>
__aicore__ inline void RmsNormAndRopeConvergence1(
const AscendC::LocalTensor<T1> &srcTensor, const AscendC::LocalTensor<T1> &gammaTensor,
const AscendC::LocalTensor<T1> &sinTensor, const AscendC::LocalTensor<T1> &cosTensor,
const AscendC::LocalTensor<int32_t> &slotMappingTensor, const uint32_t sN,
const AscendC::LocalTensor<float> &rmsNormTensor, const AscendC::LocalTensor<float> &gammaFp32,
const AscendC::LocalTensor<float> &ropeKTensor, const AscendC::LocalTensor<float> &ropeKRevertTensor,
const AscendC::LocalTensor<float> &calTensor, const AscendC::LocalTensor<T1> &outTmpTensor,
AscendC::LocalTensor<half> &tmpfp16, AscendC::LocalTensor<int8_t> &int8OutTensor, float quantScale3)
{
int64_t slotMapGmOffset = vectorBlockIdx * row_work;
AscendC::DataCopy(gammaTensor, gamma3GmTensor, SPLIT_RMSNRORM_SIZE_ONE);
SET_FLAG(MTE2, V, EVENT_ID1);
WAIT_FLAG(MTE2, V, EVENT_ID1);
Cast(gammaFp32, gammaTensor, AscendC::RoundMode::CAST_NONE, SPLIT_RMSNRORM_SIZE_ONE);
AscendC::DataCopyPad(slotMappingTensor, slotMappingGmTensor[slotMapGmOffset],
AscendC::DataCopyExtParams(1, sN * sizeof(int32_t), 0, 0, 0),
AscendC::DataCopyPadExtParams<int32_t>(false, 0, 8 - sN % 8, 0));
if constexpr (quantMode == QuantMode::PER_TOKEN_SYMM_QUANT) {
mmTensor = calTensor.ReinterpretCast<int32_t>()[SPLIT_SIZE_ONE];
deScaleTensor = calTensor.ReinterpretCast<float>()[SPLIT_SIZE_ONE * 2];
AscendC::DataCopy(deScaleTensor, descale1gmTensor, AscendC::DataCopyParams(1, SPLIT_SIZE_ONE / 8, 0, 0));
}
SET_FLAG(MTE2, V, EVENT_ID2);
WAIT_FLAG(MTE2, V, EVENT_ID2);
SET_FLAG(MTE2, S, EVENT_ID2);
WAIT_FLAG(MTE2, S, EVENT_ID2);
for (uint64_t loop = 0; loop < sN; ++loop) {
uint64_t offset = vectorBlockIdx * static_cast<uint64_t>(row_work) * num_col_2 + loop * MM1_OUT_SIZE;
int64_t slotValue = static_cast<int64_t>(slotMappingTensor.GetValue(loop));
if (slotValue == -1) {
continue;
}
if constexpr (quantMode == QuantMode::PER_TENSOR_ASYMM_QUANT) {
AscendC::DataCopy(srcTensor, s3GmTensor[offset],
AscendC::DataCopyParams(1, MM1_OUT_SIZE / BLOCK_SIZE_16, 0, 0));
} else {
// quantMode == QuantMode::PER_TOKEN_SYMM_QUANT
AscendC::DataCopy(mmTensor, s2GmTensor[offset], AscendC::DataCopyParams(1, SPLIT_SIZE_ONE / 8, 0, 0));
}
AscendC::DataCopy(sinTensor, sin1GmTensor[(row_work * vectorBlockIdx + loop) * SPLIT_RMSNRORM_SIZE_TWO],
SPLIT_RMSNRORM_SIZE_TWO);
AscendC::DataCopy(cosTensor, cos1GmTensor[(row_work * vectorBlockIdx + loop) * SPLIT_RMSNRORM_SIZE_TWO],
SPLIT_RMSNRORM_SIZE_TWO);
SET_FLAG(MTE2, V, EVENT_ID0);
// ND
uint64_t cacheStart = static_cast<uint64_t>(slotValue) * static_cast<uint64_t>(SPLIT_SIZE_ONE);
uint64_t cacheStart1 = static_cast<uint64_t>(slotValue) * static_cast<uint64_t>(SPLIT_RMSNRORM_SIZE_ONE);
uint64_t cacheStart2 = static_cast<uint64_t>(slotValue) * static_cast<uint64_t>(SPLIT_RMSNRORM_SIZE_TWO);
// NZ
uint32_t outer_idx = slotValue / 128;
uint32_t inner_idx = slotValue % 128;
SET_FLAG(S, MTE3, EVENT_ID0);
/* RmsNorm start */
WAIT_FLAG(MTE2, V, EVENT_ID0);
if constexpr (quantMode == QuantMode::PER_TOKEN_SYMM_QUANT) {
/* DeQuant */
AscendC::Cast(mmTensor.ReinterpretCast<float>(), mmTensor, AscendC::RoundMode::CAST_NONE,
SPLIT_SIZE_ONE);
AscendC::PipeBarrier<PIPE_V>();
AscendC::Mul(mmTensor.ReinterpretCast<float>(), mmTensor.ReinterpretCast<float>(), deScaleTensor,
SPLIT_SIZE_ONE);
AscendC::PipeBarrier<PIPE_V>();
float perTokenDescale = s5GmTensor.GetValue(row_work * vectorBlockIdx + loop);
SET_FLAG(S, V, EVENT_ID0);
WAIT_FLAG(S, V, EVENT_ID0);
AscendC::Muls(mmTensor.ReinterpretCast<float>(), mmTensor.ReinterpretCast<float>(), perTokenDescale,
SPLIT_SIZE_ONE);
AscendC::PipeBarrier<PIPE_V>();
AscendC::Cast(srcTensor, mmTensor.ReinterpretCast<float>(), AscendC::RoundMode::CAST_RINT,
SPLIT_SIZE_ONE);
AscendC::PipeBarrier<PIPE_V>();
}
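// rms = sqrt(mean(x^2) + eps); y = gamma * x / rms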
Cast(rmsNormTensor, srcTensor, AscendC::RoundMode::CAST_NONE, SPLIT_RMSNRORM_SIZE_ONE);
AscendC::PipeBarrier<PIPE_V>();
Mul(calTensor, rmsNormTensor, rmsNormTensor, SPLIT_RMSNRORM_SIZE_ONE);
AscendC::PipeBarrier<PIPE_V>();
ReduceSumCustom(calTensor[SPLIT_RMSNRORM_SIZE_ONE], calTensor, calTensor[SPLIT_RMSNRORM_SIZE_ONE * 2],
SPLIT_RMSNRORM_SIZE_ONE);
SET_FLAG(V, S, EVENT_ID1);
WAIT_FLAG(V, S, EVENT_ID1);
float rms = sqrt(calTensor.GetValue(SPLIT_RMSNRORM_SIZE_ONE) / SPLIT_RMSNRORM_SIZE_ONE + epsilon_);
SET_FLAG(S, V, EVENT_ID1);
WAIT_FLAG(S, V, EVENT_ID1);
AscendC::PipeBarrier<PIPE_V>();
Duplicate(calTensor, rms, SPLIT_RMSNRORM_SIZE_ONE);
AscendC::PipeBarrier<PIPE_V>();
Div(calTensor, rmsNormTensor, calTensor, SPLIT_RMSNRORM_SIZE_ONE);
AscendC::PipeBarrier<PIPE_V>();
Mul(rmsNormTensor, gammaFp32, calTensor, SPLIT_RMSNRORM_SIZE_ONE);
AscendC::PipeBarrier<PIPE_V>();
if constexpr (CACHE_MODE == CACHE_MODE_INT8_NZCACHE) {
// quant
Muls(rmsNormTensor, rmsNormTensor, quantScale3, SPLIT_RMSNRORM_SIZE_ONE);
AscendC::PipeBarrier<PIPE_V>();
CastFrom32To16(tmpfp16, rmsNormTensor, SPLIT_RMSNRORM_SIZE_ONE);
AscendC::PipeBarrier<PIPE_V>();
CastFromF16ToI8(int8OutTensor, tmpfp16, -128, SPLIT_RMSNRORM_SIZE_ONE);
AscendC::PipeBarrier<PIPE_V>();
} else {
AscendC::PipeBarrier<PIPE_V>();
if (std::is_same<T1, __bf16>::value) {
Cast(outTmpTensor, rmsNormTensor, AscendC::RoundMode::CAST_RINT, SPLIT_RMSNRORM_SIZE_ONE);
} else {
Cast(outTmpTensor, rmsNormTensor, AscendC::RoundMode::CAST_NONE, SPLIT_RMSNRORM_SIZE_ONE);
}
}
/* RmsNorm end */
/* Rope K start */
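// Rotate-half RoPE: out = k * cos + sign * rotate_half(k) * sin (elementwise),
// where rotate_half swaps the two halves of the rope dims and sign is -1 on
// the first half, +1 on the second.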
uint64_t revertOffset = SPLIT_RMSNRORM_SIZE_TWO / 2;
Cast(ropeKTensor, srcTensor[SPLIT_RMSNRORM_SIZE_ONE], AscendC::RoundMode::CAST_NONE,
SPLIT_RMSNRORM_SIZE_TWO);
Cast(ropeKRevertTensor[revertOffset], srcTensor[SPLIT_RMSNRORM_SIZE_ONE], AscendC::RoundMode::CAST_NONE,
revertOffset);
Cast(ropeKRevertTensor, srcTensor[SPLIT_RMSNRORM_SIZE_ONE + revertOffset], AscendC::RoundMode::CAST_NONE,
revertOffset);
Duplicate(calTensor, static_cast<float>(-1), revertOffset);
Duplicate(calTensor[revertOffset], static_cast<float>(1), revertOffset);
AscendC::PipeBarrier<PIPE_V>();
Cast(calTensor[SPLIT_RMSNRORM_SIZE_TWO], cosTensor, AscendC::RoundMode::CAST_NONE, SPLIT_RMSNRORM_SIZE_TWO);
Cast(calTensor[SPLIT_RMSNRORM_SIZE_TWO * 2], sinTensor, AscendC::RoundMode::CAST_NONE,
SPLIT_RMSNRORM_SIZE_TWO);
AscendC::PipeBarrier<PIPE_V>();
Mul(ropeKTensor, calTensor[SPLIT_RMSNRORM_SIZE_TWO], ropeKTensor, SPLIT_RMSNRORM_SIZE_TWO);
Mul(ropeKRevertTensor, calTensor[SPLIT_RMSNRORM_SIZE_TWO * 2], ropeKRevertTensor, SPLIT_RMSNRORM_SIZE_TWO);
AscendC::PipeBarrier<PIPE_V>();
Mul(ropeKRevertTensor, calTensor, ropeKRevertTensor, SPLIT_RMSNRORM_SIZE_TWO);
AscendC::PipeBarrier<PIPE_V>();
Add(ropeKRevertTensor, ropeKTensor, ropeKRevertTensor, SPLIT_RMSNRORM_SIZE_TWO);
AscendC::PipeBarrier<PIPE_V>();
if (std::is_same<T1, __bf16>::value) {
Cast(outTmpTensor[SPLIT_RMSNRORM_SIZE_ONE], ropeKRevertTensor, AscendC::RoundMode::CAST_RINT,
SPLIT_RMSNRORM_SIZE_TWO);
} else {
Cast(outTmpTensor[SPLIT_RMSNRORM_SIZE_ONE], ropeKRevertTensor, AscendC::RoundMode::CAST_NONE,
SPLIT_RMSNRORM_SIZE_TWO);
}
AscendC::PipeBarrier<PIPE_V>();
/* Rope K end */
SET_FLAG(V, MTE3, EVENT_ID0);
WAIT_FLAG(V, MTE3, EVENT_ID0);
WAIT_FLAG(S, MTE3, EVENT_ID0);
if constexpr (CACHE_MODE == CACHE_MODE_KVCACHE) {
DataCopy(keycacheGmTensor1[cacheStart], outTmpTensor, SPLIT_SIZE_ONE);
} else if constexpr (CACHE_MODE == CACHE_MODE_INT8_NZCACHE) {
uint64_t cacheStartI8Nz1 = outer_idx * 128 * 512 + inner_idx * I8_C0_SIZE;
uint64_t cacheStartNz2 = outer_idx * 128 * 64 + inner_idx * C0_SIZE;
// nope:int8 nz
AscendC::DataCopyExtParams outExt;
outExt.blockCount = SPLIT_RMSNRORM_SIZE_ONE / I8_C0_SIZE;
outExt.blockLen = I8_C0_SIZE * sizeof(int8_t);
outExt.srcStride = 0;
outExt.dstStride = (128 * I8_C0_SIZE - I8_C0_SIZE) * sizeof(int8_t);
DataCopyPad(keycacheGmTensor1[cacheStartI8Nz1], int8OutTensor, outExt);
// rope:T1 nz
outExt.blockCount = SPLIT_RMSNRORM_SIZE_TWO / C0_SIZE;
outExt.blockLen = C0_SIZE * sizeof(T1);
outExt.srcStride = 0;
outExt.dstStride = (128 * C0_SIZE - C0_SIZE) * sizeof(T1);
DataCopyPad(keycacheGmTensor2[cacheStartNz2], outTmpTensor[SPLIT_RMSNRORM_SIZE_ONE], outExt);
} else if constexpr (CACHE_MODE == CACHE_MODE_NZCACHE) {
uint64_t cacheStartNz1 = outer_idx * 128 * 512 + inner_idx * C0_SIZE;
uint64_t cacheStartNz2 = outer_idx * 128 * 64 + inner_idx * C0_SIZE;
// nope:T1 nz
AscendC::DataCopyExtParams outExt;
outExt.blockCount = SPLIT_RMSNRORM_SIZE_ONE / C0_SIZE;
outExt.blockLen = C0_SIZE * sizeof(T1);
outExt.srcStride = 0;
outExt.dstStride = (128 * C0_SIZE - C0_SIZE) * sizeof(T1);
DataCopyPad(keycacheGmTensor1[cacheStartNz1], outTmpTensor, outExt);
// rope:T1 nz
outExt.blockCount = SPLIT_RMSNRORM_SIZE_TWO / C0_SIZE;
outExt.blockLen = C0_SIZE * sizeof(T1);
outExt.srcStride = 0;
outExt.dstStride = (128 * C0_SIZE - C0_SIZE) * sizeof(T1);
DataCopyPad(keycacheGmTensor2[cacheStartNz2], outTmpTensor[SPLIT_RMSNRORM_SIZE_ONE], outExt);
} else {
// keycache1
DataCopy(keycacheGmTensor1[cacheStart1], outTmpTensor, SPLIT_RMSNRORM_SIZE_ONE);
// keycache2
DataCopy(keycacheGmTensor2[cacheStart2], outTmpTensor[SPLIT_RMSNRORM_SIZE_ONE],
SPLIT_RMSNRORM_SIZE_TWO);
}
SET_FLAG(MTE3, MTE2, EVENT_ID1);
WAIT_FLAG(MTE3, MTE2, EVENT_ID1);
}
}
private:
uint32_t n;
uint32_t splitN;
uint32_t rotaryCoeff;
uint32_t blockIdx;
uint32_t sub_block_idx{0};
uint32_t vectorBlockIdx;
uint32_t blockOffset;
uint32_t perTaskNum;
uint32_t resTaskNum;
MlaTilingData mlaParams;
uint32_t num_core_;
uint32_t num_col_1;
uint32_t num_col_2;
float epsilon_;
uint32_t num_row;
uint32_t quantMin_;
uint32_t row_work;
uint32_t row_work_;
AsdopsBuffer<ArchType::ASCEND_V220> buf;
AscendC::LocalTensor<int32_t> mmTensor;
AscendC::LocalTensor<float> deScaleTensor;
AscendC::GlobalTensor<InDtype> hiddenStateGmTensor;
AscendC::GlobalTensor<InDtype> gamma1GmTensor;
AscendC::GlobalTensor<InDtype> quantScale1GmTensor;
AscendC::GlobalTensor<int8_t> quantOffset1GmTensor;
AscendC::GlobalTensor<int8_t> wdqkvGmTensor;
AscendC::GlobalTensor<InDtype> gamma2GmTensor;
AscendC::GlobalTensor<InDtype> quantScale2GmTensor;
AscendC::GlobalTensor<InDtype> quantScale3GmTensor;
AscendC::GlobalTensor<int8_t> quantOffset2GmTensor;
AscendC::GlobalTensor<InDtype> gamma3GmTensor;
AscendC::GlobalTensor<InDtype> sin1GmTensor;
AscendC::GlobalTensor<InDtype> cos1GmTensor;
AscendC::GlobalTensor<InDtype> sin2GmTensor;
AscendC::GlobalTensor<InDtype> cos2GmTensor;
AscendC::GlobalTensor<K_NOPE_DTYPE> keycacheGmTensor1;
AscendC::GlobalTensor<InDtype> keycacheGmTensor2;
AscendC::GlobalTensor<int32_t> slotMappingGmTensor;
AscendC::GlobalTensor<int8_t> wuqGmTensor;
AscendC::GlobalTensor<InDtype> wukGmTensor;
// int8 when CACHE_MODE == CACHE_MODE_INT8_NZCACHE, otherwise InDtype
AscendC::GlobalTensor<Q_OUT_DTYPE> qGmTensor;
AscendC::GlobalTensor<InDtype> qGmTensor2;
AscendC::GlobalTensor<int8_t> s1GmTensor;
AscendC::GlobalTensor<int32_t> s2GmTensor;
AscendC::GlobalTensor<InDtype> s3GmTensor;
AscendC::GlobalTensor<int32_t> s4GmTensor;
AscendC::GlobalTensor<float> s5GmTensor;
AscendC::GlobalTensor<float> descale1gmTensor;
AscendC::GlobalTensor<float> descale2gmTensor;
AscendC::GlobalTensor<InDtype> beta1GmTensor;
AscendC::GlobalTensor<InDtype> beta2GmTensor;
AscendC::GlobalTensor<int32_t> bias1gmTensor;
AscendC::GlobalTensor<int32_t> bias2gmTensor;
#ifdef __DAV_C220_CUBE__
PpMatmulW8a8Aic<mm1WithSyncAll, 0, DataFormat::ND, weightFormat1> mm_w8a8_aic_1;
PpMatmulW8a8Aic<false, 0, DataFormat::ND, weightFormat2> mm_w8a8_aic_2;
PpMatmulEinSum<InDtype, InDtype, weightFormat3, false, 0, CONST_64, splitGapC> mm_ein_sum;
#endif
#ifdef __DAV_C220_VEC__
PpMatmulW8a8Aiv<InDtype, mm1WithSyncAll, quantMode> mm_w8a8_aiv_1;
PpMatmulW8a8Aiv<InDtype, false, quantMode> mm_w8a8_aiv_2;
Quant<InDtype, true, false, quantMode, false> Quant1;
RmsNormQuant<InDtype, true, false, quantMode, quantMode == QuantMode::PER_TOKEN_SYMM_QUANT> rmsNormQuant2;
RopeFp16<InDtype, InDtype, Q_OUT_DTYPE, CACHE_MODE> ropeFp16;
EinSumQuant<InDtype, InDtype> einSumQuant;
#endif
};
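// Cube-side schedule: MM1, an extra AIC/AIV handshake in per-token mode (the
// vector cores dequantize MM1's s32 output), MM2, then the einsum Q projection;
// int8-NZ cache mode adds a final handshake for the einsum quantization.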
template <typename InDtype, int8_t CACHE_MODE, DataFormat weightFormat1, DataFormat weightFormat2,
DataFormat weightFormat3, QuantMode quantMode>
__aicore__ inline void
MLAOperation<InDtype, CACHE_MODE, weightFormat1, weightFormat2, weightFormat3, quantMode>::ProcessCube()
{
#ifdef __DAV_C220_CUBE__
mm_w8a8_aic_1.Process();
if constexpr (quantMode == QuantMode::PER_TOKEN_SYMM_QUANT) {
FftsCrossCoreSync<PIPE_FIX, 0>(MMAIC);
WaitFlagDev(MMAIC);
FftsCrossCoreSync<PIPE_FIX, 2>(MMAIV);
}
mm_w8a8_aic_2.PreloadWeight();
mm_w8a8_aic_2.Process();
mm_ein_sum.Process();
if constexpr (CACHE_MODE == CACHE_MODE_INT8_NZCACHE) {
FftsCrossCoreSync<PIPE_FIX, 0>(EINSUMOUT);
WaitFlagDev(EINSUMOUT);
FftsCrossCoreSync<PIPE_FIX, 0x2>(EINSUMQUANT);
}
#endif
}
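// Vector-side schedule, interleaved with the cube cores via FFTS flags:
// quantize input -> start MM1 -> dequantize MM1 -> RMSNorm + quantize ->
// start MM2 -> RMSNorm/RoPE of K + cache scatter -> dequantize MM2 ->
// start the einsum -> RoPE of Q (plus einsum quantization for int8-NZ cache).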
template <typename InDtype, int8_t CACHE_MODE, DataFormat weightFormat1, DataFormat weightFormat2,
DataFormat weightFormat3, QuantMode quantMode>
__aicore__ inline void
MLAOperation<InDtype, CACHE_MODE, weightFormat1, weightFormat2, weightFormat3, quantMode>::ProcessVector()
{
#ifdef __DAV_C220_VEC__
if (row_work_ != 0) {
uint32_t num_col_align_int8 = (num_col_1 + REPEAT_TIME_256 - 1) / REPEAT_TIME_256 * REPEAT_TIME_256;
uint32_t num_col_align_f16 = (num_col_1 + REPEAT_TIME_128 - 1) / REPEAT_TIME_128 * REPEAT_TIME_128;
uint32_t num_col_align_f32 = (num_col_1 + REPEAT_TIME_64 - 1) / REPEAT_TIME_64 * REPEAT_TIME_64;
AscendC::LocalTensor<InDtype> input_tensor = buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(0);
AscendC::LocalTensor<InDtype> scale_tensor =
buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2);
AscendC::LocalTensor<int8_t> offset_tensor = buf.GetBuffer<BufferType::ASCEND_UB, int8_t>(
HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 32);
AscendC::LocalTensor<float> res1_tensor =
buf.GetBuffer<BufferType::ASCEND_UB, float>(HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64);
AscendC::LocalTensor<float> res3_tensor = buf.GetBuffer<BufferType::ASCEND_UB, float>(
HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64 + num_col_align_f32 * 4);
AscendC::LocalTensor<int8_t> output_tensor = buf.GetBuffer<BufferType::ASCEND_UB, int8_t>(
HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + HIDDTEN_STATE * 2 + 64 + num_col_align_f32 * 4 +
BUF_FACTOR * num_col_align_f32 * 4 + 64);
Quant1.Launch(output_tensor, input_tensor, scale_tensor, offset_tensor, res1_tensor, res3_tensor);
}
FftsCrossCoreSync<PIPE_MTE3, 0>(QUANT1);
WaitFlagDev(QUANT1);
AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(AIC_MM1_START);
if constexpr (quantMode == QuantMode::PER_TENSOR_ASYMM_QUANT) {
mm_w8a8_aiv_1.Process();
FftsCrossCoreSync<PIPE_MTE3, 0>(RMSNORMQUANT2);
WaitFlagDev(RMSNORMQUANT2);
} else { // quantMode == QuantMode::PER_TOKEN_SYMM_QUANT
WaitFlagDev(MMAIV);
}
if (row_work_ != 0) {
uint32_t num_col_align_int8 = (num_col_2 + REPEAT_TIME_256 - 1) / REPEAT_TIME_256 * REPEAT_TIME_256;
uint32_t num_col_align_f16 = (num_col_2 + REPEAT_TIME_128 - 1) / REPEAT_TIME_128 * REPEAT_TIME_128;
uint32_t num_col_align_f32 = (num_col_2 + REPEAT_TIME_64 - 1) / REPEAT_TIME_64 * REPEAT_TIME_64;
AscendC::LocalTensor<InDtype> input_tensor = buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(0);
AscendC::LocalTensor<InDtype> gamma_tensor = buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(MM1_OUT_SIZE * 2);
AscendC::LocalTensor<InDtype> beta_tensor =
buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2);
AscendC::LocalTensor<InDtype> scale_tensor =
buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2 + SPLIT_SIZE_TWO * 2);
AscendC::LocalTensor<int8_t> offset_tensor = buf.GetBuffer<BufferType::ASCEND_UB, int8_t>(
MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2 + SPLIT_SIZE_TWO * 2 + 32);
AscendC::LocalTensor<float> res1_tensor = buf.GetBuffer<BufferType::ASCEND_UB, float>(
MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2 + SPLIT_SIZE_TWO * 2 + 64);
AscendC::LocalTensor<float> res3_tensor = buf.GetBuffer<BufferType::ASCEND_UB, float>(
MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2 + SPLIT_SIZE_TWO * 2 + 64 + num_col_align_f32 * 4);
AscendC::LocalTensor<int8_t> output_tensor = buf.GetBuffer<BufferType::ASCEND_UB, int8_t>(
MM1_OUT_SIZE * 2 + SPLIT_SIZE_TWO * 2 + SPLIT_SIZE_TWO * 2 + 64 + num_col_align_f32 * 4 +
BUF_FACTOR * num_col_align_f32 * 4 + 64 + MM1_OUT_SIZE * 4 * 2 + 32);
rmsNormQuant2.Launch(output_tensor, input_tensor, gamma_tensor, beta_tensor, scale_tensor, offset_tensor,
res1_tensor, res3_tensor);
}
FftsCrossCoreSync<PIPE_MTE3, 0>(MM2);
WaitFlagDev(MM2);
AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(AIC_MM2_START);
if (row_work_ != 0) {
AscendC::LocalTensor<InDtype> input_tensor = buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(0);
AscendC::LocalTensor<InDtype> gamma_tensor = buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(MM1_OUT_SIZE * 2);
AscendC::LocalTensor<InDtype> sin_tensor =
buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2);
AscendC::LocalTensor<InDtype> cos_tensor = buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(
MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2 + SPLIT_RMSNRORM_SIZE_TWO * 2);
AscendC::LocalTensor<int32_t> slotMapping_tensor = buf.GetBuffer<BufferType::ASCEND_UB, int32_t>(
MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2 + SPLIT_RMSNRORM_SIZE_TWO * 4);
int32_t rms3_ub_offset =
MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2 + SPLIT_RMSNRORM_SIZE_TWO * 4 + 4096 * 32;
AscendC::LocalTensor<float> tmp32_tensor = buf.GetBuffer<BufferType::ASCEND_UB, float>(rms3_ub_offset);
int32_t out_ub_offset = MM1_OUT_SIZE * 2 + SPLIT_RMSNRORM_SIZE_ONE * 2 + SPLIT_RMSNRORM_SIZE_TWO * 4 +
4096 * 32 + SPLIT_RMSNRORM_SIZE_ONE * 3 * 4 + SPLIT_RMSNRORM_SIZE_TWO * 2 * 4 +
MM1_OUT_SIZE * 4 * 2 + 32;
AscendC::LocalTensor<InDtype> temp_tensor = buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(out_ub_offset);
AscendC::LocalTensor<half> tmpfp16;
AscendC::LocalTensor<int8_t> int8OutTensor;
float scale3 = 0;
if constexpr (CACHE_MODE == CACHE_MODE_INT8_NZCACHE) {
// quantScale3
AscendC::LocalTensor<InDtype> quantScaleTensor =
buf.GetBuffer<BufferType::ASCEND_UB, InDtype>(rms3_ub_offset);
AscendC::LocalTensor<float> floatQuantScaleTensor =
buf.GetBuffer<BufferType::ASCEND_UB, float>(rms3_ub_offset + 32);
// int8out
tmpfp16 = buf.GetBuffer<BufferType::ASCEND_UB, half>(rms3_ub_offset +
SPLIT_RMSNRORM_SIZE_ONE * sizeof(float) * 2);
int8OutTensor = buf.GetBuffer<BufferType::ASCEND_UB, int8_t>(out_ub_offset);
AscendC::DataCopy(quantScaleTensor, quantScale3GmTensor, AscendC::DataCopyParams(1, 1, 0, 0));
SET_FLAG(MTE2, V, EVENT_ID1);
WAIT_FLAG(MTE2, V, EVENT_ID1);
Cast(floatQuantScaleTensor, quantScaleTensor, AscendC::RoundMode::CAST_NONE, 1);
AscendC::SetFlag<HardEvent::V_S>(EVENT_ID1);
AscendC::WaitFlag<HardEvent::V_S>(EVENT_ID1);
scale3 = 1 / (float)(floatQuantScaleTensor.GetValue(0));
}
RmsNormAndRopeConvergence1<InDtype>(
input_tensor, // n * 576
gamma_tensor, // gamma
sin_tensor, // sin
cos_tensor, // cos
slotMapping_tensor, // slotMapping
row_work_, tmp32_tensor, tmp32_tensor[SPLIT_RMSNRORM_SIZE_ONE],
tmp32_tensor[SPLIT_RMSNRORM_SIZE_ONE + SPLIT_RMSNRORM_SIZE_ONE],
tmp32_tensor[SPLIT_RMSNRORM_SIZE_ONE + SPLIT_RMSNRORM_SIZE_ONE + SPLIT_RMSNRORM_SIZE_TWO],
tmp32_tensor[SPLIT_RMSNRORM_SIZE_ONE + SPLIT_RMSNRORM_SIZE_ONE + SPLIT_RMSNRORM_SIZE_TWO +
SPLIT_RMSNRORM_SIZE_TWO],
temp_tensor, tmpfp16, int8OutTensor, scale3);
}
mm_w8a8_aiv_2.Process();
FftsCrossCoreSync<PIPE_MTE3, 0>(MM2OUT);
WaitFlagDev(MM2OUT);
AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(AIC_MM3_START);
ropeFp16.Process();
if constexpr (CACHE_MODE == CACHE_MODE_INT8_NZCACHE) {
WaitFlagDev(EINSUMQUANT);
einSumQuant.Process();
}
#endif
}
} // namespace MLAPO_BF16