vllm-ascend/csrc/mla_preprocess/op_host/mla_preprocess.h
Chen Chen bcc313e8f2 add mla_preprocess kernel (#3226)
### What this PR does / why we need it?

- Adds the `mla_preprocess` custom kernel to provide an optimized
pre-processing operator for Multi-head Latent Attention (MLA) on Ascend
NPUs.
- Wires the new kernel into the C++ extension pipeline so vLLM can
invoke it directly, cutting the Python-side tensor shuffling and memory
copies that previously bottlenecked the MLA preprocessing path (a
registration sketch follows below).
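
For orientation, wiring a kernel into the torch extension pipeline generally means declaring an op schema and registering the device implementation against it. The fragment below is a minimal sketch under assumed names (library `_C`, op `mla_preprocess_tiling`), not the registration code this PR actually adds:

```cpp
// Illustrative only: the library name and schema here are assumptions.
#include <torch/library.h>

TORCH_LIBRARY_FRAGMENT(_C, m) {
    // Declare the schema; the NPU implementation would be registered
    // separately via TORCH_LIBRARY_IMPL against the Ascend dispatch key.
    m.def(
        "mla_preprocess_tiling(Tensor hidden_state, Tensor wuk, "
        "str? cache_mode, str? quant_mode) -> (Tensor, Tensor, int)");
}
```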

### Does this PR introduce any user-facing change?

- No. The change only introduces a low-level kernel; public APIs and
inference behavior remain unchanged.

### How was this patch tested?

- Dedicated Ascend kernels are not covered by our CI yet, so no extra
automated tests were added. Future MLA-focused regression runs will
cover this path.

- vLLM version: v0.11.0

Signed-off-by: Chen Chen <0109chenchen@gmail.com>
2025-10-12 07:39:45 +08:00


// Adapted from
// https://gitee.com/ascend/ascend-transformer-boost.git
// https://gitee.com/ascend/op-plugin.git
//
// Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
// This file is a part of the CANN Open Software.
// Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
// Please refer to the License for details. You may not use this file except in compliance with the License.
// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
// See LICENSE in the root of the software repository for the full text of the License.
//
#include <math.h>

#include <algorithm>
#include <cstdint>
#include <fstream>
#include <iostream>
#include <stdexcept>
#include <tuple>
#include <unordered_map>

#include <torch/extension.h> // at::Tensor, c10::optional, TORCH_CHECK (stands in for the commented-out torch_helper.h)

#include "acl/acl.h"
// #include "defines.h"
// #include "torch_helper.h"
#include "tiling/platform/platform_ascendc.h"
#include "tiling/mla_preprocess_tiling.h"
// #include "aclrtlaunch_mla_preprocess.h"
// namespace sglang {
namespace mlapo {
constexpr uint32_t DIM_2 = 2;
constexpr uint32_t AXES_ALIGN_SIZE = 512;
constexpr uint32_t BASE_BLOCK_STEP = 2;
constexpr uint32_t CONST_16 = 16;
constexpr uint32_t CONST_32 = 32;
constexpr uint32_t CONST_128 = 128;
constexpr uint32_t CONST_256 = 256;
constexpr uint32_t CONST_512 = 512;
constexpr uint32_t L1_BUFFER_SIZE = 524288;
constexpr uint32_t L1_PINGPONG_BUFFER_LEN = 262144;
constexpr uint32_t L0AB_PINGPONG_BUFFER_LEN = 131072;
constexpr uint32_t L1_SCALE_SIZE = 4096;
constexpr uint32_t L1_BIAS_SIZE = 2048;
constexpr uint32_t L0C_SIZE = 128 * 1024;
constexpr uint32_t CONCAT_SIZE = 512;
constexpr uint32_t HIDDEN_STRATE = 7168;
constexpr uint32_t HIDDEN_STRATE_ROPE = 192;
constexpr uint32_t HIDDEN_STRATE_MM = 2112;
constexpr uint32_t HIDDEN_STRATE_RMS = 1536;
constexpr uint32_t UB_SIZE = 196352;
constexpr uint32_t HEADDIM = 64;
constexpr uint32_t FP32_REPEAT_MASK = 64;
constexpr uint32_t FP16_REPEAT_MASK = 128;
constexpr int32_t NUM1 = 1;
constexpr int32_t NUM2 = 2;
constexpr int32_t NUM3 = 3;
constexpr int32_t NUM4 = 4;
constexpr int32_t NUM8 = 8;
constexpr uint32_t INDEX_WDQKV = 5;
constexpr uint32_t INDEX_WUQ = 18;
constexpr uint32_t INDEX_WUK = 20;
constexpr uint32_t MAX_SUPPORT_TOKEN_NUMS = 1024;
inline uint32_t CeilDiv(const uint32_t dividend, const uint32_t divisor)
{
if (divisor == 0) {
return UINT32_MAX;
}
return (dividend + divisor - 1) / divisor;
}
inline uint32_t RoundUp(const uint32_t val, const uint32_t align = 16)
{
if (align == 0) {
return 0;
}
return (val + align - 1) / align * align;
}
inline uint32_t RoundDown(const uint32_t val, const uint32_t align = 16)
{
if (align == 0) {
return 0;
}
return val / align * align;
}
template <typename T = uint32_t>
inline T Max(const T a, const T b)
{
return a > b ? a : b;
}
template <typename T = uint32_t>
inline T Min(const T a, const T b)
{
return a < b ? a : b;
}
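// Quantization modes accepted by the op. The string keys in quant_mode_map
// further below map onto the first two enumerators:
// "per_tensor_quant_asymm" -> 0, "per_token_quant_symm" -> 1.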
struct MlaPreprocess {
enum class QuantMode : int32_t {
PER_TENSOR_ASYMM_QUANT = 0,
PER_TOKEN_SYMM_QUANT,
PER_TOKEN_ASYMM_QUANT,
NO_QUANT
};
};
using QuantMode = MlaPreprocess::QuantMode;
struct PlatformInfo {
uint32_t coreNum;
uint32_t coreNumAic;
uint32_t coreNumAiv;
uint64_t ubSize;
uint64_t l1Size;
uint64_t l2Size;
uint64_t l0aSize;
uint64_t l0bSize;
uint64_t l0cSize;
};
struct OpParam {
uint32_t N;
uint32_t headNum;
int32_t cacheMode;
QuantMode quantMode;
caffe2::TypeMeta inDtype;
};
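// Computes tiling for the ping-pong matmul (PpMatmul): picks base block
// sizes (m0, n0, k0), loop counts, the core split, and swizzle parameters
// from the platform's L0/L1/L2 capacities.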
class PpMatmulTilingApi
{
public:
PpMatmulTilingApi(struct PlatformInfo &platformInfo, uint32_t numBatch, uint32_t m, uint32_t k, uint32_t n,
bool transA, bool transB, bool enDequant, bool deqOnTheFly)
: platformInfo_(platformInfo),
numBatch_(numBatch),
m_(m),
k_(k),
n_(n),
transA_(transA),
transB_(transB),
enDequant_(enDequant),
deqOnTheFly_(deqOnTheFly)
{
inDataSize_ = enDequant ? sizeof(uint8_t) : sizeof(uint16_t);
}
void GetTilingData(PpMatmulTilingData &tiling);
private:
void GetTileSize();
float GetCost(const uint32_t m0, const uint32_t n0);
void UpdateTileSize(const uint32_t m0, const uint32_t n0);
void Swizzle();
uint32_t ComputeL1AbSize();
uint32_t ComputeK0ForABpingpong(uint32_t l1AbSize);
bool IsLoadAllAmat(uint32_t l1AbSize);
uint32_t ComputeK0ForOnlyBpingpong(uint32_t l1AbSize);
private:
uint32_t numBatch_{0};
uint32_t m_{0};
uint32_t k_{0};
uint32_t n_{0};
uint32_t m0_{0};
uint32_t k0_{0};
uint32_t n0_{0};
uint32_t mLoop_{0};
uint32_t kLoop_{0};
uint32_t nLoop_{0};
uint32_t coreLoop_{0};
uint32_t swizzleCount_{0};
uint32_t blockDim_{0};
uint32_t swizzleDirect_{0};
uint32_t inDataSize_{0};
uint32_t b0matPingPongBufferLen_{L1_PINGPONG_BUFFER_LEN};
bool transA_{false};
bool transB_{false};
bool enDequant_{false};
bool enShuffleK_{false};
bool enLoadAllAmat_{false};
bool deqOnTheFly_{false};
struct PlatformInfo platformInfo_;
};
void PpMatmulTilingApi::GetTilingData(PpMatmulTilingData &tiling)
{
GetTileSize();
tiling.numBatch = numBatch_;
tiling.m = m_;
tiling.k = k_;
tiling.n = n_;
tiling.m0 = m0_;
tiling.k0 = k0_;
tiling.n0 = n0_;
tiling.mLoop = mLoop_;
tiling.kLoop = kLoop_;
tiling.nLoop = nLoop_;
tiling.coreLoop = coreLoop_;
tiling.swizzleCount = swizzleCount_;
tiling.swizzleDirect = swizzleDirect_;
tiling.enShuffleK = static_cast<uint32_t>(enShuffleK_);
tiling.blockDim = blockDim_;
tiling.enLoadAllAmat = static_cast<uint32_t>(enLoadAllAmat_);
tiling.b0matPingPongBufferLen = b0matPingPongBufferLen_;
}
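// Brute-force search over base-block candidates: m0/n0 grow in powers of two
// from 16 up to 512, candidates that overflow L0C (or exceed n0 = 256 in the
// dequant path) are skipped, and the pair with the lowest GetCost() estimate
// wins. k0 is then sized from the available L1 ping-pong budget.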
void PpMatmulTilingApi::GetTileSize()
{
bool priFlag = !(m_ < n_);
uint32_t roundBase = pow(2, ceil(log2(CeilDiv(priFlag ? n_ : m_, CONST_16)))) * CONST_16; // round the sub axis up to a power-of-two multiple of 16
uint32_t priAxes = RoundUp(priFlag ? m_ : n_, CONST_16);
uint32_t subAxes = RoundUp(priFlag ? n_ : m_, roundBase);
float minCost = __FLT_MAX__;
uint32_t maxAxes0 = AXES_ALIGN_SIZE;
uint32_t maxPriAxes0 = Min(maxAxes0, priAxes);
uint32_t maxSubAxes0 = Min(maxAxes0, subAxes);
for (uint32_t priAxes0 = CONST_16; priAxes0 <= maxPriAxes0; priAxes0 *= BASE_BLOCK_STEP) {
for (uint32_t subAxes0 = CONST_16; subAxes0 <= maxSubAxes0; subAxes0 *= BASE_BLOCK_STEP) {
if (priAxes0 * subAxes0 * sizeof(float) > platformInfo_.l0cSize) {
continue;
}
uint32_t newM0 = priFlag ? priAxes0 : subAxes0;
uint32_t newN0 = priFlag ? subAxes0 : priAxes0;
if (newN0 > CONST_256 && enDequant_) {
continue;
}
float cost = GetCost(newM0, newN0);
if (cost < minCost) {
minCost = cost;
UpdateTileSize(newM0, newN0);
}
}
}
Swizzle();
uint32_t l1AbSize = ComputeL1AbSize();
k0_ = ComputeK0ForABpingpong(l1AbSize);
kLoop_ = CeilDiv(k_, k0_);
// If the whole A matrix fits in L1 (dequant path, single m loop, A reused
// across core loops), keep A resident and ping-pong only B with a larger k0.
if (IsLoadAllAmat(l1AbSize)) {
k0_ = ComputeK0ForOnlyBpingpong(l1AbSize);
kLoop_ = CeilDiv(k_, k0_);
}
}
uint32_t PpMatmulTilingApi::ComputeK0ForOnlyBpingpong(uint32_t l1AbSize)
{
enLoadAllAmat_ = true;
b0matPingPongBufferLen_ = static_cast<uint32_t>(
static_cast<float>((l1AbSize - RoundUp(m_, CONST_16) * RoundUp(k_, CONST_32) * inDataSize_) / DIM_2));
uint32_t k0MaxB0 =
static_cast<uint32_t>(static_cast<float>(b0matPingPongBufferLen_ / (RoundUp(n0_, CONST_16) * inDataSize_)));
uint32_t k0B0 = k0MaxB0 < CONST_512 ? RoundDown(k0MaxB0, CONST_32) : RoundDown(k0MaxB0, CONST_512);
return k0B0 > CONST_512 ? RoundDown(k0B0, CONST_512) : k0B0;
}
bool PpMatmulTilingApi::IsLoadAllAmat(uint32_t l1AbSize)
{
return (coreLoop_ > blockDim_) && enDequant_ && (kLoop_ > 1) &&
(l1AbSize > RoundUp(m_, CONST_16) * RoundUp(k_, CONST_32) * inDataSize_) && (mLoop_ == 1);
}
uint32_t PpMatmulTilingApi::ComputeK0ForABpingpong(uint32_t l1AbSize)
{
uint32_t k0Max = static_cast<uint32_t>(static_cast<float>(l1AbSize / DIM_2) / ((m0_ + n0_) * inDataSize_));
uint32_t tmpK0;
if (enDequant_) {
tmpK0 = k0Max < CONST_512 ? RoundDown(k0Max, CONST_32) : RoundDown(k0Max, CONST_512);
} else {
tmpK0 = k0Max < CONST_256 ? RoundDown(k0Max, CONST_16) : RoundDown(k0Max, CONST_256);
}
if (tmpK0 > CONST_512) {
tmpK0 = RoundDown(tmpK0, CONST_512);
}
return tmpK0;
}
uint32_t PpMatmulTilingApi::ComputeL1AbSize()
{
if (enDequant_ && deqOnTheFly_) {
return L1_BUFFER_SIZE;
}
return enDequant_ ? (L1_BUFFER_SIZE - L1_BIAS_SIZE - L1_SCALE_SIZE) : L1_BUFFER_SIZE;
}
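// Tile scoring heuristic: returns 1/(aCoef * n0) + 1/(bCoef * m0) per
// candidate tile. The coefficients switch to bwCoef = 5 when the A or B
// panel streamed per wave exceeds L2, and double when transA with m0 a
// multiple of 256 (or non-transposed B with n0 a multiple of 256).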
float PpMatmulTilingApi::GetCost(const uint32_t m0, const uint32_t n0)
{
float aCoef = 1.0;
float bCoef = 1.0;
float bwCoef = 5.0;
uint32_t mLoop = CeilDiv(m_, m0);
uint32_t nLoop = CeilDiv(n_, n0);
if (mLoop == 0 || nLoop == 0) {
return __FLT_MAX__;
}
uint32_t rqdNumCore = numBatch_ * mLoop * nLoop;
uint32_t blockDim = Min(rqdNumCore, platformInfo_.coreNumAic);
uint32_t mOnce = blockDim < nLoop ? m0 : blockDim / nLoop * m0;
uint32_t nOnce = blockDim < nLoop ? platformInfo_.coreNumAic * n0 : n_;
if (mOnce * k_ * sizeof(uint16_t) > platformInfo_.l2Size) {
aCoef = bwCoef;
}
if (nOnce * k_ * sizeof(uint16_t) > platformInfo_.l2Size) {
bCoef = bwCoef;
}
if (transA_ && m0 % CONST_256 == 0) {
aCoef *= NUM2;
}
if (!transB_ && n0 % CONST_256 == 0) {
bCoef *= NUM2;
}
return 1 / (aCoef * static_cast<float>(n0)) + 1 / (bCoef * static_cast<float>(m0));
}
void PpMatmulTilingApi::UpdateTileSize(const uint32_t m0, const uint32_t n0)
{
m0_ = m0;
n0_ = n0;
mLoop_ = CeilDiv(m_, m0_);
nLoop_ = CeilDiv(n_, n0_);
coreLoop_ = numBatch_ * mLoop_ * nLoop_;
const uint32_t maxNumCubeCore = platformInfo_.coreNumAic;
if (mLoop_ == 1 && transB_ && coreLoop_ % maxNumCubeCore < maxNumCubeCore / NUM4 * NUM3) {
uint32_t tmpM0 = RoundUp(m_, CONST_16);
uint32_t maxN0 = L0C_SIZE / (tmpM0 * sizeof(float));
if (enDequant_) {
maxN0 = maxN0 < CONST_256 ? maxN0 : CONST_256;
}
uint32_t x = CeilDiv(n_, maxNumCubeCore);
uint32_t y = CeilDiv(x, maxN0);
uint32_t tmpN0 = RoundUp(CeilDiv(x, y), CONST_16);
uint32_t rqdL0cSize = tmpM0 * tmpN0 * sizeof(float);
if (rqdL0cSize < L0C_SIZE && (tmpM0 + tmpN0) * CONST_256 * inDataSize_ < L1_BUFFER_SIZE) {
m0_ = tmpM0;
n0_ = tmpN0;
nLoop_ = CeilDiv(n_, n0_);
coreLoop_ = numBatch_ * nLoop_;
}
}
blockDim_ = Min(coreLoop_, maxNumCubeCore);
}
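// Selects the swizzle direction and count used to order core blocks: per
// candidate count i, the "Nz" (direction 1) and "Zn" (direction 0)
// traversals are scored by their A/B panel footprint
// (n0 * i + m0 * ceil(blockDim / i), or the transpose of that), and the
// cheapest count is kept.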
void PpMatmulTilingApi::Swizzle()
{
float minCost = m_ * k_ + k_ * n_;
for (uint32_t i = 1; i <= blockDim_; ++i) {
int c = static_cast<int32_t>((blockDim_ + i - 1) / i);
float cost;
// B0 + A < A0 + B
if (i * n0_ + m_ < m0_ * c + n_) {
swizzleDirect_ = 1; // Nz
cost = n0_ * i + m0_ * c;
if (cost <= minCost) {
minCost = cost;
swizzleCount_ = i;
}
} else {
swizzleDirect_ = 0; // Zn
cost = m0_ * i + n0_ * c;
if (cost < minCost) {
minCost = cost;
swizzleCount_ = i;
}
}
}
}
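// Aggregates the per-stage tilings (three matmuls plus the vector stages)
// into one MlaTilingData blob consumed by the device kernel.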
class MlaPreprocessTiling
{
public:
MlaPreprocessTiling(struct PlatformInfo &platformInfo, struct OpParam &opParam, MlaTilingData *tilingData)
{
this->tilingData = tilingData;
this->platformInfo = platformInfo;
this->opParam = opParam;
}
void Init();
void RmsNormQuantTiling();
void RopeConcatTiling();
void EinSumQuantTiling();
void SetTilingKey();
void SetMlapoWorkSpace();
private:
MlaTilingData *tilingData;
struct PlatformInfo platformInfo;
struct OpParam opParam;
};
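// Two RMSNorm+quant passes per token: pass 1 normalizes the 7168-wide hidden
// state, pass 2 the 2112-wide mm1 output; both fan out over all AIV cores
// with an int8 quantization floor of -128.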
void MlaPreprocessTiling::RmsNormQuantTiling()
{
tilingData->rmsNumCore1 = platformInfo.coreNumAiv;
tilingData->rmsNumCol1 = HIDDEN_STRATE;
tilingData->rmsNumRow1 = opParam.N;
tilingData->rmsQuantMin1 = -CONST_128;
tilingData->rmsNumCore2 = platformInfo.coreNumAiv;
tilingData->rmsNumCol2 = HIDDEN_STRATE_MM;
tilingData->rmsNumRow2 = opParam.N;
tilingData->rmsQuantMin2 = -CONST_128;
}
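// Distributes ntokens * headNumQ rope heads across the AIV cores, then sizes
// how many 64-element heads (plus the 512-wide concat payload) fit in one UB
// pass, using fp32 intermediates for precision.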
void MlaPreprocessTiling::RopeConcatTiling()
{
uint32_t ntokens = opParam.N;
uint32_t hiddenSizeQ = HEADDIM * opParam.headNum;
uint32_t headDim = HEADDIM;
uint32_t headNumQ = hiddenSizeQ / headDim;
uint32_t concatSize = CONCAT_SIZE;
uint32_t maxCore = platformInfo.coreNumAiv;
uint32_t maxUbSize = platformInfo.ubSize;
uint32_t allHeadNum = ntokens * headNumQ;
uint32_t tempCore = (allHeadNum + maxCore - 1) / maxCore;
uint32_t realCore = (allHeadNum + tempCore - 1) / tempCore; // actual number of cores used
uint32_t nlCoreRun = (allHeadNum + realCore - 1) / realCore; // heads handled by each front core
uint32_t lCoreRun = allHeadNum - (realCore - 1) * nlCoreRun; // heads handled by the tail core
uint32_t dataTypeSize = 2;
// Calculate how many rows can be moved at a time. Per-element bytes:
// q (4+2), reverseq (4), neg (4), sin (4+2), cos (4+2), concat (2).
uint32_t allSize =
headDim * (3 * (4 + dataTypeSize) + 2 * 4) + concatSize * dataTypeSize; // fp32 intermediates lift rope precision
uint32_t maxNPerLoopForUb = maxUbSize / allSize; // max rows UB can hold per pass
uint32_t preCoreLoopTime = (nlCoreRun + maxNPerLoopForUb - 1) / maxNPerLoopForUb; // loop count on front cores
uint32_t preCoreLoopNLast =
nlCoreRun -
(preCoreLoopTime - 1) * maxNPerLoopForUb; // rows processed in the last pass on front cores
uint32_t lastCoreLoopTime = (lCoreRun + maxNPerLoopForUb - 1) / maxNPerLoopForUb; // loop count on the tail core
uint32_t lastCoreLoopNLast =
lCoreRun -
(lastCoreLoopTime - 1) * maxNPerLoopForUb; // rows processed in the last pass on the tail core
tilingData->hiddenSizeQ = hiddenSizeQ;
tilingData->headNumQ = headNumQ;
tilingData->headDim = headDim;
tilingData->concatSize = concatSize;
tilingData->rotaryCoeff = NUM2;
tilingData->ntokens = ntokens;
tilingData->realCore = realCore;
tilingData->nlCoreRun = nlCoreRun;
tilingData->lCoreRun = lCoreRun;
tilingData->maxNPerLoopForUb = maxNPerLoopForUb;
tilingData->preCoreLoopTime = preCoreLoopTime;
tilingData->preCoreLoopNLast = preCoreLoopNLast;
tilingData->lastCoreLoopTime = lastCoreLoopTime;
tilingData->lastCoreLoopNLast = lastCoreLoopNLast;
}
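// Tiles the einsum + quantize stage: token batches are spread over the AIV
// cores (the first esqFrontCore cores take one extra batch), and heads are
// chunked into [H', 512] UB passes whose size depends on the dtype and
// quant-mode path chosen below.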
void MlaPreprocessTiling::EinSumQuantTiling()
{
uint32_t aivCore = platformInfo.coreNumAiv;
uint32_t ubSize = UB_SIZE - 1024;
// input shape
uint32_t esqBatch = opParam.N; // tokenNum
uint32_t esqHeadNum = opParam.headNum; // headNum
uint32_t esqColNum = AXES_ALIGN_SIZE; // 512
// split core
uint32_t esqFrontCore = esqBatch % aivCore;
uint32_t esqTailCore = aivCore - esqFrontCore;
uint32_t esqFrontCoreBatch = CeilDiv(esqBatch, aivCore);
uint32_t esqTailCoreBatch = esqBatch / aivCore;
// split ub --> compute H', the number of head rows handled per UB pass
uint32_t splitFactor = 0;
uint32_t esqHeadPerLoop = 0; // The number of head rows per UB calculation
uint32_t repeatMask = 0;
if (opParam.inDtype == at::kBFloat16 || opParam.quantMode == QuantMode::PER_TOKEN_SYMM_QUANT) {
// Move all scales in at once, broadcast them, and cache the result: H * 32 bytes
uint32_t scaleUb = RoundUp(esqHeadNum) * CONST_32;
// bf16 input [H', colNum](fp16 + fp32 + int8), ub reuse
splitFactor = esqColNum * (sizeof(uint16_t) + sizeof(float) + sizeof(uint8_t));
splitFactor *= NUM2;
esqHeadPerLoop = (ubSize - scaleUb) / splitFactor; // 26
repeatMask = FP32_REPEAT_MASK;
} else {
// fp16 input [H', colNum](fp16*2 + int8) + [H', 1](fp16) + [H', 16](fp16)
splitFactor =
esqColNum * (NUM2 * sizeof(uint16_t) + sizeof(uint8_t)) + sizeof(uint16_t) + (CONST_16 * sizeof(uint16_t));
esqHeadPerLoop = ubSize / splitFactor;
repeatMask = FP16_REPEAT_MASK;
esqHeadPerLoop = RoundDown(esqHeadPerLoop);
}
uint32_t esqUbHeadLoop = esqHeadNum / esqHeadPerLoop; // full UB passes over the heads
uint32_t esqHeadTail = esqHeadNum % esqHeadPerLoop; // head rows left for the final UB pass
uint32_t esqColLoop = esqColNum / repeatMask; // full repeat-mask loops per row
uint32_t esqColTail =
esqColNum % repeatMask; // leftover columns when colNum is not 64/128 aligned
tilingData->esqFrontCore = esqFrontCore;
tilingData->esqTailCore = esqTailCore;
tilingData->esqFrontCoreBatch = esqFrontCoreBatch;
tilingData->esqTailCoreBatch = esqTailCoreBatch;
tilingData->esqHeadNum = esqHeadNum;
tilingData->esqColNum = esqColNum;
tilingData->esqUbHeadLoop = esqUbHeadLoop;
tilingData->esqHeadPerLoop = esqHeadPerLoop;
tilingData->esqHeadTail = esqHeadTail;
tilingData->esqColLoop = esqColLoop;
tilingData->esqColTail = esqColTail;
}
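// Carves the user workspace into equal-stride staging buffers (offsets
// s1..s5, each maxWorkspaceSize apart, sized by the largest per-token
// intermediate). fp16 paths allocate three buffers; bf16 / per-token
// symmetric-quant paths allocate four plus a per-token scale region of
// N * sizeof(float) * 2 bytes.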
void MlaPreprocessTiling::SetMlapoWorkSpace()
{
uint64_t s1wsFactor =
static_cast<uint64_t>(opParam.cacheMode == 2 ? std::max(HIDDEN_STRATE * sizeof(int8_t),
opParam.headNum * AXES_ALIGN_SIZE * sizeof(uint16_t))
: HIDDEN_STRATE * sizeof(int8_t));
uint64_t workSizeS1 = s1wsFactor;
uint64_t workSizeS2 = opParam.headNum * HIDDEN_STRATE_ROPE * sizeof(uint16_t);
uint64_t workSizeS3 = HIDDEN_STRATE_MM * sizeof(uint16_t);
uint64_t workSizeS4 = std::max(opParam.headNum * HIDDEN_STRATE_ROPE, HIDDEN_STRATE_MM) * sizeof(uint32_t);
uint64_t maxWorkspaceSize = workSizeS1;
maxWorkspaceSize = std::max(maxWorkspaceSize, workSizeS2);
maxWorkspaceSize = std::max(maxWorkspaceSize, workSizeS3);
maxWorkspaceSize = std::max(maxWorkspaceSize, workSizeS4);
maxWorkspaceSize *= static_cast<uint64_t>(opParam.N);
uint64_t pertokenWorkspace = static_cast<uint64_t>(opParam.N) * sizeof(float) * 2;
uint64_t userWorkspaceSize;
if (opParam.inDtype == at::kBFloat16 || opParam.quantMode == QuantMode::PER_TOKEN_SYMM_QUANT) {
userWorkspaceSize = 4 * maxWorkspaceSize + pertokenWorkspace;
} else {
userWorkspaceSize = 3 * maxWorkspaceSize;
}
tilingData->userWorkspaceSize = userWorkspaceSize;
tilingData->s1Offset = 0;
tilingData->s2Offset = tilingData->s1Offset + maxWorkspaceSize;
tilingData->s3Offset = tilingData->s2Offset + maxWorkspaceSize;
tilingData->s4Offset = tilingData->s3Offset + maxWorkspaceSize;
tilingData->s5Offset = tilingData->s4Offset + maxWorkspaceSize;
}
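// Tiling-key layout: bit 8 flags bf16 input, bits 3+ hold the quant mode,
// and the low bits hold the cache mode.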
void MlaPreprocessTiling::SetTilingKey()
{
uint64_t tilingKey = (static_cast<uint64_t>(opParam.inDtype == at::kBFloat16)) << 8;
tilingKey |= static_cast<uint64_t>(opParam.cacheMode);
tilingKey |= (static_cast<uint64_t>(opParam.quantMode) << 3);
tilingData->tilingKey = tilingKey;
}
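// Builds the tiling for the three pipeline matmuls, which line up with the
// wdqkv / wuq / wuk weights named in the INDEX_* constants above:
// mm1: [N, 7168] x [7168, 2112], int8 dequant, B transposed
// mm2: [N, 1536] x [1536, headNum * 192], int8 dequant, B transposed
// mm3: headNum batched [N, 128] x [128, 512], no dequant
// then fills the vector-stage tilings, workspace offsets, and tiling key.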
void MlaPreprocessTiling::Init()
{
tilingData->numCore = platformInfo.coreNumAic;
tilingData->n = opParam.N;
bool deqOnTheFly = false;
if (opParam.inDtype == at::kBFloat16 || opParam.quantMode == QuantMode::PER_TOKEN_SYMM_QUANT) {
deqOnTheFly = true;
}
PpMatmulTilingApi mm1TilingApi(platformInfo,
1, // numBatch
opParam.N, // m
HIDDEN_STRATE, // k
HIDDEN_STRATE_MM, // n
false, // transA
true, // transB
true, // enDequant
deqOnTheFly); // dequant fused on the fly (bf16 / per-token-quant path)
mm1TilingApi.GetTilingData(tilingData->mm1);
PpMatmulTilingApi mm2TilingApi(platformInfo,
1, // numBatch
opParam.N, // m
HIDDEN_STRATE_RMS, // k
opParam.headNum * HIDDEN_STRATE_ROPE, // n
false, // transA
true, // transB
true, // enDequant
deqOnTheFly); // dequant fused on the fly (bf16 / per-token-quant path)
mm2TilingApi.GetTilingData(tilingData->mm2);
PpMatmulTilingApi mm3TilingApi(platformInfo,
opParam.headNum, // numBatch
opParam.N, // m
CONST_128, // k
CONCAT_SIZE, // n
false, // transA
false, // transB
false, // enDequant
deqOnTheFly); // dequant fused on the fly (bf16 / per-token-quant path)
mm3TilingApi.GetTilingData(tilingData->mm3);
RmsNormQuantTiling();
RopeConcatTiling();
EinSumQuantTiling();
SetMlapoWorkSpace();
SetTilingKey();
return;
}
std::unordered_map<c10::string_view, uint16_t> cache_mode_map = {
{"krope_ctkv", 1}, {"int8_nzcache", 2}, {"nzcache", 3}};
std::unordered_map<c10::string_view, uint16_t> quant_mode_map = {
{"per_tensor_quant_asymm", 0},
{"per_token_quant_symm", 1},
};
template <typename MapType>
inline int get_op_mode(const MapType &mode_map, c10::optional<c10::string_view> mode_opt, c10::string_view default_mode,
const char *mode_name)
{
c10::string_view mode_str = mode_opt.value_or(default_mode);
auto it = mode_map.find(mode_str);
TORCH_CHECK(it != mode_map.end(), "Unsupported ", mode_name, " value: '", mode_str, "'");
return it->second;
}
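// Example: get_op_mode(cache_mode_map, cache_mode, "krope_ctkv", "cache_mode")
// returns 1 when cache_mode is unset or "krope_ctkv", and fails via
// TORCH_CHECK for any string not present in the map.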
// std::tuple<at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &> mla_preprocess(
// const at::Tensor &hiddenState, const at::Tensor &gamma0, const at::Tensor &beta0, const at::Tensor &wdqkv,
// const at::Tensor &descale0, const at::Tensor &gamma1, const at::Tensor &beta1, const at::Tensor &wuq,
// const at::Tensor &descale1, const at::Tensor &gamma2, const at::Tensor &cos, const at::Tensor &sin,
// const at::Tensor &wuk, const at::Tensor &kv_cache, const at::Tensor &kv_cache_rope, const at::Tensor &slotmapping,
// const at::Tensor &quant_scale0, const at::Tensor &quant_offset0, const at::Tensor &bias0,
// const at::Tensor &quant_scale1, const at::Tensor &quant_offset1, const at::Tensor &bias1,
// const c10::optional<at::Tensor> &ctkv_scale, const c10::optional<at::Tensor> &q_nope_scale,
// c10::optional<c10::string_view> cache_mode, c10::optional<c10::string_view> quant_mode, at::Tensor &q_out0,
// at::Tensor &kv_cache_out0, at::Tensor &q_out1, at::Tensor &kv_cache_out1)
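// Host-side tiling entry point: probes the platform, builds MlaTilingData for
// the current token count N, stages it into the per-N slot of a cached device
// buffer (up to MAX_SUPPORT_TOKEN_NUMS slots), and returns the workspace
// tensor, a byte view of the staged tiling, and the cube-core block dim.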
std::tuple<at::Tensor, at::Tensor, uint32_t> mla_preprocess_tiling(
const at::Tensor &hiddenState,
const at::Tensor &wuk,
c10::optional<c10::string_view> cache_mode,
c10::optional<c10::string_view> quant_mode
)
{
auto cacheMode = get_op_mode(cache_mode_map, cache_mode, "krope_ctkv", "cache_mode");
auto quantMode = get_op_mode(quant_mode_map, quant_mode, "per_token_quant_symm", "quant_mode");
platform_ascendc::PlatformAscendC *platformAscendC = platform_ascendc::PlatformAscendCManager::GetInstance();
struct PlatformInfo platformInfo;
platformInfo.coreNum = platformAscendC->GetCoreNum();
platformInfo.coreNumAic = platformAscendC->GetCoreNumAic();
platformInfo.coreNumAiv = platformAscendC->GetCoreNumAiv();
platformAscendC->GetCoreMemSize(platform_ascendc::CoreMemType::UB, platformInfo.ubSize);
platformAscendC->GetCoreMemSize(platform_ascendc::CoreMemType::L1, platformInfo.l1Size);
platformAscendC->GetCoreMemSize(platform_ascendc::CoreMemType::L2, platformInfo.l2Size);
platformAscendC->GetCoreMemSize(platform_ascendc::CoreMemType::L0_A, platformInfo.l0aSize);
platformAscendC->GetCoreMemSize(platform_ascendc::CoreMemType::L0_B, platformInfo.l0bSize);
platformAscendC->GetCoreMemSize(platform_ascendc::CoreMemType::L0_C, platformInfo.l0cSize);
int32_t N = hiddenState.sizes()[0];
int32_t headNum = wuk.sizes()[0];
OpParam opParam;
opParam.N = N;
opParam.headNum = headNum;
opParam.cacheMode = static_cast<int32_t>(cacheMode);
opParam.quantMode = static_cast<QuantMode>(quantMode);
opParam.inDtype = hiddenState.options().dtype();
MlaTilingData tilingData;
MlaPreprocessTiling mlaTiling(platformInfo, opParam, &tilingData);
mlaTiling.Init();
uint32_t blockDim = platformInfo.coreNumAic;
// workspace
uint64_t system_workspace_size = static_cast<uint64_t>(platformAscendC->GetLibApiWorkSpaceSize());
uint64_t workspace_size = system_workspace_size + tilingData.userWorkspaceSize;
auto options = at::TensorOptions().dtype(at::kByte).device(hiddenState.options().device());
auto workspace_tensor = at::empty({static_cast<int64_t>(workspace_size)}, options);
// tiling
int32_t bIndex = N - 1;
uint32_t tilingSize = sizeof(MlaTilingData);
static auto global_tiling_data =
at::empty({tilingSize * MAX_SUPPORT_TOKEN_NUMS},
at::TensorOptions().dtype(at::kByte).device(hiddenState.options().device()));
if (bIndex >= 0 && bIndex < MAX_SUPPORT_TOKEN_NUMS) {
aclrtMemcpy(global_tiling_data.data_ptr<uint8_t>() + (tilingSize * bIndex), tilingSize, &tilingData, tilingSize,
ACL_MEMCPY_HOST_TO_DEVICE);
} else {
// Handle the case where bIndex is out of range
TORCH_CHECK(false, "bIndex is out of range: ", bIndex);
}
at::Tensor tiling = at::from_blob(
global_tiling_data.data_ptr<uint8_t>() + (tilingSize * bIndex),
tilingSize,
at::kByte);
return std::make_tuple(workspace_tensor, tiling, blockDim);
}
} // namespace mlapo