add mla_preprocess kernel (#3226)

### What this PR does / why we need it?

- Adds the `mla_preprocess` custom kernel to provide an optimized
pre-processing operator for Multi-head Latent Attention (MLA) on Ascend
NPUs.
- Wires the new kernel into the C++ extension pipeline so vLLM can
invoke it directly, cutting Python-side tensor shuffling and memory
copies that previously bottlenecked MLA compilation paths.

### Does this PR introduce any user-facing change?

- No. The change only introduces a low-level kernel; public APIs and
inference behavior remain unchanged.

### How was this patch tested?

- Dedicated Ascend kernels are not covered by our CI yet, so no extra
automated tests were added. Future MLA-focused regression runs will
cover this path.

- vLLM version: v0.11.0

Signed-off-by: Chen Chen <0109chenchen@gmail.com>
This commit is contained in:
Chen Chen
2025-10-12 07:39:45 +08:00
committed by GitHub
parent 1b1207e3c3
commit bcc313e8f2
32 changed files with 9158 additions and 3 deletions

View File

@ -12,8 +12,8 @@ repos:
- id: codespell
args: [
--toml, pyproject.toml,
'--skip', 'tests/e2e/multicard/test_torchair_graph_mode.py,tests/prompts/**,./benchmarks/sonnet.txt,*tests/lora/data/**,build/**,./vllm_ascend.egg-info/**,.github/**,typos.toml',
'-L', 'CANN,cann,NNAL,nnal,ASCEND,ascend,EnQue,CopyIn'
'--skip', 'tests/e2e/multicard/test_torchair_graph_mode.py,csrc/mla_preprocess/**,tests/prompts/**,./benchmarks/sonnet.txt,*tests/lora/data/**,build/**,./vllm_ascend.egg-info/**,.github/**,typos.toml',
'-L', 'CANN,cann,NNAL,nnal,ASCEND,ascend,EnQue,CopyIn,ArchType,AND'
]
additional_dependencies:
- tomli
@ -35,6 +35,10 @@ repos:
rev: v1.32.0
hooks:
- id: typos
args: [
"--force-exclude",
"--exclude", "csrc/mla_preprocess/**"
]
- repo: https://github.com/PyCQA/isort
rev: 6.0.1
hooks:

View File

@ -44,11 +44,13 @@ else()
endif()
include(${ASCENDC_CMAKE_DIR}/ascendc.cmake)
file(GLOB KERNEL_FILES
${CMAKE_CURRENT_SOURCE_DIR}/csrc/kernels/*.cpp)
ascendc_library(vllm_ascend_kernels SHARED
${KERNEL_FILES}
${CMAKE_CURRENT_SOURCE_DIR}/csrc/mla_preprocess/op_kernel/mla_preprocess_kernel.cpp
)
message("TORCH_NPU_PATH is ${TORCH_NPU_PATH}")
@ -90,7 +92,11 @@ target_link_libraries(
libtorch_npu.so
vllm_ascend_kernels
ascendcl
tiling_api
register
platform
ascendalog
dl
)
target_link_options(vllm_ascend_C PRIVATE "-Wl,-rpath,$ORIGIN:$ORIGIN/lib")

View File

@ -0,0 +1,698 @@
// Adapted from
// https://gitee.com/ascend/ascend-transformer-boost.git
// https://gitee.com/ascend/op-plugin.git
//
// Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
// This file is a part of the CANN Open Software.
// Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
// Please refer to the License for details. You may not use this file except in compliance with the License.
// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
// See LICENSE in the root of the software repository for the full text of the License.
//
#include <fstream>
#include <iostream>
#include <math.h>
#include <stdexcept>
#include "acl/acl.h"
// #include "defines.h"
// #include "torch_helper.h"
#include "tiling/platform/platform_ascendc.h"
#include "tiling/mla_preprocess_tiling.h"
// #include "aclrtlaunch_mla_preprocess.h"
// namespace sglang {
namespace mlapo {
constexpr uint32_t DIM_2 = 2;
constexpr uint32_t AXES_ALIGN_SIZE = 512;
constexpr uint32_t BASE_BLOCK_STEP = 2;
constexpr uint32_t CONST_16 = 16;
constexpr uint32_t CONST_32 = 32;
constexpr uint32_t CONST_128 = 128;
constexpr uint32_t CONST_256 = 256;
constexpr uint32_t CONST_512 = 512;
constexpr uint32_t L1_BUFFER_SIZE = 524288;
constexpr uint32_t L1_PINGPONG_BUFFER_LEN = 262144;
constexpr uint32_t L0AB_PINGPONG_BUFFER_LEN = 131072;
constexpr uint32_t L1_SCALE_SIZE = 4096;
constexpr uint32_t L1_BIAS_SIZE = 2048;
constexpr uint32_t L0C_SIZE = 128 * 1024;
constexpr uint32_t CONCAT_SIZE = 512;
constexpr uint32_t HIDDEN_STRATE = 7168;
constexpr uint32_t HIDDEN_STRATE_ROPE = 192;
constexpr uint32_t HIDDEN_STRATE_MM = 2112;
constexpr uint32_t HIDDEN_STRATE_RMS = 1536;
constexpr uint32_t UB_SIZE = 196352;
constexpr uint32_t HEADDIM = 64;
constexpr uint32_t FP32_REPEAT_MASK = 64;
constexpr uint32_t FP16_REPEAT_MASK = 128;
constexpr int32_t NUM1 = 1;
constexpr int32_t NUM2 = 2;
constexpr int32_t NUM3 = 3;
constexpr int32_t NUM4 = 4;
constexpr int32_t NUM8 = 8;
constexpr uint32_t INDEX_WDQKV = 5;
constexpr uint32_t INDEX_WUQ = 18;
constexpr uint32_t INDEX_WUK = 20;
constexpr uint32_t MAX_SUPPORT_TOKEN_NUMS = 1024;
inline uint32_t CeilDiv(const uint32_t dividend, const uint32_t divisor)
{
if (divisor == 0) {
return UINT32_MAX;
}
return (dividend + divisor - 1) / divisor;
}
inline uint32_t RoundUp(const uint32_t val, const uint32_t align = 16)
{
if (align == 0) {
return 0;
}
return (val + align - 1) / align * align;
}
inline uint32_t RoundDown(const uint32_t val, const uint32_t align = 16)
{
if (align == 0) {
return 0;
}
return val / align * align;
}
template <typename T = uint32_t>
inline T Max(const T a, const T b)
{
return a > b ? a : b;
}
template <typename T = uint32_t>
inline T Min(const T a, const T b)
{
return a < b ? a : b;
}
struct MlaPreprocess {
enum class QuantMode : int32_t {
PER_TENSOR_ASYMM_QUANT = 0,
PER_TOKEN_SYMM_QUANT,
PER_TOKEN_ASYMM_QUANT,
NO_QUANT
};
};
using QuantMode = MlaPreprocess::QuantMode;
struct PlatformInfo {
uint32_t coreNum;
uint32_t coreNumAic;
uint32_t coreNumAiv;
uint64_t ubSize;
uint64_t l1Size;
uint64_t l2Size;
uint64_t l0aSize;
uint64_t l0bSize;
uint64_t l0cSize;
};
struct OpParam {
uint32_t N;
uint32_t headNum;
int32_t cacheMode;
QuantMode quantMode;
caffe2::TypeMeta inDtype;
};
class PpMatmulTilingApi
{
public:
PpMatmulTilingApi(struct PlatformInfo &platformInfo, uint32_t numBatch, uint32_t m, uint32_t k, uint32_t n,
bool transA, bool transB, bool enDequant, bool deqOnTheFly)
: platformInfo_(platformInfo),
numBatch_(numBatch),
m_(m),
k_(k),
n_(n),
transA_(transA),
transB_(transB),
enDequant_(enDequant),
deqOnTheFly_(deqOnTheFly)
{
inDataSize_ = enDequant ? sizeof(uint8_t) : sizeof(uint16_t);
}
void GetTilingData(PpMatmulTilingData &tiling);
private:
void GetTileSize();
float GetCost(const uint32_t m0, const uint32_t n0);
void UpdateTileSize(const uint32_t m0, const uint32_t n0);
void Swizzle();
uint32_t ComputeL1AbSize();
uint32_t ComputeK0ForABpingpong(uint32_t l1AbSize);
bool IsLoadAllAmat(uint32_t l1AbSize);
uint32_t ComputeK0ForOnlyBpingpong(uint32_t l1AbSize);
private:
uint32_t numBatch_{0};
uint32_t m_{0};
uint32_t k_{0};
uint32_t n_{0};
uint32_t m0_{0};
uint32_t k0_{0};
uint32_t n0_{0};
uint32_t mLoop_{0};
uint32_t kLoop_{0};
uint32_t nLoop_{0};
uint32_t coreLoop_{0};
uint32_t swizzleCount_{0};
uint32_t blockDim_{0};
uint32_t swizzleDirect_{0};
uint32_t inDataSize_{0};
uint32_t b0matPingPongBufferLen_{L1_PINGPONG_BUFFER_LEN};
bool transA_{false};
bool transB_{false};
bool enDequant_{false};
bool enShuffleK_{false};
bool enLoadAllAmat_{false};
bool deqOnTheFly_{false};
struct PlatformInfo platformInfo_;
};
void PpMatmulTilingApi::GetTilingData(PpMatmulTilingData &tiling)
{
GetTileSize();
tiling.numBatch = numBatch_;
tiling.m = m_;
tiling.k = k_;
tiling.n = n_;
tiling.m0 = m0_;
tiling.k0 = k0_;
tiling.n0 = n0_;
tiling.mLoop = mLoop_;
tiling.kLoop = kLoop_;
tiling.nLoop = nLoop_;
tiling.coreLoop = coreLoop_;
tiling.swizzleCount = swizzleCount_;
tiling.swizzleDirect = swizzleDirect_;
tiling.enShuffleK = static_cast<uint32_t>(enShuffleK_);
tiling.blockDim = blockDim_;
tiling.enLoadAllAmat = static_cast<uint32_t>(enLoadAllAmat_);
tiling.b0matPingPongBufferLen = b0matPingPongBufferLen_;
}
void PpMatmulTilingApi::GetTileSize()
{
bool priFlag = !(m_ < n_);
uint32_t roundBase = pow(2, ceil(log(CeilDiv(priFlag ? n_ : m_, CONST_16)))) * CONST_16;
uint32_t priAxes = RoundUp(priFlag ? m_ : n_, CONST_16);
uint32_t subAxes = RoundUp(priFlag ? n_ : m_, roundBase);
float minCost = __FLT_MAX__;
uint32_t maxAxes0 = AXES_ALIGN_SIZE;
uint32_t maxPriAxes0 = Min(maxAxes0, priAxes);
uint32_t maxSubAxes0 = Min(maxAxes0, subAxes);
for (uint32_t priAxes0 = CONST_16; priAxes0 <= maxPriAxes0; priAxes0 *= BASE_BLOCK_STEP) {
for (uint32_t subAxes0 = CONST_16; subAxes0 <= maxSubAxes0; subAxes0 *= BASE_BLOCK_STEP) {
if (priAxes0 * subAxes0 * sizeof(float) > platformInfo_.l0cSize) {
continue;
}
uint32_t newM0 = priFlag ? priAxes0 : subAxes0;
uint32_t newN0 = priFlag ? subAxes0 : priAxes0;
if (newN0 > CONST_256 && enDequant_) {
continue;
}
float cost = GetCost(newM0, newN0);
if (cost < minCost) {
minCost = cost;
UpdateTileSize(newM0, newN0);
}
}
}
Swizzle();
uint32_t l1AbSize = ComputeL1AbSize();
k0_ = ComputeK0ForABpingpong(l1AbSize);
kLoop_ = CeilDiv(k_, k0_);
}
uint32_t PpMatmulTilingApi::ComputeK0ForOnlyBpingpong(uint32_t l1AbSize)
{
enLoadAllAmat_ = true;
b0matPingPongBufferLen_ = static_cast<uint32_t>(
static_cast<float>((l1AbSize - RoundUp(m_, CONST_16) * RoundUp(k_, CONST_32) * inDataSize_) / DIM_2));
uint32_t k0MaxB0 =
static_cast<uint32_t>(static_cast<float>(b0matPingPongBufferLen_ / (RoundUp(n0_, CONST_16) * inDataSize_)));
uint32_t k0B0 = k0MaxB0 < CONST_512 ? RoundDown(k0MaxB0, CONST_32) : RoundDown(k0MaxB0, CONST_512);
return k0B0 > CONST_512 ? RoundDown(k0B0, CONST_512) : k0B0;
}
bool PpMatmulTilingApi::IsLoadAllAmat(uint32_t l1AbSize)
{
return (coreLoop_ > blockDim_) && enDequant_ && (kLoop_ > 1) &&
(l1AbSize > RoundUp(m_, CONST_16) * RoundUp(k_, CONST_32) * inDataSize_) && (mLoop_ == 1);
}
uint32_t PpMatmulTilingApi::ComputeK0ForABpingpong(uint32_t l1AbSize)
{
uint32_t k0Max = static_cast<uint32_t>(static_cast<float>(l1AbSize / DIM_2) / ((m0_ + n0_) * inDataSize_));
uint32_t tmpK0;
if (enDequant_) {
tmpK0 = k0Max < CONST_512 ? RoundDown(k0Max, CONST_32) : RoundDown(k0Max, CONST_512);
} else {
tmpK0 = k0Max < CONST_256 ? RoundDown(k0Max, CONST_16) : RoundDown(k0Max, CONST_256);
}
if (tmpK0 > CONST_512) {
tmpK0 = RoundDown(tmpK0, CONST_512);
}
return tmpK0;
}
uint32_t PpMatmulTilingApi::ComputeL1AbSize()
{
if (enDequant_ && deqOnTheFly_) {
return L1_BUFFER_SIZE;
}
return enDequant_ ? (L1_BUFFER_SIZE - L1_BIAS_SIZE - L1_SCALE_SIZE) : L1_BUFFER_SIZE;
}
float PpMatmulTilingApi::GetCost(const uint32_t m0, const uint32_t n0)
{
float aCoef = 1.0;
float bCoef = 1.0;
float bwCoef = 5.0;
uint32_t mLoop = CeilDiv(m_, m0);
uint32_t nLoop = CeilDiv(n_, n0);
if (mLoop == 0 || nLoop == 0) {
return __FLT_MAX__;
}
uint32_t rqdNumCore = numBatch_ * mLoop * nLoop;
uint32_t blockDim = Min(rqdNumCore, platformInfo_.coreNumAic);
uint32_t mOnce = blockDim < nLoop ? m0 : blockDim / nLoop * m0;
uint32_t nOnce = blockDim < nLoop ? platformInfo_.coreNumAic * n0 : n_;
if (mOnce * k_ * sizeof(uint16_t) > platformInfo_.l2Size) {
aCoef = bwCoef;
}
if (nOnce * k_ * sizeof(uint16_t) > platformInfo_.l2Size) {
bCoef = bwCoef;
}
if (transA_ && m0 % CONST_256 == 0) {
aCoef *= NUM2;
}
if (!transB_ && n0 % CONST_256 == 0) {
bCoef *= NUM2;
}
return 1 / (aCoef * static_cast<float>(n0)) + 1 / (bCoef * static_cast<float>(m0));
}
void PpMatmulTilingApi::UpdateTileSize(const uint32_t m0, const uint32_t n0)
{
m0_ = m0;
n0_ = n0;
mLoop_ = CeilDiv(m_, m0_);
nLoop_ = CeilDiv(n_, n0_);
coreLoop_ = numBatch_ * mLoop_ * nLoop_;
const uint32_t maxNumCubeCore = platformInfo_.coreNumAic;
if (mLoop_ == 1 && transB_ && coreLoop_ % maxNumCubeCore < maxNumCubeCore / NUM4 * NUM3) {
uint32_t tmpM0 = RoundUp(m_, CONST_16);
uint32_t maxN0 = L0C_SIZE / (tmpM0 * sizeof(float));
if (enDequant_) {
maxN0 = maxN0 < CONST_256 ? maxN0 : CONST_256;
}
uint32_t x = CeilDiv(n_, maxNumCubeCore);
uint32_t y = CeilDiv(x, maxN0);
uint32_t tmpN0 = RoundUp(CeilDiv(x, y), CONST_16);
uint32_t rqdL0cSize = tmpM0 * tmpN0 * sizeof(float);
if (rqdL0cSize < L0C_SIZE && (tmpM0 + tmpN0) * CONST_256 * inDataSize_ < L1_BUFFER_SIZE) {
m0_ = tmpM0;
n0_ = tmpN0;
nLoop_ = CeilDiv(n_, n0_);
coreLoop_ = numBatch_ * nLoop_;
}
}
blockDim_ = Min(coreLoop_, maxNumCubeCore);
}
void PpMatmulTilingApi::Swizzle()
{
float minCost = m_ * k_ + k_ * n_;
for (uint32_t i = 1; i <= blockDim_; ++i) {
int c = static_cast<int32_t>((blockDim_ + i - 1) / i);
float cost;
// B0 + A < A0 + B
if (i * n0_ + m_ < m0_ * c + n_) {
swizzleDirect_ = 1; // Nz
cost = n0_ * i + m0_ * c;
if (cost <= minCost) {
minCost = cost;
swizzleCount_ = i;
}
} else {
swizzleDirect_ = 0; // Zn
cost = m0_ * i + n0_ * c;
if (cost < minCost) {
minCost = cost;
swizzleCount_ = i;
}
}
}
}
class MlaPreprocessTiling
{
public:
MlaPreprocessTiling(struct PlatformInfo &platformInfo, struct OpParam &opParam, MlaTilingData *tilingData)
{
this->tilingData = tilingData;
this->platformInfo = platformInfo;
this->opParam = opParam;
}
void Init();
void RmsNormQuantTiling();
void RopeConcatTiling();
void EinSumQuantTiling();
void SetTilingKey();
void SetMlapoWorkSpace();
private:
MlaTilingData *tilingData;
struct PlatformInfo platformInfo;
struct OpParam opParam;
};
void MlaPreprocessTiling::RmsNormQuantTiling()
{
tilingData->rmsNumCore1 = platformInfo.coreNumAiv;
tilingData->rmsNumCol1 = HIDDEN_STRATE;
tilingData->rmsNumRow1 = opParam.N;
tilingData->rmsQuantMin1 = -CONST_128;
tilingData->rmsNumCore2 = platformInfo.coreNumAiv;
tilingData->rmsNumCol2 = HIDDEN_STRATE_MM;
tilingData->rmsNumRow2 = opParam.N;
tilingData->rmsQuantMin2 = -CONST_128;
}
void MlaPreprocessTiling::RopeConcatTiling()
{
uint32_t ntokens = opParam.N;
uint32_t hiddenSizeQ = HEADDIM * opParam.headNum;
uint32_t headDim = HEADDIM;
uint32_t headNumQ = hiddenSizeQ / headDim;
uint32_t concatSize = CONCAT_SIZE;
uint32_t maxCore = platformInfo.coreNumAiv;
uint32_t maxUbSize = platformInfo.ubSize;
uint32_t allHeadNum = ntokens * headNumQ;
uint32_t tempCore = (allHeadNum + maxCore - 1) / maxCore;
uint32_t realCore = (allHeadNum + tempCore - 1) / tempCore; // Actual number of the core for operation
uint32_t nlCoreRun = (allHeadNum + realCore - 1) / realCore; // The number of heads in the front core
uint32_t lCoreRun = allHeadNum - (realCore - 1) * nlCoreRun; // The number of heads in the tail core
uint32_t dataTypeSize = 2;
// Calculate how many lines can be moved at a time. q 4+2、reverseq 4、neg 4、sin 4+2、cos 4+2 + concat 2
uint32_t allSize =
headDim * (3 * (4 + dataTypeSize) + 2 * 4) + concatSize * dataTypeSize; // lift precision calculation of ROPE
uint32_t maxNPerLoopForUb = maxUbSize / allSize; // the maximum number of rows at a time for UB
uint32_t preCoreLoopTime = (nlCoreRun + maxNPerLoopForUb - 1) / maxNPerLoopForUb; // Number of cycles of front core
uint32_t preCoreLoopNLast =
nlCoreRun -
(preCoreLoopTime - 1) * maxNPerLoopForUb; // rows of data processed in the last batch of the front core
uint32_t lastCoreLoopTime = (lCoreRun + maxNPerLoopForUb - 1) / maxNPerLoopForUb; // Number of cycles of tail core
uint32_t lastCoreLoopNLast =
lCoreRun -
(lastCoreLoopTime - 1) * maxNPerLoopForUb; // rows of data processed in the last batch of the tail core
tilingData->hiddenSizeQ = hiddenSizeQ;
tilingData->headNumQ = headNumQ;
tilingData->headDim = headDim;
tilingData->concatSize = concatSize;
tilingData->rotaryCoeff = NUM2;
tilingData->ntokens = ntokens;
tilingData->realCore = realCore;
tilingData->nlCoreRun = nlCoreRun;
tilingData->lCoreRun = nlCoreRun;
tilingData->maxNPerLoopForUb = maxNPerLoopForUb;
tilingData->preCoreLoopTime = preCoreLoopTime;
tilingData->preCoreLoopNLast = preCoreLoopNLast;
tilingData->lastCoreLoopTime = lastCoreLoopTime;
tilingData->lastCoreLoopNLast = lastCoreLoopNLast;
}
void MlaPreprocessTiling::EinSumQuantTiling()
{
uint32_t aivCore = platformInfo.coreNumAiv;
uint32_t ubSize = UB_SIZE - 1024;
// input shape
uint32_t esqBatch = opParam.N; // tokenNum
uint32_t esqHeadNum = opParam.headNum; // headNum
uint32_t esqColNum = AXES_ALIGN_SIZE; // 512
// split core
uint32_t esqFrontCore = esqBatch % aivCore;
uint32_t esqTailCore = aivCore - esqFrontCore;
uint32_t esqFrontCoreBatch = CeilDiv(esqBatch, aivCore);
uint32_t esqTailCoreBatch = esqBatch / aivCore;
// split ub --> calc H' <-- The number of rows handled in a UB cycle.
uint32_t splitFactor = 0;
uint32_t esqHeadPerLoop = 0; // The number of head rows per UB calculation
uint32_t repeatMask = 0;
if (opParam.inDtype == at::kBFloat16 || opParam.quantMode == QuantMode::PER_TOKEN_SYMM_QUANT) {
// Move scales in at once, broadcast, and cache them all H * 32bytes
uint32_t scaleUb = RoundUp(esqHeadNum) * CONST_32;
// bf16 input [H', colNum](f16 + fp32 + int8), ub reuse
splitFactor = esqColNum * (sizeof(uint16_t) + sizeof(float) + sizeof(uint8_t));
splitFactor *= NUM2;
esqHeadPerLoop = (ubSize - scaleUb) / splitFactor; // 26
repeatMask = FP32_REPEAT_MASK;
} else {
// fp16 input [H', cloNum](fp16*2 + int8) + [H', 1](fp16) + [H', 16](fp16)
splitFactor =
esqColNum * (NUM2 * sizeof(uint16_t) + sizeof(uint8_t)) + sizeof(uint16_t) + (CONST_16 * sizeof(uint16_t));
esqHeadPerLoop = ubSize / splitFactor;
repeatMask = FP16_REPEAT_MASK;
esqHeadPerLoop = RoundDown(esqHeadPerLoop);
}
uint32_t esqUbHeadLoop = esqHeadNum / esqHeadPerLoop; // UB complete cycles
uint32_t esqHeadTail = esqHeadNum % esqHeadPerLoop; // The number of rows that UB last processed the head.
uint32_t esqColLoop = esqColNum / repeatMask; // Each row counts the number of times to cycle through columns.
uint32_t esqColTail =
esqColNum % repeatMask; // colNum is not 64/128 aligned, the number of columns is calculated last.
tilingData->esqFrontCore = esqFrontCore;
tilingData->esqTailCore = esqTailCore;
tilingData->esqFrontCoreBatch = esqFrontCoreBatch;
tilingData->esqTailCoreBatch = esqTailCoreBatch;
tilingData->esqHeadNum = esqHeadNum;
tilingData->esqColNum = esqColNum;
tilingData->esqUbHeadLoop = esqUbHeadLoop;
tilingData->esqHeadPerLoop = esqHeadPerLoop;
tilingData->esqHeadTail = esqHeadTail;
tilingData->esqColLoop = esqColLoop;
tilingData->esqColTail = esqColTail;
}
void MlaPreprocessTiling::SetMlapoWorkSpace()
{
uint64_t s1wsFactor =
static_cast<uint64_t>(opParam.cacheMode == 2 ? std::max(HIDDEN_STRATE * sizeof(int8_t),
opParam.headNum * AXES_ALIGN_SIZE * sizeof(uint16_t))
: HIDDEN_STRATE * sizeof(int8_t));
uint64_t workSizeS1 = s1wsFactor;
uint64_t workSizeS2 = opParam.headNum * HIDDEN_STRATE_ROPE * sizeof(uint16_t);
uint64_t workSizeS3 = HIDDEN_STRATE_MM * sizeof(uint16_t);
uint64_t workSizeS4 = std::max(opParam.headNum * HIDDEN_STRATE_ROPE, HIDDEN_STRATE_MM) * sizeof(uint32_t);
uint64_t maxWorkspaceSize = workSizeS1;
maxWorkspaceSize = std::max(maxWorkspaceSize, workSizeS2);
maxWorkspaceSize = std::max(maxWorkspaceSize, workSizeS3);
maxWorkspaceSize = std::max(maxWorkspaceSize, workSizeS4);
maxWorkspaceSize *= static_cast<uint64_t>(opParam.N);
uint64_t pertokenWorkspace = static_cast<uint64_t>(opParam.N) * sizeof(float) * 2;
uint64_t userWorkspaceSize;
if (opParam.inDtype == at::kBFloat16 || opParam.quantMode == QuantMode::PER_TOKEN_SYMM_QUANT) {
userWorkspaceSize = 4 * maxWorkspaceSize + pertokenWorkspace;
} else {
userWorkspaceSize = 3 * maxWorkspaceSize;
}
tilingData->userWorkspaceSize = userWorkspaceSize;
tilingData->s1Offset = 0;
tilingData->s2Offset = tilingData->s1Offset + maxWorkspaceSize;
tilingData->s3Offset = tilingData->s2Offset + maxWorkspaceSize;
tilingData->s4Offset = tilingData->s3Offset + maxWorkspaceSize;
tilingData->s5Offset = tilingData->s4Offset + maxWorkspaceSize;
}
void MlaPreprocessTiling::SetTilingKey()
{
uint64_t tilingKey = (static_cast<uint64_t>(opParam.inDtype == at::kBFloat16)) << 8;
tilingKey |= static_cast<uint64_t>(opParam.cacheMode);
tilingKey |= (static_cast<uint64_t>(opParam.quantMode) << 3);
tilingData->tilingKey = tilingKey;
}
void MlaPreprocessTiling::Init()
{
tilingData->numCore = platformInfo.coreNumAic;
tilingData->n = opParam.N;
bool deqOnTheFly = false;
if (opParam.inDtype == at::kBFloat16 || opParam.quantMode == QuantMode::PER_TOKEN_SYMM_QUANT) {
deqOnTheFly = true;
}
PpMatmulTilingApi mm1TilingApi(platformInfo,
1, // numBatch
opParam.N, // m
HIDDEN_STRATE, // k
HIDDEN_STRATE_MM, // n
false, // transA
true, // transB
true, // enDequant
deqOnTheFly); // in bf16.cce?
mm1TilingApi.GetTilingData(tilingData->mm1);
PpMatmulTilingApi mm2TilingApi(platformInfo,
1, // numBatch
opParam.N, // m
HIDDEN_STRATE_RMS, // k
opParam.headNum * HIDDEN_STRATE_ROPE, // n
false, // transA
true, // transB
true, // enDequant
deqOnTheFly); // in bf16.cce?
mm2TilingApi.GetTilingData(tilingData->mm2);
PpMatmulTilingApi mm3TilingApi(platformInfo,
opParam.headNum, // numBatch
opParam.N, // m
CONST_128, // k
CONCAT_SIZE, // n
false, // transA
false, // transB
false, // enDequant
deqOnTheFly); // in bf16.cce?
mm3TilingApi.GetTilingData(tilingData->mm3);
RmsNormQuantTiling();
RopeConcatTiling();
EinSumQuantTiling();
SetMlapoWorkSpace();
SetTilingKey();
return;
}
std::unordered_map<c10::string_view, uint16_t> cache_mode_map = {
{"krope_ctkv", 1}, {"int8_nzcache", 2}, {"nzcache", 3}};
std::unordered_map<c10::string_view, uint16_t> quant_mode_map = {
{"per_tensor_quant_asymm", 0},
{"per_token_quant_symm", 1},
};
template <typename MapType>
inline int get_op_mode(const MapType &mode_map, c10::optional<c10::string_view> mode_opt, c10::string_view default_mode,
const char *mode_name)
{
c10::string_view mode_str = mode_opt.value_or(default_mode);
auto it = mode_map.find(mode_str);
TORCH_CHECK(it != mode_map.end(), "Unsupported ", mode_name, " value: '", mode_str, "'");
return it->second;
}
// std::tuple<at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &> mla_preprocess(
// const at::Tensor &hiddenState, const at::Tensor &gamma0, const at::Tensor &beta0, const at::Tensor &wdqkv,
// const at::Tensor &descale0, const at::Tensor &gamma1, const at::Tensor &beta1, const at::Tensor &wuq,
// const at::Tensor &descale1, const at::Tensor &gamma2, const at::Tensor &cos, const at::Tensor &sin,
// const at::Tensor &wuk, const at::Tensor &kv_cache, const at::Tensor &kv_cache_rope, const at::Tensor &slotmapping,
// const at::Tensor &quant_scale0, const at::Tensor &quant_offset0, const at::Tensor &bias0,
// const at::Tensor &quant_scale1, const at::Tensor &quant_offset1, const at::Tensor &bias1,
// const c10::optional<at::Tensor> &ctkv_scale, const c10::optional<at::Tensor> &q_nope_scale,
// c10::optional<c10::string_view> cache_mode, c10::optional<c10::string_view> quant_mode, at::Tensor &q_out0,
// at::Tensor &kv_cache_out0, at::Tensor &q_out1, at::Tensor &kv_cache_out1)
std::tuple<at::Tensor, at::Tensor, uint32_t> mla_preprocess_tiling(
const at::Tensor &hiddenState,
const at::Tensor &wuk,
c10::optional<c10::string_view> cache_mode,
c10::optional<c10::string_view> quant_mode
)
{
auto cacheMode = get_op_mode(cache_mode_map, cache_mode, "krope_ctkv", "cache_mode");
auto quantMode = get_op_mode(quant_mode_map, quant_mode, "per_token_quant_symm", "quant_mode");
platform_ascendc::PlatformAscendC *platformAscendC = platform_ascendc::PlatformAscendCManager::GetInstance();
struct PlatformInfo platformInfo;
platformInfo.coreNum = platformAscendC->GetCoreNum();
platformInfo.coreNumAic = platformAscendC->GetCoreNumAic();
platformInfo.coreNumAiv = platformAscendC->GetCoreNumAiv();
platformAscendC->GetCoreMemSize(platform_ascendc::CoreMemType::UB, platformInfo.ubSize);
platformAscendC->GetCoreMemSize(platform_ascendc::CoreMemType::L1, platformInfo.l1Size);
platformAscendC->GetCoreMemSize(platform_ascendc::CoreMemType::L2, platformInfo.l2Size);
platformAscendC->GetCoreMemSize(platform_ascendc::CoreMemType::L0_A, platformInfo.l0aSize);
platformAscendC->GetCoreMemSize(platform_ascendc::CoreMemType::L0_B, platformInfo.l0bSize);
platformAscendC->GetCoreMemSize(platform_ascendc::CoreMemType::L0_C, platformInfo.l0cSize);
int32_t N = hiddenState.sizes()[0];
int32_t headNum = wuk.sizes()[0];
OpParam opParam;
opParam.N = N;
opParam.headNum = headNum;
opParam.cacheMode = static_cast<int32_t>(cacheMode);
opParam.quantMode = static_cast<QuantMode>(quantMode);
opParam.inDtype = hiddenState.options().dtype();
MlaTilingData tilingData;
MlaPreprocessTiling mlaTiling(platformInfo, opParam, &tilingData);
mlaTiling.Init();
uint32_t blockDim = platformInfo.coreNumAic;
// workspace
uint64_t system_workspace_size = static_cast<uint64_t>(platformAscendC->GetLibApiWorkSpaceSize());
uint64_t workspace_size = system_workspace_size + tilingData.userWorkspaceSize;
auto options = at::TensorOptions().dtype(at::kByte).device(hiddenState.options().device());
auto workspace_tensor = at::empty({static_cast<int64_t>(workspace_size)}, options);
// tiling
int32_t bIndex = N - 1;
uint32_t tilingSize = sizeof(MlaTilingData);
static auto global_tiling_data =
at::empty({tilingSize * MAX_SUPPORT_TOKEN_NUMS},
at::TensorOptions().dtype(at::kByte).device(hiddenState.options().device()));
if (bIndex >= 0 && bIndex < MAX_SUPPORT_TOKEN_NUMS) {
aclrtMemcpy(global_tiling_data.data_ptr<uint8_t>() + (tilingSize * bIndex), tilingSize, &tilingData, tilingSize,
ACL_MEMCPY_HOST_TO_DEVICE);
} else {
// Handle the case where bIndex is out of range
TORCH_CHECK(false, "bIndex is out of range: ", bIndex);
}
at::Tensor tiling = at::from_blob(
global_tiling_data.data_ptr<uint8_t>() + (tilingSize * bIndex),
tilingSize,
at::kByte);
return std::make_tuple(workspace_tensor, tiling, blockDim);
}
} // namespace npu_kernel

View File

@ -0,0 +1,95 @@
// Adapted from
// https://gitee.com/ascend/ascend-transformer-boost
//
// Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
// This file is a part of the CANN Open Software.
// Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
// Please refer to the License for details. You may not use this file except in compliance with the License.
// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
// See LICENSE in the root of the software repository for the full text of the License.
//
#ifndef MLAPREPROCESS_TILING_H
#define MLAPREPROCESS_TILING_H
#include <cstdint>
struct PpMatmulTilingData {
uint32_t numBatch{0};
uint32_t m{0};
uint32_t k{0};
uint32_t n{0};
uint32_t m0{0};
uint32_t k0{0};
uint32_t n0{0};
uint32_t mLoop{0};
uint32_t kLoop{0};
uint32_t nLoop{0};
uint32_t coreLoop{0};
uint32_t swizzleCount{0};
uint32_t swizzleDirect{0};
uint32_t enShuffleK{0};
uint32_t blockDim{0};
uint32_t enLoadAllAmat{0};
uint32_t b0matPingPongBufferLen{0};
};
struct MlaTilingData {
uint32_t tilingKey{0};
uint64_t userWorkspaceSize{0};
uint64_t s1Offset{0};
uint64_t s2Offset{0};
uint64_t s3Offset{0};
uint64_t s4Offset{0};
uint64_t s5Offset{0};
uint32_t numCore{0};
uint32_t n{0};
uint32_t perTaskNum{0};
uint32_t resTaskNum{0};
PpMatmulTilingData mm1;
PpMatmulTilingData mm2;
PpMatmulTilingData mm3;
// rms1
uint32_t rmsNumCore1{0};
uint32_t rmsNumCol1{0};
uint32_t rmsNumRow1{0};
uint32_t rmsQuantMin1{0};
// rms2
uint32_t rmsNumCore2{0};
uint32_t rmsNumCol2{0};
uint32_t rmsNumRow2{0};
uint32_t rmsQuantMin2{0};
uint32_t hiddenSizeQ{0};
uint32_t headNumQ{0};
uint32_t headDim{0};
uint32_t concatSize{0};
uint32_t rotaryCoeff{0};
uint32_t ntokens{0};
uint32_t realCore{0};
uint32_t nlCoreRun{0};
uint32_t lCoreRun{0};
uint32_t maxNPerLoopForUb{0};
uint32_t preCoreLoopTime{0};
uint32_t preCoreLoopNLast{0};
uint32_t lastCoreLoopTime{0};
uint32_t lastCoreLoopNLast{0};
// EinSumQuant
uint32_t esqFrontCore{0};
uint32_t esqTailCore{0};
uint32_t esqFrontCoreBatch{0};
uint32_t esqTailCoreBatch{0};
uint32_t esqHeadNum{0};
uint32_t esqColNum{0};
uint32_t esqUbHeadLoop{0};
uint32_t esqHeadPerLoop{0};
uint32_t esqHeadTail{0};
uint32_t esqColLoop{0};
uint32_t esqColTail{0};
};
#endif // MLAPREPROCESS_TILING_H

View File

@ -0,0 +1,25 @@
/* Adapted from
* https://gitee.com/ascend/ascend-transformer-boost.git
*
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef INCLUDE_COMMON_H
#define INCLUDE_COMMON_H
#define CONST_2 2
#define SET_FLAG(trigger, waiter, e) AscendC::SetFlag<AscendC::HardEvent::trigger##_##waiter>((e))
#define WAIT_FLAG(trigger, waiter, e) AscendC::WaitFlag<AscendC::HardEvent::trigger##_##waiter>((e))
#define PIPE_BARRIER(pipe) AscendC::PipeBarrier<PIPE_##pipe>()
#ifndef __force_inline__
#define __force_inline__ inline __attribute__((always_inline))
#endif
#endif

View File

@ -0,0 +1,121 @@
/* Adapted from
* https://gitee.com/ascend/ascend-transformer-boost.git
*
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef INCLUDE_COMMON_FUNC_H
#define INCLUDE_COMMON_FUNC_H
#include <limits>
#include <type_traits>
#ifdef __CCE_KT_TEST__
#include "stub_def.h"
#include "stub_fun.h"
#else
#include "kernel_macros.h"
#endif
template <uint32_t ALIGN, typename T = uint32_t>
inline __aicore__ T RoundUp(const T val)
{
static_assert(ALIGN != 0, "align must not be zero");
static_assert(std::is_arithmetic<T>::value, "T must be an arithmetic type");
T align = ALIGN;
if (val + align - 1 < val) {
return val;
}
return (val + align - 1) / align * align;
}
template <typename T>
inline __aicore__ T RoundUp(const T val, const T align)
{
static_assert(std::is_arithmetic<T>::value, "T must be an arithmetic type");
if (align == 0 || val + align - 1 < val) {
return val;
}
return (val + align - 1) / align * align;
}
template <uint32_t DIVISOR, typename T = uint32_t>
inline __aicore__ T CeilDiv(const T dividend)
{
static_assert(DIVISOR != 0, "align must not be zero");
static_assert(std::is_arithmetic<T>::value, "T must be an arithmetic type");
T divisor = DIVISOR;
if (dividend + divisor - 1 < dividend) {
return dividend;
}
return (dividend + divisor - 1) / divisor;
}
template <typename T>
constexpr T T_MAX = std::numeric_limits<T>::max();
template <typename T>
inline __aicore__ T CeilDiv(const T dividend, const T divisor)
{
static_assert(std::is_arithmetic<T>::value, "T must be an arithmetic type");
if (divisor == 0 || dividend + divisor - 1 < dividend) {
return T_MAX<T>;
}
return (dividend + divisor - 1) / divisor;
}
template <typename T>
__aicore__ inline T Min(const T lhs, const T rhs)
{
return lhs < rhs ? lhs : rhs;
}
template <typename Dtype>
__aicore__ __attribute__((always_inline)) inline uint32_t BlockSize()
{
return 32 / sizeof(Dtype);
}
template <typename Dtype>
__aicore__ __attribute__((always_inline)) inline uint32_t MatrixSize()
{
return 512 / sizeof(Dtype);
}
template <typename Dtype>
__aicore__ __attribute__((always_inline)) inline uint64_t BlockSizeRoundUp(uint64_t num)
{
return (num + BlockSize<Dtype>() - 1) / BlockSize<Dtype>() * BlockSize<Dtype>();
}
template <typename Dtype>
__aicore__ __attribute__((always_inline)) inline uint64_t NumBlocksRoundUp(uint64_t num)
{
return (num + BlockSize<Dtype>() - 1) / BlockSize<Dtype>();
}
template <typename Dtype>
__aicore__ __attribute__((always_inline)) inline uint64_t MatrixSizeRoundUp(uint64_t num)
{
return (num + MatrixSize<Dtype>() - 1) / MatrixSize<Dtype>() * MatrixSize<Dtype>();
}
template <typename Dtype>
__aicore__ __attribute__((always_inline)) inline uint64_t NumMatrixsRoundUp(uint64_t num)
{
return (num + MatrixSize<Dtype>() - 1) / MatrixSize<Dtype>();
}
template <typename Dtype>
__aicore__ __attribute__((always_inline)) inline uint64_t L0HalfSize()
{
return 32 * 1024 / sizeof(Dtype);
}
#endif

View File

@ -0,0 +1,36 @@
/* Adapted from
* https://gitee.com/ascend/ascend-transformer-boost.git
*
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef INCLUDE_HARDWARE_H
#define INCLUDE_HARDWARE_H
enum class ArchType { ASCEND_V220, ASCEND_V200, ASCEND_M200 };
template <ArchType ArchTag>
struct HardwareInfo {
static uint32_t const l2BW = 5;
static uint32_t const hbmBW = 1;
static uint32_t const supportMix = 0;
static uint32_t const l1Size = 512 * 1024;
static uint32_t const l0ASize = 64 * 1024;
static uint32_t const l0BSize = 64 * 1024;
static uint32_t const l0CSize = 128 * 1024;
static uint32_t const l2Size = 192 * 1024 * 1024;
static uint32_t const biasSize = 1024;
static uint32_t const fixBufSize = 7 * 1024;
static uint32_t const ubSize = 192 * 1024;
static uint32_t const fractalSize = 512;
static uint32_t const l1l0BlockSize = 32;
static uint32_t const btBlockSize = 64;
static uint32_t const fbBlockSize = 128;
};
#endif

View File

@ -0,0 +1,92 @@
/* Adapted from
* https://gitee.com/ascend/ascend-transformer-boost.git
*
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef INCLUDE_ITERTOR_H
#define INCLUDE_ITERTOR_H
#include "common_func.h"
#include "hardware.h"
#include "kernel_operator.h"
#include "layout.h"
#include "mem.h"
/////////////////////////////////////////////////////
// gm_to_l1
/////////////////////////////////////////////////////
template <ArchType ArchTag, typename DataType, DataFormat FormatInGM, DataFormat FormatInL1>
struct gm_to_l1 {
__aicore__ gm_to_l1(AscendC::LocalTensor<DataType> l1Tensor, AscendC::GlobalTensor<DataType> gmTensor,
uint32_t nTileActual, uint32_t nTileCeil, uint32_t nVal, uint32_t dTileActual,
uint32_t dTileCeil, uint32_t dVal) {};
};
/////////////////////////////////////////////////////
// l1_to_l0_a
/////////////////////////////////////////////////////
template <ArchType ArchTag, typename DataType, bool IsTransPose, DataFormat DFmtIn, DataFormat DFmtOut>
struct l1_to_l0_a {
__aicore__ l1_to_l0_a(AscendC::LocalTensor<DataType> l0Tensor, AscendC::LocalTensor<DataType> l1Tensor,
uint32_t mTileCeil, uint32_t kPartCeil, uint32_t mSrcStride, uint32_t kSrcStride,
uint32_t mDstStride, uint32_t kDstStride) {};
};
/////////////////////////////////////////////////////
// l1_to_l0_b
/////////////////////////////////////////////////////
template <ArchType ArchTag, typename DataType, bool IsTransPose, DataFormat DFmtIn, DataFormat DFmtOut>
struct l1_to_l0_b {
__aicore__ l1_to_l0_b(AscendC::LocalTensor<DataType> l0Tensor, AscendC::LocalTensor<DataType> l1Tensor,
uint32_t nTileCeil, uint32_t kPartCeil, uint32_t nSrcStride, uint32_t kSrcStride,
uint32_t nDstStride, uint32_t kDstStride) {};
};
/////////////////////////////////////////////////////
// l0c_to_gm
/////////////////////////////////////////////////////
template <ArchType ArchTag, DataFormat OutFormatType, typename OutDataType, typename L0CDataType>
struct l0c_to_gm {
__aicore__ l0c_to_gm(AscendC::GlobalTensor<OutDataType> gmTensor, AscendC::LocalTensor<L0CDataType> l0cTensor,
uint32_t mTileActual, uint32_t nTileActual, uint32_t mTileCeil, uint32_t nActual,
uint8_t unitFlag = 0) {};
};
/////////////////////////////////////////////////////
// l0c_to_l1
/////////////////////////////////////////////////////
template <ArchType ArchTag, DataFormat LayoutOut, typename ElementOut, typename ElementIn>
struct l0c_to_l1 {
__aicore__ l0c_to_l1(AscendC::LocalTensor<ElementOut> l1Tensor, AscendC::LocalTensor<ElementIn> l0cTensor,
AscendC::LocalTensor<uint64_t> deqTensor, uint32_t mTileActual, uint32_t nTileActual,
uint32_t mTileCeil, uint32_t nActual) {};
};
template <ArchType ArchTag, typename DataType>
struct l1_to_bt {
__aicore__ l1_to_bt(uint64_t dst, const AscendC::LocalTensor<DataType> &src, uint16_t convControl, uint16_t nBurst,
uint16_t lenBurst, uint16_t srcGap, uint16_t dstGap) {};
};
template <ArchType ArchTag, typename DataType>
struct l1_to_fb {
__aicore__ l1_to_fb(AscendC::LocalTensor<DataType> &dst, AscendC::LocalTensor<DataType> &src, uint16_t burstNum,
uint16_t burstLen, uint16_t srcGap, uint16_t dstGap) {};
};
#include "iterators/gm_to_l1_iterator.inc"
#include "iterators/gm_to_ub_iterator.inc"
#include "iterators/l0c_to_gm_iterator.inc"
#include "iterators/l0c_to_l1_iterator.inc"
#include "iterators/l0c_to_ub_iterator.inc"
#include "iterators/l1_to_bt_iterator.inc"
#include "iterators/l1_to_fb_iterator.inc"
#include "iterators/l1_to_l0_iterator.inc"
#include "iterators/l1_to_ub_iterator.inc"
#endif

View File

@ -0,0 +1,162 @@
/* Adapted from
* https://gitee.com/ascend/ascend-transformer-boost.git
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "../iterator.h"
// Partial specialization for V220, ND_in, ND_out
template <ArchType ArchTag, typename DataType>
struct gm_to_l1<ArchTag, DataType, DataFormat::ND, DataFormat::ND> {
using HardwareParams = HardwareInfo<ArchTag>;
static constexpr uint32_t BLOCK_SIZE = HardwareParams::l1l0BlockSize / sizeof(DataType);
__aicore__ gm_to_l1(AscendC::LocalTensor<DataType> l1Tensor,
AscendC::GlobalTensor<DataType> gmTensor,
uint32_t nTileActual,
uint32_t nTileCeil,
uint32_t nVal,
uint32_t dTileActual,
uint32_t dTileCeil,
uint32_t dVal)
{
AscendC::DataCopy(l1Tensor, // dst
gmTensor, // src
AscendC::DataCopyParams(1, // nBurst
CeilDiv<BLOCK_SIZE>(nTileActual * dTileActual), // lenBurst
0, // srcGap
0)); // dstGap
};
};
// Partial specialization for NZ_in, NZ_out
template <ArchType ArchTag, typename DataType>
struct gm_to_l1<ArchTag, DataType, DataFormat::NZ, DataFormat::NZ> {
using HardwareParams = HardwareInfo<ArchTag>;
static constexpr uint32_t BLOCK_SIZE = HardwareParams::l1l0BlockSize / sizeof(DataType);
static constexpr uint32_t STRIDE_LIMIT = 65536;
__aicore__ gm_to_l1(AscendC::LocalTensor<DataType> l1Tensor,
AscendC::GlobalTensor<DataType> gmTensor,
uint32_t nTileActual,
uint32_t nTileCeil,
uint32_t nVal,
uint32_t dTileActual,
uint32_t dTileCeil,
uint32_t dVal)
{
uint64_t srcStride = nVal - nTileCeil;
if (srcStride < STRIDE_LIMIT) {
AscendC::DataCopy(l1Tensor, // dst
gmTensor, // src
AscendC::DataCopyParams(dTileCeil / BLOCK_SIZE, // nBurst
nTileCeil, // lenBurst
srcStride, // srcGap
0)); // dstGap
} else {
for (uint64_t i = 0; i < dTileCeil / BLOCK_SIZE; i++) {
uint64_t dstOffset = i * nTileCeil * BLOCK_SIZE;
uint64_t srcOffset = i * nVal * BLOCK_SIZE;
AscendC::DataCopy(l1Tensor[dstOffset], // dst
gmTensor[srcOffset], // src
AscendC::DataCopyParams(1, // nBurst
nTileCeil, // lenBurst
0, // srcGap
0)); // dstGap
}
}
};
};
// Partial specialization for V220, ND_in, ND_out
template <ArchType ArchTag, typename DataType>
struct gm_to_l1<ArchTag, DataType, DataFormat::ND, DataFormat::NZ> {
using HardwareParams = HardwareInfo<ArchTag>;
static constexpr uint32_t BLOCK_SIZE = HardwareParams::l1l0BlockSize / sizeof(DataType);
static constexpr uint32_t STRIDE_LIMIT = 65536;
__aicore__ gm_to_l1(AscendC::LocalTensor<DataType> l1Tensor,
AscendC::GlobalTensor<DataType> gmTensor,
uint32_t nTileActual,
uint32_t nTileCeil,
uint32_t nVal,
uint32_t dTileActual,
uint32_t dTileCeil,
uint32_t dVal)
{
if (dVal < STRIDE_LIMIT) {
AscendC::DataCopy(l1Tensor,
gmTensor,
AscendC::Nd2NzParams(1, // ndNum
nTileActual, // nValue
dTileActual, // dValue
0, // srcNdMatrixStride, unused
dVal, // srcDValue
nTileCeil, // dstNzC0Stride
1, // dstNzNStride
0)); // dstNzMatrixStride, unused
} else {
for (uint32_t i = 0; i < nTileActual; i++) {
AscendC::DataCopy(l1Tensor[i * BLOCK_SIZE],
gmTensor[i * dVal],
AscendC::Nd2NzParams(1, // ndNum
1, // nValue
dTileActual, // dValue
0, // srcNdMatrixStride, unused
0, // srcDValue
nTileCeil, // dstNzC0Stride
0, // dstNzNStride
0)); // dstNzMatrixStride, unused
}
}
};
};
// Partial specialization for V220, ND_in, NZ_out
template <ArchType ArchTag, typename DataType>
struct gm_to_l1<ArchTag, DataType, DataFormat::ND, DataFormat::ZN> {
using HardwareParams = HardwareInfo<ArchTag>;
static constexpr uint32_t BLOCK_SIZE = HardwareParams::l1l0BlockSize / sizeof(DataType);
static constexpr uint32_t STRIDE_LIMIT = 65536;
__aicore__ gm_to_l1(AscendC::LocalTensor<DataType> l1Tensor,
AscendC::GlobalTensor<DataType> gmTensor,
uint32_t nTileActual,
uint32_t nTileCeil,
uint32_t nVal,
uint32_t dTileActual,
uint32_t dTileCeil,
uint32_t dVal)
{
if (dVal < STRIDE_LIMIT) {
AscendC::DataCopy(l1Tensor,
gmTensor,
AscendC::Nd2NzParams(1, // ndNum
nTileActual, // nValue
dTileActual, // dValue
0, // srcNdMatrixStride, unused
dVal, // srcDValue
nTileCeil, // dstNzC0Stride
1, // dstNzNStride
0)); // dstNzMatrixStride, unused
} else {
for (uint32_t i = 0; i < nTileActual; ++i) {
AscendC::DataCopy(l1Tensor,
gmTensor,
AscendC::Nd2NzParams(1, // ndNum
1, // nValue
dTileActual, // dValue
0, // srcNdMatrixStride, unused
0, // srcDValue
nTileCeil, // dstNzC0Stride
0, // dstNzNStride
0)); // dstNzMatrixStride, unused
}
}
};
};

View File

@ -0,0 +1,89 @@
/* Adapted from
* https://gitee.com/ascend/ascend-transformer-boost.git
*
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "../iterator.h"
template <ArchType ArchTag, typename DType> struct gm_to_ub {
__aicore__ inline gm_to_ub(AscendC::LocalTensor<DType> dstTensor, AscendC::GlobalTensor<DType> srcTensor,
uint8_t sid, uint16_t nBurst, uint16_t lenBurst, uint16_t srcStride, uint16_t dstStride)
{
AscendC::DataCopy(dstTensor, srcTensor, AscendC::DataCopyParams(nBurst, lenBurst, srcStride, dstStride));
};
};
template <ArchType ArchTag, typename DType> struct gm_to_ub_align {
__aicore__ inline gm_to_ub_align(AscendC::LocalTensor<DType> dstTensor, AscendC::GlobalTensor<DType> srcTensor,
uint8_t sid, uint16_t nBurst, uint32_t lenBurst, uint8_t leftPaddingNum,
uint8_t rightPaddingNum, uint32_t srcGap, uint32_t dstGap)
{
AscendC::DataCopyPad(dstTensor, srcTensor, AscendC::DataCopyExtParams(nBurst, lenBurst, srcGap, dstGap, 0),
AscendC::DataCopyPadExtParams<DType>(false, leftPaddingNum, rightPaddingNum, 0));
};
};
template <ArchType ArchTag, typename DType> struct ub_to_ub {
__aicore__ inline ub_to_ub(AscendC::LocalTensor<DType> dstTensor, AscendC::LocalTensor<DType> srcTensor,
uint8_t sid, uint16_t nBurst, uint16_t lenBurst, uint16_t srcStride, uint16_t dstStride)
{
AscendC::DataCopy(dstTensor, srcTensor, AscendC::DataCopyParams(nBurst, lenBurst, srcStride, dstStride));
};
};
template <ArchType ArchTag, typename DataType, DataFormat InDataFormat = DataFormat::ND,
DataFormat OutDataFormat = DataFormat::ND>
struct ub_to_gm {
__aicore__ inline ub_to_gm(AscendC::GlobalTensor<DataType> dstTensor, AscendC::LocalTensor<DataType> srcTensor,
uint8_t sid, uint16_t nBurst, uint16_t lenBurst, uint16_t srcStride, uint16_t dstStride)
{
AscendC::DataCopy(dstTensor, srcTensor, AscendC::DataCopyParams(nBurst, lenBurst, srcStride, dstStride));
};
};
template <ArchType ArchTag, typename DataType> struct ub_to_gm<ArchTag, DataType, DataFormat::NZ, DataFormat::NZ> {
using HardwareParams = HardwareInfo<ArchTag>;
static constexpr uint32_t BLOCK_SIZE = HardwareParams::l1l0BlockSize / sizeof(DataType);
__aicore__ ub_to_gm(AscendC::GlobalTensor<DataType> gmTensor, AscendC::LocalTensor<DataType> ubTensor,
uint32_t nTileActual, uint32_t nTileCeil, uint32_t nVal, uint32_t dTileActual,
uint32_t dTileCeil, uint32_t dVal)
{
constexpr uint32_t STRIDE_LIMIT = 65536;
uint64_t dstStride = nVal - nTileCeil;
if (dstStride < STRIDE_LIMIT) {
AscendC::DataCopy(gmTensor, // dst
ubTensor, // src
AscendC::DataCopyParams(dTileCeil / BLOCK_SIZE, // nBurst
nTileCeil, // lenBurst
0, // srcGap
dstStride)); // dstGap
} else {
for (uint64_t i = 0; i < dTileCeil / BLOCK_SIZE; ++i) {
uint64_t dstOffset = i * nVal * BLOCK_SIZE;
uint64_t srcOffset = i * nTileCeil * BLOCK_SIZE;
AscendC::DataCopy(gmTensor[dstOffset], // dst
ubTensor[srcOffset], // src
AscendC::DataCopyParams(1, // nBurst
nTileCeil, // lenBurst
0, // srcGap
0)); // dstGap
}
}
};
};
template <ArchType ArchTag, typename DType> struct ub_to_gm_align {
__aicore__ inline ub_to_gm_align(AscendC::GlobalTensor<DType> dstTensor, AscendC::LocalTensor<DType> srcTensor,
uint8_t sid, uint16_t nBurst, uint32_t lenBurst, uint8_t leftPaddingNum,
uint8_t rightPaddingNum, uint32_t srcGap, uint32_t dstGap)
{
AscendC::DataCopyPad(dstTensor, srcTensor, AscendC::DataCopyExtParams(nBurst, lenBurst, srcGap, dstGap, 0));
};
};

View File

@ -0,0 +1,228 @@
/* Adapted from
* https://gitee.com/ascend/ascend-transformer-boost.git
*
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "../iterator.h"
constexpr uint32_t BLOCK_NUM = 16;
constexpr uint32_t BLOCK_SIZE_INT8 = 32;
template <>
struct l0c_to_gm<ArchType::ASCEND_V220, DataFormat::ND, half, float> {
/**
* @brief Copy data from L0C buffer to global memory, partial specialized for
*
* @param gmTensor the destination tensor on global memory, which is stored in ND format.
* @param l0cTensor the source tensor on L0C buffer, which is stored in FRACTAL_NZ format.
* @param mTileActual the m-direction size of the matrix in L0C buffer.
* @param nTileActual the n-direction size of the matrix in L0C buffer.
* @param srcStride the source stride between the adjacent fractal matrix along n-direction in unit of C0_SIZE.
* @param dstStride the leading dimension of the destination matrix in unit of element.
*/
__aicore__ l0c_to_gm(AscendC::GlobalTensor<half> gmTensor,
AscendC::LocalTensor<float> l0cTensor,
uint32_t mTileActual,
uint32_t nTileActual,
uint32_t srcStride,
uint32_t dstStride,
uint8_t unitFlag = 0)
{
#ifdef __DAV_C220_CUBE__
auto intriParams = AscendC::FixpipeParamsV220(nTileActual, // nSize
mTileActual, // mSize
srcStride, // srcStride
dstStride, // dstStride
false); // enRelu
intriParams.quantPre = QuantMode_t::F322F16;
intriParams.unitFlag = unitFlag;
AscendC::Fixpipe<half, float, AscendC::CFG_ROW_MAJOR>(gmTensor, l0cTensor, intriParams);
#else
AscendC::FixpipeParams<float> intriParams(
(nTileActual + BLOCK_NUM - 1) / AscendC::BLOCK_CUBE,
static_cast<uint16_t>(mTileActual * BLOCK_NUM * sizeof(float) / BLOCK_SIZE_INT8),
0,
dstStride);
intriParams.nz2ndParams = {true, 1, 0, 0, static_cast<uint16_t>(nTileActual)};
intriParams.quantParams = {QuantMode_t::F322F16};
AscendC::Fixpipe(gmTensor, l0cTensor, intriParams);
#endif
};
};
template <>
struct l0c_to_gm<ArchType::ASCEND_V220, DataFormat::ND, half, int32_t> {
__aicore__ l0c_to_gm(AscendC::GlobalTensor<half> gmTensor,
AscendC::LocalTensor<int32_t> l0cTensor,
uint32_t mTileActual,
uint32_t nTileActual,
uint32_t srcStride,
uint32_t dstStride,
uint8_t unitFlag = 0)
{
#ifdef __DAV_C220_CUBE__
auto intriParams = AscendC::FixpipeParamsV220(nTileActual, // nSize
mTileActual, // mSize
srcStride, // srcStride
dstStride, // dstStride
false); // enRelu
intriParams.quantPre = QuantMode_t::VDEQF16;
intriParams.unitFlag = unitFlag;
AscendC::Fixpipe<half, int32_t, AscendC::CFG_ROW_MAJOR>(gmTensor, l0cTensor, intriParams);
#else
AscendC::FixpipeParams<int32_t> intriParams(
(nTileActual + BLOCK_NUM - 1) / AscendC::BLOCK_CUBE,
static_cast<uint16_t>(mTileActual * BLOCK_NUM * sizeof(float) / BLOCK_SIZE_INT8),
0,
dstStride);
intriParams.nz2ndParams = {true, 1, 0, 0, static_cast<uint16_t>(nTileActual)};
intriParams.quantParams = {QuantMode_t::VDEQF16};
AscendC::Fixpipe(gmTensor, l0cTensor, intriParams);
#endif
};
};
#ifdef __DAV_C220_CUBE__
template <>
struct l0c_to_gm<ArchType::ASCEND_V220, DataFormat::ND, __bf16, float> {
__aicore__ l0c_to_gm(AscendC::GlobalTensor<__bf16> gmTensor,
AscendC::LocalTensor<float> l0cTensor,
uint32_t mTileActual,
uint32_t nTileActual,
uint32_t srcStride,
uint32_t dstStride,
uint8_t unitFlag = 0)
{
#ifdef __DAV_C220_CUBE__
auto intriParams = AscendC::FixpipeParamsV220(nTileActual, // nSize
mTileActual, // mSize
srcStride, // srcStride
dstStride, // dstStride
false); // enRelu
intriParams.quantPre = QuantMode_t::F322BF16;
intriParams.unitFlag = unitFlag;
AscendC::Fixpipe<__bf16, float, AscendC::CFG_ROW_MAJOR>(gmTensor, l0cTensor, intriParams);
#else
AscendC::FixpipeParams<float> intriParams(
(nTileActual + BLOCK_NUM - 1) / AscendC::BLOCK_CUBE,
static_cast<uint16_t>(mTileActual * BLOCK_NUM * sizeof(float) / BLOCK_SIZE_INT8),
0,
dstStride);
intriParams.nz2ndParams = {true, 1, 0, 0, static_cast<uint16_t>(nTileActual)};
intriParams.quantParams = {QuantMode_t::F322BF16};
AscendC::Fixpipe(gmTensor, l0cTensor, intriParams);
#endif
};
};
#endif
// Partial specialization ND, float
template <>
struct l0c_to_gm<ArchType::ASCEND_V220, DataFormat::ND, float, float> {
__aicore__ l0c_to_gm(AscendC::GlobalTensor<float> gmTensor,
AscendC::LocalTensor<float> l0cTensor,
uint32_t mTileActual,
uint32_t nTileActual,
uint32_t srcStride,
uint32_t dstStride,
uint8_t unitFlag = 0)
{
#ifdef __DAV_C220_CUBE__
auto intriParams = AscendC::FixpipeParamsV220(nTileActual, // nSize
mTileActual, // mSize
srcStride, // srcStride
dstStride, // dstStride
false); // enRelu
intriParams.quantPre = QuantMode_t::NoQuant;
intriParams.unitFlag = unitFlag;
AscendC::Fixpipe<float, float, AscendC::CFG_ROW_MAJOR>(gmTensor, l0cTensor, intriParams);
#else
AscendC::FixpipeParams<float> intriParams(
(nTileActual + BLOCK_NUM - 1) / AscendC::BLOCK_CUBE,
static_cast<uint16_t>(mTileActual * BLOCK_NUM * sizeof(float) / BLOCK_SIZE_INT8),
0,
dstStride);
intriParams.nz2ndParams = {true, 1, 0, 0, static_cast<uint16_t>(nTileActual)};
intriParams.quantParams = {QuantMode_t::NoQuant};
AscendC::Fixpipe(gmTensor, l0cTensor, intriParams);
#endif
};
};
template <>
struct l0c_to_gm<ArchType::ASCEND_V220, DataFormat::NZ, half, float> {
__aicore__ l0c_to_gm(AscendC::GlobalTensor<half> gmTensor,
AscendC::LocalTensor<float> l0cTensor,
uint32_t mTileActual,
uint32_t nTileActual,
uint32_t srcStride,
uint32_t dstStride,
uint8_t unitFlag = 0)
{
#ifdef __DAV_C220_CUBE__
auto intriParams = AscendC::FixpipeParamsV220(nTileActual, // nSize
mTileActual, // mSize
srcStride, // srcStride
dstStride, // dstStride
false); // enRelu
intriParams.quantPre = QuantMode_t::F322F16;
intriParams.unitFlag = unitFlag;
AscendC::Fixpipe<half, float, AscendC::CFG_NZ>(gmTensor, l0cTensor, intriParams);
#else
AscendC::FixpipeParams<float> intriParams(
(nTileActual + BLOCK_NUM - 1) / AscendC::BLOCK_CUBE,
static_cast<uint16_t>(mTileActual * BLOCK_NUM * sizeof(float) / BLOCK_SIZE_INT8),
0,
dstStride - (nTileActual * sizeof(half) / sizeof(float)));
intriParams.quantParams = {QuantMode_t::F322F16};
AscendC::Fixpipe(gmTensor, l0cTensor, intriParams);
#endif
};
};
template <>
struct l0c_to_gm<ArchType::ASCEND_V220, DataFormat::ND, int32_t, int32_t> {
__aicore__ l0c_to_gm(AscendC::GlobalTensor<int32_t> gmTensor,
AscendC::LocalTensor<int32_t> l0cTensor,
uint32_t mTileActual,
uint32_t nTileActual,
uint32_t srcStride,
uint32_t dstStride,
uint8_t unitFlag = 0)
{
#ifdef __DAV_C220_CUBE__
auto intriParams = AscendC::FixpipeParamsV220(nTileActual, // nSize
mTileActual, // mSize
srcStride, // srcStride
dstStride, // dstStride
false); // enRelu
intriParams.quantPre = QuantMode_t::NoQuant;
intriParams.unitFlag = unitFlag;
AscendC::Fixpipe<int32_t, int32_t, AscendC::CFG_ROW_MAJOR>(gmTensor, l0cTensor, intriParams);
#else
AscendC::FixpipeParams<int32_t> intriParams(
(nTileActual + BLOCK_NUM - 1) / AscendC::BLOCK_CUBE,
static_cast<uint16_t>(mTileActual * BLOCK_NUM * sizeof(float) / BLOCK_SIZE_INT8),
0,
dstStride);
intriParams.nz2ndParams = {true, 1, 0, 0, static_cast<uint16_t>(nTileActual)};
intriParams.quantParams = {QuantMode_t::VDEQF16};
AscendC::Fixpipe(gmTensor, l0cTensor, intriParams);
#endif
};
};

View File

@ -0,0 +1,42 @@
/* Adapted from
* https://gitee.com/ascend/ascend-transformer-boost.git
*
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "../iterator.h"
/////////////////////////////////////////////////////
// l0c_to_l1
/////////////////////////////////////////////////////
// Partial specialization ZN, half, int32_t
template <ArchType ArchTag>
struct l0c_to_l1<ArchTag, DataFormat::ZN, half, int32_t> {
using ElementOut = half;
using ElementIn = int32_t;
__aicore__ l0c_to_l1(AscendC::LocalTensor<ElementOut> l1Tensor,
AscendC::LocalTensor<ElementIn> l0cTensor,
AscendC::LocalTensor<uint64_t> deqTensor,
uint32_t mTileActual,
uint32_t nTileActual,
uint32_t mTileCeil,
uint32_t nActual)
{
constexpr uint32_t BLOCK_NUM = 16;
constexpr uint32_t BLOCK_SIZE = 32;
AscendC::FixpipeParams<ElementIn> intriParams(
(nTileActual + BLOCK_NUM - 1) / AscendC::BLOCK_CUBE,
static_cast<uint16_t>(mTileActual * BLOCK_NUM * sizeof(float) / BLOCK_SIZE),
0,
mTileCeil - static_cast<uint16_t>(mTileActual * BLOCK_NUM * sizeof(float) / BLOCK_SIZE) *
sizeof(ElementOut) / sizeof(ElementIn));
intriParams.nz2ndParams = {false, 1, 0, 0, static_cast<uint16_t>(nTileActual)};
intriParams.quantParams = {QuantMode_t::VDEQF16};
AscendC::Fixpipe(l1Tensor, l0cTensor, deqTensor, intriParams);
};
};

View File

@ -0,0 +1,71 @@
/* Adapted from
* https://gitee.com/ascend/ascend-transformer-boost.git
*
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "../iterator.h"
/////////////////////////////////////////////////////
// l0c_to_ub
/////////////////////////////////////////////////////
// Partial specialization ZN, half, int32_t
template <ArchType ArchTag, typename ElementIn, typename ElementOut, bool MatrixMode = true>
struct l0c_to_ub {
__aicore__ l0c_to_ub(AscendC::LocalTensor<ElementOut> ubTensor,
AscendC::LocalTensor<ElementIn> l0cTensor,
uint16_t nBurst,
uint16_t lenBurst,
uint16_t srcStride,
uint16_t dstStride)
{
constexpr auto mode =
MatrixMode ? AscendC::BlockMode::BLOCK_MODE_MATRIX : AscendC::BlockMode::BLOCK_MODE_VECTOR;
AscendC::DataCopy(ubTensor,
l0cTensor,
AscendC::DataCopyParams(nBurst, // count
lenBurst, // len
srcStride, // srcStrideIn
dstStride), // dstStrideIn
AscendC::DataCopyEnhancedParams(mode, // blockModeIn
AscendC::DeqScale::DEQ_NONE, // deqScaleIn
0, // deqValueIn
0, // sidStoreModeIn
false, // isReluIn
pad_t::PAD_NONE, // padModeIn
0) // padValueIn
);
};
};
template <ArchType ArchTag>
struct l0c_to_ub<ArchTag, int32_t, half> {
__aicore__ l0c_to_ub(AscendC::LocalTensor<half> ubTensor,
AscendC::LocalTensor<int32_t> l0cTensor,
uint16_t nBurst,
uint16_t lenBurst,
uint16_t srcStride,
uint16_t dstStride)
{
AscendC::DataCopy(ubTensor,
l0cTensor,
AscendC::DataCopyParams(nBurst, // count
lenBurst, // len
srcStride, // srcStrideIn
dstStride), // dstStrideIn
AscendC::DataCopyEnhancedParams(AscendC::BlockMode::BLOCK_MODE_MATRIX, // blockModeIn
AscendC::DeqScale::VDEQ16, // deqScaleIn
0, // deqValueIn
0, // sidStoreModeIn
false, // isReluIn
pad_t::PAD_NONE, // padModeIn
0) // padValueIn
);
};
};

View File

@ -0,0 +1,39 @@
/* Adapted from
* https://gitee.com/ascend/ascend-transformer-boost.git
*
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "../iterator.h"
/////////////////////////////////////////////////////
// l1_to_bt
/////////////////////////////////////////////////////
// Partial specialization for V220
template <typename DataType>
struct l1_to_bt<ArchType::ASCEND_V220, DataType> {
__aicore__ l1_to_bt(uint64_t dst,
const AscendC::LocalTensor<DataType> &src,
uint16_t convControl,
uint16_t nBurst,
uint16_t lenBurst,
uint16_t srcGap,
uint16_t dstGap)
{
AscendC::LocalTensor<DataType> dstTensor;
dstTensor.InitBuffer(dst, nBurst * lenBurst);
dstTensor.address_.logicPos = static_cast<uint8_t>(AscendC::TPosition::C2);
AscendC::DataCopy(dstTensor,
src,
AscendC::DataCopyParams(nBurst, // nBurst
lenBurst, // lenBurst
srcGap, // srcGap
dstGap)); // dstGap
}
};

View File

@ -0,0 +1,36 @@
/* Adapted from
* https://gitee.com/ascend/ascend-transformer-boost.git
*
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "../iterator.h"
/////////////////////////////////////////////////////
// l1_to_fb
/////////////////////////////////////////////////////
// Partial specialization for V220
template <typename DataType>
struct l1_to_fb<ArchType::ASCEND_V220, DataType> {
__aicore__ l1_to_fb(AscendC::LocalTensor<DataType> &dst,
AscendC::LocalTensor<DataType> &src,
uint16_t burstNum,
uint16_t burstLen,
uint16_t srcGap,
uint16_t dstGap)
{
dst.address_.logicPos = static_cast<uint8_t>(AscendC::TPosition::C2PIPE2GM);
AscendC::DataCopy(dst,
src,
AscendC::DataCopyParams(burstNum, // nBurst
burstLen, // lenBurst
srcGap, // srcGap
dstGap)); // dstGap);
}
};

View File

@ -0,0 +1,310 @@
/* Adapted from
* https://gitee.com/ascend/ascend-transformer-boost.git
*
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "../iterator.h"
/////////////////////////////////////////////////////
// l1_to_l0_a
/////////////////////////////////////////////////////
// Partial specialization for vector
template <ArchType ArchTag, typename DataType, bool IsTransPose>
struct l1_to_l0_a<ArchTag, DataType, IsTransPose, DataFormat::VECTOR, DataFormat::VECTOR> {
using HardwareParams = HardwareInfo<ArchTag>;
static constexpr uint32_t FRACTAL_SIZE = HardwareParams::fractalSize / sizeof(DataType);
__aicore__ l1_to_l0_a(AscendC::LocalTensor<DataType> l0Tensor,
AscendC::LocalTensor<DataType> l1Tensor,
uint32_t mTileCeil,
uint32_t kPartCeil,
uint32_t mSrcStride,
uint32_t kSrcStride,
uint32_t mDstStride,
uint32_t kDstStride)
{
AscendC::LoadData(l0Tensor,
l1Tensor,
AscendC::LoadData2dParams(0, // baseIdx
kPartCeil, // repeat
kSrcStride, // srcStride
0, // sid
kDstStride, // dstStride
IsTransPose, // transpose
0)); // addrCalMode
};
};
// Partial specialization for no transpose, not vector
template <ArchType ArchTag, typename DataType>
struct l1_to_l0_a<ArchTag, DataType, false, DataFormat::ZN, DataFormat::ZZ> {
using HardwareParams = HardwareInfo<ArchTag>;
static constexpr uint32_t BLOCK_SIZE = HardwareParams::l1l0BlockSize / sizeof(DataType);
static constexpr uint32_t FRACTAL_SIZE = HardwareParams::fractalSize / sizeof(DataType);
static constexpr uint32_t BLOCK_NUM_PER_FRACTAL = HardwareParams::fractalSize / HardwareParams::l1l0BlockSize;
__aicore__ l1_to_l0_a(AscendC::LocalTensor<DataType> l0Tensor,
AscendC::LocalTensor<DataType> l1Tensor,
uint32_t mTileCeil,
uint32_t kPartCeil,
uint32_t mSrcStride,
uint32_t kSrcStride,
uint32_t mDstStride,
uint32_t kDstStride)
{
for (uint32_t i = 0; i < mTileCeil / BLOCK_NUM_PER_FRACTAL; i++) {
AscendC::LoadData(l0Tensor[i * mDstStride * FRACTAL_SIZE], // dst
l1Tensor[i * mSrcStride * FRACTAL_SIZE], // src
AscendC::LoadData2dParams(0, // baseIdx
static_cast<uint16_t>(kPartCeil / BLOCK_SIZE), // repeat
kSrcStride, // srcStride
0, // sid
kDstStride - 1, // dstStride
false, // transpose
0)); // addrCalMode
}
};
};
// Partial specialization for transpose, not vector
template <ArchType ArchTag, typename DataType>
struct l1_to_l0_a<ArchTag, DataType, true, DataFormat::ZN, DataFormat::ZZ> {
using HardwareParams = HardwareInfo<ArchTag>;
static constexpr uint32_t BLOCK_SIZE = HardwareParams::l1l0BlockSize / sizeof(DataType);
static constexpr uint32_t FRACTAL_SIZE = HardwareParams::fractalSize / sizeof(DataType);
static constexpr uint32_t BLOCK_NUM_PER_FRACTAL = HardwareParams::fractalSize / HardwareParams::l1l0BlockSize;
__aicore__ l1_to_l0_a(AscendC::LocalTensor<DataType> l0Tensor,
AscendC::LocalTensor<DataType> l1Tensor,
uint32_t mTileCeil,
uint32_t kPartCeil,
uint32_t mSrcStride,
uint32_t kSrcStride,
uint32_t mDstStride,
uint32_t kDstStride)
{
for (uint32_t i = 0; i < mTileCeil / BLOCK_SIZE; i++) {
AscendC::LoadData(l0Tensor[i * mDstStride * FRACTAL_SIZE],
l1Tensor[i * mSrcStride * FRACTAL_SIZE],
AscendC::LoadData2dParams(0,
static_cast<uint16_t>(kPartCeil / BLOCK_NUM_PER_FRACTAL),
kSrcStride,
0,
kDstStride - 1,
true,
0));
}
};
};
template <ArchType ArchTag, typename DataType>
struct l1_to_l0_a<ArchTag, DataType, false, DataFormat::NZ, DataFormat::ZZ> {
using HardwareParams = HardwareInfo<ArchTag>;
// 16 * 32
static constexpr uint32_t ROW_BLOCK_SIZE = 16;
static constexpr uint32_t COL_BLOCK_SIZE = 32 / sizeof(DataType);
static constexpr uint32_t FRACTAL_SIZE = HardwareParams::fractalSize / sizeof(DataType);
static constexpr uint32_t BLOCK_NUM_PER_FRACTAL = HardwareParams::fractalSize / HardwareParams::l1l0BlockSize;
__aicore__ l1_to_l0_a(AscendC::LocalTensor<DataType> l0Tensor,
AscendC::LocalTensor<DataType> l1Tensor,
uint32_t mTileCeil,
uint32_t kPartCeil,
uint32_t mSrcStride,
uint32_t kSrcStride,
uint32_t mDstStride,
uint32_t kDstStride)
{
for (uint32_t i = 0; i < mTileCeil / ROW_BLOCK_SIZE; i++) {
AscendC::LoadData(l0Tensor[i * ROW_BLOCK_SIZE * kPartCeil],
l1Tensor[i * FRACTAL_SIZE],
AscendC::LoadData2dParams(0,
static_cast<uint16_t>(kPartCeil / COL_BLOCK_SIZE),
mTileCeil / ROW_BLOCK_SIZE,
0,
0,
false,
0));
}
};
};
template <>
struct l1_to_l0_a<ArchType::ASCEND_V220, int8_t, true, DataFormat::ZN, DataFormat::ZZ> {
using HardwareParams = HardwareInfo<ArchType::ASCEND_V220>;
static constexpr uint32_t BLOCK_SIZE = HardwareParams::l1l0BlockSize / sizeof(int8_t); // 32
static constexpr uint32_t FRACTAL_SIZE = HardwareParams::fractalSize / sizeof(int8_t); // 512
static constexpr uint32_t BLOCK_NUM_PER_FRACTAL = HardwareParams::fractalSize / HardwareParams::l1l0BlockSize; // 16
static constexpr uint32_t NUM_FRACTAL_PER_ITER = 2;
__aicore__ l1_to_l0_a(AscendC::LocalTensor<int8_t> l0Tensor,
AscendC::LocalTensor<int8_t> l1Tensor,
uint32_t mTileCeil,
uint32_t kPartCeil,
uint32_t mSrcStride,
uint32_t kSrcStride,
uint32_t mDstStride,
uint32_t kDstStride)
{
for (uint64_t i = 0; i < mTileCeil / (BLOCK_NUM_PER_FRACTAL * NUM_FRACTAL_PER_ITER); ++i) {
AscendC::LoadDataWithTranspose(
l0Tensor[i * mDstStride * FRACTAL_SIZE * NUM_FRACTAL_PER_ITER], // dstLocalTensor
l1Tensor[i * mSrcStride * FRACTAL_SIZE], // srcLocalTensor
AscendC::LoadData2dTransposeParams(0, // baseIdx
static_cast<uint16_t>(CeilDiv<BLOCK_SIZE>(kPartCeil)), // repeat
kSrcStride, // srcStride
0, // dstGap
mDstStride - 1)); // dstFracGap
}
}
};
/////////////////////////////////////////////////////
// l1_to_l0_b
/////////////////////////////////////////////////////
// Partial specialization for vector
template <ArchType ArchTag, typename DataType, bool IsTransPose>
struct l1_to_l0_b<ArchTag, DataType, IsTransPose, DataFormat::VECTOR, DataFormat::VECTOR> {
using HardwareParams = HardwareInfo<ArchTag>;
static constexpr uint32_t FRACTAL_SIZE = HardwareParams::fractalSize / sizeof(DataType);
__aicore__ l1_to_l0_b(AscendC::LocalTensor<DataType> l0Tensor,
AscendC::LocalTensor<DataType> l1Tensor,
uint32_t nTileCeil,
uint32_t kPartCeil,
uint32_t nSrcStride,
uint32_t kSrcStride,
uint32_t nDstStride,
uint32_t kDstStride)
{
AscendC::LoadData(
l0Tensor, l1Tensor, AscendC::LoadData2dParams(0, kPartCeil, kSrcStride, 0, kDstStride, IsTransPose, 0));
};
};
template <ArchType ArchTag>
struct l1_to_l0_b<ArchTag, int8_t, true, DataFormat::NZ, DataFormat::ZN> {
using HardwareParams = HardwareInfo<ArchTag>;
using DataType = int8_t;
static constexpr uint32_t BLOCK_SIZE = HardwareParams::l1l0BlockSize / sizeof(DataType);
__aicore__ l1_to_l0_b(AscendC::LocalTensor<DataType> l0Tensor,
AscendC::LocalTensor<DataType> l1Tensor,
uint32_t nTileCeil,
uint32_t kPartCeil,
uint32_t nSrcStride,
uint32_t kSrcStride,
uint32_t nDstStride,
uint32_t kDstStride)
{
for (uint32_t i = 0; i < nTileCeil / BLOCK_SIZE; i++) {
AscendC::LoadDataWithTranspose(l0Tensor[i * kPartCeil * BLOCK_SIZE],
l1Tensor[i * BLOCK_SIZE * BLOCK_SIZE],
AscendC::LoadData2dTransposeParams(0, // startIndexIn
kPartCeil / BLOCK_SIZE, // repeatTimesIn
nTileCeil / BLOCK_SIZE, // srcStrideIn
1, // dstGapIn
0, // dstfracGapIn
0) // addrModeIn
);
}
};
};
// Partial specialization for no transpose, not vector
template <ArchType ArchTag, typename DataType>
struct l1_to_l0_b<ArchTag, DataType, false, DataFormat::ZN, DataFormat::NZ> {
using HardwareParams = HardwareInfo<ArchTag>;
static constexpr uint32_t BLOCK_SIZE = HardwareParams::l1l0BlockSize / sizeof(DataType);
static constexpr uint32_t FRACTAL_SIZE = HardwareParams::fractalSize / sizeof(DataType);
static constexpr uint32_t BLOCK_NUM_PER_FRACTAL = HardwareParams::fractalSize / HardwareParams::l1l0BlockSize;
__aicore__ l1_to_l0_b(AscendC::LocalTensor<DataType> l0Tensor,
AscendC::LocalTensor<DataType> l1Tensor,
uint32_t nTileCeil,
uint32_t kPartCeil,
uint32_t nSrcStride,
uint32_t kSrcStride,
uint32_t nDstStride,
uint32_t kDstStride)
{
for (uint32_t i = 0; i < kPartCeil / BLOCK_NUM_PER_FRACTAL; i++) {
AscendC::LoadData(l0Tensor[i * kDstStride * FRACTAL_SIZE],
l1Tensor[i * kSrcStride * FRACTAL_SIZE],
AscendC::LoadData2dParams(0, // baseIdx
static_cast<uint16_t>(nTileCeil / BLOCK_SIZE), // repeat
nSrcStride, // srcStride
0, // sid
nDstStride - 1, // dstStride
true, // transpose
0)); // addrCalMode
}
};
};
// Partial specialization for transpose, not vector
template <ArchType ArchTag, typename DataType>
struct l1_to_l0_b<ArchTag, DataType, true, DataFormat::ZN, DataFormat::NZ> {
using HardwareParams = HardwareInfo<ArchTag>;
static constexpr uint32_t BLOCK_SIZE = HardwareParams::l1l0BlockSize / sizeof(DataType);
static constexpr uint32_t FRACTAL_SIZE = HardwareParams::fractalSize / sizeof(DataType);
static constexpr uint32_t BLOCK_NUM_PER_FRACTAL = HardwareParams::fractalSize / HardwareParams::l1l0BlockSize;
__aicore__ l1_to_l0_b(AscendC::LocalTensor<DataType> l0Tensor,
AscendC::LocalTensor<DataType> l1Tensor,
uint32_t nTileCeil,
uint32_t kPartCeil,
uint32_t nSrcStride,
uint32_t kSrcStride,
uint32_t nDstStride,
uint32_t kDstStride)
{
AscendC::LoadData(
l0Tensor,
l1Tensor,
AscendC::LoadData2dParams(0, // baseIdx
static_cast<uint16_t>(kPartCeil * nTileCeil / FRACTAL_SIZE), // repeat
1, // srcStride
0, // sid
0, // dstStride
false, // transpose
0)); // addr_cal_mode_t
};
};
template <>
struct l1_to_l0_b<ArchType::ASCEND_V220, int8_t, false, DataFormat::ZN, DataFormat::NZ> {
using HardwareParams = HardwareInfo<ArchType::ASCEND_V220>;
static constexpr uint32_t BLOCK_SIZE = HardwareParams::l1l0BlockSize / sizeof(int8_t); // 32
static constexpr uint32_t FRACTAL_SIZE = HardwareParams::fractalSize / sizeof(int8_t); // 16
static constexpr uint32_t BLOCK_NUM_PER_FRACTAL = HardwareParams::fractalSize / HardwareParams::l1l0BlockSize;
static constexpr uint32_t NUM_FRACTAL_PER_ITER = 2;
__aicore__ l1_to_l0_b(AscendC::LocalTensor<int8_t> l0Tensor,
AscendC::LocalTensor<int8_t> l1Tensor,
uint32_t nTileCeil,
uint32_t kPartCeil,
uint32_t nSrcStride,
uint32_t kSrcStride,
uint32_t nDstStride,
uint32_t kDstStride)
{
for (uint64_t i = 0; i < kPartCeil / (BLOCK_NUM_PER_FRACTAL * NUM_FRACTAL_PER_ITER); ++i) {
AscendC::LoadDataWithTranspose(
l0Tensor[i * kDstStride * FRACTAL_SIZE], // dstLocalTensor
l1Tensor[i * kSrcStride * FRACTAL_SIZE * NUM_FRACTAL_PER_ITER], // srcLocalTensor
AscendC::LoadData2dTransposeParams(0, // baseIdx
static_cast<uint16_t>(CeilDiv<BLOCK_SIZE>(nTileCeil)), // repeat
nSrcStride / NUM_FRACTAL_PER_ITER, // srcStride
1, // dstGap
0)); // dstFracGap
}
};
};

View File

@ -0,0 +1,44 @@
/* Adapted from
* https://gitee.com/ascend/ascend-transformer-boost.git
*
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "../iterator.h"
/////////////////////////////////////////////////////
// l1_to_ub
/////////////////////////////////////////////////////
template <ArchType ArchTag, typename DataType>
struct l1_to_ub {
__aicore__ l1_to_ub(AscendC::LocalTensor<DataType> ubTensor,
AscendC::LocalTensor<DataType> l1Tensor,
uint16_t nBurst,
uint16_t lenBurst,
uint16_t srcStride,
uint16_t dstStride)
{
AscendC::DataCopy(ubTensor, l1Tensor, AscendC::DataCopyParams(nBurst, lenBurst, srcStride, dstStride));
};
};
/////////////////////////////////////////////////////
// ub_to_l1
/////////////////////////////////////////////////////
template <ArchType ArchTag, typename DataType>
struct ub_to_l1 {
__aicore__ ub_to_l1(AscendC::LocalTensor<DataType> l1Tensor,
AscendC::LocalTensor<DataType> ubTensor,
uint16_t nBurst,
uint16_t lenBurst,
uint16_t srcStride,
uint16_t dstStride)
{
AscendC::DataCopy(l1Tensor, ubTensor, AscendC::DataCopyParams(nBurst, lenBurst, srcStride, dstStride));
};
};

View File

@ -0,0 +1,395 @@
/* Adapted from
* https://gitee.com/ascend/ascend-transformer-boost.git
*
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef ASCEND_OPS_UTILS_COMMON_KERNEL_KERNEL_UTILS_H
#define ASCEND_OPS_UTILS_COMMON_KERNEL_KERNEL_UTILS_H
#include "kernel_operator.h"
using AscendC::HardEvent;
__aicore__ inline uint32_t CeilDiv(uint32_t x, uint32_t y)
{
return y == 0 ? 0 : ((x + y - 1) / y);
}
__aicore__ inline uint32_t RoundUp(uint32_t x, uint32_t y = 16)
{
return (x + y - 1) / y * y;
}
__aicore__ inline uint32_t Min(uint32_t x, uint32_t y)
{
return x < y ? x : y;
}
__aicore__ inline uint32_t Max(uint32_t x, uint32_t y)
{
return x > y ? x : y;
}
template <typename T, typename Q>
__aicore__ inline void CopyIn(const AscendC::GlobalTensor<T> &gm, Q &queue, uint64_t offset, uint32_t count)
{
AscendC::LocalTensor<T> local = queue.template AllocTensor<T>();
DataCopy(local, gm[offset], count);
queue.EnQue(local);
}
template <typename T, typename Q>
__aicore__ inline void CopyOut(const AscendC::GlobalTensor<T> &gm, Q &queue, uint64_t offset, uint32_t count)
{
AscendC::LocalTensor<T> local = queue.template DeQue<T>();
DataCopy(gm[offset], local, count);
queue.FreeTensor(local);
}
template <typename T>
__aicore__ inline void CastFrom16To32(const AscendC::LocalTensor<float> &out, const AscendC::LocalTensor<T> &in,
uint32_t count)
{
Cast(out, in, AscendC::RoundMode::CAST_NONE, count);
AscendC::PipeBarrier<PIPE_V>();
}
template <typename T>
__aicore__ inline void CastFrom32To16(const AscendC::LocalTensor<T> &out, const AscendC::LocalTensor<float> &in,
uint32_t count)
{
if constexpr (AscendC::IsSameType<T, half>::value) {
Cast(out, in, AscendC::RoundMode::CAST_NONE,
count); // 310p cast fp32->half 只能用CAST_NONE这里拉齐310p和910b
} else { // bf16
Cast(out, in, AscendC::RoundMode::CAST_RINT, count);
}
AscendC::PipeBarrier<PIPE_V>();
}
__aicore__ inline void CastFromF16ToI8(const AscendC::LocalTensor<int8_t> &out, const AscendC::LocalTensor<half> &in,
half quantMin, uint32_t count)
{
Maxs(in, in, quantMin, count);
AscendC::PipeBarrier<PIPE_V>();
Mins(in, in, (half)127, count); // 127: limit
AscendC::PipeBarrier<PIPE_V>();
#if defined(__CCE_KT_TEST__) || (__CCE_AICORE__ == 220)
Cast(out, in, AscendC::RoundMode::CAST_RINT, count);
#else
Cast(out, in, AscendC::RoundMode::CAST_NONE, count);
#endif
AscendC::PipeBarrier<PIPE_V>();
}
template <typename T, typename Q>
__aicore__ inline void CopyInAndCastF32(const AscendC::LocalTensor<float> &out, const AscendC::GlobalTensor<T> &gm,
Q &queue, uint64_t offset, uint32_t count)
{
CopyIn(gm, queue, offset, count);
AscendC::LocalTensor<T> local = queue.template DeQue<T>();
Cast(out, local, AscendC::RoundMode::CAST_NONE, count);
queue.FreeTensor(local);
AscendC::PipeBarrier<PIPE_V>();
}
template <typename T, typename Q>
__aicore__ inline void Cast16AndCopyOut(const AscendC::LocalTensor<float> &in, const AscendC::GlobalTensor<T> &gm,
Q &queue, uint64_t offset, uint32_t count)
{
AscendC::LocalTensor<T> local = queue.template AllocTensor<T>();
CastFrom32To16(local, in, count);
queue.EnQue(local);
CopyOut(gm, queue, offset, count);
AscendC::PipeBarrier<PIPE_V>();
}
template <typename T>
__aicore__ inline T ComputeSum(const AscendC::LocalTensor<T> &in, const AscendC::LocalTensor<T> &tmp,
const AscendC::LocalTensor<T> &workLocal, uint32_t count)
{
#if __CCE_AICORE__ == 100
float sum = 0;
int64_t elementNumPerRep = AscendC::ONE_REPEAT_BYTE_SIZE / sizeof(T);
AscendC::LocalTensor<T> src = in;
while (count > elementNumPerRep) {
int64_t repeatTimes = count / elementNumPerRep;
int64_t tailCount = count % elementNumPerRep;
int64_t bodyCount = repeatTimes * elementNumPerRep;
if (repeatTimes > 0) {
AscendC::AscendCUtils::SetMask<T>(elementNumPerRep);
vcadd((__ubuf__ T *)tmp.GetPhyAddr(), (__ubuf__ T *)src.GetPhyAddr(), repeatTimes, 1, 1, 8);
AscendC::SetFlag<HardEvent::V_S>(EVENT_ID0); // PipeBarrier(PIPE_V)?
AscendC::WaitFlag<HardEvent::V_S>(EVENT_ID0);
}
if (tailCount != 0) {
AscendC::AscendCUtils::SetMask<T>(tailCount);
vcadd((__ubuf__ T *)tmp[bodyCount].GetPhyAddr(), (__ubuf__ T *)src[bodyCount].GetPhyAddr(), 1, 1, 1, 8);
AscendC::SetFlag<HardEvent::V_S>(EVENT_ID0);
AscendC::WaitFlag<HardEvent::V_S>(EVENT_ID0);
sum += tmp.GetValue(bodyCount);
}
count = repeatTimes;
src = tmp;
}
if (count > 1) {
AscendC::AscendCUtils::SetMask<T>(count);
vcadd((__ubuf__ T *)tmp.GetPhyAddr(), (__ubuf__ T *)tmp.GetPhyAddr(), 1, 1, 1, 8);
AscendC::SetFlag<HardEvent::V_S>(EVENT_ID0);
AscendC::WaitFlag<HardEvent::V_S>(EVENT_ID0);
}
sum += tmp.GetValue(0);
return sum;
#else
ReduceSum(tmp, in, workLocal, count);
AscendC::SetFlag<HardEvent::V_S>(EVENT_ID0);
AscendC::WaitFlag<HardEvent::V_S>(EVENT_ID0);
return tmp.GetValue(0);
#endif
}
__aicore__ inline float ComputeSliceSquareSum(const AscendC::LocalTensor<float> &in,
const AscendC::LocalTensor<float> &tmp,
const AscendC::LocalTensor<float> &workLocal, uint32_t count)
{
Mul(tmp, in, in, count);
AscendC::PipeBarrier<PIPE_V>();
return ComputeSum(tmp, tmp, workLocal, count);
}
template <typename T>
__aicore__ inline void ComputeRmsNorm(const AscendC::LocalTensor<T> &out, const AscendC::LocalTensor<float> &in,
float rms, const AscendC::LocalTensor<T> &gamma, uint32_t count,
uint32_t precisionMode, uint32_t gemmaMode,
const AscendC::LocalTensor<float> &tmp)
{
float value = 1.0;
Duplicate(tmp, rms, count);
AscendC::PipeBarrier<PIPE_V>();
Div(tmp, in, tmp, count);
AscendC::PipeBarrier<PIPE_V>();
if (precisionMode == 0) {
CastFrom16To32(in, gamma, count);
AscendC::PipeBarrier<PIPE_V>();
if (gemmaMode == 1) {
Adds(in, in, value, count);
AscendC::PipeBarrier<PIPE_V>();
}
Mul(in, in, tmp, count);
AscendC::PipeBarrier<PIPE_V>();
CastFrom32To16(out, in, count);
return;
}
if constexpr (std::is_same<T, half>::value) {
CastFrom32To16(out, tmp, count);
Mul(out, out, gamma, count);
AscendC::PipeBarrier<PIPE_V>();
}
}
template <typename T, uint32_t gemmaMode>
__aicore__ inline void CastGAndIsGemmaMode(const AscendC::LocalTensor<float> &out, const AscendC::LocalTensor<T> &gamma,
uint32_t count)
{
Cast(out, gamma, AscendC::RoundMode::CAST_NONE, count);
AscendC::PipeBarrier<PIPE_V>();
float value = 1.0;
if constexpr (gemmaMode == 1) {
Adds(out, out, value, count);
AscendC::PipeBarrier<PIPE_V>();
}
}
template <typename T, uint32_t precisionMode>
__aicore__ inline void ComputeRmsNormFast(const AscendC::LocalTensor<T> &out, const AscendC::LocalTensor<float> &in,
float rms, const AscendC::LocalTensor<T> &gamma, uint32_t count,
const AscendC::LocalTensor<float> &tmp,
const AscendC::LocalTensor<float> &fp32_g)
{
float value = 1.0;
Duplicate(tmp, rms, count);
AscendC::PipeBarrier<PIPE_V>();
Div(tmp, in, tmp, count);
AscendC::PipeBarrier<PIPE_V>();
if constexpr (precisionMode == 0) {
Mul(in, fp32_g, tmp, count);
AscendC::PipeBarrier<PIPE_V>();
CastFrom32To16(out, in, count);
return;
}
if constexpr (std::is_same<T, half>::value) {
CastFrom32To16(out, tmp, count);
Mul(out, out, gamma, count);
AscendC::PipeBarrier<PIPE_V>();
}
}
template <bool WITH_BETA = true>
__aicore__ inline void ComputeRmsNorm(const AscendC::LocalTensor<float> &out, const AscendC::LocalTensor<float> &in,
float rms, const AscendC::LocalTensor<half> &gamma,
const AscendC::LocalTensor<half> &beta, const AscendC::LocalTensor<float> &tmp,
uint32_t count)
{
Duplicate(tmp, rms, count);
AscendC::PipeBarrier<PIPE_V>();
Div(out, in, tmp, count);
AscendC::PipeBarrier<PIPE_V>();
CastFrom16To32(tmp, gamma, count);
Mul(out, out, tmp, count);
AscendC::PipeBarrier<PIPE_V>();
if constexpr (WITH_BETA) {
CastFrom16To32(tmp, beta, count);
Add(out, out, tmp, count);
AscendC::PipeBarrier<PIPE_V>();
}
}
template <typename T>
__aicore__ inline void ComputeRmsNorm(const AscendC::LocalTensor<float> &out, const AscendC::LocalTensor<float> &in,
float reciprocal_of_rms, const AscendC::LocalTensor<T> &gamma,
const AscendC::LocalTensor<float> &tmp, const AscendC::LocalTensor<T> &res_out,
uint32_t count)
{
Duplicate(tmp, reciprocal_of_rms, count);
AscendC::PipeBarrier<PIPE_V>();
Mul(out, in, tmp, count);
AscendC::PipeBarrier<PIPE_V>();
CastFrom16To32(tmp, gamma, count);
Mul(out, out, tmp, count);
AscendC::PipeBarrier<PIPE_V>();
CastFrom32To16(res_out, out, count);
}
template <typename T>
__aicore__ inline void ComputeResidualAdd(const AscendC::LocalTensor<T> &out, const AscendC::LocalTensor<T> &in,
const AscendC::LocalTensor<T> &resIn, uint32_t count)
{
Add(out, in, resIn, count);
AscendC::PipeBarrier<PIPE_V>();
}
template <typename T>
__aicore__ inline void ComputeMean(const AscendC::LocalTensor<T> &out, const AscendC::LocalTensor<T> &in, T aveNum,
uint32_t count)
{
Duplicate(out, aveNum, count);
AscendC::PipeBarrier<PIPE_V>();
Mul(out, in, out, count);
AscendC::PipeBarrier<PIPE_V>();
T sum = ComputeSum(out, out, out, count);
AscendC::SetFlag<HardEvent::S_V>(EVENT_ID0);
AscendC::WaitFlag<HardEvent::S_V>(EVENT_ID0);
Duplicate(out, sum, count);
AscendC::PipeBarrier<PIPE_V>();
}
template <typename T>
__aicore__ inline void ComputeLayerNorm(const AscendC::LocalTensor<float> &out, const AscendC::LocalTensor<float> &in,
const AscendC::LocalTensor<float> &mean, float eps, float aveNum,
const AscendC::LocalTensor<T> &gamma, const AscendC::LocalTensor<T> &beta,
uint32_t count)
{
Sub(in, in, mean, count);
AscendC::PipeBarrier<PIPE_V>();
Mul(out, in, in, count);
AscendC::PipeBarrier<PIPE_V>();
Muls(out, out, aveNum, count);
AscendC::PipeBarrier<PIPE_V>();
ReduceSum(out, out, out, count);
AscendC::SetFlag<HardEvent::V_S>(EVENT_ID0);
AscendC::WaitFlag<HardEvent::V_S>(EVENT_ID0);
float var = out.GetValue(0);
AscendC::SetFlag<HardEvent::S_V>(EVENT_ID0);
AscendC::WaitFlag<HardEvent::S_V>(EVENT_ID0);
Duplicate(out, var, count);
AscendC::PipeBarrier<PIPE_V>();
Adds(out, out, eps, count);
AscendC::PipeBarrier<PIPE_V>();
Sqrt(out, out, count);
AscendC::PipeBarrier<PIPE_V>();
Div(out, in, out, count);
AscendC::PipeBarrier<PIPE_V>();
Cast(in, gamma, AscendC::RoundMode::CAST_NONE, count);
AscendC::PipeBarrier<PIPE_V>();
Mul(out, out, in, count);
AscendC::PipeBarrier<PIPE_V>();
Cast(in, beta, AscendC::RoundMode::CAST_NONE, count);
AscendC::PipeBarrier<PIPE_V>();
Add(out, out, in, count);
AscendC::PipeBarrier<PIPE_V>();
}
__aicore__ inline void ComputeFp16ToI8Quant(const AscendC::LocalTensor<int8_t> &out,
const AscendC::LocalTensor<half> &in, const AscendC::LocalTensor<half> &tmp,
half scale, half offset, half quantMin, uint32_t count)
{
Muls(tmp, in, scale, count);
AscendC::PipeBarrier<PIPE_V>();
Adds(tmp, tmp, offset, count);
AscendC::PipeBarrier<PIPE_V>();
CastFromF16ToI8(out, tmp, quantMin, count);
}
__aicore__ inline void ComputeFp32ToI8Quant(const AscendC::LocalTensor<int8_t> &out,
const AscendC::LocalTensor<float> &in,
const AscendC::LocalTensor<half> &tmp, half scale, half offset,
half quantMin, uint32_t count)
{
CastFrom32To16(tmp, in, count);
AscendC::PipeBarrier<PIPE_V>();
ComputeFp16ToI8Quant(out, tmp, tmp, scale, offset, quantMin, count);
}
__aicore__ inline void ComputeHighPrecisionFp32ToI8Quant(const AscendC::LocalTensor<int8_t> &out,
const AscendC::LocalTensor<float> &in,
const AscendC::LocalTensor<half> &tmp, float scale,
float offset, half quantMin, uint32_t count)
{
Muls(in, in, scale, count);
AscendC::PipeBarrier<PIPE_V>();
Adds(in, in, offset, count);
AscendC::PipeBarrier<PIPE_V>();
CastFrom32To16(tmp, in, count);
CastFromF16ToI8(out, tmp, quantMin, count);
}
__aicore__ inline void CopyGmTilingToUb(__ubuf__ uint8_t *&tilingInUb, const __gm__ uint8_t *tilingInGm,
size_t tilingSize, AscendC::TPipe *pipe)
{
uint32_t roundTilingSize = RoundUp(tilingSize, 32);
AscendC::TBuf<AscendC::TPosition::VECCALC> tilingBuf;
AscendC::GlobalTensor<uint8_t> tilingGm;
tilingGm.SetGlobalBuffer((__gm__ uint8_t *)tilingInGm);
pipe->InitBuffer(tilingBuf, roundTilingSize);
AscendC::LocalTensor<uint8_t> tilingUb = tilingBuf.Get<uint8_t>();
AscendC::DataCopy(tilingUb, tilingGm, roundTilingSize);
tilingInUb = (__ubuf__ uint8_t *)tilingUb.GetPhyAddr();
}
template <typename T>
__aicore__ inline uint32_t GetReduceSumWorkLocalSize(uint32_t sliceSize)
{
uint32_t elementsPerBlock = 32 / sizeof(T);
uint32_t elementsPerRepeat = 256 / sizeof(T);
uint32_t firstMaxRepeat = sliceSize < elementsPerRepeat ? 1u : (sliceSize / elementsPerRepeat);
uint32_t iter1OutputCount = firstMaxRepeat;
uint32_t iter1AlignEnd = RoundUp(iter1OutputCount, elementsPerBlock);
return iter1AlignEnd;
}
#endif

View File

@ -0,0 +1,18 @@
/* Adapted from
* https://gitee.com/ascend/ascend-transformer-boost.git
*
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef INCLUDE_LAYOUT_H
#define INCLUDE_LAYOUT_H
enum class DataFormat { ND = 0, NZ, ZN, ZZ, NN, VECTOR };
#endif

View File

@ -0,0 +1,82 @@
/* Adapted from
* https://gitee.com/ascend/ascend-transformer-boost.git
*
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef INCLUDE_MEM_H
#define INCLUDE_MEM_H
#include "hardware.h"
#include "kernel_event.h"
#include "kernel_tensor.h"
enum class BufferType { ASCEND_UB, ASCEND_CB, ASCEND_L0A, ASCEND_L0B, ASCEND_L0C, ASCEND_MAX };
template <BufferType BufferType_>
__aicore__ constexpr AscendC::TPosition GetPosition()
{
if constexpr (BufferType_ == BufferType::ASCEND_UB) {
return AscendC::TPosition::VECIN;
} else if constexpr (BufferType_ == BufferType::ASCEND_CB) {
return AscendC::TPosition::A1;
} else if constexpr (BufferType_ == BufferType::ASCEND_L0A) {
return AscendC::TPosition::A2;
} else if constexpr (BufferType_ == BufferType::ASCEND_L0B) {
return AscendC::TPosition::B2;
} else if constexpr (BufferType_ == BufferType::ASCEND_L0C) {
return AscendC::TPosition::CO1;
}
return AscendC::TPosition::GM;
}
template <ArchType ArchTag>
struct AsdopsBuffer {
public:
__aicore__ AsdopsBuffer()
{
constexpr uint32_t bufferSize[(uint32_t)BufferType::ASCEND_MAX] = {
HardwareInfo<ArchTag>::ubSize, HardwareInfo<ArchTag>::l1Size, HardwareInfo<ArchTag>::l0ASize,
HardwareInfo<ArchTag>::l0BSize, HardwareInfo<ArchTag>::l0CSize};
#ifdef __DAV_C220_VEC__
tensor[(uint32_t)BufferType::ASCEND_UB].InitBuffer(0, bufferSize[(uint32_t)BufferType::ASCEND_UB]);
tensor[(uint32_t)BufferType::ASCEND_UB].address_.logicPos = static_cast<uint8_t>(AscendC::TPosition::VECIN);
#elif defined(__DAV_C220_CUBE__)
tensor[(uint32_t)BufferType::ASCEND_CB].InitBuffer(0, bufferSize[(uint32_t)BufferType::ASCEND_CB]);
tensor[(uint32_t)BufferType::ASCEND_CB].address_.logicPos = static_cast<uint8_t>(AscendC::TPosition::A1);
tensor[(uint32_t)BufferType::ASCEND_L0A].InitBuffer(0, bufferSize[(uint32_t)BufferType::ASCEND_L0A]);
tensor[(uint32_t)BufferType::ASCEND_L0A].address_.logicPos = static_cast<uint8_t>(AscendC::TPosition::A2);
tensor[(uint32_t)BufferType::ASCEND_L0B].InitBuffer(0, bufferSize[(uint32_t)BufferType::ASCEND_L0B]);
tensor[(uint32_t)BufferType::ASCEND_L0B].address_.logicPos = static_cast<uint8_t>(AscendC::TPosition::B2);
tensor[(uint32_t)BufferType::ASCEND_L0C].InitBuffer(0, bufferSize[(uint32_t)BufferType::ASCEND_L0C]);
tensor[(uint32_t)BufferType::ASCEND_L0C].address_.logicPos = static_cast<uint8_t>(AscendC::TPosition::CO1);
#else
tensor[(uint32_t)BufferType::ASCEND_UB].InitBuffer(0, bufferSize[(uint32_t)BufferType::ASCEND_UB]);
tensor[(uint32_t)BufferType::ASCEND_UB].address_.logicPos = static_cast<uint8_t>(AscendC::TPosition::VECIN);
tensor[(uint32_t)BufferType::ASCEND_CB].InitBuffer(0, bufferSize[(uint32_t)BufferType::ASCEND_CB]);
tensor[(uint32_t)BufferType::ASCEND_CB].address_.logicPos = static_cast<uint8_t>(AscendC::TPosition::A1);
tensor[(uint32_t)BufferType::ASCEND_L0A].InitBuffer(0, bufferSize[(uint32_t)BufferType::ASCEND_L0A]);
tensor[(uint32_t)BufferType::ASCEND_L0A].address_.logicPos = static_cast<uint8_t>(AscendC::TPosition::A2);
tensor[(uint32_t)BufferType::ASCEND_L0B].InitBuffer(0, bufferSize[(uint32_t)BufferType::ASCEND_L0B]);
tensor[(uint32_t)BufferType::ASCEND_L0B].address_.logicPos = static_cast<uint8_t>(AscendC::TPosition::B2);
tensor[(uint32_t)BufferType::ASCEND_L0C].InitBuffer(0, bufferSize[(uint32_t)BufferType::ASCEND_L0C]);
tensor[(uint32_t)BufferType::ASCEND_L0C].address_.logicPos = static_cast<uint8_t>(AscendC::TPosition::CO1);
#endif
};
template <BufferType BufferType_, typename DstDataType = half>
__aicore__ AscendC::LocalTensor<DstDataType> GetBuffer(const uint32_t offset) const
{
return tensor[(uint32_t)BufferType_][offset].template ReinterpretCast<DstDataType>();
}
public:
AscendC::LocalTensor<uint8_t> tensor[(uint32_t)BufferType::ASCEND_MAX];
};
#endif

View File

@ -0,0 +1,67 @@
/* Adapted from
* https://gitee.com/ascend/ascend-transformer-boost.git
*
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef INCLUDE_MMA_H
#define INCLUDE_MMA_H
#include "hardware.h"
#include "kernel_tensor.h"
template <ArchType ArchTag, typename ElementA, typename ElementB, typename AccDTypeC, bool IsTransposeA>
struct mmad {
__aicore__ mmad(AscendC::LocalTensor<AccDTypeC> l0cTensor, AscendC::LocalTensor<ElementA> l0aTensor,
AscendC::LocalTensor<ElementB> l0bTensor, uint32_t mTileActual, uint32_t nTileActual,
uint32_t kPartActual, bool initC, uint8_t unitFlag = 0) {};
__aicore__ mmad(AscendC::LocalTensor<AccDTypeC> l0cTensor, AscendC::LocalTensor<ElementA> l0aTensor,
AscendC::LocalTensor<ElementB> l0bTensor, uint64_t biasBt, uint32_t mTileActual,
uint32_t nTileActual, uint32_t kPartActual, bool initC, uint8_t unitFlag = 0) {};
};
// Partial specialization for V220, int8_t, not_vector_A, not TransposeA
template <ArchType ArchTag, typename AccDTypeC, typename ElementA, typename ElementB>
struct mmad<ArchTag, ElementA, ElementB, AccDTypeC, false> {
__aicore__ mmad(AscendC::LocalTensor<AccDTypeC> l0cTensor, AscendC::LocalTensor<ElementA> l0aTensor,
AscendC::LocalTensor<ElementB> l0bTensor, uint32_t mTileActual, uint32_t nTileActual,
uint32_t kPartActual, bool initC, uint8_t unitFlag = 0)
{
AscendC::Mmad(l0cTensor, // C
l0aTensor, // A
l0bTensor, // B
AscendC::MmadParams(mTileActual, // m
nTileActual, // n
kPartActual, // k
unitFlag, // unitFlag
false, // cmatrixSource
initC)); // cmatrixInitVal
};
__aicore__ mmad(AscendC::LocalTensor<AccDTypeC> l0cTensor, AscendC::LocalTensor<ElementA> l0aTensor,
AscendC::LocalTensor<ElementB> l0bTensor, uint64_t biasBt, uint32_t mTileActual,
uint32_t nTileActual, uint32_t kPartActual, bool initC, uint8_t unitFlag = 0)
{
AscendC::LocalTensor<AccDTypeC> biasTensor;
biasTensor.InitBuffer(biasBt, nTileActual);
biasTensor.address_.logicPos = static_cast<uint8_t>(AscendC::TPosition::C2);
AscendC::Mmad(l0cTensor, // C
l0aTensor, // A
l0bTensor, // B
biasTensor, // bt
AscendC::MmadParams(mTileActual, // m
nTileActual, // n
kPartActual, // k
unitFlag, // unitFlag
true, // cmatrixSource
false)); // cmatrixInitVal
};
};
#endif

View File

@ -0,0 +1,38 @@
/* Adapted from
* https://gitee.com/ascend/ascend-transformer-boost.git
*
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef INCLUDE_SET_FPC_H
#define INCLUDE_SET_FPC_H
#include "hardware.h"
#include "kernel_tensor.h"
/////////////////////////////////////////////////////
// SetQuantPreAddr
/////////////////////////////////////////////////////
template <ArchType ArchTag, typename DataType>
struct SetQuantPreAddr {
__aicore__ SetQuantPreAddr(AscendC::LocalTensor<DataType> quantPreTensor) {};
};
template <typename DataType>
struct SetQuantPreAddr<ArchType::ASCEND_V220, DataType> {
static constexpr uint32_t QUANT_PRE_ADDR_MASK = 0xffff;
static constexpr uint32_t USELESS_BIT_NUM = 7;
static constexpr uint32_t QUANT_PRE_BIT_POS_IN_FPC = 8;
__aicore__ SetQuantPreAddr(AscendC::LocalTensor<DataType> quantPreTensor)
{
uint64_t quantPreAddr = (uint64_t)(__fbuf__ uint64_t *)quantPreTensor.GetPhyAddr();
AscendC::SetFixPipeConfigImpl(quantPreTensor);
};
};
#endif

View File

@ -0,0 +1,274 @@
/* Adapted from
* https://gitee.com/ascend/ascend-transformer-boost.git
*
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef INCLUDE_SIMD_H
#define INCLUDE_SIMD_H
#include "hardware.h"
#include "kernel_operator.h"
/////////////////////////////////////////////////////
// vcgadd
/////////////////////////////////////////////////////
template <ArchType ArchTag, typename DType>
__aicore__ inline void cgadd_v(AscendC::LocalTensor<DType> dst, AscendC::LocalTensor<DType> src, const int32_t repeat,
const int32_t dstRepStride, const int32_t srcBlkStride, const int32_t srcRepStride)
{
AscendC::BlockReduceSum<DType, false>(dst, src, repeat, 0, dstRepStride, srcBlkStride, srcRepStride);
}
/////////////////////////////////////////////////////
// vadd
/////////////////////////////////////////////////////
template <ArchType ArchTag, typename DType>
__aicore__ inline void add_v(AscendC::LocalTensor<DType> dst, AscendC::LocalTensor<DType> src0,
AscendC::LocalTensor<DType> src1, uint8_t repeat, uint8_t dstBlockStride,
uint8_t src0BlockStride, uint8_t src1BlockStride, uint8_t dstRepeatStride,
uint8_t src0RepeatStride, uint8_t src1RepeatStride)
{
AscendC::Add<DType, false>(dst, src0, src1, (uint64_t)0, repeat,
AscendC::BinaryRepeatParams(dstBlockStride, src0BlockStride, src1BlockStride,
dstRepeatStride, src0RepeatStride, src1RepeatStride));
}
/////////////////////////////////////////////////////
// vadds
/////////////////////////////////////////////////////
template <ArchType ArchTag, typename DType>
__aicore__ inline void adds_v(AscendC::LocalTensor<DType> dst, AscendC::LocalTensor<DType> src, DType scalarValue,
uint8_t repeat, uint8_t dstBlockStride, uint8_t srcBlockStride, uint8_t dstRepeatStride,
uint8_t srcRepeatStride)
{
AscendC::Adds<DType, false>(
dst, src, scalarValue, (uint64_t)0, repeat,
AscendC::UnaryRepeatParams(dstBlockStride, srcBlockStride, dstRepeatStride, srcRepeatStride));
}
/////////////////////////////////////////////////////
// vcadd
/////////////////////////////////////////////////////
template <ArchType ArchTag, typename DType>
__aicore__ inline void cadd_v(AscendC::LocalTensor<DType> dst, AscendC::LocalTensor<DType> src, uint8_t repeat,
uint16_t dstRepeatStride, uint16_t srcBlockStride, uint16_t srcRepeatStride)
{
AscendC::RepeatReduceSum<DType, false>(dst, src, repeat, 0, 0, srcBlockStride, dstRepeatStride, srcRepeatStride);
}
/////////////////////////////////////////////////////
// vbrcb
/////////////////////////////////////////////////////
template <ArchType ArchTag, typename DType>
__aicore__ inline void brcb_v(AscendC::LocalTensor<DType> dst, AscendC::LocalTensor<DType> src, uint16_t dstBlockStride,
uint16_t dstRepeatStride, uint8_t repeat)
{
AscendC::Brcb(dst, src, repeat, AscendC::BrcbRepeatParams(dstBlockStride, dstRepeatStride));
}
/////////////////////////////////////////////////////
// vcmax
/////////////////////////////////////////////////////
template <ArchType ArchTag, typename DType, AscendC::ReduceOrder OrderType>
__aicore__ inline void cmax_v(AscendC::LocalTensor<DType> dst, AscendC::LocalTensor<DType> src, uint8_t repeat,
uint16_t dstRepeatStride, uint16_t srcBlockStride, uint16_t srcRepeatStride)
{
#if defined(__DAV_C220_VEC__)
AscendC::WholeReduceMax<DType, false>(dst, src, (int32_t)0, repeat, dstRepeatStride, srcBlockStride,
srcRepeatStride, OrderType);
#else
AscendC::WholeReduceMax<DType, false>(dst, src, (int32_t)0, repeat, dstRepeatStride, srcBlockStride,
srcRepeatStride);
#endif
}
/////////////////////////////////////////////////////
// vconv
/////////////////////////////////////////////////////
template <ArchType ArchTag, typename DTypeIn, typename DTypeOut>
__aicore__ inline void conv_v(AscendC::LocalTensor<DTypeOut> dst, AscendC::LocalTensor<DTypeIn> src, uint8_t repeat,
uint16_t dstBlockStride, uint16_t srcBlockStride, uint16_t dstRepeatStride,
uint16_t srcRepeatStride)
{
if constexpr (std::is_same<DTypeIn, float>::value && std::is_same<DTypeOut, __bf16>::value) {
AscendC::Cast<DTypeOut, DTypeIn, false>(
dst, src, AscendC::RoundMode::CAST_RINT, (uint64_t)0, repeat,
AscendC::UnaryRepeatParams(dstBlockStride, srcBlockStride, dstRepeatStride, srcRepeatStride));
} else {
AscendC::Cast<DTypeOut, DTypeIn, false>(
dst, src, AscendC::RoundMode::CAST_NONE, (uint64_t)0, repeat,
AscendC::UnaryRepeatParams(dstBlockStride, srcBlockStride, dstRepeatStride, srcRepeatStride));
}
}
/////////////////////////////////////////////////////
// vconv_f322bf16r
/////////////////////////////////////////////////////
template <ArchType ArchTag, typename DTypeIn, typename DTypeOut>
__aicore__ inline void convr_v(AscendC::LocalTensor<DTypeOut> dst, AscendC::LocalTensor<DTypeIn> src, uint8_t repeat,
uint16_t dstBlockStride, uint16_t srcBlockStride, uint16_t dstRepeatStride,
uint16_t srcRepeatStride)
{
AscendC::Cast<DTypeOut, DTypeIn, false>(
dst, src, AscendC::RoundMode::CAST_RINT, (uint64_t)0, repeat,
AscendC::UnaryRepeatParams(dstBlockStride, srcBlockStride, dstRepeatStride, srcRepeatStride));
}
/////////////////////////////////////////////////////
// vdiv
/////////////////////////////////////////////////////
template <ArchType ArchTag, typename DType>
__aicore__ inline void div_v(AscendC::LocalTensor<DType> dst, AscendC::LocalTensor<DType> src0,
AscendC::LocalTensor<DType> src1, uint8_t repeat, uint8_t dstBlockStride,
uint8_t src0BlockStride, uint8_t src1BlockStride, uint8_t dstRepeatStride,
uint8_t src0RepeatStride, uint8_t src1RepeatStride)
{
AscendC::Div<DType, false>(dst, src0, src1, (uint64_t)0, repeat,
AscendC::BinaryRepeatParams(dstBlockStride, src0BlockStride, src1BlockStride,
dstRepeatStride, src0RepeatStride, src1RepeatStride));
}
/////////////////////////////////////////////////////
// vexp
/////////////////////////////////////////////////////
template <ArchType ArchTag, typename DType>
__aicore__ inline void exp_v(AscendC::LocalTensor<DType> dst, AscendC::LocalTensor<DType> src, uint8_t repeat,
uint16_t dstBlockStride, uint16_t srcBlockStride, uint16_t dstRepeatStride,
uint16_t srcRepeatStride)
{
AscendC::Exp<DType, false>(
dst, src, (uint64_t)0, repeat,
AscendC::UnaryRepeatParams(dstBlockStride, srcBlockStride, dstRepeatStride, srcRepeatStride));
}
/////////////////////////////////////////////////////
// vmax
/////////////////////////////////////////////////////
template <ArchType ArchTag, typename DType>
__aicore__ inline void max_v(AscendC::LocalTensor<DType> dst, AscendC::LocalTensor<DType> src0,
AscendC::LocalTensor<DType> src1, uint8_t repeat, uint8_t dstBlockStride,
uint8_t src0BlockStride, uint8_t src1BlockStride, uint8_t dstRepeatStride,
uint8_t src0RepeatStride, uint8_t src1RepeatStride)
{
AscendC::Max<DType, false>(dst, src0, src1, (uint64_t)0, repeat,
AscendC::BinaryRepeatParams(dstBlockStride, src0BlockStride, src1BlockStride,
dstRepeatStride, src0RepeatStride, src1RepeatStride));
}
/////////////////////////////////////////////////////
// vmul
/////////////////////////////////////////////////////
template <ArchType ArchTag, typename DType>
__aicore__ inline void mul_v(AscendC::LocalTensor<DType> dst, AscendC::LocalTensor<DType> src0,
AscendC::LocalTensor<DType> src1, uint8_t repeat, uint8_t dstBlockStride,
uint8_t src0BlockStride, uint8_t src1BlockStride, uint8_t dstRepeatStride,
uint8_t src0RepeatStride, uint8_t src1RepeatStride)
{
AscendC::Mul<DType, false>(dst, src0, src1, (uint64_t)0, repeat,
AscendC::BinaryRepeatParams(dstBlockStride, src0BlockStride, src1BlockStride,
dstRepeatStride, src0RepeatStride, src1RepeatStride));
}
/////////////////////////////////////////////////////
// vmuls
/////////////////////////////////////////////////////
template <ArchType ArchTag, typename DType>
__aicore__ inline void muls_v(AscendC::LocalTensor<DType> dst, AscendC::LocalTensor<DType> src0, DType src1,
uint8_t repeat, uint16_t dstBlockStride, uint16_t srcBlockStride,
uint16_t dstRepeatStride, uint16_t srcRepeatStride)
{
AscendC::Muls<DType, false>(
dst, src0, src1, (uint64_t)0, repeat,
AscendC::UnaryRepeatParams(dstBlockStride, srcBlockStride, dstRepeatStride, srcRepeatStride));
}
/////////////////////////////////////////////////////
// vsub
/////////////////////////////////////////////////////
template <ArchType ArchTag, typename DType>
__aicore__ inline void sub_v(AscendC::LocalTensor<DType> dst, AscendC::LocalTensor<DType> src0,
AscendC::LocalTensor<DType> src1, uint8_t repeat, uint8_t dstBlockStride,
uint8_t src0BlockStride, uint8_t src1BlockStride, uint8_t dstRepeatStride,
uint8_t src0RepeatStride, uint8_t src1RepeatStride)
{
AscendC::Sub<DType, false>(dst, src0, src1, (uint64_t)0, repeat,
AscendC::BinaryRepeatParams(dstBlockStride, src0BlockStride, src1BlockStride,
dstRepeatStride, src0RepeatStride, src1RepeatStride));
}
/////////////////////////////////////////////////////
// vmaxs
/////////////////////////////////////////////////////
template <ArchType ArchTag, typename DType>
__aicore__ inline void maxs_v(AscendC::LocalTensor<DType> dst, AscendC::LocalTensor<DType> src0, DType src1,
uint8_t repeat, uint16_t dstBlockStride, uint16_t srcBlockStride,
uint16_t dstRepeatStride, uint16_t srcRepeatStride)
{
AscendC::Maxs<DType, false>(
dst, src0, src1, (uint64_t)0, repeat,
AscendC::UnaryRepeatParams(dstBlockStride, srcBlockStride, dstRepeatStride, srcRepeatStride));
}
/////////////////////////////////////////////////////
// vmins
/////////////////////////////////////////////////////
template <ArchType ArchTag, typename DType>
__aicore__ inline void mins_v(AscendC::LocalTensor<DType> dst, AscendC::LocalTensor<DType> src0, DType src1,
uint8_t repeat, uint16_t dstBlockStride, uint16_t srcBlockStride,
uint16_t dstRepeatStride, uint16_t srcRepeatStride)
{
AscendC::Mins<DType, false>(
dst, src0, src1, (uint64_t)0, repeat,
AscendC::UnaryRepeatParams(dstBlockStride, srcBlockStride, dstRepeatStride, srcRepeatStride));
}
/////////////////////////////////////////////////////
// vsqrt
/////////////////////////////////////////////////////
template <ArchType ArchTag, typename DType>
__aicore__ inline void sqrt_v(AscendC::LocalTensor<DType> dst, AscendC::LocalTensor<DType> src, uint8_t repeat,
uint16_t dstBlockStride, uint16_t srcBlockStride, uint16_t dstRepeatStride,
uint16_t srcRepeatStride)
{
AscendC::Sqrt<DType, false>(
dst, src, (uint64_t)0, repeat,
AscendC::UnaryRepeatParams(dstBlockStride, srcBlockStride, dstRepeatStride, srcRepeatStride));
}
/////////////////////////////////////////////////////
// vln
/////////////////////////////////////////////////////
template <ArchType ArchTag, typename DType>
__aicore__ inline void ln_v(AscendC::LocalTensor<DType> dst, AscendC::LocalTensor<DType> src, uint8_t repeat,
uint16_t dstBlockStride, uint16_t srcBlockStride, uint16_t dstRepeatStride,
uint16_t srcRepeatStride)
{
AscendC::Ln<DType, false>(
dst, src, (uint64_t)0, repeat,
AscendC::UnaryRepeatParams(dstBlockStride, srcBlockStride, dstRepeatStride, srcRepeatStride));
}
/////////////////////////////////////////////////////
// vtranspose
/////////////////////////////////////////////////////
template <ArchType ArchTag, typename DType>
__aicore__ inline void tranpose_v(AscendC::LocalTensor<DType> dst, AscendC::LocalTensor<DType> src)
{
AscendC::Transpose(dst, src);
}
/////////////////////////////////////////////////////
// vcgmax
/////////////////////////////////////////////////////
template <ArchType ArchTag, typename DType>
__aicore__ inline void cgmax_v(AscendC::LocalTensor<DType> dst, AscendC::LocalTensor<DType> src, const int32_t repeat,
const int32_t dstRepStride, const int32_t srcBlkStride, const int32_t srcRepStride)
{
AscendC::BlockReduceMax<DType, false>(dst, src, repeat, 0, dstRepStride, srcBlkStride, srcRepStride);
}
#endif

View File

@ -0,0 +1,69 @@
/* Adapted from
* https://gitee.com/ascend/ascend-transformer-boost.git
*
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef INCLUDE_UTILS_H
#define INCLUDE_UTILS_H
template <typename IN_DTYPE>
__aicore__ inline void CreateCaMatrix(const AscendC::LocalTensor<IN_DTYPE> &dst, const uint16_t repeats,
const uint16_t blockNum, const uint16_t dstGap, const IN_DTYPE initValue)
{
AscendC::InitConstValue<IN_DTYPE>(dst,
AscendC::InitConstValueParams<IN_DTYPE>(repeats, blockNum, dstGap, initValue));
}
__aicore__ inline void SetFftsBaseAddr(uint64_t config)
{
AscendC::SetSyncBaseAddr(config);
}
template <typename IN_DTYPE>
__aicore__ inline void SetPadding(IN_DTYPE padValue)
{
AscendC::SetLoadDataPaddingValue<IN_DTYPE>(padValue);
}
__aicore__ inline void SetAtomicnone()
{
AscendC::SetAtomicNone();
}
__aicore__ inline void SetMasknorm()
{
#if __CCE_AICORE__ == 100
return;
#endif
AscendC::SetMaskNorm();
}
__aicore__ inline void SetNdpara(uint16_t ndNum, uint16_t srcNdStride, uint16_t dstNdStride)
{
AscendC::SetFixpipeNz2ndFlag(ndNum, srcNdStride, dstNdStride);
}
template <typename IN_DTYPE>
__aicore__ inline void SetVectorMask(const uint64_t maskHigh, const uint64_t maskLow)
{
AscendC::SetVectorMask<IN_DTYPE>(maskHigh, maskLow);
}
__aicore__ inline int64_t GetSubBlockidx()
{
return AscendC::GetSubBlockIdx();
}
__aicore__ inline void WaitFlagDev(uint16_t flagId)
{
AscendC::WaitEvent(flagId);
}
template <pipe_t pipe, uint8_t mode>
__aicore__ inline void FftsCrossCoreSync(uint16_t flagId)
{
AscendC::CrossCoreSetFlag<mode, pipe>(flagId);
}
template <typename IN_DTYPE, bool setRelu = false>
__aicore__ inline void SetFpc(const AscendC::LocalTensor<IN_DTYPE> &preTensor, bool isUnitFlag = false)
{
AscendC::SetFixPipeConfig<IN_DTYPE, setRelu>(preTensor, isUnitFlag);
}
#endif

View File

@ -0,0 +1,114 @@
// Adapted from
// https://gitee.com/ascend/ascend-transformer-boost
//
// Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
// This file is a part of the CANN Open Software.
// Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
// Please refer to the License for details. You may not use this file except in compliance with the License.
// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
// See LICENSE in the root of the software repository for the full text of the License.
//
#ifndef __MLA_PREPROCESS_H__
#define __MLA_PREPROCESS_H__
// sync
constexpr int32_t QUANT1 = 1;
constexpr int32_t MM1 = 2;
constexpr int32_t MM1QUANT = 3;
constexpr int32_t RMSNORMQUANT2 = 4;
constexpr int32_t MM2 = 5;
constexpr int32_t MM2QUANT = 6;
constexpr int32_t BMM3 = 7;
constexpr int32_t BMM3SPLIT = 8;
constexpr int32_t MM2OUT = 9;
constexpr int32_t EINSUMOUT = 11;
constexpr int32_t EINSUMQUANT = 12;
// ropeConcat
constexpr uint32_t ELE_NUM_FP16 = 16; // nums of fp16 elements in one block
constexpr uint32_t ELE_NUM_FP32 = 8; // nums of fp32 elements in one block
constexpr uint8_t DEFAULT_REPEAT_STRIDE = 8; // stride, 8 * 32 = 256
// rmsNormQuant
constexpr int32_t NUM_PER_REP_FP32 = 64; // ONE_REPEAT_BYTE_SIZE / sizeof(float);
constexpr float ZERO = 0;
constexpr uint32_t BUF_FACTOR = 3; // 1(g) + 1(sqx) + 1(sum) = 3
constexpr uint32_t OFFSET_GAMMA = 0; // the offset of gamma is 0
constexpr uint32_t OFFSET_SQX = 1; // the offset of sqx is 1
constexpr uint32_t OFFSET_SUM = 2; // the offset of sum is 2
constexpr uint32_t OFFSET_WORKSPACE = 3; // the offset of workspace is 3
constexpr uint32_t REPEAT_TIME_256 = 256; // 128 default stride
constexpr uint32_t REPEAT_TIME_128 = 128; // 128 default stride
constexpr uint32_t REPEAT_TIME_64 = 64; // 64 default stride
constexpr uint8_t CACHE_MODE_KVCACHE = 0; // single input single output
constexpr uint8_t CACHE_MODE_KROPE_CTKV = 1; // double in and double out
constexpr uint8_t CACHE_MODE_INT8_NZCACHE = 2; // high performance KV NZ format/quant int8
constexpr uint8_t CACHE_MODE_NZCACHE = 3;
// pp matmul
constexpr uint32_t HIDDTEN_STATE = 7168;
constexpr uint32_t FLOAT_BLOCK_SIZE = 64;
constexpr uint32_t HALF_BLOCK_SIZE = 64;
constexpr uint32_t HALF_VECTOR_SIZE = 64;
constexpr uint32_t MM1_OUT_SIZE = 2112;
constexpr uint32_t SPLIT_SIZE_ONE = 576;
constexpr uint32_t SPLIT_SIZE_TWO = 1536;
constexpr uint32_t SPLIT_RMSNRORM_SIZE_ONE = 512;
constexpr uint32_t SPLIT_RMSNRORM_SIZE_TWO = 64;
constexpr uint32_t ROPE_SPLIT_SIZE_ONE = 64;
constexpr uint32_t ROPE_SPLIT_SIZE_TWO = 128;
constexpr uint32_t MMSIZE1 = 128 * 192; // 24576
constexpr uint32_t MMSIZE2 = 64 * 128; // 8192
constexpr uint64_t L0_PINGPONG_BUFFER_LEN = 32768; // 32 KB
constexpr uint64_t L1_PINGPONG_BUFFER_LEN = 262144; // 256 KB
constexpr uint64_t BLOCK_SIZE_16 = 16;
constexpr uint64_t BLOCK_SIZE_32 = 32;
constexpr uint64_t CUBE_MATRIX_SIZE_512 = 16 * 32; // 16 * 23
constexpr uint64_t FB_BUFF_SIZE = 1024 * 7;
constexpr uint64_t SCALE_L1_LEN = 4096;
constexpr uint64_t BIAS_L1_LEN = 2048;
constexpr uint64_t CONST_0 = 0;
constexpr uint64_t CONST_4 = 4;
constexpr uint64_t CONST_8 = 8;
constexpr uint64_t CONST_32 = 32;
constexpr uint64_t CONST_64 = 64;
constexpr uint64_t CONST_128 = 128;
// ropeConcat
constexpr uint32_t ROPE_CONCAT_NUM_BUFFER = 2;
// rmsNormQuant
constexpr uint32_t OFFSET_ABS = 3; // the offset of abs is 3
constexpr uint32_t OFFSET_WORKSPACE_BF16 = 4; // the offset of workspace is 4
// sync bf16
constexpr int32_t AIC_MM1_START = 2;
constexpr int32_t AIC_MM3_START = 3;
constexpr int32_t AIC_MM2_START = 6;
constexpr int32_t MMAIC = 7;
constexpr int32_t MMAIV = 8;
constexpr uint32_t MAX_HW_SYNC_COUNTER = 5;
constexpr uint32_t SYNC_MODE = 2;
// TilingKey
constexpr uint32_t KEY_FP16_CACHEMODE_0_QUANTMODE_0 = 0;
constexpr uint32_t KEY_FP16_CACHEMODE_1_QUANTMODE_0 = 1;
constexpr uint32_t KEY_BF16_CACHEMODE_0_QUANTMODE_0 = 256;
constexpr uint32_t KEY_BF16_CACHEMODE_1_QUANTMODE_0 = 257;
constexpr uint32_t KEY_BF16_CACHEMODE_3_QUANTMODE_0 = 259;
enum class QuantMode : int32_t {
PER_TENSOR_ASYMM_QUANT = 0,
PER_TOKEN_SYMM_QUANT,
PER_TOKEN_ASYMM_QUANT,
NO_QUANT,
};
#endif // __MLA_PREPROCESS_H__

View File

@ -0,0 +1,299 @@
// Adapted from
// https://gitee.com/ascend/ascend-transformer-boost
//
// Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
// This file is a part of the CANN Open Software.
// Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
// Please refer to the License for details. You may not use this file except in compliance with the License.
// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
// See LICENSE in the root of the software repository for the full text of the License.
//
#include "kernel_operator.h"
#include "../../kernels/types.h"
#include "mla_preprocess_mix_fp16.hpp"
#include "mla_preprocess_mix_bf16.hpp"
#include "../op_host/tiling/mla_preprocess_tiling.h"
extern "C" __global__ __aicore__ void mla_preprocess(
GM_ADDR hiddenState, GM_ADDR gamma1, GM_ADDR beta1, GM_ADDR quantScale1, GM_ADDR quantOffset1, GM_ADDR wdqkv,
GM_ADDR bias1, GM_ADDR gamma2, GM_ADDR beta2, GM_ADDR quantScale2, GM_ADDR quantOffset2, GM_ADDR gamma3,
GM_ADDR sin1, GM_ADDR cos1, GM_ADDR sin2, GM_ADDR cos2, GM_ADDR keycache, GM_ADDR slotMapping, GM_ADDR wuq,
GM_ADDR bias2, GM_ADDR wuk, GM_ADDR descale1, GM_ADDR descale2, GM_ADDR ctkvScale, GM_ADDR qnopeScale, GM_ADDR q,
GM_ADDR keycacheOut, GM_ADDR q2, GM_ADDR keycacheOut2, GM_ADDR workspace, GM_ADDR tiling)
{
#if defined(__CCE_KT_TEST__) || (__CCE_AICORE__ == 220)
PRELOAD(2);
#endif
SetAtomicnone();
SetMasknorm();
#ifdef __DAV_C220_CUBE__
SetPadding<uint64_t>((uint64_t)0);
SetNdpara(1, 0, 0);
#endif
MlaTilingData mlaTilingData;
__gm__ MlaTilingData *tilingData = reinterpret_cast<__gm__ MlaTilingData *>(tiling);
mlaTilingData.tilingKey = tilingData->tilingKey;
mlaTilingData.n = tilingData->n;
mlaTilingData.mm1.numBatch = tilingData->mm1.numBatch;
mlaTilingData.mm1.m = tilingData->mm1.m;
mlaTilingData.mm1.k = tilingData->mm1.k;
mlaTilingData.mm1.n = tilingData->mm1.n;
mlaTilingData.mm1.m0 = tilingData->mm1.m0;
mlaTilingData.mm1.k0 = tilingData->mm1.k0;
mlaTilingData.mm1.n0 = tilingData->mm1.n0;
mlaTilingData.mm1.mLoop = tilingData->mm1.mLoop;
mlaTilingData.mm1.kLoop = tilingData->mm1.kLoop;
mlaTilingData.mm1.nLoop = tilingData->mm1.nLoop;
mlaTilingData.mm1.coreLoop = tilingData->mm1.coreLoop;
mlaTilingData.mm1.swizzleCount = tilingData->mm1.swizzleCount;
mlaTilingData.mm1.enShuffleK = tilingData->mm1.enShuffleK;
mlaTilingData.mm1.blockDim = tilingData->mm1.blockDim;
mlaTilingData.mm1.enLoadAllAmat = tilingData->mm1.enLoadAllAmat;
mlaTilingData.mm1.b0matPingPongBufferLen = tilingData->mm1.b0matPingPongBufferLen;
mlaTilingData.mm2.numBatch = tilingData->mm2.numBatch;
mlaTilingData.mm2.m = tilingData->mm2.m;
mlaTilingData.mm2.k = tilingData->mm2.k;
mlaTilingData.mm2.n = tilingData->mm2.n;
mlaTilingData.mm2.m0 = tilingData->mm2.m0;
mlaTilingData.mm2.k0 = tilingData->mm2.k0;
mlaTilingData.mm2.n0 = tilingData->mm2.n0;
mlaTilingData.mm2.mLoop = tilingData->mm2.mLoop;
mlaTilingData.mm2.kLoop = tilingData->mm2.kLoop;
mlaTilingData.mm2.nLoop = tilingData->mm2.nLoop;
mlaTilingData.mm2.coreLoop = tilingData->mm2.coreLoop;
mlaTilingData.mm2.swizzleCount = tilingData->mm2.swizzleCount;
mlaTilingData.mm2.enShuffleK = tilingData->mm2.enShuffleK;
mlaTilingData.mm2.blockDim = tilingData->mm2.blockDim;
mlaTilingData.mm2.enLoadAllAmat = tilingData->mm2.enLoadAllAmat;
mlaTilingData.mm2.b0matPingPongBufferLen = tilingData->mm2.b0matPingPongBufferLen;
mlaTilingData.mm3.numBatch = tilingData->mm3.numBatch;
mlaTilingData.mm3.m = tilingData->mm3.m;
mlaTilingData.mm3.k = tilingData->mm3.k;
mlaTilingData.mm3.n = tilingData->mm3.n;
mlaTilingData.mm3.m0 = tilingData->mm3.m0;
mlaTilingData.mm3.k0 = tilingData->mm3.k0;
mlaTilingData.mm3.n0 = tilingData->mm3.n0;
mlaTilingData.mm3.mLoop = tilingData->mm3.mLoop;
mlaTilingData.mm3.kLoop = tilingData->mm3.kLoop;
mlaTilingData.mm3.nLoop = tilingData->mm3.nLoop;
mlaTilingData.mm3.coreLoop = tilingData->mm3.coreLoop;
mlaTilingData.mm3.swizzleCount = tilingData->mm3.swizzleCount;
mlaTilingData.mm3.enShuffleK = tilingData->mm3.enShuffleK;
mlaTilingData.mm3.blockDim = tilingData->mm3.blockDim;
mlaTilingData.perTaskNum = tilingData->perTaskNum;
mlaTilingData.resTaskNum = tilingData->resTaskNum;
mlaTilingData.numCore = tilingData->numCore;
mlaTilingData.rmsNumCore1 = tilingData->rmsNumCore1;
mlaTilingData.rmsNumCol1 = tilingData->rmsNumCol1;
mlaTilingData.rmsNumCore2 = tilingData->rmsNumCore2;
mlaTilingData.rmsNumCol2 = tilingData->rmsNumCol2;
mlaTilingData.hiddenSizeQ = tilingData->hiddenSizeQ;
mlaTilingData.headNumQ = tilingData->headNumQ;
mlaTilingData.headDim = tilingData->headDim;
mlaTilingData.concatSize = tilingData->concatSize;
mlaTilingData.rotaryCoeff = tilingData->rotaryCoeff;
mlaTilingData.ntokens = tilingData->ntokens;
mlaTilingData.realCore = tilingData->realCore;
mlaTilingData.nlCoreRun = tilingData->nlCoreRun;
mlaTilingData.lCoreRun = tilingData->lCoreRun;
mlaTilingData.maxNPerLoopForUb = tilingData->maxNPerLoopForUb;
mlaTilingData.preCoreLoopTime = tilingData->preCoreLoopTime;
mlaTilingData.preCoreLoopNLast = tilingData->preCoreLoopNLast;
mlaTilingData.lastCoreLoopTime = tilingData->lastCoreLoopTime;
mlaTilingData.lastCoreLoopNLast = tilingData->lastCoreLoopNLast;
mlaTilingData.esqFrontCore = tilingData->esqFrontCore;
mlaTilingData.esqTailCore = tilingData->esqTailCore;
mlaTilingData.esqFrontCoreBatch = tilingData->esqFrontCoreBatch;
mlaTilingData.esqTailCoreBatch = tilingData->esqTailCoreBatch;
mlaTilingData.esqHeadNum = tilingData->esqHeadNum;
mlaTilingData.esqColNum = tilingData->esqColNum;
mlaTilingData.esqUbHeadLoop = tilingData->esqUbHeadLoop;
mlaTilingData.esqHeadPerLoop = tilingData->esqHeadPerLoop;
mlaTilingData.esqHeadTail = tilingData->esqHeadTail;
mlaTilingData.esqColLoop = tilingData->esqColLoop;
mlaTilingData.esqColTail = tilingData->esqColTail;
mlaTilingData.s1Offset = tilingData->s1Offset;
mlaTilingData.s2Offset = tilingData->s2Offset;
mlaTilingData.s3Offset = tilingData->s3Offset;
mlaTilingData.s4Offset = tilingData->s4Offset;
mlaTilingData.s5Offset = tilingData->s5Offset;
GM_ADDR s1 = workspace + static_cast<uint64_t>(mlaTilingData.s1Offset);
GM_ADDR s2 = workspace + static_cast<uint64_t>(mlaTilingData.s2Offset);
GM_ADDR s3 = workspace + static_cast<uint64_t>(mlaTilingData.s3Offset);
GM_ADDR s4 = workspace + static_cast<uint64_t>(mlaTilingData.s4Offset);
GM_ADDR s5 = workspace + static_cast<uint64_t>(mlaTilingData.s5Offset);
switch (mlaTilingData.tilingKey) {
case KEY_FP16_CACHEMODE_0_QUANTMODE_0: {
MLAPO_FP16::MLAOperation<CACHE_MODE_KVCACHE, DataFormat::NZ, DataFormat::NZ, DataFormat::ND> opFp16Cm0Qm0(
mlaTilingData, tiling);
opFp16Cm0Qm0.Init(hiddenState, gamma1, beta1, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2,
quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2,
s1, s2, s3);
if ASCEND_IS_AIC {
opFp16Cm0Qm0.ProcessCube();
}
if ASCEND_IS_AIV {
opFp16Cm0Qm0.ProcessVector();
}
break;
}
case KEY_FP16_CACHEMODE_1_QUANTMODE_0: {
MLAPO_FP16::MLAOperation<CACHE_MODE_KROPE_CTKV, DataFormat::NZ, DataFormat::NZ, DataFormat::ND>
opFp16Cm1Qm0(mlaTilingData, tiling);
opFp16Cm1Qm0.Init(hiddenState, gamma1, beta1, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2,
quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2,
s1, s2, s3);
if ASCEND_IS_AIC {
opFp16Cm1Qm0.ProcessCube();
}
if ASCEND_IS_AIV {
opFp16Cm1Qm0.ProcessVector();
}
break;
}
case KEY_BF16_CACHEMODE_0_QUANTMODE_0: {
MLAPO_BF16::MLAOperation<__bf16, 0, DataFormat::NZ, DataFormat::NZ, DataFormat::ND,
QuantMode::PER_TENSOR_ASYMM_QUANT>
opBf16Cm0Qm0(mlaTilingData, tiling);
opBf16Cm0Qm0.Init(hiddenState, gamma1, beta1, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2,
quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2,
s1, s2, s3, s4, s5);
if ASCEND_IS_AIC {
opBf16Cm0Qm0.ProcessCube();
}
if ASCEND_IS_AIV {
opBf16Cm0Qm0.ProcessVector();
}
break;
}
case KEY_BF16_CACHEMODE_1_QUANTMODE_0: {
MLAPO_BF16::MLAOperation<__bf16, 1, DataFormat::NZ, DataFormat::NZ, DataFormat::ND,
QuantMode::PER_TENSOR_ASYMM_QUANT>
opBf16Cm1Qm0(mlaTilingData, tiling);
opBf16Cm1Qm0.Init(hiddenState, gamma1, beta1, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2,
quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2,
s1, s2, s3, s4, s5);
if ASCEND_IS_AIC {
opBf16Cm1Qm0.ProcessCube();
}
if ASCEND_IS_AIV {
opBf16Cm1Qm0.ProcessVector();
}
break;
}
case KEY_BF16_CACHEMODE_3_QUANTMODE_0: {
MLAPO_BF16::MLAOperation<__bf16, 3, DataFormat::NZ, DataFormat::NZ, DataFormat::ND,
QuantMode::PER_TENSOR_ASYMM_QUANT>
opBf16Cm3Qm0(mlaTilingData, tiling);
opBf16Cm3Qm0.Init(hiddenState, gamma1, beta1, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2,
quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2,
s1, s2, s3, s4, s5);
if ASCEND_IS_AIC {
opBf16Cm3Qm0.ProcessCube();
}
if ASCEND_IS_AIV {
opBf16Cm3Qm0.ProcessVector();
}
break;
}
default: {
break;
}
}
return;
}
namespace vllm_ascend {
extern void mla_preprocess_impl(
void* stream,
void* hidden_state,
void* gamma1,
void* beta1,
void* quant_scale1,
void* quant_offset1,
void* wdqkv,
void* bias1,
void* gamma2,
void* beta2,
void* quant_scale2,
void* quant_offset2,
void* gamma3,
void* sin1,
void* cos1,
void* sin2,
void* cos2,
void* keycache,
void* slot_mapping,
void* wuq,
void* bias2,
void* wuk,
void* descale1,
void* descale2,
void* ctkv_scale,
void* qnope_scale,
void* q,
void* keycache_out,
void* q2,
void* keycache_out2,
void* workspace,
void* tiling,
const uint32_t block_dim)
{
mla_preprocess<<<block_dim, nullptr, stream>>>(
hidden_state,
gamma1,
beta1,
quant_scale1,
quant_offset1,
wdqkv,
bias1,
gamma2,
beta2,
quant_scale2,
quant_offset2,
gamma3,
sin1,
cos1,
sin2,
cos2,
keycache,
slot_mapping,
wuq,
bias2,
wuk,
descale1,
descale2,
ctkv_scale,
qnope_scale,
q,
keycache_out,
q2,
keycache_out2,
workspace,
tiling);
}
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -124,4 +124,40 @@ namespace vllm_ascend {
uint32_t output_hidden_dim,
uint32_t slice_offset,
uint32_t output_full_dim);
extern void mla_preprocess_impl(
void* stream,
void* hidden_state,
void* gamma1,
void* beta1,
void* quant_scale1,
void* quant_offset1,
void* wdqkv,
void* bias1,
void* gamma2,
void* beta2,
void* quant_scale2,
void* quant_offset2,
void* gamma3,
void* sin1,
void* cos1,
void* sin2,
void* cos2,
void* keycache,
void* slot_mapping,
void* wuq,
void* bias2,
void* wuk,
void* descale1,
void* descale2,
void* ctkv_scale,
void* qnope_scale,
void* q,
void* keycache_out,
void* q2,
void* keycache_out2,
void* workspace,
void* tiling,
const uint32_t block_dim
);
}

View File

@ -23,6 +23,7 @@
#include "acl/acl.h"
#include "ops.h"
#include "utils.h"
#include "mla_preprocess/op_host/mla_preprocess.h"
namespace vllm_ascend {
@ -106,6 +107,83 @@ std::tuple<at::Tensor, at::Tensor> rotary_embedding(at::Tensor &positions, at::T
return {query_dst, key_dst};
}
std::tuple<at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &> mla_preprocess(
const at::Tensor &hiddenState, const at::Tensor &gamma0, const at::Tensor &beta0, const at::Tensor &wdqkv,
const at::Tensor &descale0, const at::Tensor &gamma1, const at::Tensor &beta1, const at::Tensor &wuq,
const at::Tensor &descale1, const at::Tensor &gamma2, const at::Tensor &cos, const at::Tensor &sin,
const at::Tensor &wuk, const at::Tensor &kv_cache, const at::Tensor &kv_cache_rope, const at::Tensor &slotmapping,
const at::Tensor &quant_scale0, const at::Tensor &quant_offset0, const at::Tensor &bias0,
const at::Tensor &quant_scale1, const at::Tensor &quant_offset1, const at::Tensor &bias1,
const c10::optional<at::Tensor> &ctkv_scale, const c10::optional<at::Tensor> &q_nope_scale,
c10::optional<c10::string_view> cache_mode, c10::optional<c10::string_view> quant_mode, at::Tensor &q_out0,
at::Tensor &kv_cache_out0, at::Tensor &q_out1, at::Tensor &kv_cache_out1)
{
at::Tensor CtkvScale =
ctkv_scale.has_value()
? ctkv_scale.value()
: at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device()));
at::Tensor QnopeScale =
q_nope_scale.has_value()
? q_nope_scale.value()
: at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device()));
auto [workspace_tensor, tiling, block_dim] = mlapo::mla_preprocess_tiling(
hiddenState,
wuk,
cache_mode,
quant_mode
);
void *hidden_state_ptr = hiddenState.data_ptr();
void *gamma0_ptr = gamma0.data_ptr();
void *beta0_ptr = beta0.data_ptr();
void *quant_scale0_ptr = quant_scale0.data_ptr();
void *quant_offset0_ptr = quant_offset0.data_ptr();
void *wdqkv_ptr = wdqkv.data_ptr();
void *bias0_ptr = bias0.data_ptr();
void *gamma1_ptr = gamma1.data_ptr();
void *beta1_ptr = beta1.data_ptr();
void *quant_scale1_ptr = quant_scale1.data_ptr();
void *quant_offset1_ptr = quant_offset1.data_ptr();
void *gamma2_ptr = gamma2.data_ptr();
void *sin_ptr = sin.data_ptr();
void *cos_ptr = cos.data_ptr();
void *kv_cache_ptr = kv_cache.data_ptr();
void *slotmapping_ptr = slotmapping.data_ptr();
void *wuq_ptr = wuq.data_ptr();
void *bias1_ptr = bias1.data_ptr();
void *wuk_ptr = wuk.data_ptr();
void *descale0_ptr = descale0.data_ptr();
void *descale1_ptr = descale1.data_ptr();
void *ctkv_scale_ptr = CtkvScale.data_ptr();
void *qnope_scale_ptr = QnopeScale.data_ptr();
void *q_out0_ptr = q_out0.data_ptr();
void *kv_cache_out0_ptr = kv_cache_out0.data_ptr();
void *q_out1_ptr = q_out1.data_ptr();
void *kv_cache_out1_ptr = kv_cache_out1.data_ptr();
void *workspace_ptr = workspace_tensor.data_ptr();
void *tiling_ptr = tiling.data_ptr();
aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
at_npu::native::OpCommand cmd;
cmd.Name("mla_preprocess");
cmd.SetCustomHandler([stream, hidden_state_ptr, gamma0_ptr, beta0_ptr, quant_scale0_ptr, quant_offset0_ptr, wdqkv_ptr, bias0_ptr,
gamma1_ptr, beta1_ptr, quant_scale1_ptr, quant_offset1_ptr, gamma2_ptr, sin_ptr, cos_ptr,
kv_cache_ptr, slotmapping_ptr, wuq_ptr, bias1_ptr, wuk_ptr, descale0_ptr, descale1_ptr, ctkv_scale_ptr,
qnope_scale_ptr, q_out0_ptr, kv_cache_out0_ptr, q_out1_ptr, kv_cache_out1_ptr, workspace_ptr,
tiling_ptr, block_dim]() -> int {
mla_preprocess_impl(stream, hidden_state_ptr, gamma0_ptr, beta0_ptr, quant_scale0_ptr, quant_offset0_ptr, wdqkv_ptr, bias0_ptr,
gamma1_ptr, beta1_ptr, quant_scale1_ptr, quant_offset1_ptr, gamma2_ptr, sin_ptr, cos_ptr, sin_ptr, cos_ptr,
kv_cache_ptr, slotmapping_ptr, wuq_ptr, bias1_ptr, wuk_ptr, descale0_ptr, descale1_ptr, ctkv_scale_ptr,
qnope_scale_ptr, q_out0_ptr, kv_cache_out0_ptr, q_out1_ptr, kv_cache_out1_ptr, workspace_ptr,
tiling_ptr, block_dim);
return 0;
});
cmd.Run();
return std::forward_as_tuple(q_out0, kv_cache_out0, q_out1, kv_cache_out1);
}
std::tuple<at::Tensor, at::Tensor> get_masked_input_and_mask(
at::Tensor &input,
const int64_t org_vocab_start_index,
@ -422,4 +500,17 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
"sgmv_expand(Tensor! x, Tensor! weight, Tensor! lora_indices, Tensor! seq_len, Tensor! y,"
" int slice_offset, int slice_size) -> Tensor");
ops.impl("sgmv_expand", torch::kPrivateUse1, &vllm_ascend::sgmv_expand);
ops.def(
"mla_preprocess(Tensor hiddenState, Tensor gamma0, Tensor beta0, Tensor wdqkv,"
" Tensor descale0, Tensor gamma1, Tensor beta1, Tensor wuq, Tensor descale1,"
" Tensor gamma2, Tensor cos, Tensor sin, Tensor wuk, Tensor kv_cache,"
" Tensor kv_cache_rope, Tensor slotmapping, Tensor quant_scale0,"
" Tensor quant_offset0, Tensor bias0, Tensor quant_scale1, Tensor quant_offset1,"
" Tensor bias1, Tensor? ctkv_scale, Tensor? q_nope_scale, str? cache_mode,"
" str? quant_mode, Tensor! q_out0, Tensor! kv_cache_out0, Tensor! q_out1,"
" Tensor! kv_cache_out1) -> (Tensor q_out0, Tensor kv_cache_out0,"
" Tensor q_out1, Tensor kv_cache_out1)"
);
ops.impl("mla_preprocess", torch::kPrivateUse1, &vllm_ascend::mla_preprocess);
}

View File

@ -81,6 +81,41 @@ at::Tensor sgmv_expand_meta(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_
return y_out;
}
std::tuple<at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &> mla_preprocess(
const at::Tensor &hiddenState,
const at::Tensor &gamma0,
const at::Tensor &beta0,
const at::Tensor &wdqkv,
const at::Tensor &descale0,
const at::Tensor &gamma1,
const at::Tensor &beta1,
const at::Tensor &wuq,
const at::Tensor &descale1,
const at::Tensor &gamma2,
const at::Tensor &cos,
const at::Tensor &sin,
const at::Tensor &wuk,
const at::Tensor &kv_cache,
const at::Tensor &kv_cache_rope,
const at::Tensor &slotmapping,
const at::Tensor &quant_scale0,
const at::Tensor &quant_offset0,
const at::Tensor &bias0,
const at::Tensor &quant_scale1,
const at::Tensor &quant_offset1,
const at::Tensor &bias1,
const c10::optional<at::Tensor> &ctkv_scale,
const c10::optional<at::Tensor> &q_nope_scale,
c10::optional<c10::string_view> cache_mode,
c10::optional<c10::string_view> quant_mode,
at::Tensor &q_out0,
at::Tensor &kv_cache_out0,
at::Tensor &q_out1,
at::Tensor &kv_cache_out1)
{
return {q_out0, kv_cache_out0, q_out1, kv_cache_out1};
}
} // namespace meta
} // namespace vllm_ascend
@ -97,6 +132,7 @@ namespace {
ops.impl("bgmv_expand", &vllm_ascend::meta::bgmv_expand_meta);
// Sgmv expand
ops.impl("sgmv_expand", &vllm_ascend::meta::sgmv_expand_meta);
// MLA preprocess
ops.impl("mla_preprocess", &vllm_ascend::meta::mla_preprocess);
}
}

View File

@ -0,0 +1,112 @@
import gc
import torch
import torch_npu
from vllm_ascend.utils import enable_custom_op
enable_custom_op()
@torch.inference_mode()
def test_mla_preprocess_kernel():
token_num = 1
head_num = 2
N_7168 = 7168
block_num = 1
block_size = 128
dtype = torch.bfloat16
hidden_states = torch.randn((token_num, N_7168), dtype=dtype).npu()
gamma0 = torch.randn((N_7168), dtype=dtype).npu()
beta0 = torch.randn((N_7168), dtype=dtype).npu()
quant_scale0 = torch.randn((1, ), dtype=dtype).npu()
quant_offset0 = torch.randint(0, 7, (1, ), dtype=torch.int8).npu()
wdqkv = torch.randint(0, 7, (1, 224, 2112, 32), dtype=torch.int8).npu()
wdqkv = torch_npu.npu_format_cast(wdqkv.contiguous(), 29)
de_scale0 = torch.rand((2112, ), dtype=torch.float).npu()
bias0 = torch.randint(0, 7, (2112, ), dtype=torch.int32).npu()
gamma1 = torch.randn((1536), dtype=dtype).npu()
beta1 = torch.randn((1536), dtype=dtype).npu()
quant_scale1 = torch.randn((1, ), dtype=dtype).npu()
quant_offset1 = torch.randint(0, 7, (1, ), dtype=torch.int8).npu()
wuq = torch.randint(0, 7, (1, 48, head_num * 192, 32),
dtype=torch.int8).npu()
wuq = torch_npu.npu_format_cast(wuq.contiguous(), 29)
de_scale1 = torch.rand((head_num * 192, ), dtype=torch.float).npu()
bias1 = torch.randint(0, 7, (head_num * 192, ), dtype=torch.int32).npu()
gamma2 = torch.randn((512), dtype=dtype).npu()
cos = torch.randn((token_num, 64), dtype=dtype).npu()
sin = torch.randn((token_num, 64), dtype=dtype).npu()
wuk = torch.randn((head_num, 128, 512), dtype=dtype).npu()
wuk = torch_npu.npu_format_cast(wuk, 29)
kv_cache = torch.randint(0,
7,
(block_num, head_num * 512 // 32, block_size, 32),
dtype=dtype).npu()
kv_cache_rope = torch.randn(
(block_num, head_num * 64 // 16, block_size, 16), dtype=dtype).npu()
slotmapping = torch.randint(0, 7, (token_num, ), dtype=torch.int32).npu()
ctkv_scale = torch.randn((1, ), dtype=dtype).npu()
qnope_scale = torch.randn((head_num), dtype=dtype).npu()
q_nope_out = torch.empty(
(hidden_states.shape[0], wuk.shape[0], kv_cache.shape[-1]),
dtype=hidden_states.dtype,
device=hidden_states.device,
)
q_rope_out = torch.empty(
(hidden_states.shape[0], wuk.shape[0], kv_cache_rope.shape[-1]),
dtype=hidden_states.dtype,
device=hidden_states.device,
)
q_nope_old = q_nope_out.clone()
q_rope_old = q_rope_out.clone()
torch.ops._C_ascend.mla_preprocess(
hidden_states,
gamma0,
beta0,
wdqkv,
de_scale0,
gamma1,
beta1,
wuq,
de_scale1,
gamma2,
cos,
sin,
wuk,
kv_cache,
kv_cache_rope,
slotmapping,
quant_scale0=quant_scale0,
quant_offset0=quant_offset0,
bias0=bias0,
quant_scale1=quant_scale1,
quant_offset1=quant_offset1,
bias1=bias1,
ctkv_scale=ctkv_scale,
q_nope_scale=qnope_scale,
cache_mode="krope_ctkv",
quant_mode="per_tensor_quant_asymm",
q_out0=q_nope_out,
kv_cache_out0=kv_cache,
q_out1=q_rope_out,
kv_cache_out1=kv_cache_rope,
)
assert not torch.equal(q_nope_out, q_nope_old)
assert not torch.equal(q_rope_out, q_rope_old)
gc.collect()
torch.npu.empty_cache()
torch.npu.reset_peak_memory_stats()