Mirror of https://github.com/vllm-project/vllm-ascend.git (synced 2025-10-20 05:33:51 +08:00)
add mla_preprocess kernel (#3226)
### What this PR does / why we need it?
- Adds the `mla_preprocess` custom kernel to provide an optimized pre-processing operator for Multi-head Latent Attention (MLA) on Ascend NPUs.
- Wires the new kernel into the C++ extension pipeline so vLLM can invoke it directly, cutting the Python-side tensor shuffling and memory copies that previously bottlenecked MLA compilation paths.

### Does this PR introduce any user-facing change?
- No. The change only introduces a low-level kernel; public APIs and inference behavior remain unchanged.

### How was this patch tested?
- Dedicated Ascend kernels are not covered by our CI yet, so no extra automated tests were added. Future MLA-focused regression runs will cover this path.

- vLLM version: v0.11.0

Signed-off-by: Chen Chen <0109chenchen@gmail.com>
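For orientation, a minimal host-side sketch (not part of the diff) of how the tiling entry point `mla_preprocess_tiling`, added below in `csrc/mla_preprocess/op_host/mla_preprocess.h`, might be driven; the include paths, tensor arguments, and mode strings are illustrative assumptions.

```cpp
// Hedged sketch: assumes the new header is on the include path and that the caller
// already holds the hidden-state and W_UK weight tensors on an NPU device.
#include <tuple>
#include <ATen/ATen.h>
#include "mla_preprocess.h"  // defines mlapo::mla_preprocess_tiling in this PR

std::tuple<at::Tensor, at::Tensor, uint32_t> prepare_mla_preprocess_launch(
    const at::Tensor &hidden_state, const at::Tensor &wuk)
{
    // Returns the workspace buffer, the device-side tiling blob, and the cube
    // block dimension for the subsequent kernel launch (launch itself not shown).
    return mlapo::mla_preprocess_tiling(hidden_state, wuk,
                                        "krope_ctkv",             // cache_mode
                                        "per_token_quant_symm");  // quant_mode
}
```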
@ -12,8 +12,8 @@ repos:
   - id: codespell
     args: [
       --toml, pyproject.toml,
-      '--skip', 'tests/e2e/multicard/test_torchair_graph_mode.py,tests/prompts/**,./benchmarks/sonnet.txt,*tests/lora/data/**,build/**,./vllm_ascend.egg-info/**,.github/**,typos.toml',
-      '-L', 'CANN,cann,NNAL,nnal,ASCEND,ascend,EnQue,CopyIn'
+      '--skip', 'tests/e2e/multicard/test_torchair_graph_mode.py,csrc/mla_preprocess/**,tests/prompts/**,./benchmarks/sonnet.txt,*tests/lora/data/**,build/**,./vllm_ascend.egg-info/**,.github/**,typos.toml',
+      '-L', 'CANN,cann,NNAL,nnal,ASCEND,ascend,EnQue,CopyIn,ArchType,AND'
     ]
     additional_dependencies:
       - tomli
@ -35,6 +35,10 @@ repos:
   rev: v1.32.0
   hooks:
     - id: typos
+      args: [
+        "--force-exclude",
+        "--exclude", "csrc/mla_preprocess/**"
+      ]
 - repo: https://github.com/PyCQA/isort
   rev: 6.0.1
   hooks:
@ -44,11 +44,13 @@ else()
endif()

include(${ASCENDC_CMAKE_DIR}/ascendc.cmake)

file(GLOB KERNEL_FILES
  ${CMAKE_CURRENT_SOURCE_DIR}/csrc/kernels/*.cpp)

ascendc_library(vllm_ascend_kernels SHARED
  ${KERNEL_FILES}
  ${CMAKE_CURRENT_SOURCE_DIR}/csrc/mla_preprocess/op_kernel/mla_preprocess_kernel.cpp
)

message("TORCH_NPU_PATH is ${TORCH_NPU_PATH}")
@ -90,7 +92,11 @@ target_link_libraries(
  libtorch_npu.so
  vllm_ascend_kernels
  ascendcl
  tiling_api
  register
  platform
  ascendalog
  dl
)

target_link_options(vllm_ascend_C PRIVATE "-Wl,-rpath,$ORIGIN:$ORIGIN/lib")
csrc/mla_preprocess/op_host/mla_preprocess.h (new file, 698 lines)
@ -0,0 +1,698 @@
|
||||
// Adapted from
|
||||
// https://gitee.com/ascend/ascend-transformer-boost.git
|
||||
// https://gitee.com/ascend/op-plugin.git
|
||||
//
|
||||
// Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
|
||||
// This file is a part of the CANN Open Software.
|
||||
// Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
// Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
// See LICENSE in the root of the software repository for the full text of the License.
|
||||
//
|
||||
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <math.h>
|
||||
#include <stdexcept>
|
||||
#include "acl/acl.h"
|
||||
// #include "defines.h"
|
||||
// #include "torch_helper.h"
|
||||
#include "tiling/platform/platform_ascendc.h"
|
||||
#include "tiling/mla_preprocess_tiling.h"
|
||||
|
||||
// #include "aclrtlaunch_mla_preprocess.h"
|
||||
|
||||
// namespace sglang {
|
||||
namespace mlapo {
|
||||
|
||||
constexpr uint32_t DIM_2 = 2;
|
||||
|
||||
constexpr uint32_t AXES_ALIGN_SIZE = 512;
|
||||
constexpr uint32_t BASE_BLOCK_STEP = 2;
|
||||
constexpr uint32_t CONST_16 = 16;
|
||||
constexpr uint32_t CONST_32 = 32;
|
||||
constexpr uint32_t CONST_128 = 128;
|
||||
constexpr uint32_t CONST_256 = 256;
|
||||
constexpr uint32_t CONST_512 = 512;
|
||||
constexpr uint32_t L1_BUFFER_SIZE = 524288;
|
||||
constexpr uint32_t L1_PINGPONG_BUFFER_LEN = 262144;
|
||||
constexpr uint32_t L0AB_PINGPONG_BUFFER_LEN = 131072;
|
||||
constexpr uint32_t L1_SCALE_SIZE = 4096;
|
||||
constexpr uint32_t L1_BIAS_SIZE = 2048;
|
||||
constexpr uint32_t L0C_SIZE = 128 * 1024;
|
||||
constexpr uint32_t CONCAT_SIZE = 512;
|
||||
|
||||
constexpr uint32_t HIDDEN_STRATE = 7168;
|
||||
constexpr uint32_t HIDDEN_STRATE_ROPE = 192;
|
||||
constexpr uint32_t HIDDEN_STRATE_MM = 2112;
|
||||
constexpr uint32_t HIDDEN_STRATE_RMS = 1536;
|
||||
constexpr uint32_t UB_SIZE = 196352;
|
||||
constexpr uint32_t HEADDIM = 64;
|
||||
constexpr uint32_t FP32_REPEAT_MASK = 64;
|
||||
constexpr uint32_t FP16_REPEAT_MASK = 128;
|
||||
|
||||
constexpr int32_t NUM1 = 1;
|
||||
constexpr int32_t NUM2 = 2;
|
||||
constexpr int32_t NUM3 = 3;
|
||||
constexpr int32_t NUM4 = 4;
|
||||
constexpr int32_t NUM8 = 8;
|
||||
constexpr uint32_t INDEX_WDQKV = 5;
|
||||
constexpr uint32_t INDEX_WUQ = 18;
|
||||
constexpr uint32_t INDEX_WUK = 20;
|
||||
|
||||
constexpr uint32_t MAX_SUPPORT_TOKEN_NUMS = 1024;
|
||||
|
||||
inline uint32_t CeilDiv(const uint32_t dividend, const uint32_t divisor)
|
||||
{
|
||||
if (divisor == 0) {
|
||||
return UINT32_MAX;
|
||||
}
|
||||
return (dividend + divisor - 1) / divisor;
|
||||
}
|
||||
|
||||
inline uint32_t RoundUp(const uint32_t val, const uint32_t align = 16)
|
||||
{
|
||||
if (align == 0) {
|
||||
return 0;
|
||||
}
|
||||
return (val + align - 1) / align * align;
|
||||
}
|
||||
|
||||
inline uint32_t RoundDown(const uint32_t val, const uint32_t align = 16)
|
||||
{
|
||||
if (align == 0) {
|
||||
return 0;
|
||||
}
|
||||
return val / align * align;
|
||||
}
|
||||
|
||||
template <typename T = uint32_t>
|
||||
inline T Max(const T a, const T b)
|
||||
{
|
||||
return a > b ? a : b;
|
||||
}
|
||||
|
||||
template <typename T = uint32_t>
|
||||
inline T Min(const T a, const T b)
|
||||
{
|
||||
return a < b ? a : b;
|
||||
}
|
||||
|
||||
struct MlaPreprocess {
|
||||
enum class QuantMode : int32_t {
|
||||
PER_TENSOR_ASYMM_QUANT = 0,
|
||||
PER_TOKEN_SYMM_QUANT,
|
||||
PER_TOKEN_ASYMM_QUANT,
|
||||
NO_QUANT
|
||||
};
|
||||
};
|
||||
using QuantMode = MlaPreprocess::QuantMode;
|
||||
|
||||
struct PlatformInfo {
|
||||
uint32_t coreNum;
|
||||
uint32_t coreNumAic;
|
||||
uint32_t coreNumAiv;
|
||||
uint64_t ubSize;
|
||||
uint64_t l1Size;
|
||||
uint64_t l2Size;
|
||||
uint64_t l0aSize;
|
||||
uint64_t l0bSize;
|
||||
uint64_t l0cSize;
|
||||
};
|
||||
|
||||
struct OpParam {
|
||||
uint32_t N;
|
||||
uint32_t headNum;
|
||||
int32_t cacheMode;
|
||||
QuantMode quantMode;
|
||||
caffe2::TypeMeta inDtype;
|
||||
};
|
||||
|
||||
class PpMatmulTilingApi
|
||||
{
|
||||
public:
|
||||
PpMatmulTilingApi(struct PlatformInfo &platformInfo, uint32_t numBatch, uint32_t m, uint32_t k, uint32_t n,
|
||||
bool transA, bool transB, bool enDequant, bool deqOnTheFly)
|
||||
: platformInfo_(platformInfo),
|
||||
numBatch_(numBatch),
|
||||
m_(m),
|
||||
k_(k),
|
||||
n_(n),
|
||||
transA_(transA),
|
||||
transB_(transB),
|
||||
enDequant_(enDequant),
|
||||
deqOnTheFly_(deqOnTheFly)
|
||||
{
|
||||
inDataSize_ = enDequant ? sizeof(uint8_t) : sizeof(uint16_t);
|
||||
}
|
||||
void GetTilingData(PpMatmulTilingData &tiling);
|
||||
|
||||
private:
|
||||
void GetTileSize();
|
||||
float GetCost(const uint32_t m0, const uint32_t n0);
|
||||
void UpdateTileSize(const uint32_t m0, const uint32_t n0);
|
||||
void Swizzle();
|
||||
uint32_t ComputeL1AbSize();
|
||||
uint32_t ComputeK0ForABpingpong(uint32_t l1AbSize);
|
||||
bool IsLoadAllAmat(uint32_t l1AbSize);
|
||||
uint32_t ComputeK0ForOnlyBpingpong(uint32_t l1AbSize);
|
||||
|
||||
private:
|
||||
uint32_t numBatch_{0};
|
||||
uint32_t m_{0};
|
||||
uint32_t k_{0};
|
||||
uint32_t n_{0};
|
||||
uint32_t m0_{0};
|
||||
uint32_t k0_{0};
|
||||
uint32_t n0_{0};
|
||||
uint32_t mLoop_{0};
|
||||
uint32_t kLoop_{0};
|
||||
uint32_t nLoop_{0};
|
||||
uint32_t coreLoop_{0};
|
||||
uint32_t swizzleCount_{0};
|
||||
uint32_t blockDim_{0};
|
||||
uint32_t swizzleDirect_{0};
|
||||
uint32_t inDataSize_{0};
|
||||
uint32_t b0matPingPongBufferLen_{L1_PINGPONG_BUFFER_LEN};
|
||||
bool transA_{false};
|
||||
bool transB_{false};
|
||||
bool enDequant_{false};
|
||||
bool enShuffleK_{false};
|
||||
bool enLoadAllAmat_{false};
|
||||
bool deqOnTheFly_{false};
|
||||
|
||||
struct PlatformInfo platformInfo_;
|
||||
};
|
||||
|
||||
void PpMatmulTilingApi::GetTilingData(PpMatmulTilingData &tiling)
|
||||
{
|
||||
GetTileSize();
|
||||
tiling.numBatch = numBatch_;
|
||||
tiling.m = m_;
|
||||
tiling.k = k_;
|
||||
tiling.n = n_;
|
||||
tiling.m0 = m0_;
|
||||
tiling.k0 = k0_;
|
||||
tiling.n0 = n0_;
|
||||
tiling.mLoop = mLoop_;
|
||||
tiling.kLoop = kLoop_;
|
||||
tiling.nLoop = nLoop_;
|
||||
tiling.coreLoop = coreLoop_;
|
||||
tiling.swizzleCount = swizzleCount_;
|
||||
tiling.swizzleDirect = swizzleDirect_;
|
||||
tiling.enShuffleK = static_cast<uint32_t>(enShuffleK_);
|
||||
tiling.blockDim = blockDim_;
|
||||
tiling.enLoadAllAmat = static_cast<uint32_t>(enLoadAllAmat_);
|
||||
tiling.b0matPingPongBufferLen = b0matPingPongBufferLen_;
|
||||
}
|
||||
|
||||
void PpMatmulTilingApi::GetTileSize()
|
||||
{
|
||||
bool priFlag = !(m_ < n_);
|
||||
uint32_t roundBase = pow(2, ceil(log(CeilDiv(priFlag ? n_ : m_, CONST_16)))) * CONST_16;
|
||||
uint32_t priAxes = RoundUp(priFlag ? m_ : n_, CONST_16);
|
||||
uint32_t subAxes = RoundUp(priFlag ? n_ : m_, roundBase);
|
||||
float minCost = __FLT_MAX__;
|
||||
uint32_t maxAxes0 = AXES_ALIGN_SIZE;
|
||||
uint32_t maxPriAxes0 = Min(maxAxes0, priAxes);
|
||||
uint32_t maxSubAxes0 = Min(maxAxes0, subAxes);
|
||||
for (uint32_t priAxes0 = CONST_16; priAxes0 <= maxPriAxes0; priAxes0 *= BASE_BLOCK_STEP) {
|
||||
for (uint32_t subAxes0 = CONST_16; subAxes0 <= maxSubAxes0; subAxes0 *= BASE_BLOCK_STEP) {
|
||||
if (priAxes0 * subAxes0 * sizeof(float) > platformInfo_.l0cSize) {
|
||||
continue;
|
||||
}
|
||||
uint32_t newM0 = priFlag ? priAxes0 : subAxes0;
|
||||
uint32_t newN0 = priFlag ? subAxes0 : priAxes0;
|
||||
if (newN0 > CONST_256 && enDequant_) {
|
||||
continue;
|
||||
}
|
||||
float cost = GetCost(newM0, newN0);
|
||||
if (cost < minCost) {
|
||||
minCost = cost;
|
||||
UpdateTileSize(newM0, newN0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Swizzle();
|
||||
|
||||
uint32_t l1AbSize = ComputeL1AbSize();
|
||||
k0_ = ComputeK0ForABpingpong(l1AbSize);
|
||||
kLoop_ = CeilDiv(k_, k0_);
|
||||
}
|
||||
|
||||
uint32_t PpMatmulTilingApi::ComputeK0ForOnlyBpingpong(uint32_t l1AbSize)
|
||||
{
|
||||
enLoadAllAmat_ = true;
|
||||
b0matPingPongBufferLen_ = static_cast<uint32_t>(
|
||||
static_cast<float>((l1AbSize - RoundUp(m_, CONST_16) * RoundUp(k_, CONST_32) * inDataSize_) / DIM_2));
|
||||
uint32_t k0MaxB0 =
|
||||
static_cast<uint32_t>(static_cast<float>(b0matPingPongBufferLen_ / (RoundUp(n0_, CONST_16) * inDataSize_)));
|
||||
uint32_t k0B0 = k0MaxB0 < CONST_512 ? RoundDown(k0MaxB0, CONST_32) : RoundDown(k0MaxB0, CONST_512);
|
||||
return k0B0 > CONST_512 ? RoundDown(k0B0, CONST_512) : k0B0;
|
||||
}
|
||||
|
||||
bool PpMatmulTilingApi::IsLoadAllAmat(uint32_t l1AbSize)
|
||||
{
|
||||
return (coreLoop_ > blockDim_) && enDequant_ && (kLoop_ > 1) &&
|
||||
(l1AbSize > RoundUp(m_, CONST_16) * RoundUp(k_, CONST_32) * inDataSize_) && (mLoop_ == 1);
|
||||
}
|
||||
|
||||
uint32_t PpMatmulTilingApi::ComputeK0ForABpingpong(uint32_t l1AbSize)
|
||||
{
|
||||
uint32_t k0Max = static_cast<uint32_t>(static_cast<float>(l1AbSize / DIM_2) / ((m0_ + n0_) * inDataSize_));
|
||||
uint32_t tmpK0;
|
||||
if (enDequant_) {
|
||||
tmpK0 = k0Max < CONST_512 ? RoundDown(k0Max, CONST_32) : RoundDown(k0Max, CONST_512);
|
||||
} else {
|
||||
tmpK0 = k0Max < CONST_256 ? RoundDown(k0Max, CONST_16) : RoundDown(k0Max, CONST_256);
|
||||
}
|
||||
if (tmpK0 > CONST_512) {
|
||||
tmpK0 = RoundDown(tmpK0, CONST_512);
|
||||
}
|
||||
return tmpK0;
|
||||
}
|
||||
|
||||
uint32_t PpMatmulTilingApi::ComputeL1AbSize()
|
||||
{
|
||||
if (enDequant_ && deqOnTheFly_) {
|
||||
return L1_BUFFER_SIZE;
|
||||
}
|
||||
return enDequant_ ? (L1_BUFFER_SIZE - L1_BIAS_SIZE - L1_SCALE_SIZE) : L1_BUFFER_SIZE;
|
||||
}
|
||||
|
||||
float PpMatmulTilingApi::GetCost(const uint32_t m0, const uint32_t n0)
|
||||
{
|
||||
float aCoef = 1.0;
|
||||
float bCoef = 1.0;
|
||||
float bwCoef = 5.0;
|
||||
uint32_t mLoop = CeilDiv(m_, m0);
|
||||
uint32_t nLoop = CeilDiv(n_, n0);
|
||||
if (mLoop == 0 || nLoop == 0) {
|
||||
return __FLT_MAX__;
|
||||
}
|
||||
uint32_t rqdNumCore = numBatch_ * mLoop * nLoop;
|
||||
uint32_t blockDim = Min(rqdNumCore, platformInfo_.coreNumAic);
|
||||
uint32_t mOnce = blockDim < nLoop ? m0 : blockDim / nLoop * m0;
|
||||
uint32_t nOnce = blockDim < nLoop ? platformInfo_.coreNumAic * n0 : n_;
|
||||
if (mOnce * k_ * sizeof(uint16_t) > platformInfo_.l2Size) {
|
||||
aCoef = bwCoef;
|
||||
}
|
||||
if (nOnce * k_ * sizeof(uint16_t) > platformInfo_.l2Size) {
|
||||
bCoef = bwCoef;
|
||||
}
|
||||
if (transA_ && m0 % CONST_256 == 0) {
|
||||
aCoef *= NUM2;
|
||||
}
|
||||
if (!transB_ && n0 % CONST_256 == 0) {
|
||||
bCoef *= NUM2;
|
||||
}
|
||||
return 1 / (aCoef * static_cast<float>(n0)) + 1 / (bCoef * static_cast<float>(m0));
|
||||
}
|
||||
|
||||
void PpMatmulTilingApi::UpdateTileSize(const uint32_t m0, const uint32_t n0)
|
||||
{
|
||||
m0_ = m0;
|
||||
n0_ = n0;
|
||||
mLoop_ = CeilDiv(m_, m0_);
|
||||
nLoop_ = CeilDiv(n_, n0_);
|
||||
coreLoop_ = numBatch_ * mLoop_ * nLoop_;
|
||||
const uint32_t maxNumCubeCore = platformInfo_.coreNumAic;
|
||||
if (mLoop_ == 1 && transB_ && coreLoop_ % maxNumCubeCore < maxNumCubeCore / NUM4 * NUM3) {
|
||||
uint32_t tmpM0 = RoundUp(m_, CONST_16);
|
||||
uint32_t maxN0 = L0C_SIZE / (tmpM0 * sizeof(float));
|
||||
if (enDequant_) {
|
||||
maxN0 = maxN0 < CONST_256 ? maxN0 : CONST_256;
|
||||
}
|
||||
uint32_t x = CeilDiv(n_, maxNumCubeCore);
|
||||
uint32_t y = CeilDiv(x, maxN0);
|
||||
uint32_t tmpN0 = RoundUp(CeilDiv(x, y), CONST_16);
|
||||
uint32_t rqdL0cSize = tmpM0 * tmpN0 * sizeof(float);
|
||||
if (rqdL0cSize < L0C_SIZE && (tmpM0 + tmpN0) * CONST_256 * inDataSize_ < L1_BUFFER_SIZE) {
|
||||
m0_ = tmpM0;
|
||||
n0_ = tmpN0;
|
||||
nLoop_ = CeilDiv(n_, n0_);
|
||||
coreLoop_ = numBatch_ * nLoop_;
|
||||
}
|
||||
}
|
||||
blockDim_ = Min(coreLoop_, maxNumCubeCore);
|
||||
}
|
||||
|
||||
void PpMatmulTilingApi::Swizzle()
|
||||
{
|
||||
float minCost = m_ * k_ + k_ * n_;
|
||||
for (uint32_t i = 1; i <= blockDim_; ++i) {
|
||||
int c = static_cast<int32_t>((blockDim_ + i - 1) / i);
|
||||
float cost;
|
||||
// B0 + A < A0 + B
|
||||
if (i * n0_ + m_ < m0_ * c + n_) {
|
||||
swizzleDirect_ = 1; // Nz
|
||||
cost = n0_ * i + m0_ * c;
|
||||
if (cost <= minCost) {
|
||||
minCost = cost;
|
||||
swizzleCount_ = i;
|
||||
}
|
||||
} else {
|
||||
swizzleDirect_ = 0; // Zn
|
||||
cost = m0_ * i + n0_ * c;
|
||||
if (cost < minCost) {
|
||||
minCost = cost;
|
||||
swizzleCount_ = i;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
class MlaPreprocessTiling
|
||||
{
|
||||
public:
|
||||
MlaPreprocessTiling(struct PlatformInfo &platformInfo, struct OpParam &opParam, MlaTilingData *tilingData)
|
||||
{
|
||||
this->tilingData = tilingData;
|
||||
this->platformInfo = platformInfo;
|
||||
this->opParam = opParam;
|
||||
}
|
||||
void Init();
|
||||
|
||||
void RmsNormQuantTiling();
|
||||
void RopeConcatTiling();
|
||||
void EinSumQuantTiling();
|
||||
|
||||
void SetTilingKey();
|
||||
void SetMlapoWorkSpace();
|
||||
|
||||
private:
|
||||
MlaTilingData *tilingData;
|
||||
struct PlatformInfo platformInfo;
|
||||
struct OpParam opParam;
|
||||
};
|
||||
|
||||
void MlaPreprocessTiling::RmsNormQuantTiling()
|
||||
{
|
||||
tilingData->rmsNumCore1 = platformInfo.coreNumAiv;
|
||||
tilingData->rmsNumCol1 = HIDDEN_STRATE;
|
||||
tilingData->rmsNumRow1 = opParam.N;
|
||||
tilingData->rmsQuantMin1 = -CONST_128;
|
||||
tilingData->rmsNumCore2 = platformInfo.coreNumAiv;
|
||||
tilingData->rmsNumCol2 = HIDDEN_STRATE_MM;
|
||||
tilingData->rmsNumRow2 = opParam.N;
|
||||
tilingData->rmsQuantMin2 = -CONST_128;
|
||||
}
|
||||
|
||||
void MlaPreprocessTiling::RopeConcatTiling()
|
||||
{
|
||||
uint32_t ntokens = opParam.N;
|
||||
uint32_t hiddenSizeQ = HEADDIM * opParam.headNum;
|
||||
uint32_t headDim = HEADDIM;
|
||||
uint32_t headNumQ = hiddenSizeQ / headDim;
|
||||
uint32_t concatSize = CONCAT_SIZE;
|
||||
uint32_t maxCore = platformInfo.coreNumAiv;
|
||||
uint32_t maxUbSize = platformInfo.ubSize;
|
||||
|
||||
uint32_t allHeadNum = ntokens * headNumQ;
|
||||
|
||||
uint32_t tempCore = (allHeadNum + maxCore - 1) / maxCore;
|
||||
uint32_t realCore = (allHeadNum + tempCore - 1) / tempCore; // Actual number of the core for operation
|
||||
uint32_t nlCoreRun = (allHeadNum + realCore - 1) / realCore; // The number of heads in the front core
|
||||
uint32_t lCoreRun = allHeadNum - (realCore - 1) * nlCoreRun; // The number of heads in the tail core
|
||||
|
||||
uint32_t dataTypeSize = 2;
|
||||
|
||||
// Calculate how many rows can be moved at a time: q 4+2, reverseq 4, neg 4, sin 4+2, cos 4+2, plus concat 2
|
||||
uint32_t allSize =
|
||||
headDim * (3 * (4 + dataTypeSize) + 2 * 4) + concatSize * dataTypeSize; // lift precision calculation of ROPE
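// For illustration with the values set above (headDim = 64, dataTypeSize = 2, concatSize = 512):
// allSize = 64 * (3 * 6 + 8) + 512 * 2 = 2688 bytes of UB per row.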
|
||||
uint32_t maxNPerLoopForUb = maxUbSize / allSize; // the maximum number of rows at a time for UB
|
||||
uint32_t preCoreLoopTime = (nlCoreRun + maxNPerLoopForUb - 1) / maxNPerLoopForUb; // Number of cycles of front core
|
||||
uint32_t preCoreLoopNLast =
|
||||
nlCoreRun -
|
||||
(preCoreLoopTime - 1) * maxNPerLoopForUb; // rows of data processed in the last batch of the front core
|
||||
uint32_t lastCoreLoopTime = (lCoreRun + maxNPerLoopForUb - 1) / maxNPerLoopForUb; // Number of cycles of tail core
|
||||
uint32_t lastCoreLoopNLast =
|
||||
lCoreRun -
|
||||
(lastCoreLoopTime - 1) * maxNPerLoopForUb; // rows of data processed in the last batch of the tail core
|
||||
|
||||
tilingData->hiddenSizeQ = hiddenSizeQ;
|
||||
tilingData->headNumQ = headNumQ;
|
||||
tilingData->headDim = headDim;
|
||||
tilingData->concatSize = concatSize;
|
||||
tilingData->rotaryCoeff = NUM2;
|
||||
tilingData->ntokens = ntokens;
|
||||
tilingData->realCore = realCore;
|
||||
tilingData->nlCoreRun = nlCoreRun;
|
||||
tilingData->lCoreRun = nlCoreRun;
|
||||
tilingData->maxNPerLoopForUb = maxNPerLoopForUb;
|
||||
tilingData->preCoreLoopTime = preCoreLoopTime;
|
||||
tilingData->preCoreLoopNLast = preCoreLoopNLast;
|
||||
tilingData->lastCoreLoopTime = lastCoreLoopTime;
|
||||
tilingData->lastCoreLoopNLast = lastCoreLoopNLast;
|
||||
}
|
||||
|
||||
void MlaPreprocessTiling::EinSumQuantTiling()
|
||||
{
|
||||
uint32_t aivCore = platformInfo.coreNumAiv;
|
||||
uint32_t ubSize = UB_SIZE - 1024;
|
||||
|
||||
// input shape
|
||||
uint32_t esqBatch = opParam.N; // tokenNum
|
||||
uint32_t esqHeadNum = opParam.headNum; // headNum
|
||||
uint32_t esqColNum = AXES_ALIGN_SIZE; // 512
|
||||
|
||||
// split core
|
||||
uint32_t esqFrontCore = esqBatch % aivCore;
|
||||
uint32_t esqTailCore = aivCore - esqFrontCore;
|
||||
uint32_t esqFrontCoreBatch = CeilDiv(esqBatch, aivCore);
|
||||
uint32_t esqTailCoreBatch = esqBatch / aivCore;
|
||||
|
||||
// split ub --> calc H' <-- The number of rows handled in a UB cycle.
|
||||
uint32_t splitFactor = 0;
|
||||
uint32_t esqHeadPerLoop = 0; // The number of head rows per UB calculation
|
||||
uint32_t repeatMask = 0;
|
||||
|
||||
if (opParam.inDtype == at::kBFloat16 || opParam.quantMode == QuantMode::PER_TOKEN_SYMM_QUANT) {
|
||||
// Move scales in at once, broadcast, and cache them all H * 32bytes
|
||||
uint32_t scaleUb = RoundUp(esqHeadNum) * CONST_32;
|
||||
// bf16 input [H', colNum](f16 + fp32 + int8), ub reuse
|
||||
splitFactor = esqColNum * (sizeof(uint16_t) + sizeof(float) + sizeof(uint8_t));
|
||||
splitFactor *= NUM2;
|
||||
esqHeadPerLoop = (ubSize - scaleUb) / splitFactor; // 26
|
||||
repeatMask = FP32_REPEAT_MASK;
|
||||
} else {
|
||||
// fp16 input [H', colNum](fp16*2 + int8) + [H', 1](fp16) + [H', 16](fp16)
|
||||
splitFactor =
|
||||
esqColNum * (NUM2 * sizeof(uint16_t) + sizeof(uint8_t)) + sizeof(uint16_t) + (CONST_16 * sizeof(uint16_t));
|
||||
esqHeadPerLoop = ubSize / splitFactor;
|
||||
repeatMask = FP16_REPEAT_MASK;
|
||||
esqHeadPerLoop = RoundDown(esqHeadPerLoop);
|
||||
}
|
||||
uint32_t esqUbHeadLoop = esqHeadNum / esqHeadPerLoop; // UB complete cycles
|
||||
uint32_t esqHeadTail = esqHeadNum % esqHeadPerLoop; // The number of rows that UB last processed the head.
|
||||
uint32_t esqColLoop = esqColNum / repeatMask; // Each row counts the number of times to cycle through columns.
|
||||
uint32_t esqColTail =
|
||||
esqColNum % repeatMask; // colNum is not 64/128 aligned, the number of columns is calculated last.
|
||||
|
||||
tilingData->esqFrontCore = esqFrontCore;
|
||||
tilingData->esqTailCore = esqTailCore;
|
||||
tilingData->esqFrontCoreBatch = esqFrontCoreBatch;
|
||||
tilingData->esqTailCoreBatch = esqTailCoreBatch;
|
||||
tilingData->esqHeadNum = esqHeadNum;
|
||||
tilingData->esqColNum = esqColNum;
|
||||
tilingData->esqUbHeadLoop = esqUbHeadLoop;
|
||||
tilingData->esqHeadPerLoop = esqHeadPerLoop;
|
||||
tilingData->esqHeadTail = esqHeadTail;
|
||||
tilingData->esqColLoop = esqColLoop;
|
||||
tilingData->esqColTail = esqColTail;
|
||||
}
|
||||
|
||||
void MlaPreprocessTiling::SetMlapoWorkSpace()
|
||||
{
|
||||
uint64_t s1wsFactor =
|
||||
static_cast<uint64_t>(opParam.cacheMode == 2 ? std::max(HIDDEN_STRATE * sizeof(int8_t),
|
||||
opParam.headNum * AXES_ALIGN_SIZE * sizeof(uint16_t))
|
||||
: HIDDEN_STRATE * sizeof(int8_t));
|
||||
uint64_t workSizeS1 = s1wsFactor;
|
||||
uint64_t workSizeS2 = opParam.headNum * HIDDEN_STRATE_ROPE * sizeof(uint16_t);
|
||||
uint64_t workSizeS3 = HIDDEN_STRATE_MM * sizeof(uint16_t);
|
||||
uint64_t workSizeS4 = std::max(opParam.headNum * HIDDEN_STRATE_ROPE, HIDDEN_STRATE_MM) * sizeof(uint32_t);
|
||||
|
||||
uint64_t maxWorkspaceSize = workSizeS1;
|
||||
maxWorkspaceSize = std::max(maxWorkspaceSize, workSizeS2);
|
||||
maxWorkspaceSize = std::max(maxWorkspaceSize, workSizeS3);
|
||||
maxWorkspaceSize = std::max(maxWorkspaceSize, workSizeS4);
|
||||
maxWorkspaceSize *= static_cast<uint64_t>(opParam.N);
|
||||
|
||||
uint64_t pertokenWorkspace = static_cast<uint64_t>(opParam.N) * sizeof(float) * 2;
|
||||
|
||||
uint64_t userWorkspaceSize;
|
||||
if (opParam.inDtype == at::kBFloat16 || opParam.quantMode == QuantMode::PER_TOKEN_SYMM_QUANT) {
|
||||
userWorkspaceSize = 4 * maxWorkspaceSize + pertokenWorkspace;
|
||||
} else {
|
||||
userWorkspaceSize = 3 * maxWorkspaceSize;
|
||||
}
|
||||
|
||||
tilingData->userWorkspaceSize = userWorkspaceSize;
|
||||
tilingData->s1Offset = 0;
|
||||
tilingData->s2Offset = tilingData->s1Offset + maxWorkspaceSize;
|
||||
tilingData->s3Offset = tilingData->s2Offset + maxWorkspaceSize;
|
||||
tilingData->s4Offset = tilingData->s3Offset + maxWorkspaceSize;
|
||||
tilingData->s5Offset = tilingData->s4Offset + maxWorkspaceSize;
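// s1Offset..s5Offset are laid out back to back, each maxWorkspaceSize apart, within the user workspace.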
|
||||
}
|
||||
|
||||
void MlaPreprocessTiling::SetTilingKey()
|
||||
{
|
||||
uint64_t tilingKey = (static_cast<uint64_t>(opParam.inDtype == at::kBFloat16)) << 8;
|
||||
|
||||
tilingKey |= static_cast<uint64_t>(opParam.cacheMode);
|
||||
tilingKey |= (static_cast<uint64_t>(opParam.quantMode) << 3);
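// Key layout: bit 8 = bf16-input flag, bits 3 and up = quant mode, low bits = cache mode.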
|
||||
|
||||
tilingData->tilingKey = tilingKey;
|
||||
}
|
||||
|
||||
void MlaPreprocessTiling::Init()
|
||||
{
|
||||
tilingData->numCore = platformInfo.coreNumAic;
|
||||
tilingData->n = opParam.N;
|
||||
|
||||
bool deqOnTheFly = false;
|
||||
if (opParam.inDtype == at::kBFloat16 || opParam.quantMode == QuantMode::PER_TOKEN_SYMM_QUANT) {
|
||||
deqOnTheFly = true;
|
||||
}
|
||||
|
||||
PpMatmulTilingApi mm1TilingApi(platformInfo,
|
||||
1, // numBatch
|
||||
opParam.N, // m
|
||||
HIDDEN_STRATE, // k
|
||||
HIDDEN_STRATE_MM, // n
|
||||
false, // transA
|
||||
true, // transB
|
||||
true, // enDequant
|
||||
deqOnTheFly); // in bf16.cce?
|
||||
mm1TilingApi.GetTilingData(tilingData->mm1);
|
||||
|
||||
PpMatmulTilingApi mm2TilingApi(platformInfo,
|
||||
1, // numBatch
|
||||
opParam.N, // m
|
||||
HIDDEN_STRATE_RMS, // k
|
||||
opParam.headNum * HIDDEN_STRATE_ROPE, // n
|
||||
false, // transA
|
||||
true, // transB
|
||||
true, // enDequant
|
||||
deqOnTheFly); // in bf16.cce?
|
||||
mm2TilingApi.GetTilingData(tilingData->mm2);
|
||||
|
||||
PpMatmulTilingApi mm3TilingApi(platformInfo,
|
||||
opParam.headNum, // numBatch
|
||||
opParam.N, // m
|
||||
CONST_128, // k
|
||||
CONCAT_SIZE, // n
|
||||
false, // transA
|
||||
false, // transB
|
||||
false, // enDequant
|
||||
deqOnTheFly); // in bf16.cce?
|
||||
mm3TilingApi.GetTilingData(tilingData->mm3);
|
||||
|
||||
RmsNormQuantTiling();
|
||||
RopeConcatTiling();
|
||||
EinSumQuantTiling();
|
||||
|
||||
SetMlapoWorkSpace();
|
||||
SetTilingKey();
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
std::unordered_map<c10::string_view, uint16_t> cache_mode_map = {
|
||||
{"krope_ctkv", 1}, {"int8_nzcache", 2}, {"nzcache", 3}};
|
||||
|
||||
std::unordered_map<c10::string_view, uint16_t> quant_mode_map = {
|
||||
{"per_tensor_quant_asymm", 0},
|
||||
{"per_token_quant_symm", 1},
|
||||
};
|
||||
|
||||
template <typename MapType>
|
||||
inline int get_op_mode(const MapType &mode_map, c10::optional<c10::string_view> mode_opt, c10::string_view default_mode,
|
||||
const char *mode_name)
|
||||
{
|
||||
c10::string_view mode_str = mode_opt.value_or(default_mode);
|
||||
auto it = mode_map.find(mode_str);
|
||||
TORCH_CHECK(it != mode_map.end(), "Unsupported ", mode_name, " value: '", mode_str, "'");
|
||||
return it->second;
|
||||
}
|
||||
|
||||
// std::tuple<at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &> mla_preprocess(
|
||||
// const at::Tensor &hiddenState, const at::Tensor &gamma0, const at::Tensor &beta0, const at::Tensor &wdqkv,
|
||||
// const at::Tensor &descale0, const at::Tensor &gamma1, const at::Tensor &beta1, const at::Tensor &wuq,
|
||||
// const at::Tensor &descale1, const at::Tensor &gamma2, const at::Tensor &cos, const at::Tensor &sin,
|
||||
// const at::Tensor &wuk, const at::Tensor &kv_cache, const at::Tensor &kv_cache_rope, const at::Tensor &slotmapping,
|
||||
// const at::Tensor &quant_scale0, const at::Tensor &quant_offset0, const at::Tensor &bias0,
|
||||
// const at::Tensor &quant_scale1, const at::Tensor &quant_offset1, const at::Tensor &bias1,
|
||||
// const c10::optional<at::Tensor> &ctkv_scale, const c10::optional<at::Tensor> &q_nope_scale,
|
||||
// c10::optional<c10::string_view> cache_mode, c10::optional<c10::string_view> quant_mode, at::Tensor &q_out0,
|
||||
// at::Tensor &kv_cache_out0, at::Tensor &q_out1, at::Tensor &kv_cache_out1)
|
||||
std::tuple<at::Tensor, at::Tensor, uint32_t> mla_preprocess_tiling(
|
||||
const at::Tensor &hiddenState,
|
||||
const at::Tensor &wuk,
|
||||
c10::optional<c10::string_view> cache_mode,
|
||||
c10::optional<c10::string_view> quant_mode
|
||||
)
|
||||
{
|
||||
auto cacheMode = get_op_mode(cache_mode_map, cache_mode, "krope_ctkv", "cache_mode");
|
||||
auto quantMode = get_op_mode(quant_mode_map, quant_mode, "per_token_quant_symm", "quant_mode");
|
||||
|
||||
platform_ascendc::PlatformAscendC *platformAscendC = platform_ascendc::PlatformAscendCManager::GetInstance();
|
||||
|
||||
struct PlatformInfo platformInfo;
|
||||
platformInfo.coreNum = platformAscendC->GetCoreNum();
|
||||
platformInfo.coreNumAic = platformAscendC->GetCoreNumAic();
|
||||
platformInfo.coreNumAiv = platformAscendC->GetCoreNumAiv();
|
||||
platformAscendC->GetCoreMemSize(platform_ascendc::CoreMemType::UB, platformInfo.ubSize);
|
||||
platformAscendC->GetCoreMemSize(platform_ascendc::CoreMemType::L1, platformInfo.l1Size);
|
||||
platformAscendC->GetCoreMemSize(platform_ascendc::CoreMemType::L2, platformInfo.l2Size);
|
||||
platformAscendC->GetCoreMemSize(platform_ascendc::CoreMemType::L0_A, platformInfo.l0aSize);
|
||||
platformAscendC->GetCoreMemSize(platform_ascendc::CoreMemType::L0_B, platformInfo.l0bSize);
|
||||
platformAscendC->GetCoreMemSize(platform_ascendc::CoreMemType::L0_C, platformInfo.l0cSize);
|
||||
|
||||
int32_t N = hiddenState.sizes()[0];
|
||||
int32_t headNum = wuk.sizes()[0];
|
||||
|
||||
OpParam opParam;
|
||||
opParam.N = N;
|
||||
opParam.headNum = headNum;
|
||||
opParam.cacheMode = static_cast<int32_t>(cacheMode);
|
||||
opParam.quantMode = static_cast<QuantMode>(quantMode);
|
||||
opParam.inDtype = hiddenState.options().dtype();
|
||||
|
||||
MlaTilingData tilingData;
|
||||
MlaPreprocessTiling mlaTiling(platformInfo, opParam, &tilingData);
|
||||
|
||||
mlaTiling.Init();
|
||||
uint32_t blockDim = platformInfo.coreNumAic;
|
||||
|
||||
// workspace
|
||||
uint64_t system_workspace_size = static_cast<uint64_t>(platformAscendC->GetLibApiWorkSpaceSize());
|
||||
uint64_t workspace_size = system_workspace_size + tilingData.userWorkspaceSize;
|
||||
auto options = at::TensorOptions().dtype(at::kByte).device(hiddenState.options().device());
|
||||
auto workspace_tensor = at::empty({static_cast<int64_t>(workspace_size)}, options);
|
||||
|
||||
// tiling
|
||||
int32_t bIndex = N - 1;
|
||||
uint32_t tilingSize = sizeof(MlaTilingData);
|
||||
static auto global_tiling_data =
|
||||
at::empty({tilingSize * MAX_SUPPORT_TOKEN_NUMS},
|
||||
at::TensorOptions().dtype(at::kByte).device(hiddenState.options().device()));
|
||||
if (bIndex >= 0 && bIndex < MAX_SUPPORT_TOKEN_NUMS) {
|
||||
aclrtMemcpy(global_tiling_data.data_ptr<uint8_t>() + (tilingSize * bIndex), tilingSize, &tilingData, tilingSize,
|
||||
ACL_MEMCPY_HOST_TO_DEVICE);
|
||||
} else {
|
||||
// Handle the case where bIndex is out of range
|
||||
TORCH_CHECK(false, "bIndex is out of range: ", bIndex);
|
||||
}
|
||||
at::Tensor tiling = at::from_blob(
|
||||
global_tiling_data.data_ptr<uint8_t>() + (tilingSize * bIndex),
|
||||
tilingSize,
|
||||
at::kByte);
|
||||
|
||||
return std::make_tuple(workspace_tensor, tiling, blockDim);
|
||||
}
|
||||
|
||||
}  // namespace mlapo
|
csrc/mla_preprocess/op_host/tiling/mla_preprocess_tiling.h (new file, 95 lines)
@ -0,0 +1,95 @@
// Adapted from
// https://gitee.com/ascend/ascend-transformer-boost
//
// Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
// This file is a part of the CANN Open Software.
// Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
// Please refer to the License for details. You may not use this file except in compliance with the License.
// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
// See LICENSE in the root of the software repository for the full text of the License.
//

#ifndef MLAPREPROCESS_TILING_H
#define MLAPREPROCESS_TILING_H

#include <cstdint>

struct PpMatmulTilingData {
    uint32_t numBatch{0};
    uint32_t m{0};
    uint32_t k{0};
    uint32_t n{0};
    uint32_t m0{0};
    uint32_t k0{0};
    uint32_t n0{0};
    uint32_t mLoop{0};
    uint32_t kLoop{0};
    uint32_t nLoop{0};
    uint32_t coreLoop{0};
    uint32_t swizzleCount{0};
    uint32_t swizzleDirect{0};
    uint32_t enShuffleK{0};
    uint32_t blockDim{0};
    uint32_t enLoadAllAmat{0};
    uint32_t b0matPingPongBufferLen{0};
};

struct MlaTilingData {
    uint32_t tilingKey{0};
    uint64_t userWorkspaceSize{0};
    uint64_t s1Offset{0};
    uint64_t s2Offset{0};
    uint64_t s3Offset{0};
    uint64_t s4Offset{0};
    uint64_t s5Offset{0};

    uint32_t numCore{0};
    uint32_t n{0};
    uint32_t perTaskNum{0};
    uint32_t resTaskNum{0};

    PpMatmulTilingData mm1;
    PpMatmulTilingData mm2;
    PpMatmulTilingData mm3;
    // rms1
    uint32_t rmsNumCore1{0};
    uint32_t rmsNumCol1{0};
    uint32_t rmsNumRow1{0};
    uint32_t rmsQuantMin1{0};
    // rms2
    uint32_t rmsNumCore2{0};
    uint32_t rmsNumCol2{0};
    uint32_t rmsNumRow2{0};
    uint32_t rmsQuantMin2{0};

    uint32_t hiddenSizeQ{0};
    uint32_t headNumQ{0};
    uint32_t headDim{0};
    uint32_t concatSize{0};
    uint32_t rotaryCoeff{0};
    uint32_t ntokens{0};
    uint32_t realCore{0};
    uint32_t nlCoreRun{0};
    uint32_t lCoreRun{0};
    uint32_t maxNPerLoopForUb{0};
    uint32_t preCoreLoopTime{0};
    uint32_t preCoreLoopNLast{0};
    uint32_t lastCoreLoopTime{0};
    uint32_t lastCoreLoopNLast{0};

    // EinSumQuant
    uint32_t esqFrontCore{0};
    uint32_t esqTailCore{0};
    uint32_t esqFrontCoreBatch{0};
    uint32_t esqTailCoreBatch{0};
    uint32_t esqHeadNum{0};
    uint32_t esqColNum{0};
    uint32_t esqUbHeadLoop{0};
    uint32_t esqHeadPerLoop{0};
    uint32_t esqHeadTail{0};
    uint32_t esqColLoop{0};
    uint32_t esqColTail{0};
};

#endif // MLAPREPROCESS_TILING_H
csrc/mla_preprocess/op_kernel/kernel/common.h (new file, 25 lines)
@ -0,0 +1,25 @@
/* Adapted from
 * https://gitee.com/ascend/ascend-transformer-boost.git
 *
 * Copyright (c) 2024 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */
#ifndef INCLUDE_COMMON_H
#define INCLUDE_COMMON_H

#define CONST_2 2

#define SET_FLAG(trigger, waiter, e) AscendC::SetFlag<AscendC::HardEvent::trigger##_##waiter>((e))
#define WAIT_FLAG(trigger, waiter, e) AscendC::WaitFlag<AscendC::HardEvent::trigger##_##waiter>((e))
#define PIPE_BARRIER(pipe) AscendC::PipeBarrier<PIPE_##pipe>()

#ifndef __force_inline__
#define __force_inline__ inline __attribute__((always_inline))
#endif

#endif
csrc/mla_preprocess/op_kernel/kernel/common_func.h (new file, 121 lines)
@ -0,0 +1,121 @@
/* Adapted from
 * https://gitee.com/ascend/ascend-transformer-boost.git
 *
 * Copyright (c) 2024 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

#ifndef INCLUDE_COMMON_FUNC_H
#define INCLUDE_COMMON_FUNC_H

#include <limits>
#include <type_traits>

#ifdef __CCE_KT_TEST__
#include "stub_def.h"
#include "stub_fun.h"
#else
#include "kernel_macros.h"
#endif

template <uint32_t ALIGN, typename T = uint32_t>
inline __aicore__ T RoundUp(const T val)
{
    static_assert(ALIGN != 0, "align must not be zero");
    static_assert(std::is_arithmetic<T>::value, "T must be an arithmetic type");
    T align = ALIGN;
    if (val + align - 1 < val) {
        return val;
    }
    return (val + align - 1) / align * align;
}

template <typename T>
inline __aicore__ T RoundUp(const T val, const T align)
{
    static_assert(std::is_arithmetic<T>::value, "T must be an arithmetic type");
    if (align == 0 || val + align - 1 < val) {
        return val;
    }
    return (val + align - 1) / align * align;
}

template <uint32_t DIVISOR, typename T = uint32_t>
inline __aicore__ T CeilDiv(const T dividend)
{
    static_assert(DIVISOR != 0, "align must not be zero");
    static_assert(std::is_arithmetic<T>::value, "T must be an arithmetic type");
    T divisor = DIVISOR;
    if (dividend + divisor - 1 < dividend) {
        return dividend;
    }
    return (dividend + divisor - 1) / divisor;
}

template <typename T>
constexpr T T_MAX = std::numeric_limits<T>::max();

template <typename T>
inline __aicore__ T CeilDiv(const T dividend, const T divisor)
{
    static_assert(std::is_arithmetic<T>::value, "T must be an arithmetic type");
    if (divisor == 0 || dividend + divisor - 1 < dividend) {
        return T_MAX<T>;
    }
    return (dividend + divisor - 1) / divisor;
}

template <typename T>
__aicore__ inline T Min(const T lhs, const T rhs)
{
    return lhs < rhs ? lhs : rhs;
}

template <typename Dtype>
__aicore__ __attribute__((always_inline)) inline uint32_t BlockSize()
{
    return 32 / sizeof(Dtype);
}

template <typename Dtype>
__aicore__ __attribute__((always_inline)) inline uint32_t MatrixSize()
{
    return 512 / sizeof(Dtype);
}

template <typename Dtype>
__aicore__ __attribute__((always_inline)) inline uint64_t BlockSizeRoundUp(uint64_t num)
{
    return (num + BlockSize<Dtype>() - 1) / BlockSize<Dtype>() * BlockSize<Dtype>();
}

template <typename Dtype>
__aicore__ __attribute__((always_inline)) inline uint64_t NumBlocksRoundUp(uint64_t num)
{
    return (num + BlockSize<Dtype>() - 1) / BlockSize<Dtype>();
}

template <typename Dtype>
__aicore__ __attribute__((always_inline)) inline uint64_t MatrixSizeRoundUp(uint64_t num)
{
    return (num + MatrixSize<Dtype>() - 1) / MatrixSize<Dtype>() * MatrixSize<Dtype>();
}

template <typename Dtype>
__aicore__ __attribute__((always_inline)) inline uint64_t NumMatrixsRoundUp(uint64_t num)
{
    return (num + MatrixSize<Dtype>() - 1) / MatrixSize<Dtype>();
}

template <typename Dtype>
__aicore__ __attribute__((always_inline)) inline uint64_t L0HalfSize()
{
    return 32 * 1024 / sizeof(Dtype);
}

#endif
csrc/mla_preprocess/op_kernel/kernel/hardware.h (new file, 36 lines)
@ -0,0 +1,36 @@
/* Adapted from
 * https://gitee.com/ascend/ascend-transformer-boost.git
 *
 * Copyright (c) 2024 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */
#ifndef INCLUDE_HARDWARE_H
#define INCLUDE_HARDWARE_H

enum class ArchType { ASCEND_V220, ASCEND_V200, ASCEND_M200 };

template <ArchType ArchTag>
struct HardwareInfo {
    static uint32_t const l2BW = 5;
    static uint32_t const hbmBW = 1;
    static uint32_t const supportMix = 0;
    static uint32_t const l1Size = 512 * 1024;
    static uint32_t const l0ASize = 64 * 1024;
    static uint32_t const l0BSize = 64 * 1024;
    static uint32_t const l0CSize = 128 * 1024;
    static uint32_t const l2Size = 192 * 1024 * 1024;
    static uint32_t const biasSize = 1024;
    static uint32_t const fixBufSize = 7 * 1024;
    static uint32_t const ubSize = 192 * 1024;
    static uint32_t const fractalSize = 512;
    static uint32_t const l1l0BlockSize = 32;
    static uint32_t const btBlockSize = 64;
    static uint32_t const fbBlockSize = 128;
};

#endif
csrc/mla_preprocess/op_kernel/kernel/iterator.h (new file, 92 lines)
@ -0,0 +1,92 @@
/* Adapted from
 * https://gitee.com/ascend/ascend-transformer-boost.git
 *
 * Copyright (c) 2024 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */
#ifndef INCLUDE_ITERATOR_H
#define INCLUDE_ITERATOR_H

#include "common_func.h"
#include "hardware.h"
#include "kernel_operator.h"
#include "layout.h"
#include "mem.h"

/////////////////////////////////////////////////////
// gm_to_l1
/////////////////////////////////////////////////////
template <ArchType ArchTag, typename DataType, DataFormat FormatInGM, DataFormat FormatInL1>
struct gm_to_l1 {
    __aicore__ gm_to_l1(AscendC::LocalTensor<DataType> l1Tensor, AscendC::GlobalTensor<DataType> gmTensor,
                        uint32_t nTileActual, uint32_t nTileCeil, uint32_t nVal, uint32_t dTileActual,
                        uint32_t dTileCeil, uint32_t dVal) {};
};

/////////////////////////////////////////////////////
// l1_to_l0_a
/////////////////////////////////////////////////////
template <ArchType ArchTag, typename DataType, bool IsTransPose, DataFormat DFmtIn, DataFormat DFmtOut>
struct l1_to_l0_a {
    __aicore__ l1_to_l0_a(AscendC::LocalTensor<DataType> l0Tensor, AscendC::LocalTensor<DataType> l1Tensor,
                          uint32_t mTileCeil, uint32_t kPartCeil, uint32_t mSrcStride, uint32_t kSrcStride,
                          uint32_t mDstStride, uint32_t kDstStride) {};
};

/////////////////////////////////////////////////////
// l1_to_l0_b
/////////////////////////////////////////////////////
template <ArchType ArchTag, typename DataType, bool IsTransPose, DataFormat DFmtIn, DataFormat DFmtOut>
struct l1_to_l0_b {
    __aicore__ l1_to_l0_b(AscendC::LocalTensor<DataType> l0Tensor, AscendC::LocalTensor<DataType> l1Tensor,
                          uint32_t nTileCeil, uint32_t kPartCeil, uint32_t nSrcStride, uint32_t kSrcStride,
                          uint32_t nDstStride, uint32_t kDstStride) {};
};

/////////////////////////////////////////////////////
// l0c_to_gm
/////////////////////////////////////////////////////
template <ArchType ArchTag, DataFormat OutFormatType, typename OutDataType, typename L0CDataType>
struct l0c_to_gm {
    __aicore__ l0c_to_gm(AscendC::GlobalTensor<OutDataType> gmTensor, AscendC::LocalTensor<L0CDataType> l0cTensor,
                         uint32_t mTileActual, uint32_t nTileActual, uint32_t mTileCeil, uint32_t nActual,
                         uint8_t unitFlag = 0) {};
};

/////////////////////////////////////////////////////
// l0c_to_l1
/////////////////////////////////////////////////////
template <ArchType ArchTag, DataFormat LayoutOut, typename ElementOut, typename ElementIn>
struct l0c_to_l1 {
    __aicore__ l0c_to_l1(AscendC::LocalTensor<ElementOut> l1Tensor, AscendC::LocalTensor<ElementIn> l0cTensor,
                         AscendC::LocalTensor<uint64_t> deqTensor, uint32_t mTileActual, uint32_t nTileActual,
                         uint32_t mTileCeil, uint32_t nActual) {};
};

template <ArchType ArchTag, typename DataType>
struct l1_to_bt {
    __aicore__ l1_to_bt(uint64_t dst, const AscendC::LocalTensor<DataType> &src, uint16_t convControl, uint16_t nBurst,
                        uint16_t lenBurst, uint16_t srcGap, uint16_t dstGap) {};
};

template <ArchType ArchTag, typename DataType>
struct l1_to_fb {
    __aicore__ l1_to_fb(AscendC::LocalTensor<DataType> &dst, AscendC::LocalTensor<DataType> &src, uint16_t burstNum,
                        uint16_t burstLen, uint16_t srcGap, uint16_t dstGap) {};
};

#include "iterators/gm_to_l1_iterator.inc"
#include "iterators/gm_to_ub_iterator.inc"
#include "iterators/l0c_to_gm_iterator.inc"
#include "iterators/l0c_to_l1_iterator.inc"
#include "iterators/l0c_to_ub_iterator.inc"
#include "iterators/l1_to_bt_iterator.inc"
#include "iterators/l1_to_fb_iterator.inc"
#include "iterators/l1_to_l0_iterator.inc"
#include "iterators/l1_to_ub_iterator.inc"
#endif
@ -0,0 +1,162 @@
|
||||
/* Adapted from
|
||||
* https://gitee.com/ascend/ascend-transformer-boost.git
|
||||
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
#include "../iterator.h"
|
||||
|
||||
// Partial specialization for V220, ND_in, ND_out
|
||||
template <ArchType ArchTag, typename DataType>
|
||||
struct gm_to_l1<ArchTag, DataType, DataFormat::ND, DataFormat::ND> {
|
||||
using HardwareParams = HardwareInfo<ArchTag>;
|
||||
static constexpr uint32_t BLOCK_SIZE = HardwareParams::l1l0BlockSize / sizeof(DataType);
|
||||
|
||||
__aicore__ gm_to_l1(AscendC::LocalTensor<DataType> l1Tensor,
|
||||
AscendC::GlobalTensor<DataType> gmTensor,
|
||||
uint32_t nTileActual,
|
||||
uint32_t nTileCeil,
|
||||
uint32_t nVal,
|
||||
uint32_t dTileActual,
|
||||
uint32_t dTileCeil,
|
||||
uint32_t dVal)
|
||||
{
|
||||
AscendC::DataCopy(l1Tensor, // dst
|
||||
gmTensor, // src
|
||||
AscendC::DataCopyParams(1, // nBurst
|
||||
CeilDiv<BLOCK_SIZE>(nTileActual * dTileActual), // lenBurst
|
||||
0, // srcGap
|
||||
0)); // dstGap
|
||||
};
|
||||
};
|
||||
|
||||
// Partial specialization for NZ_in, NZ_out
|
||||
template <ArchType ArchTag, typename DataType>
|
||||
struct gm_to_l1<ArchTag, DataType, DataFormat::NZ, DataFormat::NZ> {
|
||||
using HardwareParams = HardwareInfo<ArchTag>;
|
||||
static constexpr uint32_t BLOCK_SIZE = HardwareParams::l1l0BlockSize / sizeof(DataType);
|
||||
static constexpr uint32_t STRIDE_LIMIT = 65536;
|
||||
|
||||
__aicore__ gm_to_l1(AscendC::LocalTensor<DataType> l1Tensor,
|
||||
AscendC::GlobalTensor<DataType> gmTensor,
|
||||
uint32_t nTileActual,
|
||||
uint32_t nTileCeil,
|
||||
uint32_t nVal,
|
||||
uint32_t dTileActual,
|
||||
uint32_t dTileCeil,
|
||||
uint32_t dVal)
|
||||
{
|
||||
uint64_t srcStride = nVal - nTileCeil;
|
||||
if (srcStride < STRIDE_LIMIT) {
|
||||
AscendC::DataCopy(l1Tensor, // dst
|
||||
gmTensor, // src
|
||||
AscendC::DataCopyParams(dTileCeil / BLOCK_SIZE, // nBurst
|
||||
nTileCeil, // lenBurst
|
||||
srcStride, // srcGap
|
||||
0)); // dstGap
|
||||
} else {
|
||||
for (uint64_t i = 0; i < dTileCeil / BLOCK_SIZE; i++) {
|
||||
uint64_t dstOffset = i * nTileCeil * BLOCK_SIZE;
|
||||
uint64_t srcOffset = i * nVal * BLOCK_SIZE;
|
||||
AscendC::DataCopy(l1Tensor[dstOffset], // dst
|
||||
gmTensor[srcOffset], // src
|
||||
AscendC::DataCopyParams(1, // nBurst
|
||||
nTileCeil, // lenBurst
|
||||
0, // srcGap
|
||||
0)); // dstGap
|
||||
}
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
// Partial specialization for V220, ND_in, NZ_out
|
||||
template <ArchType ArchTag, typename DataType>
|
||||
struct gm_to_l1<ArchTag, DataType, DataFormat::ND, DataFormat::NZ> {
|
||||
using HardwareParams = HardwareInfo<ArchTag>;
|
||||
static constexpr uint32_t BLOCK_SIZE = HardwareParams::l1l0BlockSize / sizeof(DataType);
|
||||
static constexpr uint32_t STRIDE_LIMIT = 65536;
|
||||
|
||||
__aicore__ gm_to_l1(AscendC::LocalTensor<DataType> l1Tensor,
|
||||
AscendC::GlobalTensor<DataType> gmTensor,
|
||||
uint32_t nTileActual,
|
||||
uint32_t nTileCeil,
|
||||
uint32_t nVal,
|
||||
uint32_t dTileActual,
|
||||
uint32_t dTileCeil,
|
||||
uint32_t dVal)
|
||||
{
|
||||
if (dVal < STRIDE_LIMIT) {
|
||||
AscendC::DataCopy(l1Tensor,
|
||||
gmTensor,
|
||||
AscendC::Nd2NzParams(1, // ndNum
|
||||
nTileActual, // nValue
|
||||
dTileActual, // dValue
|
||||
0, // srcNdMatrixStride, unused
|
||||
dVal, // srcDValue
|
||||
nTileCeil, // dstNzC0Stride
|
||||
1, // dstNzNStride
|
||||
0)); // dstNzMatrixStride, unused
|
||||
} else {
|
||||
for (uint32_t i = 0; i < nTileActual; i++) {
|
||||
AscendC::DataCopy(l1Tensor[i * BLOCK_SIZE],
|
||||
gmTensor[i * dVal],
|
||||
AscendC::Nd2NzParams(1, // ndNum
|
||||
1, // nValue
|
||||
dTileActual, // dValue
|
||||
0, // srcNdMatrixStride, unused
|
||||
0, // srcDValue
|
||||
nTileCeil, // dstNzC0Stride
|
||||
0, // dstNzNStride
|
||||
0)); // dstNzMatrixStride, unused
|
||||
}
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
// Partial specialization for V220, ND_in, ZN_out
|
||||
template <ArchType ArchTag, typename DataType>
|
||||
struct gm_to_l1<ArchTag, DataType, DataFormat::ND, DataFormat::ZN> {
|
||||
using HardwareParams = HardwareInfo<ArchTag>;
|
||||
static constexpr uint32_t BLOCK_SIZE = HardwareParams::l1l0BlockSize / sizeof(DataType);
|
||||
static constexpr uint32_t STRIDE_LIMIT = 65536;
|
||||
|
||||
__aicore__ gm_to_l1(AscendC::LocalTensor<DataType> l1Tensor,
|
||||
AscendC::GlobalTensor<DataType> gmTensor,
|
||||
uint32_t nTileActual,
|
||||
uint32_t nTileCeil,
|
||||
uint32_t nVal,
|
||||
uint32_t dTileActual,
|
||||
uint32_t dTileCeil,
|
||||
uint32_t dVal)
|
||||
{
|
||||
if (dVal < STRIDE_LIMIT) {
|
||||
AscendC::DataCopy(l1Tensor,
|
||||
gmTensor,
|
||||
AscendC::Nd2NzParams(1, // ndNum
|
||||
nTileActual, // nValue
|
||||
dTileActual, // dValue
|
||||
0, // srcNdMatrixStride, unused
|
||||
dVal, // srcDValue
|
||||
nTileCeil, // dstNzC0Stride
|
||||
1, // dstNzNStride
|
||||
0)); // dstNzMatrixStride, unused
|
||||
} else {
|
||||
for (uint32_t i = 0; i < nTileActual; ++i) {
|
||||
AscendC::DataCopy(l1Tensor,
|
||||
gmTensor,
|
||||
AscendC::Nd2NzParams(1, // ndNum
|
||||
1, // nValue
|
||||
dTileActual, // dValue
|
||||
0, // srcNdMatrixStride, unused
|
||||
0, // srcDValue
|
||||
nTileCeil, // dstNzC0Stride
|
||||
0, // dstNzNStride
|
||||
0)); // dstNzMatrixStride, unused
|
||||
}
|
||||
}
|
||||
};
|
||||
};
|
@ -0,0 +1,89 @@
|
||||
/* Adapted from
|
||||
* https://gitee.com/ascend/ascend-transformer-boost.git
|
||||
*
|
||||
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
#include "../iterator.h"
|
||||
|
||||
template <ArchType ArchTag, typename DType> struct gm_to_ub {
|
||||
__aicore__ inline gm_to_ub(AscendC::LocalTensor<DType> dstTensor, AscendC::GlobalTensor<DType> srcTensor,
|
||||
uint8_t sid, uint16_t nBurst, uint16_t lenBurst, uint16_t srcStride, uint16_t dstStride)
|
||||
{
|
||||
AscendC::DataCopy(dstTensor, srcTensor, AscendC::DataCopyParams(nBurst, lenBurst, srcStride, dstStride));
|
||||
};
|
||||
};
|
||||
|
||||
template <ArchType ArchTag, typename DType> struct gm_to_ub_align {
|
||||
__aicore__ inline gm_to_ub_align(AscendC::LocalTensor<DType> dstTensor, AscendC::GlobalTensor<DType> srcTensor,
|
||||
uint8_t sid, uint16_t nBurst, uint32_t lenBurst, uint8_t leftPaddingNum,
|
||||
uint8_t rightPaddingNum, uint32_t srcGap, uint32_t dstGap)
|
||||
{
|
||||
AscendC::DataCopyPad(dstTensor, srcTensor, AscendC::DataCopyExtParams(nBurst, lenBurst, srcGap, dstGap, 0),
|
||||
AscendC::DataCopyPadExtParams<DType>(false, leftPaddingNum, rightPaddingNum, 0));
|
||||
};
|
||||
};
|
||||
|
||||
template <ArchType ArchTag, typename DType> struct ub_to_ub {
|
||||
__aicore__ inline ub_to_ub(AscendC::LocalTensor<DType> dstTensor, AscendC::LocalTensor<DType> srcTensor,
|
||||
uint8_t sid, uint16_t nBurst, uint16_t lenBurst, uint16_t srcStride, uint16_t dstStride)
|
||||
{
|
||||
AscendC::DataCopy(dstTensor, srcTensor, AscendC::DataCopyParams(nBurst, lenBurst, srcStride, dstStride));
|
||||
};
|
||||
};
|
||||
|
||||
template <ArchType ArchTag, typename DataType, DataFormat InDataFormat = DataFormat::ND,
|
||||
DataFormat OutDataFormat = DataFormat::ND>
|
||||
struct ub_to_gm {
|
||||
__aicore__ inline ub_to_gm(AscendC::GlobalTensor<DataType> dstTensor, AscendC::LocalTensor<DataType> srcTensor,
|
||||
uint8_t sid, uint16_t nBurst, uint16_t lenBurst, uint16_t srcStride, uint16_t dstStride)
|
||||
{
|
||||
AscendC::DataCopy(dstTensor, srcTensor, AscendC::DataCopyParams(nBurst, lenBurst, srcStride, dstStride));
|
||||
};
|
||||
};
|
||||
|
||||
template <ArchType ArchTag, typename DataType> struct ub_to_gm<ArchTag, DataType, DataFormat::NZ, DataFormat::NZ> {
|
||||
using HardwareParams = HardwareInfo<ArchTag>;
|
||||
static constexpr uint32_t BLOCK_SIZE = HardwareParams::l1l0BlockSize / sizeof(DataType);
|
||||
|
||||
__aicore__ ub_to_gm(AscendC::GlobalTensor<DataType> gmTensor, AscendC::LocalTensor<DataType> ubTensor,
|
||||
uint32_t nTileActual, uint32_t nTileCeil, uint32_t nVal, uint32_t dTileActual,
|
||||
uint32_t dTileCeil, uint32_t dVal)
|
||||
{
|
||||
constexpr uint32_t STRIDE_LIMIT = 65536;
|
||||
uint64_t dstStride = nVal - nTileCeil;
|
||||
if (dstStride < STRIDE_LIMIT) {
|
||||
AscendC::DataCopy(gmTensor, // dst
|
||||
ubTensor, // src
|
||||
AscendC::DataCopyParams(dTileCeil / BLOCK_SIZE, // nBurst
|
||||
nTileCeil, // lenBurst
|
||||
0, // srcGap
|
||||
dstStride)); // dstGap
|
||||
} else {
|
||||
for (uint64_t i = 0; i < dTileCeil / BLOCK_SIZE; ++i) {
|
||||
uint64_t dstOffset = i * nVal * BLOCK_SIZE;
|
||||
uint64_t srcOffset = i * nTileCeil * BLOCK_SIZE;
|
||||
AscendC::DataCopy(gmTensor[dstOffset], // dst
|
||||
ubTensor[srcOffset], // src
|
||||
AscendC::DataCopyParams(1, // nBurst
|
||||
nTileCeil, // lenBurst
|
||||
0, // srcGap
|
||||
0)); // dstGap
|
||||
}
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
template <ArchType ArchTag, typename DType> struct ub_to_gm_align {
|
||||
__aicore__ inline ub_to_gm_align(AscendC::GlobalTensor<DType> dstTensor, AscendC::LocalTensor<DType> srcTensor,
|
||||
uint8_t sid, uint16_t nBurst, uint32_t lenBurst, uint8_t leftPaddingNum,
|
||||
uint8_t rightPaddingNum, uint32_t srcGap, uint32_t dstGap)
|
||||
{
|
||||
AscendC::DataCopyPad(dstTensor, srcTensor, AscendC::DataCopyExtParams(nBurst, lenBurst, srcGap, dstGap, 0));
|
||||
};
|
||||
};
|
@ -0,0 +1,228 @@
|
||||
/* Adapted from
|
||||
* https://gitee.com/ascend/ascend-transformer-boost.git
|
||||
*
|
||||
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
#include "../iterator.h"
|
||||
|
||||
constexpr uint32_t BLOCK_NUM = 16;
|
||||
constexpr uint32_t BLOCK_SIZE_INT8 = 32;
|
||||
|
||||
template <>
|
||||
struct l0c_to_gm<ArchType::ASCEND_V220, DataFormat::ND, half, float> {
|
||||
    /**
     * @brief Copy data from the L0C buffer to global memory; partially specialized for
     *        ND output format with a half destination and float (FP32) accumulator source.
     *
     * @param gmTensor the destination tensor on global memory, which is stored in ND format.
     * @param l0cTensor the source tensor on L0C buffer, which is stored in FRACTAL_NZ format.
     * @param mTileActual the m-direction size of the matrix in L0C buffer.
     * @param nTileActual the n-direction size of the matrix in L0C buffer.
     * @param srcStride the source stride between adjacent fractal matrices along the n-direction, in units of C0_SIZE.
     * @param dstStride the leading dimension of the destination matrix, in units of elements.
     */
|
||||
__aicore__ l0c_to_gm(AscendC::GlobalTensor<half> gmTensor,
|
||||
AscendC::LocalTensor<float> l0cTensor,
|
||||
uint32_t mTileActual,
|
||||
uint32_t nTileActual,
|
||||
uint32_t srcStride,
|
||||
uint32_t dstStride,
|
||||
uint8_t unitFlag = 0)
|
||||
{
|
||||
#ifdef __DAV_C220_CUBE__
|
||||
auto intriParams = AscendC::FixpipeParamsV220(nTileActual, // nSize
|
||||
mTileActual, // mSize
|
||||
srcStride, // srcStride
|
||||
dstStride, // dstStride
|
||||
false); // enRelu
|
||||
|
||||
intriParams.quantPre = QuantMode_t::F322F16;
|
||||
intriParams.unitFlag = unitFlag;
|
||||
AscendC::Fixpipe<half, float, AscendC::CFG_ROW_MAJOR>(gmTensor, l0cTensor, intriParams);
|
||||
#else
|
||||
AscendC::FixpipeParams<float> intriParams(
|
||||
(nTileActual + BLOCK_NUM - 1) / AscendC::BLOCK_CUBE,
|
||||
static_cast<uint16_t>(mTileActual * BLOCK_NUM * sizeof(float) / BLOCK_SIZE_INT8),
|
||||
0,
|
||||
dstStride);
|
||||
intriParams.nz2ndParams = {true, 1, 0, 0, static_cast<uint16_t>(nTileActual)};
|
||||
intriParams.quantParams = {QuantMode_t::F322F16};
|
||||
AscendC::Fixpipe(gmTensor, l0cTensor, intriParams);
|
||||
#endif
|
||||
};
|
||||
};
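Constructing one of these iterator structs performs the copy in its constructor. A minimal usage sketch for the ND/half/float specialization above; the tensor names, tile sizes, and the zero srcStride are illustrative assumptions, not taken from this patch:

// Illustrative sketch only, not part of the original file.
__aicore__ inline void FlushAccTileExample(AscendC::GlobalTensor<half> outGm, AscendC::LocalTensor<float> l0cBuf,
                                           uint32_t m, uint32_t n, uint32_t ldOut)
{
    // Constructing the iterator issues the Fixpipe copy: FP32 accumulators in L0C are
    // converted to half (F322F16) and written to an ND tensor whose leading dimension is ldOut.
    l0c_to_gm<ArchType::ASCEND_V220, DataFormat::ND, half, float>(outGm, l0cBuf, m, n,
                                                                  /*srcStride=*/0, /*dstStride=*/ldOut);
}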
|
||||
|
||||
template <>
|
||||
struct l0c_to_gm<ArchType::ASCEND_V220, DataFormat::ND, half, int32_t> {
|
||||
__aicore__ l0c_to_gm(AscendC::GlobalTensor<half> gmTensor,
|
||||
AscendC::LocalTensor<int32_t> l0cTensor,
|
||||
uint32_t mTileActual,
|
||||
uint32_t nTileActual,
|
||||
uint32_t srcStride,
|
||||
uint32_t dstStride,
|
||||
uint8_t unitFlag = 0)
|
||||
{
|
||||
#ifdef __DAV_C220_CUBE__
|
||||
auto intriParams = AscendC::FixpipeParamsV220(nTileActual, // nSize
|
||||
mTileActual, // mSize
|
||||
srcStride, // srcStride
|
||||
dstStride, // dstStride
|
||||
false); // enRelu
|
||||
|
||||
intriParams.quantPre = QuantMode_t::VDEQF16;
|
||||
intriParams.unitFlag = unitFlag;
|
||||
AscendC::Fixpipe<half, int32_t, AscendC::CFG_ROW_MAJOR>(gmTensor, l0cTensor, intriParams);
|
||||
#else
|
||||
AscendC::FixpipeParams<int32_t> intriParams(
|
||||
(nTileActual + BLOCK_NUM - 1) / AscendC::BLOCK_CUBE,
|
||||
static_cast<uint16_t>(mTileActual * BLOCK_NUM * sizeof(float) / BLOCK_SIZE_INT8),
|
||||
0,
|
||||
dstStride);
|
||||
intriParams.nz2ndParams = {true, 1, 0, 0, static_cast<uint16_t>(nTileActual)};
|
||||
intriParams.quantParams = {QuantMode_t::VDEQF16};
|
||||
AscendC::Fixpipe(gmTensor, l0cTensor, intriParams);
|
||||
#endif
|
||||
};
|
||||
};
|
||||
|
||||
#ifdef __DAV_C220_CUBE__
|
||||
|
||||
template <>
|
||||
struct l0c_to_gm<ArchType::ASCEND_V220, DataFormat::ND, __bf16, float> {
|
||||
__aicore__ l0c_to_gm(AscendC::GlobalTensor<__bf16> gmTensor,
|
||||
AscendC::LocalTensor<float> l0cTensor,
|
||||
uint32_t mTileActual,
|
||||
uint32_t nTileActual,
|
||||
uint32_t srcStride,
|
||||
uint32_t dstStride,
|
||||
uint8_t unitFlag = 0)
|
||||
{
|
||||
#ifdef __DAV_C220_CUBE__
|
||||
auto intriParams = AscendC::FixpipeParamsV220(nTileActual, // nSize
|
||||
mTileActual, // mSize
|
||||
srcStride, // srcStride
|
||||
dstStride, // dstStride
|
||||
false); // enRelu
|
||||
|
||||
intriParams.quantPre = QuantMode_t::F322BF16;
|
||||
intriParams.unitFlag = unitFlag;
|
||||
AscendC::Fixpipe<__bf16, float, AscendC::CFG_ROW_MAJOR>(gmTensor, l0cTensor, intriParams);
|
||||
#else
|
||||
AscendC::FixpipeParams<float> intriParams(
|
||||
(nTileActual + BLOCK_NUM - 1) / AscendC::BLOCK_CUBE,
|
||||
static_cast<uint16_t>(mTileActual * BLOCK_NUM * sizeof(float) / BLOCK_SIZE_INT8),
|
||||
0,
|
||||
dstStride);
|
||||
intriParams.nz2ndParams = {true, 1, 0, 0, static_cast<uint16_t>(nTileActual)};
|
||||
intriParams.quantParams = {QuantMode_t::F322BF16};
|
||||
AscendC::Fixpipe(gmTensor, l0cTensor, intriParams);
|
||||
#endif
|
||||
};
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
|
||||
// Partial specialization ND, float
|
||||
template <>
|
||||
struct l0c_to_gm<ArchType::ASCEND_V220, DataFormat::ND, float, float> {
|
||||
__aicore__ l0c_to_gm(AscendC::GlobalTensor<float> gmTensor,
|
||||
AscendC::LocalTensor<float> l0cTensor,
|
||||
uint32_t mTileActual,
|
||||
uint32_t nTileActual,
|
||||
uint32_t srcStride,
|
||||
uint32_t dstStride,
|
||||
uint8_t unitFlag = 0)
|
||||
{
|
||||
#ifdef __DAV_C220_CUBE__
|
||||
auto intriParams = AscendC::FixpipeParamsV220(nTileActual, // nSize
|
||||
mTileActual, // mSize
|
||||
srcStride, // srcStride
|
||||
dstStride, // dstStride
|
||||
false); // enRelu
|
||||
|
||||
intriParams.quantPre = QuantMode_t::NoQuant;
|
||||
intriParams.unitFlag = unitFlag;
|
||||
AscendC::Fixpipe<float, float, AscendC::CFG_ROW_MAJOR>(gmTensor, l0cTensor, intriParams);
|
||||
#else
|
||||
AscendC::FixpipeParams<float> intriParams(
|
||||
(nTileActual + BLOCK_NUM - 1) / AscendC::BLOCK_CUBE,
|
||||
static_cast<uint16_t>(mTileActual * BLOCK_NUM * sizeof(float) / BLOCK_SIZE_INT8),
|
||||
0,
|
||||
dstStride);
|
||||
intriParams.nz2ndParams = {true, 1, 0, 0, static_cast<uint16_t>(nTileActual)};
|
||||
intriParams.quantParams = {QuantMode_t::NoQuant};
|
||||
AscendC::Fixpipe(gmTensor, l0cTensor, intriParams);
|
||||
#endif
|
||||
};
|
||||
};
|
||||
|
||||
template <>
|
||||
struct l0c_to_gm<ArchType::ASCEND_V220, DataFormat::NZ, half, float> {
|
||||
__aicore__ l0c_to_gm(AscendC::GlobalTensor<half> gmTensor,
|
||||
AscendC::LocalTensor<float> l0cTensor,
|
||||
uint32_t mTileActual,
|
||||
uint32_t nTileActual,
|
||||
uint32_t srcStride,
|
||||
uint32_t dstStride,
|
||||
uint8_t unitFlag = 0)
|
||||
{
|
||||
#ifdef __DAV_C220_CUBE__
|
||||
auto intriParams = AscendC::FixpipeParamsV220(nTileActual, // nSize
|
||||
mTileActual, // mSize
|
||||
srcStride, // srcStride
|
||||
dstStride, // dstStride
|
||||
false); // enRelu
|
||||
|
||||
intriParams.quantPre = QuantMode_t::F322F16;
|
||||
intriParams.unitFlag = unitFlag;
|
||||
AscendC::Fixpipe<half, float, AscendC::CFG_NZ>(gmTensor, l0cTensor, intriParams);
|
||||
#else
|
||||
AscendC::FixpipeParams<float> intriParams(
|
||||
(nTileActual + BLOCK_NUM - 1) / AscendC::BLOCK_CUBE,
|
||||
static_cast<uint16_t>(mTileActual * BLOCK_NUM * sizeof(float) / BLOCK_SIZE_INT8),
|
||||
0,
|
||||
dstStride - (nTileActual * sizeof(half) / sizeof(float)));
|
||||
intriParams.quantParams = {QuantMode_t::F322F16};
|
||||
AscendC::Fixpipe(gmTensor, l0cTensor, intriParams);
|
||||
#endif
|
||||
};
|
||||
};
|
||||
|
||||
template <>
|
||||
struct l0c_to_gm<ArchType::ASCEND_V220, DataFormat::ND, int32_t, int32_t> {
|
||||
__aicore__ l0c_to_gm(AscendC::GlobalTensor<int32_t> gmTensor,
|
||||
AscendC::LocalTensor<int32_t> l0cTensor,
|
||||
uint32_t mTileActual,
|
||||
uint32_t nTileActual,
|
||||
uint32_t srcStride,
|
||||
uint32_t dstStride,
|
||||
uint8_t unitFlag = 0)
|
||||
{
|
||||
#ifdef __DAV_C220_CUBE__
|
||||
auto intriParams = AscendC::FixpipeParamsV220(nTileActual, // nSize
|
||||
mTileActual, // mSize
|
||||
srcStride, // srcStride
|
||||
dstStride, // dstStride
|
||||
false); // enRelu
|
||||
|
||||
intriParams.quantPre = QuantMode_t::NoQuant;
|
||||
intriParams.unitFlag = unitFlag;
|
||||
AscendC::Fixpipe<int32_t, int32_t, AscendC::CFG_ROW_MAJOR>(gmTensor, l0cTensor, intriParams);
|
||||
#else
|
||||
AscendC::FixpipeParams<int32_t> intriParams(
|
||||
(nTileActual + BLOCK_NUM - 1) / AscendC::BLOCK_CUBE,
|
||||
static_cast<uint16_t>(mTileActual * BLOCK_NUM * sizeof(float) / BLOCK_SIZE_INT8),
|
||||
0,
|
||||
dstStride);
|
||||
intriParams.nz2ndParams = {true, 1, 0, 0, static_cast<uint16_t>(nTileActual)};
|
||||
intriParams.quantParams = {QuantMode_t::VDEQF16};
|
||||
AscendC::Fixpipe(gmTensor, l0cTensor, intriParams);
|
||||
#endif
|
||||
};
|
||||
};
|
@ -0,0 +1,42 @@
|
||||
/* Adapted from
|
||||
* https://gitee.com/ascend/ascend-transformer-boost.git
|
||||
*
|
||||
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
#include "../iterator.h"
|
||||
/////////////////////////////////////////////////////
|
||||
// l0c_to_l1
|
||||
/////////////////////////////////////////////////////
|
||||
|
||||
// Partial specialization ZN, half, int32_t
|
||||
template <ArchType ArchTag>
|
||||
struct l0c_to_l1<ArchTag, DataFormat::ZN, half, int32_t> {
|
||||
using ElementOut = half;
|
||||
using ElementIn = int32_t;
|
||||
__aicore__ l0c_to_l1(AscendC::LocalTensor<ElementOut> l1Tensor,
|
||||
AscendC::LocalTensor<ElementIn> l0cTensor,
|
||||
AscendC::LocalTensor<uint64_t> deqTensor,
|
||||
uint32_t mTileActual,
|
||||
uint32_t nTileActual,
|
||||
uint32_t mTileCeil,
|
||||
uint32_t nActual)
|
||||
{
|
||||
constexpr uint32_t BLOCK_NUM = 16;
|
||||
constexpr uint32_t BLOCK_SIZE = 32;
|
||||
AscendC::FixpipeParams<ElementIn> intriParams(
|
||||
(nTileActual + BLOCK_NUM - 1) / AscendC::BLOCK_CUBE,
|
||||
static_cast<uint16_t>(mTileActual * BLOCK_NUM * sizeof(float) / BLOCK_SIZE),
|
||||
0,
|
||||
mTileCeil - static_cast<uint16_t>(mTileActual * BLOCK_NUM * sizeof(float) / BLOCK_SIZE) *
|
||||
sizeof(ElementOut) / sizeof(ElementIn));
|
||||
intriParams.nz2ndParams = {false, 1, 0, 0, static_cast<uint16_t>(nTileActual)};
|
||||
intriParams.quantParams = {QuantMode_t::VDEQF16};
|
||||
AscendC::Fixpipe(l1Tensor, l0cTensor, deqTensor, intriParams);
|
||||
};
|
||||
};
|
@ -0,0 +1,71 @@
|
||||
/* Adapted from
|
||||
* https://gitee.com/ascend/ascend-transformer-boost.git
|
||||
*
|
||||
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
#include "../iterator.h"
|
||||
|
||||
/////////////////////////////////////////////////////
|
||||
// l0c_to_ub
|
||||
/////////////////////////////////////////////////////
|
||||
|
||||
// Partial specialization ZN, half, int32_t
|
||||
template <ArchType ArchTag, typename ElementIn, typename ElementOut, bool MatrixMode = true>
|
||||
struct l0c_to_ub {
|
||||
__aicore__ l0c_to_ub(AscendC::LocalTensor<ElementOut> ubTensor,
|
||||
AscendC::LocalTensor<ElementIn> l0cTensor,
|
||||
uint16_t nBurst,
|
||||
uint16_t lenBurst,
|
||||
uint16_t srcStride,
|
||||
uint16_t dstStride)
|
||||
{
|
||||
constexpr auto mode =
|
||||
MatrixMode ? AscendC::BlockMode::BLOCK_MODE_MATRIX : AscendC::BlockMode::BLOCK_MODE_VECTOR;
|
||||
AscendC::DataCopy(ubTensor,
|
||||
l0cTensor,
|
||||
AscendC::DataCopyParams(nBurst, // count
|
||||
lenBurst, // len
|
||||
srcStride, // srcStrideIn
|
||||
dstStride), // dstStrideIn
|
||||
AscendC::DataCopyEnhancedParams(mode, // blockModeIn
|
||||
AscendC::DeqScale::DEQ_NONE, // deqScaleIn
|
||||
0, // deqValueIn
|
||||
0, // sidStoreModeIn
|
||||
false, // isReluIn
|
||||
pad_t::PAD_NONE, // padModeIn
|
||||
0) // padValueIn
|
||||
);
|
||||
};
|
||||
};
|
||||
|
||||
template <ArchType ArchTag>
|
||||
struct l0c_to_ub<ArchTag, int32_t, half> {
|
||||
__aicore__ l0c_to_ub(AscendC::LocalTensor<half> ubTensor,
|
||||
AscendC::LocalTensor<int32_t> l0cTensor,
|
||||
uint16_t nBurst,
|
||||
uint16_t lenBurst,
|
||||
uint16_t srcStride,
|
||||
uint16_t dstStride)
|
||||
{
|
||||
AscendC::DataCopy(ubTensor,
|
||||
l0cTensor,
|
||||
AscendC::DataCopyParams(nBurst, // count
|
||||
lenBurst, // len
|
||||
srcStride, // srcStrideIn
|
||||
dstStride), // dstStrideIn
|
||||
AscendC::DataCopyEnhancedParams(AscendC::BlockMode::BLOCK_MODE_MATRIX, // blockModeIn
|
||||
AscendC::DeqScale::VDEQ16, // deqScaleIn
|
||||
0, // deqValueIn
|
||||
0, // sidStoreModeIn
|
||||
false, // isReluIn
|
||||
pad_t::PAD_NONE, // padModeIn
|
||||
0) // padValueIn
|
||||
);
|
||||
};
|
||||
};
|
@ -0,0 +1,39 @@
|
||||
/* Adapted from
|
||||
* https://gitee.com/ascend/ascend-transformer-boost.git
|
||||
*
|
||||
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
#include "../iterator.h"
|
||||
|
||||
/////////////////////////////////////////////////////
|
||||
// l1_to_bt
|
||||
/////////////////////////////////////////////////////
|
||||
|
||||
// Partial specialization for V220
|
||||
template <typename DataType>
|
||||
struct l1_to_bt<ArchType::ASCEND_V220, DataType> {
|
||||
__aicore__ l1_to_bt(uint64_t dst,
|
||||
const AscendC::LocalTensor<DataType> &src,
|
||||
uint16_t convControl,
|
||||
uint16_t nBurst,
|
||||
uint16_t lenBurst,
|
||||
uint16_t srcGap,
|
||||
uint16_t dstGap)
|
||||
{
|
||||
AscendC::LocalTensor<DataType> dstTensor;
|
||||
dstTensor.InitBuffer(dst, nBurst * lenBurst);
|
||||
dstTensor.address_.logicPos = static_cast<uint8_t>(AscendC::TPosition::C2);
|
||||
AscendC::DataCopy(dstTensor,
|
||||
src,
|
||||
AscendC::DataCopyParams(nBurst, // nBurst
|
||||
lenBurst, // lenBurst
|
||||
srcGap, // srcGap
|
||||
dstGap)); // dstGap
|
||||
}
|
||||
};
|
@ -0,0 +1,36 @@
|
||||
/* Adapted from
|
||||
* https://gitee.com/ascend/ascend-transformer-boost.git
|
||||
*
|
||||
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
#include "../iterator.h"

/////////////////////////////////////////////////////
// l1_to_fb
/////////////////////////////////////////////////////

// Partial specialization for V220
template <typename DataType>
struct l1_to_fb<ArchType::ASCEND_V220, DataType> {
    __aicore__ l1_to_fb(AscendC::LocalTensor<DataType> &dst,
                        AscendC::LocalTensor<DataType> &src,
                        uint16_t burstNum,
                        uint16_t burstLen,
                        uint16_t srcGap,
                        uint16_t dstGap)
    {
        dst.address_.logicPos = static_cast<uint8_t>(AscendC::TPosition::C2PIPE2GM);
        AscendC::DataCopy(dst,
                          src,
                          AscendC::DataCopyParams(burstNum,   // nBurst
                                                  burstLen,   // lenBurst
                                                  srcGap,     // srcGap
                                                  dstGap));   // dstGap
    }
};
|
@ -0,0 +1,310 @@
|
||||
/* Adapted from
|
||||
* https://gitee.com/ascend/ascend-transformer-boost.git
|
||||
*
|
||||
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
#include "../iterator.h"
|
||||
|
||||
/////////////////////////////////////////////////////
|
||||
// l1_to_l0_a
|
||||
/////////////////////////////////////////////////////
|
||||
|
||||
// Partial specialization for vector
|
||||
template <ArchType ArchTag, typename DataType, bool IsTransPose>
|
||||
struct l1_to_l0_a<ArchTag, DataType, IsTransPose, DataFormat::VECTOR, DataFormat::VECTOR> {
|
||||
using HardwareParams = HardwareInfo<ArchTag>;
|
||||
static constexpr uint32_t FRACTAL_SIZE = HardwareParams::fractalSize / sizeof(DataType);
|
||||
|
||||
__aicore__ l1_to_l0_a(AscendC::LocalTensor<DataType> l0Tensor,
|
||||
AscendC::LocalTensor<DataType> l1Tensor,
|
||||
uint32_t mTileCeil,
|
||||
uint32_t kPartCeil,
|
||||
uint32_t mSrcStride,
|
||||
uint32_t kSrcStride,
|
||||
uint32_t mDstStride,
|
||||
uint32_t kDstStride)
|
||||
{
|
||||
AscendC::LoadData(l0Tensor,
|
||||
l1Tensor,
|
||||
AscendC::LoadData2dParams(0, // baseIdx
|
||||
kPartCeil, // repeat
|
||||
kSrcStride, // srcStride
|
||||
0, // sid
|
||||
kDstStride, // dstStride
|
||||
IsTransPose, // transpose
|
||||
0)); // addrCalMode
|
||||
};
|
||||
};
|
||||
|
||||
// Partial specialization for no transpose, not vector
|
||||
template <ArchType ArchTag, typename DataType>
|
||||
struct l1_to_l0_a<ArchTag, DataType, false, DataFormat::ZN, DataFormat::ZZ> {
|
||||
using HardwareParams = HardwareInfo<ArchTag>;
|
||||
static constexpr uint32_t BLOCK_SIZE = HardwareParams::l1l0BlockSize / sizeof(DataType);
|
||||
static constexpr uint32_t FRACTAL_SIZE = HardwareParams::fractalSize / sizeof(DataType);
|
||||
static constexpr uint32_t BLOCK_NUM_PER_FRACTAL = HardwareParams::fractalSize / HardwareParams::l1l0BlockSize;
|
||||
|
||||
__aicore__ l1_to_l0_a(AscendC::LocalTensor<DataType> l0Tensor,
|
||||
AscendC::LocalTensor<DataType> l1Tensor,
|
||||
uint32_t mTileCeil,
|
||||
uint32_t kPartCeil,
|
||||
uint32_t mSrcStride,
|
||||
uint32_t kSrcStride,
|
||||
uint32_t mDstStride,
|
||||
uint32_t kDstStride)
|
||||
{
|
||||
for (uint32_t i = 0; i < mTileCeil / BLOCK_NUM_PER_FRACTAL; i++) {
|
||||
AscendC::LoadData(l0Tensor[i * mDstStride * FRACTAL_SIZE], // dst
|
||||
l1Tensor[i * mSrcStride * FRACTAL_SIZE], // src
|
||||
AscendC::LoadData2dParams(0, // baseIdx
|
||||
static_cast<uint16_t>(kPartCeil / BLOCK_SIZE), // repeat
|
||||
kSrcStride, // srcStride
|
||||
0, // sid
|
||||
kDstStride - 1, // dstStride
|
||||
false, // transpose
|
||||
0)); // addrCalMode
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
// Partial specialization for transpose, not vector
|
||||
template <ArchType ArchTag, typename DataType>
|
||||
struct l1_to_l0_a<ArchTag, DataType, true, DataFormat::ZN, DataFormat::ZZ> {
|
||||
using HardwareParams = HardwareInfo<ArchTag>;
|
||||
static constexpr uint32_t BLOCK_SIZE = HardwareParams::l1l0BlockSize / sizeof(DataType);
|
||||
static constexpr uint32_t FRACTAL_SIZE = HardwareParams::fractalSize / sizeof(DataType);
|
||||
static constexpr uint32_t BLOCK_NUM_PER_FRACTAL = HardwareParams::fractalSize / HardwareParams::l1l0BlockSize;
|
||||
|
||||
__aicore__ l1_to_l0_a(AscendC::LocalTensor<DataType> l0Tensor,
|
||||
AscendC::LocalTensor<DataType> l1Tensor,
|
||||
uint32_t mTileCeil,
|
||||
uint32_t kPartCeil,
|
||||
uint32_t mSrcStride,
|
||||
uint32_t kSrcStride,
|
||||
uint32_t mDstStride,
|
||||
uint32_t kDstStride)
|
||||
{
|
||||
for (uint32_t i = 0; i < mTileCeil / BLOCK_SIZE; i++) {
|
||||
AscendC::LoadData(l0Tensor[i * mDstStride * FRACTAL_SIZE],
|
||||
l1Tensor[i * mSrcStride * FRACTAL_SIZE],
|
||||
AscendC::LoadData2dParams(0,
|
||||
static_cast<uint16_t>(kPartCeil / BLOCK_NUM_PER_FRACTAL),
|
||||
kSrcStride,
|
||||
0,
|
||||
kDstStride - 1,
|
||||
true,
|
||||
0));
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
template <ArchType ArchTag, typename DataType>
|
||||
struct l1_to_l0_a<ArchTag, DataType, false, DataFormat::NZ, DataFormat::ZZ> {
|
||||
using HardwareParams = HardwareInfo<ArchTag>;
|
||||
// 16 * 32
|
||||
static constexpr uint32_t ROW_BLOCK_SIZE = 16;
|
||||
static constexpr uint32_t COL_BLOCK_SIZE = 32 / sizeof(DataType);
|
||||
static constexpr uint32_t FRACTAL_SIZE = HardwareParams::fractalSize / sizeof(DataType);
|
||||
static constexpr uint32_t BLOCK_NUM_PER_FRACTAL = HardwareParams::fractalSize / HardwareParams::l1l0BlockSize;
|
||||
|
||||
__aicore__ l1_to_l0_a(AscendC::LocalTensor<DataType> l0Tensor,
|
||||
AscendC::LocalTensor<DataType> l1Tensor,
|
||||
uint32_t mTileCeil,
|
||||
uint32_t kPartCeil,
|
||||
uint32_t mSrcStride,
|
||||
uint32_t kSrcStride,
|
||||
uint32_t mDstStride,
|
||||
uint32_t kDstStride)
|
||||
{
|
||||
for (uint32_t i = 0; i < mTileCeil / ROW_BLOCK_SIZE; i++) {
|
||||
AscendC::LoadData(l0Tensor[i * ROW_BLOCK_SIZE * kPartCeil],
|
||||
l1Tensor[i * FRACTAL_SIZE],
|
||||
AscendC::LoadData2dParams(0,
|
||||
static_cast<uint16_t>(kPartCeil / COL_BLOCK_SIZE),
|
||||
mTileCeil / ROW_BLOCK_SIZE,
|
||||
0,
|
||||
0,
|
||||
false,
|
||||
0));
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
template <>
|
||||
struct l1_to_l0_a<ArchType::ASCEND_V220, int8_t, true, DataFormat::ZN, DataFormat::ZZ> {
|
||||
using HardwareParams = HardwareInfo<ArchType::ASCEND_V220>;
|
||||
static constexpr uint32_t BLOCK_SIZE = HardwareParams::l1l0BlockSize / sizeof(int8_t); // 32
|
||||
static constexpr uint32_t FRACTAL_SIZE = HardwareParams::fractalSize / sizeof(int8_t); // 512
|
||||
static constexpr uint32_t BLOCK_NUM_PER_FRACTAL = HardwareParams::fractalSize / HardwareParams::l1l0BlockSize; // 16
|
||||
static constexpr uint32_t NUM_FRACTAL_PER_ITER = 2;
|
||||
__aicore__ l1_to_l0_a(AscendC::LocalTensor<int8_t> l0Tensor,
|
||||
AscendC::LocalTensor<int8_t> l1Tensor,
|
||||
uint32_t mTileCeil,
|
||||
uint32_t kPartCeil,
|
||||
uint32_t mSrcStride,
|
||||
uint32_t kSrcStride,
|
||||
uint32_t mDstStride,
|
||||
uint32_t kDstStride)
|
||||
{
|
||||
for (uint64_t i = 0; i < mTileCeil / (BLOCK_NUM_PER_FRACTAL * NUM_FRACTAL_PER_ITER); ++i) {
|
||||
AscendC::LoadDataWithTranspose(
|
||||
l0Tensor[i * mDstStride * FRACTAL_SIZE * NUM_FRACTAL_PER_ITER], // dstLocalTensor
|
||||
l1Tensor[i * mSrcStride * FRACTAL_SIZE], // srcLocalTensor
|
||||
AscendC::LoadData2dTransposeParams(0, // baseIdx
|
||||
static_cast<uint16_t>(CeilDiv<BLOCK_SIZE>(kPartCeil)), // repeat
|
||||
kSrcStride, // srcStride
|
||||
0, // dstGap
|
||||
mDstStride - 1)); // dstFracGap
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////
|
||||
// l1_to_l0_b
|
||||
/////////////////////////////////////////////////////
|
||||
|
||||
// Partial specialization for vector
|
||||
template <ArchType ArchTag, typename DataType, bool IsTransPose>
|
||||
struct l1_to_l0_b<ArchTag, DataType, IsTransPose, DataFormat::VECTOR, DataFormat::VECTOR> {
|
||||
using HardwareParams = HardwareInfo<ArchTag>;
|
||||
static constexpr uint32_t FRACTAL_SIZE = HardwareParams::fractalSize / sizeof(DataType);
|
||||
|
||||
__aicore__ l1_to_l0_b(AscendC::LocalTensor<DataType> l0Tensor,
|
||||
AscendC::LocalTensor<DataType> l1Tensor,
|
||||
uint32_t nTileCeil,
|
||||
uint32_t kPartCeil,
|
||||
uint32_t nSrcStride,
|
||||
uint32_t kSrcStride,
|
||||
uint32_t nDstStride,
|
||||
uint32_t kDstStride)
|
||||
{
|
||||
AscendC::LoadData(
|
||||
l0Tensor, l1Tensor, AscendC::LoadData2dParams(0, kPartCeil, kSrcStride, 0, kDstStride, IsTransPose, 0));
|
||||
};
|
||||
};
|
||||
|
||||
template <ArchType ArchTag>
|
||||
struct l1_to_l0_b<ArchTag, int8_t, true, DataFormat::NZ, DataFormat::ZN> {
|
||||
using HardwareParams = HardwareInfo<ArchTag>;
|
||||
using DataType = int8_t;
|
||||
static constexpr uint32_t BLOCK_SIZE = HardwareParams::l1l0BlockSize / sizeof(DataType);
|
||||
|
||||
__aicore__ l1_to_l0_b(AscendC::LocalTensor<DataType> l0Tensor,
|
||||
AscendC::LocalTensor<DataType> l1Tensor,
|
||||
uint32_t nTileCeil,
|
||||
uint32_t kPartCeil,
|
||||
uint32_t nSrcStride,
|
||||
uint32_t kSrcStride,
|
||||
uint32_t nDstStride,
|
||||
uint32_t kDstStride)
|
||||
{
|
||||
for (uint32_t i = 0; i < nTileCeil / BLOCK_SIZE; i++) {
|
||||
AscendC::LoadDataWithTranspose(l0Tensor[i * kPartCeil * BLOCK_SIZE],
|
||||
l1Tensor[i * BLOCK_SIZE * BLOCK_SIZE],
|
||||
AscendC::LoadData2dTransposeParams(0, // startIndexIn
|
||||
kPartCeil / BLOCK_SIZE, // repeatTimesIn
|
||||
nTileCeil / BLOCK_SIZE, // srcStrideIn
|
||||
1, // dstGapIn
|
||||
0, // dstfracGapIn
|
||||
0) // addrModeIn
|
||||
);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
// Partial specialization for no transpose, not vector
|
||||
template <ArchType ArchTag, typename DataType>
|
||||
struct l1_to_l0_b<ArchTag, DataType, false, DataFormat::ZN, DataFormat::NZ> {
|
||||
using HardwareParams = HardwareInfo<ArchTag>;
|
||||
static constexpr uint32_t BLOCK_SIZE = HardwareParams::l1l0BlockSize / sizeof(DataType);
|
||||
static constexpr uint32_t FRACTAL_SIZE = HardwareParams::fractalSize / sizeof(DataType);
|
||||
static constexpr uint32_t BLOCK_NUM_PER_FRACTAL = HardwareParams::fractalSize / HardwareParams::l1l0BlockSize;
|
||||
|
||||
__aicore__ l1_to_l0_b(AscendC::LocalTensor<DataType> l0Tensor,
|
||||
AscendC::LocalTensor<DataType> l1Tensor,
|
||||
uint32_t nTileCeil,
|
||||
uint32_t kPartCeil,
|
||||
uint32_t nSrcStride,
|
||||
uint32_t kSrcStride,
|
||||
uint32_t nDstStride,
|
||||
uint32_t kDstStride)
|
||||
{
|
||||
for (uint32_t i = 0; i < kPartCeil / BLOCK_NUM_PER_FRACTAL; i++) {
|
||||
AscendC::LoadData(l0Tensor[i * kDstStride * FRACTAL_SIZE],
|
||||
l1Tensor[i * kSrcStride * FRACTAL_SIZE],
|
||||
AscendC::LoadData2dParams(0, // baseIdx
|
||||
static_cast<uint16_t>(nTileCeil / BLOCK_SIZE), // repeat
|
||||
nSrcStride, // srcStride
|
||||
0, // sid
|
||||
nDstStride - 1, // dstStride
|
||||
true, // transpose
|
||||
0)); // addrCalMode
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
// Partial specialization for transpose, not vector
|
||||
template <ArchType ArchTag, typename DataType>
|
||||
struct l1_to_l0_b<ArchTag, DataType, true, DataFormat::ZN, DataFormat::NZ> {
|
||||
using HardwareParams = HardwareInfo<ArchTag>;
|
||||
static constexpr uint32_t BLOCK_SIZE = HardwareParams::l1l0BlockSize / sizeof(DataType);
|
||||
static constexpr uint32_t FRACTAL_SIZE = HardwareParams::fractalSize / sizeof(DataType);
|
||||
static constexpr uint32_t BLOCK_NUM_PER_FRACTAL = HardwareParams::fractalSize / HardwareParams::l1l0BlockSize;
|
||||
|
||||
__aicore__ l1_to_l0_b(AscendC::LocalTensor<DataType> l0Tensor,
|
||||
AscendC::LocalTensor<DataType> l1Tensor,
|
||||
uint32_t nTileCeil,
|
||||
uint32_t kPartCeil,
|
||||
uint32_t nSrcStride,
|
||||
uint32_t kSrcStride,
|
||||
uint32_t nDstStride,
|
||||
uint32_t kDstStride)
|
||||
{
|
||||
AscendC::LoadData(
|
||||
l0Tensor,
|
||||
l1Tensor,
|
||||
AscendC::LoadData2dParams(0, // baseIdx
|
||||
static_cast<uint16_t>(kPartCeil * nTileCeil / FRACTAL_SIZE), // repeat
|
||||
1, // srcStride
|
||||
0, // sid
|
||||
0, // dstStride
|
||||
false, // transpose
|
||||
0)); // addr_cal_mode_t
|
||||
};
|
||||
};
|
||||
|
||||
template <>
|
||||
struct l1_to_l0_b<ArchType::ASCEND_V220, int8_t, false, DataFormat::ZN, DataFormat::NZ> {
|
||||
using HardwareParams = HardwareInfo<ArchType::ASCEND_V220>;
|
||||
static constexpr uint32_t BLOCK_SIZE = HardwareParams::l1l0BlockSize / sizeof(int8_t); // 32
|
||||
    static constexpr uint32_t FRACTAL_SIZE = HardwareParams::fractalSize / sizeof(int8_t);  // 512
|
||||
static constexpr uint32_t BLOCK_NUM_PER_FRACTAL = HardwareParams::fractalSize / HardwareParams::l1l0BlockSize;
|
||||
static constexpr uint32_t NUM_FRACTAL_PER_ITER = 2;
|
||||
|
||||
__aicore__ l1_to_l0_b(AscendC::LocalTensor<int8_t> l0Tensor,
|
||||
AscendC::LocalTensor<int8_t> l1Tensor,
|
||||
uint32_t nTileCeil,
|
||||
uint32_t kPartCeil,
|
||||
uint32_t nSrcStride,
|
||||
uint32_t kSrcStride,
|
||||
uint32_t nDstStride,
|
||||
uint32_t kDstStride)
|
||||
{
|
||||
for (uint64_t i = 0; i < kPartCeil / (BLOCK_NUM_PER_FRACTAL * NUM_FRACTAL_PER_ITER); ++i) {
|
||||
AscendC::LoadDataWithTranspose(
|
||||
l0Tensor[i * kDstStride * FRACTAL_SIZE], // dstLocalTensor
|
||||
l1Tensor[i * kSrcStride * FRACTAL_SIZE * NUM_FRACTAL_PER_ITER], // srcLocalTensor
|
||||
AscendC::LoadData2dTransposeParams(0, // baseIdx
|
||||
static_cast<uint16_t>(CeilDiv<BLOCK_SIZE>(nTileCeil)), // repeat
|
||||
nSrcStride / NUM_FRACTAL_PER_ITER, // srcStride
|
||||
1, // dstGap
|
||||
0)); // dstFracGap
|
||||
}
|
||||
};
|
||||
};
|
@ -0,0 +1,44 @@
/* Adapted from
 * https://gitee.com/ascend/ascend-transformer-boost.git
 *
 * Copyright (c) 2024 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */
#include "../iterator.h"

/////////////////////////////////////////////////////
// l1_to_ub
/////////////////////////////////////////////////////
template <ArchType ArchTag, typename DataType>
struct l1_to_ub {
    __aicore__ l1_to_ub(AscendC::LocalTensor<DataType> ubTensor,
                        AscendC::LocalTensor<DataType> l1Tensor,
                        uint16_t nBurst,
                        uint16_t lenBurst,
                        uint16_t srcStride,
                        uint16_t dstStride)
    {
        AscendC::DataCopy(ubTensor, l1Tensor, AscendC::DataCopyParams(nBurst, lenBurst, srcStride, dstStride));
    };
};

/////////////////////////////////////////////////////
// ub_to_l1
/////////////////////////////////////////////////////
template <ArchType ArchTag, typename DataType>
struct ub_to_l1 {
    __aicore__ ub_to_l1(AscendC::LocalTensor<DataType> l1Tensor,
                        AscendC::LocalTensor<DataType> ubTensor,
                        uint16_t nBurst,
                        uint16_t lenBurst,
                        uint16_t srcStride,
                        uint16_t dstStride)
    {
        AscendC::DataCopy(l1Tensor, ubTensor, AscendC::DataCopyParams(nBurst, lenBurst, srcStride, dstStride));
    };
};
395
csrc/mla_preprocess/op_kernel/kernel/kernel_utils.h
Normal file
@ -0,0 +1,395 @@
|
||||
/* Adapted from
|
||||
* https://gitee.com/ascend/ascend-transformer-boost.git
|
||||
*
|
||||
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
#ifndef ASCEND_OPS_UTILS_COMMON_KERNEL_KERNEL_UTILS_H
|
||||
#define ASCEND_OPS_UTILS_COMMON_KERNEL_KERNEL_UTILS_H
|
||||
#include "kernel_operator.h"
|
||||
|
||||
using AscendC::HardEvent;
|
||||
|
||||
__aicore__ inline uint32_t CeilDiv(uint32_t x, uint32_t y)
|
||||
{
|
||||
return y == 0 ? 0 : ((x + y - 1) / y);
|
||||
}
|
||||
|
||||
__aicore__ inline uint32_t RoundUp(uint32_t x, uint32_t y = 16)
|
||||
{
|
||||
return (x + y - 1) / y * y;
|
||||
}
|
||||
|
||||
__aicore__ inline uint32_t Min(uint32_t x, uint32_t y)
|
||||
{
|
||||
return x < y ? x : y;
|
||||
}
|
||||
|
||||
__aicore__ inline uint32_t Max(uint32_t x, uint32_t y)
|
||||
{
|
||||
return x > y ? x : y;
|
||||
}
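A couple of hand-worked values may help when reading the tiling math built on these helpers (the sizes are arbitrary examples, not taken from this file):

// Example: padding a row of 100 half elements to whole 32-byte UB blocks
// (32 bytes == 16 half elements; numbers worked out by hand):
//   RoundUp(100, 16)      -> 112 elements (7 full blocks)
//   CeilDiv(100 * 2, 32)  -> 7 bursts of 32 bytes
//   Min(112, 128) -> 112,  Max(112, 128) -> 128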
|
||||
|
||||
template <typename T, typename Q>
|
||||
__aicore__ inline void CopyIn(const AscendC::GlobalTensor<T> &gm, Q &queue, uint64_t offset, uint32_t count)
|
||||
{
|
||||
AscendC::LocalTensor<T> local = queue.template AllocTensor<T>();
|
||||
DataCopy(local, gm[offset], count);
|
||||
queue.EnQue(local);
|
||||
}
|
||||
|
||||
template <typename T, typename Q>
|
||||
__aicore__ inline void CopyOut(const AscendC::GlobalTensor<T> &gm, Q &queue, uint64_t offset, uint32_t count)
|
||||
{
|
||||
AscendC::LocalTensor<T> local = queue.template DeQue<T>();
|
||||
DataCopy(gm[offset], local, count);
|
||||
queue.FreeTensor(local);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__aicore__ inline void CastFrom16To32(const AscendC::LocalTensor<float> &out, const AscendC::LocalTensor<T> &in,
|
||||
uint32_t count)
|
||||
{
|
||||
Cast(out, in, AscendC::RoundMode::CAST_NONE, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__aicore__ inline void CastFrom32To16(const AscendC::LocalTensor<T> &out, const AscendC::LocalTensor<float> &in,
|
||||
uint32_t count)
|
||||
{
|
||||
if constexpr (AscendC::IsSameType<T, half>::value) {
|
||||
        Cast(out, in, AscendC::RoundMode::CAST_NONE,
             count);  // On 310P, the fp32->half cast only supports CAST_NONE; use it here to keep 310P and 910B aligned.
    } else {  // bf16
|
||||
Cast(out, in, AscendC::RoundMode::CAST_RINT, count);
|
||||
}
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
}
|
||||
|
||||
__aicore__ inline void CastFromF16ToI8(const AscendC::LocalTensor<int8_t> &out, const AscendC::LocalTensor<half> &in,
|
||||
half quantMin, uint32_t count)
|
||||
{
|
||||
Maxs(in, in, quantMin, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
Mins(in, in, (half)127, count); // 127: limit
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
#if defined(__CCE_KT_TEST__) || (__CCE_AICORE__ == 220)
|
||||
Cast(out, in, AscendC::RoundMode::CAST_RINT, count);
|
||||
#else
|
||||
Cast(out, in, AscendC::RoundMode::CAST_NONE, count);
|
||||
#endif
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
}
|
||||
|
||||
template <typename T, typename Q>
|
||||
__aicore__ inline void CopyInAndCastF32(const AscendC::LocalTensor<float> &out, const AscendC::GlobalTensor<T> &gm,
|
||||
Q &queue, uint64_t offset, uint32_t count)
|
||||
{
|
||||
CopyIn(gm, queue, offset, count);
|
||||
AscendC::LocalTensor<T> local = queue.template DeQue<T>();
|
||||
Cast(out, local, AscendC::RoundMode::CAST_NONE, count);
|
||||
queue.FreeTensor(local);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
}
|
||||
|
||||
template <typename T, typename Q>
|
||||
__aicore__ inline void Cast16AndCopyOut(const AscendC::LocalTensor<float> &in, const AscendC::GlobalTensor<T> &gm,
|
||||
Q &queue, uint64_t offset, uint32_t count)
|
||||
{
|
||||
AscendC::LocalTensor<T> local = queue.template AllocTensor<T>();
|
||||
CastFrom32To16(local, in, count);
|
||||
queue.EnQue(local);
|
||||
CopyOut(gm, queue, offset, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__aicore__ inline T ComputeSum(const AscendC::LocalTensor<T> &in, const AscendC::LocalTensor<T> &tmp,
|
||||
const AscendC::LocalTensor<T> &workLocal, uint32_t count)
|
||||
{
|
||||
#if __CCE_AICORE__ == 100
|
||||
float sum = 0;
|
||||
int64_t elementNumPerRep = AscendC::ONE_REPEAT_BYTE_SIZE / sizeof(T);
|
||||
AscendC::LocalTensor<T> src = in;
|
||||
while (count > elementNumPerRep) {
|
||||
int64_t repeatTimes = count / elementNumPerRep;
|
||||
int64_t tailCount = count % elementNumPerRep;
|
||||
int64_t bodyCount = repeatTimes * elementNumPerRep;
|
||||
if (repeatTimes > 0) {
|
||||
AscendC::AscendCUtils::SetMask<T>(elementNumPerRep);
|
||||
vcadd((__ubuf__ T *)tmp.GetPhyAddr(), (__ubuf__ T *)src.GetPhyAddr(), repeatTimes, 1, 1, 8);
|
||||
AscendC::SetFlag<HardEvent::V_S>(EVENT_ID0); // PipeBarrier(PIPE_V)?
|
||||
AscendC::WaitFlag<HardEvent::V_S>(EVENT_ID0);
|
||||
}
|
||||
|
||||
if (tailCount != 0) {
|
||||
AscendC::AscendCUtils::SetMask<T>(tailCount);
|
||||
vcadd((__ubuf__ T *)tmp[bodyCount].GetPhyAddr(), (__ubuf__ T *)src[bodyCount].GetPhyAddr(), 1, 1, 1, 8);
|
||||
AscendC::SetFlag<HardEvent::V_S>(EVENT_ID0);
|
||||
AscendC::WaitFlag<HardEvent::V_S>(EVENT_ID0);
|
||||
sum += tmp.GetValue(bodyCount);
|
||||
}
|
||||
|
||||
count = repeatTimes;
|
||||
src = tmp;
|
||||
}
|
||||
|
||||
if (count > 1) {
|
||||
AscendC::AscendCUtils::SetMask<T>(count);
|
||||
vcadd((__ubuf__ T *)tmp.GetPhyAddr(), (__ubuf__ T *)tmp.GetPhyAddr(), 1, 1, 1, 8);
|
||||
AscendC::SetFlag<HardEvent::V_S>(EVENT_ID0);
|
||||
AscendC::WaitFlag<HardEvent::V_S>(EVENT_ID0);
|
||||
}
|
||||
|
||||
sum += tmp.GetValue(0);
|
||||
return sum;
|
||||
#else
|
||||
ReduceSum(tmp, in, workLocal, count);
|
||||
AscendC::SetFlag<HardEvent::V_S>(EVENT_ID0);
|
||||
AscendC::WaitFlag<HardEvent::V_S>(EVENT_ID0);
|
||||
return tmp.GetValue(0);
|
||||
#endif
|
||||
}
|
||||
|
||||
__aicore__ inline float ComputeSliceSquareSum(const AscendC::LocalTensor<float> &in,
|
||||
const AscendC::LocalTensor<float> &tmp,
|
||||
const AscendC::LocalTensor<float> &workLocal, uint32_t count)
|
||||
{
|
||||
Mul(tmp, in, in, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
return ComputeSum(tmp, tmp, workLocal, count);
|
||||
}
|
||||
template <typename T>
|
||||
__aicore__ inline void ComputeRmsNorm(const AscendC::LocalTensor<T> &out, const AscendC::LocalTensor<float> &in,
|
||||
float rms, const AscendC::LocalTensor<T> &gamma, uint32_t count,
|
||||
uint32_t precisionMode, uint32_t gemmaMode,
|
||||
const AscendC::LocalTensor<float> &tmp)
|
||||
{
|
||||
float value = 1.0;
|
||||
Duplicate(tmp, rms, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
Div(tmp, in, tmp, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
|
||||
if (precisionMode == 0) {
|
||||
CastFrom16To32(in, gamma, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
if (gemmaMode == 1) {
|
||||
Adds(in, in, value, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
}
|
||||
Mul(in, in, tmp, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
CastFrom32To16(out, in, count);
|
||||
return;
|
||||
}
|
||||
if constexpr (std::is_same<T, half>::value) {
|
||||
CastFrom32To16(out, tmp, count);
|
||||
Mul(out, out, gamma, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
}
|
||||
}
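In equation form, the helper above computes the following (a hand-written summary; the usual definition of rms as sqrt(mean(x^2) + eps) is assumed here, it is computed by the caller rather than in this file):

//   precisionMode == 0 :  out = cast16( cast32(gamma)[ + 1 if gemmaMode == 1 ] * (in / rms) )  // math in fp32
//   precisionMode != 0 :  out = cast16(in / rms) * gamma                                       // fp16 path, half only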
|
||||
|
||||
template <typename T, uint32_t gemmaMode>
|
||||
__aicore__ inline void CastGAndIsGemmaMode(const AscendC::LocalTensor<float> &out, const AscendC::LocalTensor<T> &gamma,
|
||||
uint32_t count)
|
||||
{
|
||||
Cast(out, gamma, AscendC::RoundMode::CAST_NONE, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
float value = 1.0;
|
||||
if constexpr (gemmaMode == 1) {
|
||||
Adds(out, out, value, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, uint32_t precisionMode>
|
||||
__aicore__ inline void ComputeRmsNormFast(const AscendC::LocalTensor<T> &out, const AscendC::LocalTensor<float> &in,
|
||||
float rms, const AscendC::LocalTensor<T> &gamma, uint32_t count,
|
||||
const AscendC::LocalTensor<float> &tmp,
|
||||
const AscendC::LocalTensor<float> &fp32_g)
|
||||
{
|
||||
float value = 1.0;
|
||||
Duplicate(tmp, rms, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
Div(tmp, in, tmp, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
if constexpr (precisionMode == 0) {
|
||||
Mul(in, fp32_g, tmp, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
CastFrom32To16(out, in, count);
|
||||
return;
|
||||
}
|
||||
if constexpr (std::is_same<T, half>::value) {
|
||||
CastFrom32To16(out, tmp, count);
|
||||
Mul(out, out, gamma, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
}
|
||||
}
|
||||
|
||||
template <bool WITH_BETA = true>
|
||||
__aicore__ inline void ComputeRmsNorm(const AscendC::LocalTensor<float> &out, const AscendC::LocalTensor<float> &in,
|
||||
float rms, const AscendC::LocalTensor<half> &gamma,
|
||||
const AscendC::LocalTensor<half> &beta, const AscendC::LocalTensor<float> &tmp,
|
||||
uint32_t count)
|
||||
{
|
||||
Duplicate(tmp, rms, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
Div(out, in, tmp, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
CastFrom16To32(tmp, gamma, count);
|
||||
Mul(out, out, tmp, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
if constexpr (WITH_BETA) {
|
||||
CastFrom16To32(tmp, beta, count);
|
||||
Add(out, out, tmp, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__aicore__ inline void ComputeRmsNorm(const AscendC::LocalTensor<float> &out, const AscendC::LocalTensor<float> &in,
|
||||
float reciprocal_of_rms, const AscendC::LocalTensor<T> &gamma,
|
||||
const AscendC::LocalTensor<float> &tmp, const AscendC::LocalTensor<T> &res_out,
|
||||
uint32_t count)
|
||||
{
|
||||
Duplicate(tmp, reciprocal_of_rms, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
Mul(out, in, tmp, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
CastFrom16To32(tmp, gamma, count);
|
||||
Mul(out, out, tmp, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
CastFrom32To16(res_out, out, count);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__aicore__ inline void ComputeResidualAdd(const AscendC::LocalTensor<T> &out, const AscendC::LocalTensor<T> &in,
|
||||
const AscendC::LocalTensor<T> &resIn, uint32_t count)
|
||||
{
|
||||
Add(out, in, resIn, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__aicore__ inline void ComputeMean(const AscendC::LocalTensor<T> &out, const AscendC::LocalTensor<T> &in, T aveNum,
|
||||
uint32_t count)
|
||||
{
|
||||
Duplicate(out, aveNum, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
Mul(out, in, out, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
T sum = ComputeSum(out, out, out, count);
|
||||
AscendC::SetFlag<HardEvent::S_V>(EVENT_ID0);
|
||||
AscendC::WaitFlag<HardEvent::S_V>(EVENT_ID0);
|
||||
Duplicate(out, sum, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__aicore__ inline void ComputeLayerNorm(const AscendC::LocalTensor<float> &out, const AscendC::LocalTensor<float> &in,
|
||||
const AscendC::LocalTensor<float> &mean, float eps, float aveNum,
|
||||
const AscendC::LocalTensor<T> &gamma, const AscendC::LocalTensor<T> &beta,
|
||||
uint32_t count)
|
||||
{
|
||||
Sub(in, in, mean, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
Mul(out, in, in, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
Muls(out, out, aveNum, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
ReduceSum(out, out, out, count);
|
||||
AscendC::SetFlag<HardEvent::V_S>(EVENT_ID0);
|
||||
AscendC::WaitFlag<HardEvent::V_S>(EVENT_ID0);
|
||||
float var = out.GetValue(0);
|
||||
AscendC::SetFlag<HardEvent::S_V>(EVENT_ID0);
|
||||
AscendC::WaitFlag<HardEvent::S_V>(EVENT_ID0);
|
||||
Duplicate(out, var, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
Adds(out, out, eps, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
Sqrt(out, out, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
|
||||
Div(out, in, out, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
|
||||
Cast(in, gamma, AscendC::RoundMode::CAST_NONE, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
Mul(out, out, in, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
Cast(in, beta, AscendC::RoundMode::CAST_NONE, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
Add(out, out, in, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
}
|
||||
|
||||
__aicore__ inline void ComputeFp16ToI8Quant(const AscendC::LocalTensor<int8_t> &out,
|
||||
const AscendC::LocalTensor<half> &in, const AscendC::LocalTensor<half> &tmp,
|
||||
half scale, half offset, half quantMin, uint32_t count)
|
||||
{
|
||||
Muls(tmp, in, scale, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
Adds(tmp, tmp, offset, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
CastFromF16ToI8(out, tmp, quantMin, count);
|
||||
}
|
||||
|
||||
__aicore__ inline void ComputeFp32ToI8Quant(const AscendC::LocalTensor<int8_t> &out,
|
||||
const AscendC::LocalTensor<float> &in,
|
||||
const AscendC::LocalTensor<half> &tmp, half scale, half offset,
|
||||
half quantMin, uint32_t count)
|
||||
{
|
||||
CastFrom32To16(tmp, in, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
ComputeFp16ToI8Quant(out, tmp, tmp, scale, offset, quantMin, count);
|
||||
}
|
||||
|
||||
__aicore__ inline void ComputeHighPrecisionFp32ToI8Quant(const AscendC::LocalTensor<int8_t> &out,
|
||||
const AscendC::LocalTensor<float> &in,
|
||||
const AscendC::LocalTensor<half> &tmp, float scale,
|
||||
float offset, half quantMin, uint32_t count)
|
||||
{
|
||||
Muls(in, in, scale, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
Adds(in, in, offset, count);
|
||||
AscendC::PipeBarrier<PIPE_V>();
|
||||
CastFrom32To16(tmp, in, count);
|
||||
CastFromF16ToI8(out, tmp, quantMin, count);
|
||||
}
|
||||
|
||||
__aicore__ inline void CopyGmTilingToUb(__ubuf__ uint8_t *&tilingInUb, const __gm__ uint8_t *tilingInGm,
|
||||
size_t tilingSize, AscendC::TPipe *pipe)
|
||||
{
|
||||
uint32_t roundTilingSize = RoundUp(tilingSize, 32);
|
||||
AscendC::TBuf<AscendC::TPosition::VECCALC> tilingBuf;
|
||||
AscendC::GlobalTensor<uint8_t> tilingGm;
|
||||
|
||||
tilingGm.SetGlobalBuffer((__gm__ uint8_t *)tilingInGm);
|
||||
pipe->InitBuffer(tilingBuf, roundTilingSize);
|
||||
|
||||
AscendC::LocalTensor<uint8_t> tilingUb = tilingBuf.Get<uint8_t>();
|
||||
AscendC::DataCopy(tilingUb, tilingGm, roundTilingSize);
|
||||
|
||||
tilingInUb = (__ubuf__ uint8_t *)tilingUb.GetPhyAddr();
|
||||
}
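A sketch of the intended call pattern (hedged: the MlaTilingData struct name is assumed for illustration; this helper only stages raw bytes into UB):

//   AscendC::TPipe pipe;
//   __ubuf__ uint8_t *tilingBytes = nullptr;
//   CopyGmTilingToUb(tilingBytes, tilingGmPtr, sizeof(MlaTilingData), &pipe);
//   AscendC::PipeBarrier<PIPE_ALL>();  // make the copy visible before reading the fields
//   auto *tiling = reinterpret_cast<__ubuf__ MlaTilingData *>(tilingBytes);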
|
||||
|
||||
template <typename T>
|
||||
__aicore__ inline uint32_t GetReduceSumWorkLocalSize(uint32_t sliceSize)
|
||||
{
|
||||
uint32_t elementsPerBlock = 32 / sizeof(T);
|
||||
uint32_t elementsPerRepeat = 256 / sizeof(T);
|
||||
|
||||
uint32_t firstMaxRepeat = sliceSize < elementsPerRepeat ? 1u : (sliceSize / elementsPerRepeat);
|
||||
uint32_t iter1OutputCount = firstMaxRepeat;
|
||||
uint32_t iter1AlignEnd = RoundUp(iter1OutputCount, elementsPerBlock);
|
||||
return iter1AlignEnd;
|
||||
}
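A worked example for T = float and a 7168-element slice (the slice size is only an example; the arithmetic follows the code above):

//   elementsPerBlock  = 32 / 4    = 8
//   elementsPerRepeat = 256 / 4   = 64
//   firstMaxRepeat    = 7168 / 64 = 112
//   RoundUp(112, 8)   = 112  -> ReduceSum needs 112 floats of scratch in workLocal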
|
||||
|
||||
#endif
|
18
csrc/mla_preprocess/op_kernel/kernel/layout.h
Normal file
@ -0,0 +1,18 @@
/* Adapted from
 * https://gitee.com/ascend/ascend-transformer-boost.git
 *
 * Copyright (c) 2024 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

#ifndef INCLUDE_LAYOUT_H
#define INCLUDE_LAYOUT_H

enum class DataFormat { ND = 0, NZ, ZN, ZZ, NN, VECTOR };

#endif
82
csrc/mla_preprocess/op_kernel/kernel/mem.h
Normal file
@ -0,0 +1,82 @@
|
||||
/* Adapted from
|
||||
* https://gitee.com/ascend/ascend-transformer-boost.git
|
||||
*
|
||||
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
#ifndef INCLUDE_MEM_H
|
||||
#define INCLUDE_MEM_H
|
||||
|
||||
#include "hardware.h"
|
||||
#include "kernel_event.h"
|
||||
#include "kernel_tensor.h"
|
||||
|
||||
enum class BufferType { ASCEND_UB, ASCEND_CB, ASCEND_L0A, ASCEND_L0B, ASCEND_L0C, ASCEND_MAX };
|
||||
|
||||
template <BufferType BufferType_>
|
||||
__aicore__ constexpr AscendC::TPosition GetPosition()
|
||||
{
|
||||
if constexpr (BufferType_ == BufferType::ASCEND_UB) {
|
||||
return AscendC::TPosition::VECIN;
|
||||
} else if constexpr (BufferType_ == BufferType::ASCEND_CB) {
|
||||
return AscendC::TPosition::A1;
|
||||
} else if constexpr (BufferType_ == BufferType::ASCEND_L0A) {
|
||||
return AscendC::TPosition::A2;
|
||||
} else if constexpr (BufferType_ == BufferType::ASCEND_L0B) {
|
||||
return AscendC::TPosition::B2;
|
||||
} else if constexpr (BufferType_ == BufferType::ASCEND_L0C) {
|
||||
return AscendC::TPosition::CO1;
|
||||
}
|
||||
return AscendC::TPosition::GM;
|
||||
}
|
||||
|
||||
template <ArchType ArchTag>
|
||||
struct AsdopsBuffer {
|
||||
public:
|
||||
__aicore__ AsdopsBuffer()
|
||||
{
|
||||
constexpr uint32_t bufferSize[(uint32_t)BufferType::ASCEND_MAX] = {
|
||||
HardwareInfo<ArchTag>::ubSize, HardwareInfo<ArchTag>::l1Size, HardwareInfo<ArchTag>::l0ASize,
|
||||
HardwareInfo<ArchTag>::l0BSize, HardwareInfo<ArchTag>::l0CSize};
|
||||
#ifdef __DAV_C220_VEC__
|
||||
tensor[(uint32_t)BufferType::ASCEND_UB].InitBuffer(0, bufferSize[(uint32_t)BufferType::ASCEND_UB]);
|
||||
tensor[(uint32_t)BufferType::ASCEND_UB].address_.logicPos = static_cast<uint8_t>(AscendC::TPosition::VECIN);
|
||||
#elif defined(__DAV_C220_CUBE__)
|
||||
tensor[(uint32_t)BufferType::ASCEND_CB].InitBuffer(0, bufferSize[(uint32_t)BufferType::ASCEND_CB]);
|
||||
tensor[(uint32_t)BufferType::ASCEND_CB].address_.logicPos = static_cast<uint8_t>(AscendC::TPosition::A1);
|
||||
tensor[(uint32_t)BufferType::ASCEND_L0A].InitBuffer(0, bufferSize[(uint32_t)BufferType::ASCEND_L0A]);
|
||||
tensor[(uint32_t)BufferType::ASCEND_L0A].address_.logicPos = static_cast<uint8_t>(AscendC::TPosition::A2);
|
||||
tensor[(uint32_t)BufferType::ASCEND_L0B].InitBuffer(0, bufferSize[(uint32_t)BufferType::ASCEND_L0B]);
|
||||
tensor[(uint32_t)BufferType::ASCEND_L0B].address_.logicPos = static_cast<uint8_t>(AscendC::TPosition::B2);
|
||||
tensor[(uint32_t)BufferType::ASCEND_L0C].InitBuffer(0, bufferSize[(uint32_t)BufferType::ASCEND_L0C]);
|
||||
tensor[(uint32_t)BufferType::ASCEND_L0C].address_.logicPos = static_cast<uint8_t>(AscendC::TPosition::CO1);
|
||||
#else
|
||||
tensor[(uint32_t)BufferType::ASCEND_UB].InitBuffer(0, bufferSize[(uint32_t)BufferType::ASCEND_UB]);
|
||||
tensor[(uint32_t)BufferType::ASCEND_UB].address_.logicPos = static_cast<uint8_t>(AscendC::TPosition::VECIN);
|
||||
tensor[(uint32_t)BufferType::ASCEND_CB].InitBuffer(0, bufferSize[(uint32_t)BufferType::ASCEND_CB]);
|
||||
tensor[(uint32_t)BufferType::ASCEND_CB].address_.logicPos = static_cast<uint8_t>(AscendC::TPosition::A1);
|
||||
tensor[(uint32_t)BufferType::ASCEND_L0A].InitBuffer(0, bufferSize[(uint32_t)BufferType::ASCEND_L0A]);
|
||||
tensor[(uint32_t)BufferType::ASCEND_L0A].address_.logicPos = static_cast<uint8_t>(AscendC::TPosition::A2);
|
||||
tensor[(uint32_t)BufferType::ASCEND_L0B].InitBuffer(0, bufferSize[(uint32_t)BufferType::ASCEND_L0B]);
|
||||
tensor[(uint32_t)BufferType::ASCEND_L0B].address_.logicPos = static_cast<uint8_t>(AscendC::TPosition::B2);
|
||||
tensor[(uint32_t)BufferType::ASCEND_L0C].InitBuffer(0, bufferSize[(uint32_t)BufferType::ASCEND_L0C]);
|
||||
tensor[(uint32_t)BufferType::ASCEND_L0C].address_.logicPos = static_cast<uint8_t>(AscendC::TPosition::CO1);
|
||||
#endif
|
||||
};
|
||||
|
||||
template <BufferType BufferType_, typename DstDataType = half>
|
||||
__aicore__ AscendC::LocalTensor<DstDataType> GetBuffer(const uint32_t offset) const
|
||||
{
|
||||
return tensor[(uint32_t)BufferType_][offset].template ReinterpretCast<DstDataType>();
|
||||
}
|
||||
|
||||
public:
|
||||
AscendC::LocalTensor<uint8_t> tensor[(uint32_t)BufferType::ASCEND_MAX];
|
||||
};
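A minimal sketch of carving typed tiles out of this wrapper (offsets and element types are arbitrary examples; note that vector kernels only initialize the UB entry and cube kernels only the L1/L0 entries, per the #ifdef above):

// Illustrative only, not part of the original header.
__aicore__ inline void BufferViewExample()
{
    AsdopsBuffer<ArchType::ASCEND_V220> buf;
    // GetBuffer reinterprets the underlying uint8_t tensor at a byte offset.
    AscendC::LocalTensor<half> x = buf.GetBuffer<BufferType::ASCEND_UB, half>(0);
    AscendC::LocalTensor<float> y = buf.GetBuffer<BufferType::ASCEND_UB, float>(64 * 1024);
    (void)x;
    (void)y;
}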
|
||||
|
||||
#endif
|
67
csrc/mla_preprocess/op_kernel/kernel/mma.h
Normal file
@ -0,0 +1,67 @@
|
||||
/* Adapted from
|
||||
* https://gitee.com/ascend/ascend-transformer-boost.git
|
||||
*
|
||||
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
|
||||
* This file is a part of the CANN Open Software.
|
||||
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
#ifndef INCLUDE_MMA_H
|
||||
#define INCLUDE_MMA_H
|
||||
|
||||
#include "hardware.h"
|
||||
#include "kernel_tensor.h"
|
||||
|
||||
template <ArchType ArchTag, typename ElementA, typename ElementB, typename AccDTypeC, bool IsTransposeA>
|
||||
struct mmad {
|
||||
__aicore__ mmad(AscendC::LocalTensor<AccDTypeC> l0cTensor, AscendC::LocalTensor<ElementA> l0aTensor,
|
||||
AscendC::LocalTensor<ElementB> l0bTensor, uint32_t mTileActual, uint32_t nTileActual,
|
||||
uint32_t kPartActual, bool initC, uint8_t unitFlag = 0) {};
|
||||
|
||||
__aicore__ mmad(AscendC::LocalTensor<AccDTypeC> l0cTensor, AscendC::LocalTensor<ElementA> l0aTensor,
|
||||
AscendC::LocalTensor<ElementB> l0bTensor, uint64_t biasBt, uint32_t mTileActual,
|
||||
uint32_t nTileActual, uint32_t kPartActual, bool initC, uint8_t unitFlag = 0) {};
|
||||
};
|
||||
|
||||
// Partial specialization for V220, int8_t, not_vector_A, not TransposeA
|
||||
template <ArchType ArchTag, typename AccDTypeC, typename ElementA, typename ElementB>
|
||||
struct mmad<ArchTag, ElementA, ElementB, AccDTypeC, false> {
|
||||
__aicore__ mmad(AscendC::LocalTensor<AccDTypeC> l0cTensor, AscendC::LocalTensor<ElementA> l0aTensor,
|
||||
AscendC::LocalTensor<ElementB> l0bTensor, uint32_t mTileActual, uint32_t nTileActual,
|
||||
uint32_t kPartActual, bool initC, uint8_t unitFlag = 0)
|
||||
{
|
||||
AscendC::Mmad(l0cTensor, // C
|
||||
l0aTensor, // A
|
||||
l0bTensor, // B
|
||||
AscendC::MmadParams(mTileActual, // m
|
||||
nTileActual, // n
|
||||
kPartActual, // k
|
||||
unitFlag, // unitFlag
|
||||
false, // cmatrixSource
|
||||
initC)); // cmatrixInitVal
|
||||
};
|
||||
|
||||
__aicore__ mmad(AscendC::LocalTensor<AccDTypeC> l0cTensor, AscendC::LocalTensor<ElementA> l0aTensor,
|
||||
AscendC::LocalTensor<ElementB> l0bTensor, uint64_t biasBt, uint32_t mTileActual,
|
||||
uint32_t nTileActual, uint32_t kPartActual, bool initC, uint8_t unitFlag = 0)
|
||||
{
|
||||
AscendC::LocalTensor<AccDTypeC> biasTensor;
|
||||
biasTensor.InitBuffer(biasBt, nTileActual);
|
||||
biasTensor.address_.logicPos = static_cast<uint8_t>(AscendC::TPosition::C2);
|
||||
AscendC::Mmad(l0cTensor, // C
|
||||
l0aTensor, // A
|
||||
l0bTensor, // B
|
||||
biasTensor, // bt
|
||||
AscendC::MmadParams(mTileActual, // m
|
||||
nTileActual, // n
|
||||
kPartActual, // k
|
||||
unitFlag, // unitFlag
|
||||
true, // cmatrixSource
|
||||
false)); // cmatrixInitVal
|
||||
};
|
||||
};
|
||||
|
||||
#endif
|
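The `mmad` wrapper above uses a dispatch idiom that recurs throughout these kernel headers: the primary template's constructors are empty, and a partial specialization (here on `IsTransposeA = false`) issues the real `AscendC::Mmad` call, so the architecture and layout choice is resolved entirely at compile time. A minimal host-side sketch of the same idiom, independent of AscendC (all names below are illustrative, not part of the kernel):

```cpp
#include <cstdio>

// Primary template: constructor intentionally does nothing.
template <bool IsTransposeA>
struct mmad_like {
    mmad_like(int m, int n, int k) {}
};

// Specialization: the constructor itself is the "call".
template <>
struct mmad_like<false> {
    mmad_like(int m, int n, int k) { std::printf("issue MMAD %dx%dx%d\n", m, n, k); }
};

int main() {
    mmad_like<false>(16, 16, 32);  // selected at compile time, runs the real path
    mmad_like<true>(16, 16, 32);   // falls back to the empty primary template
    return 0;
}
```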
38
csrc/mla_preprocess/op_kernel/kernel/set_fpc.h
Normal file
@ -0,0 +1,38 @@
/* Adapted from
* https://gitee.com/ascend/ascend-transformer-boost.git
*
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef INCLUDE_SET_FPC_H
#define INCLUDE_SET_FPC_H

#include "hardware.h"
#include "kernel_tensor.h"

/////////////////////////////////////////////////////
// SetQuantPreAddr
/////////////////////////////////////////////////////
template <ArchType ArchTag, typename DataType>
struct SetQuantPreAddr {
__aicore__ SetQuantPreAddr(AscendC::LocalTensor<DataType> quantPreTensor) {};
};

template <typename DataType>
struct SetQuantPreAddr<ArchType::ASCEND_V220, DataType> {
static constexpr uint32_t QUANT_PRE_ADDR_MASK = 0xffff;
static constexpr uint32_t USELESS_BIT_NUM = 7;
static constexpr uint32_t QUANT_PRE_BIT_POS_IN_FPC = 8;

__aicore__ SetQuantPreAddr(AscendC::LocalTensor<DataType> quantPreTensor)
{
uint64_t quantPreAddr = (uint64_t)(__fbuf__ uint64_t *)quantPreTensor.GetPhyAddr();
AscendC::SetFixPipeConfigImpl(quantPreTensor);
};
};
#endif
274
csrc/mla_preprocess/op_kernel/kernel/simd.h
Normal file
@ -0,0 +1,274 @@
/* Adapted from
* https://gitee.com/ascend/ascend-transformer-boost.git
*
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef INCLUDE_SIMD_H
#define INCLUDE_SIMD_H

#include "hardware.h"
#include "kernel_operator.h"

/////////////////////////////////////////////////////
// vcgadd
/////////////////////////////////////////////////////
template <ArchType ArchTag, typename DType>
__aicore__ inline void cgadd_v(AscendC::LocalTensor<DType> dst, AscendC::LocalTensor<DType> src, const int32_t repeat,
const int32_t dstRepStride, const int32_t srcBlkStride, const int32_t srcRepStride)
{
AscendC::BlockReduceSum<DType, false>(dst, src, repeat, 0, dstRepStride, srcBlkStride, srcRepStride);
}

/////////////////////////////////////////////////////
// vadd
/////////////////////////////////////////////////////
template <ArchType ArchTag, typename DType>
__aicore__ inline void add_v(AscendC::LocalTensor<DType> dst, AscendC::LocalTensor<DType> src0,
AscendC::LocalTensor<DType> src1, uint8_t repeat, uint8_t dstBlockStride,
uint8_t src0BlockStride, uint8_t src1BlockStride, uint8_t dstRepeatStride,
uint8_t src0RepeatStride, uint8_t src1RepeatStride)
{
AscendC::Add<DType, false>(dst, src0, src1, (uint64_t)0, repeat,
AscendC::BinaryRepeatParams(dstBlockStride, src0BlockStride, src1BlockStride,
dstRepeatStride, src0RepeatStride, src1RepeatStride));
}

/////////////////////////////////////////////////////
// vadds
/////////////////////////////////////////////////////
template <ArchType ArchTag, typename DType>
__aicore__ inline void adds_v(AscendC::LocalTensor<DType> dst, AscendC::LocalTensor<DType> src, DType scalarValue,
uint8_t repeat, uint8_t dstBlockStride, uint8_t srcBlockStride, uint8_t dstRepeatStride,
uint8_t srcRepeatStride)
{
AscendC::Adds<DType, false>(
dst, src, scalarValue, (uint64_t)0, repeat,
AscendC::UnaryRepeatParams(dstBlockStride, srcBlockStride, dstRepeatStride, srcRepeatStride));
}

/////////////////////////////////////////////////////
// vcadd
/////////////////////////////////////////////////////
template <ArchType ArchTag, typename DType>
__aicore__ inline void cadd_v(AscendC::LocalTensor<DType> dst, AscendC::LocalTensor<DType> src, uint8_t repeat,
uint16_t dstRepeatStride, uint16_t srcBlockStride, uint16_t srcRepeatStride)
{
AscendC::RepeatReduceSum<DType, false>(dst, src, repeat, 0, 0, srcBlockStride, dstRepeatStride, srcRepeatStride);
}
/////////////////////////////////////////////////////
// vbrcb
/////////////////////////////////////////////////////
template <ArchType ArchTag, typename DType>
__aicore__ inline void brcb_v(AscendC::LocalTensor<DType> dst, AscendC::LocalTensor<DType> src, uint16_t dstBlockStride,
uint16_t dstRepeatStride, uint8_t repeat)
{
AscendC::Brcb(dst, src, repeat, AscendC::BrcbRepeatParams(dstBlockStride, dstRepeatStride));
}

/////////////////////////////////////////////////////
// vcmax
/////////////////////////////////////////////////////
template <ArchType ArchTag, typename DType, AscendC::ReduceOrder OrderType>
__aicore__ inline void cmax_v(AscendC::LocalTensor<DType> dst, AscendC::LocalTensor<DType> src, uint8_t repeat,
uint16_t dstRepeatStride, uint16_t srcBlockStride, uint16_t srcRepeatStride)
{
#if defined(__DAV_C220_VEC__)
AscendC::WholeReduceMax<DType, false>(dst, src, (int32_t)0, repeat, dstRepeatStride, srcBlockStride,
srcRepeatStride, OrderType);
#else
AscendC::WholeReduceMax<DType, false>(dst, src, (int32_t)0, repeat, dstRepeatStride, srcBlockStride,
srcRepeatStride);
#endif
}

/////////////////////////////////////////////////////
// vconv
/////////////////////////////////////////////////////
template <ArchType ArchTag, typename DTypeIn, typename DTypeOut>
__aicore__ inline void conv_v(AscendC::LocalTensor<DTypeOut> dst, AscendC::LocalTensor<DTypeIn> src, uint8_t repeat,
uint16_t dstBlockStride, uint16_t srcBlockStride, uint16_t dstRepeatStride,
uint16_t srcRepeatStride)
{
if constexpr (std::is_same<DTypeIn, float>::value && std::is_same<DTypeOut, __bf16>::value) {
AscendC::Cast<DTypeOut, DTypeIn, false>(
dst, src, AscendC::RoundMode::CAST_RINT, (uint64_t)0, repeat,
AscendC::UnaryRepeatParams(dstBlockStride, srcBlockStride, dstRepeatStride, srcRepeatStride));
} else {
AscendC::Cast<DTypeOut, DTypeIn, false>(
dst, src, AscendC::RoundMode::CAST_NONE, (uint64_t)0, repeat,
AscendC::UnaryRepeatParams(dstBlockStride, srcBlockStride, dstRepeatStride, srcRepeatStride));
}
}

/////////////////////////////////////////////////////
// vconv_f322bf16r
/////////////////////////////////////////////////////
template <ArchType ArchTag, typename DTypeIn, typename DTypeOut>
__aicore__ inline void convr_v(AscendC::LocalTensor<DTypeOut> dst, AscendC::LocalTensor<DTypeIn> src, uint8_t repeat,
uint16_t dstBlockStride, uint16_t srcBlockStride, uint16_t dstRepeatStride,
uint16_t srcRepeatStride)
{
AscendC::Cast<DTypeOut, DTypeIn, false>(
dst, src, AscendC::RoundMode::CAST_RINT, (uint64_t)0, repeat,
AscendC::UnaryRepeatParams(dstBlockStride, srcBlockStride, dstRepeatStride, srcRepeatStride));
}

/////////////////////////////////////////////////////
// vdiv
/////////////////////////////////////////////////////
template <ArchType ArchTag, typename DType>
__aicore__ inline void div_v(AscendC::LocalTensor<DType> dst, AscendC::LocalTensor<DType> src0,
AscendC::LocalTensor<DType> src1, uint8_t repeat, uint8_t dstBlockStride,
uint8_t src0BlockStride, uint8_t src1BlockStride, uint8_t dstRepeatStride,
uint8_t src0RepeatStride, uint8_t src1RepeatStride)
{
AscendC::Div<DType, false>(dst, src0, src1, (uint64_t)0, repeat,
AscendC::BinaryRepeatParams(dstBlockStride, src0BlockStride, src1BlockStride,
dstRepeatStride, src0RepeatStride, src1RepeatStride));
}

/////////////////////////////////////////////////////
// vexp
/////////////////////////////////////////////////////
template <ArchType ArchTag, typename DType>
__aicore__ inline void exp_v(AscendC::LocalTensor<DType> dst, AscendC::LocalTensor<DType> src, uint8_t repeat,
uint16_t dstBlockStride, uint16_t srcBlockStride, uint16_t dstRepeatStride,
uint16_t srcRepeatStride)
{
AscendC::Exp<DType, false>(
dst, src, (uint64_t)0, repeat,
AscendC::UnaryRepeatParams(dstBlockStride, srcBlockStride, dstRepeatStride, srcRepeatStride));
}

/////////////////////////////////////////////////////
// vmax
/////////////////////////////////////////////////////
template <ArchType ArchTag, typename DType>
__aicore__ inline void max_v(AscendC::LocalTensor<DType> dst, AscendC::LocalTensor<DType> src0,
AscendC::LocalTensor<DType> src1, uint8_t repeat, uint8_t dstBlockStride,
uint8_t src0BlockStride, uint8_t src1BlockStride, uint8_t dstRepeatStride,
uint8_t src0RepeatStride, uint8_t src1RepeatStride)
{
AscendC::Max<DType, false>(dst, src0, src1, (uint64_t)0, repeat,
AscendC::BinaryRepeatParams(dstBlockStride, src0BlockStride, src1BlockStride,
dstRepeatStride, src0RepeatStride, src1RepeatStride));
}

/////////////////////////////////////////////////////
// vmul
/////////////////////////////////////////////////////
template <ArchType ArchTag, typename DType>
__aicore__ inline void mul_v(AscendC::LocalTensor<DType> dst, AscendC::LocalTensor<DType> src0,
AscendC::LocalTensor<DType> src1, uint8_t repeat, uint8_t dstBlockStride,
uint8_t src0BlockStride, uint8_t src1BlockStride, uint8_t dstRepeatStride,
uint8_t src0RepeatStride, uint8_t src1RepeatStride)
{
AscendC::Mul<DType, false>(dst, src0, src1, (uint64_t)0, repeat,
AscendC::BinaryRepeatParams(dstBlockStride, src0BlockStride, src1BlockStride,
dstRepeatStride, src0RepeatStride, src1RepeatStride));
}

/////////////////////////////////////////////////////
// vmuls
/////////////////////////////////////////////////////
template <ArchType ArchTag, typename DType>
__aicore__ inline void muls_v(AscendC::LocalTensor<DType> dst, AscendC::LocalTensor<DType> src0, DType src1,
uint8_t repeat, uint16_t dstBlockStride, uint16_t srcBlockStride,
uint16_t dstRepeatStride, uint16_t srcRepeatStride)
{
AscendC::Muls<DType, false>(
dst, src0, src1, (uint64_t)0, repeat,
AscendC::UnaryRepeatParams(dstBlockStride, srcBlockStride, dstRepeatStride, srcRepeatStride));
}

/////////////////////////////////////////////////////
// vsub
/////////////////////////////////////////////////////
template <ArchType ArchTag, typename DType>
__aicore__ inline void sub_v(AscendC::LocalTensor<DType> dst, AscendC::LocalTensor<DType> src0,
AscendC::LocalTensor<DType> src1, uint8_t repeat, uint8_t dstBlockStride,
uint8_t src0BlockStride, uint8_t src1BlockStride, uint8_t dstRepeatStride,
uint8_t src0RepeatStride, uint8_t src1RepeatStride)
{
AscendC::Sub<DType, false>(dst, src0, src1, (uint64_t)0, repeat,
AscendC::BinaryRepeatParams(dstBlockStride, src0BlockStride, src1BlockStride,
dstRepeatStride, src0RepeatStride, src1RepeatStride));
}

/////////////////////////////////////////////////////
// vmaxs
/////////////////////////////////////////////////////
template <ArchType ArchTag, typename DType>
__aicore__ inline void maxs_v(AscendC::LocalTensor<DType> dst, AscendC::LocalTensor<DType> src0, DType src1,
uint8_t repeat, uint16_t dstBlockStride, uint16_t srcBlockStride,
uint16_t dstRepeatStride, uint16_t srcRepeatStride)
{
AscendC::Maxs<DType, false>(
dst, src0, src1, (uint64_t)0, repeat,
AscendC::UnaryRepeatParams(dstBlockStride, srcBlockStride, dstRepeatStride, srcRepeatStride));
}

/////////////////////////////////////////////////////
// vmins
/////////////////////////////////////////////////////
template <ArchType ArchTag, typename DType>
__aicore__ inline void mins_v(AscendC::LocalTensor<DType> dst, AscendC::LocalTensor<DType> src0, DType src1,
uint8_t repeat, uint16_t dstBlockStride, uint16_t srcBlockStride,
uint16_t dstRepeatStride, uint16_t srcRepeatStride)
{
AscendC::Mins<DType, false>(
dst, src0, src1, (uint64_t)0, repeat,
AscendC::UnaryRepeatParams(dstBlockStride, srcBlockStride, dstRepeatStride, srcRepeatStride));
}

/////////////////////////////////////////////////////
// vsqrt
/////////////////////////////////////////////////////
template <ArchType ArchTag, typename DType>
__aicore__ inline void sqrt_v(AscendC::LocalTensor<DType> dst, AscendC::LocalTensor<DType> src, uint8_t repeat,
uint16_t dstBlockStride, uint16_t srcBlockStride, uint16_t dstRepeatStride,
uint16_t srcRepeatStride)
{
AscendC::Sqrt<DType, false>(
dst, src, (uint64_t)0, repeat,
AscendC::UnaryRepeatParams(dstBlockStride, srcBlockStride, dstRepeatStride, srcRepeatStride));
}

/////////////////////////////////////////////////////
// vln
/////////////////////////////////////////////////////
template <ArchType ArchTag, typename DType>
__aicore__ inline void ln_v(AscendC::LocalTensor<DType> dst, AscendC::LocalTensor<DType> src, uint8_t repeat,
uint16_t dstBlockStride, uint16_t srcBlockStride, uint16_t dstRepeatStride,
uint16_t srcRepeatStride)
{
AscendC::Ln<DType, false>(
dst, src, (uint64_t)0, repeat,
AscendC::UnaryRepeatParams(dstBlockStride, srcBlockStride, dstRepeatStride, srcRepeatStride));
}

/////////////////////////////////////////////////////
// vtranspose
/////////////////////////////////////////////////////
template <ArchType ArchTag, typename DType>
__aicore__ inline void tranpose_v(AscendC::LocalTensor<DType> dst, AscendC::LocalTensor<DType> src)
{
AscendC::Transpose(dst, src);
}

/////////////////////////////////////////////////////
// vcgmax
/////////////////////////////////////////////////////
template <ArchType ArchTag, typename DType>
__aicore__ inline void cgmax_v(AscendC::LocalTensor<DType> dst, AscendC::LocalTensor<DType> src, const int32_t repeat,
const int32_t dstRepStride, const int32_t srcBlkStride, const int32_t srcRepStride)
{
AscendC::BlockReduceMax<DType, false>(dst, src, repeat, 0, dstRepStride, srcBlkStride, srcRepStride);
}
#endif
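All of these wrappers funnel the same addressing model into AscendC vector intrinsics: each call runs `repeat` iterations over 256-byte repeats made of 32-byte blocks, with block and repeat strides selecting where each iteration reads and writes (compare the `8 * 32 = 256` and `ONE_REPEAT_BYTE_SIZE / sizeof(float)` comments in `mla_preprocess.h` later in this diff). A stand-alone sketch of how such strides map to fp32 element offsets, under that assumed model (illustrative only, not the AscendC implementation):

```cpp
#include <cstdio>

// Assumed model: 32-byte block = 8 fp32 lanes, 256-byte repeat = 8 blocks.
constexpr int kLanesPerBlock = 8;

int element_offset(int rep, int blk, int lane, int blockStride, int repeatStride) {
    // Strides are expressed in 32-byte blocks, as in the wrappers above.
    return (rep * repeatStride + blk * blockStride) * kLanesPerBlock + lane;
}

int main() {
    // With blockStride = 1 and repeatStride = 8 the accesses are fully contiguous.
    std::printf("%d\n", element_offset(1, 2, 3, 1, 8));  // (1*8 + 2*1)*8 + 3 = 83
    return 0;
}
```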
69
csrc/mla_preprocess/op_kernel/kernel/utils.h
Normal file
@ -0,0 +1,69 @@
/* Adapted from
* https://gitee.com/ascend/ascend-transformer-boost.git
*
* Copyright (c) 2024 Huawei Technologies Co., Ltd.
* This file is a part of the CANN Open Software.
* Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef INCLUDE_UTILS_H
#define INCLUDE_UTILS_H

template <typename IN_DTYPE>
__aicore__ inline void CreateCaMatrix(const AscendC::LocalTensor<IN_DTYPE> &dst, const uint16_t repeats,
const uint16_t blockNum, const uint16_t dstGap, const IN_DTYPE initValue)
{
AscendC::InitConstValue<IN_DTYPE>(dst,
AscendC::InitConstValueParams<IN_DTYPE>(repeats, blockNum, dstGap, initValue));
}
__aicore__ inline void SetFftsBaseAddr(uint64_t config)
{
AscendC::SetSyncBaseAddr(config);
}
template <typename IN_DTYPE>
__aicore__ inline void SetPadding(IN_DTYPE padValue)
{
AscendC::SetLoadDataPaddingValue<IN_DTYPE>(padValue);
}
__aicore__ inline void SetAtomicnone()
{
AscendC::SetAtomicNone();
}
__aicore__ inline void SetMasknorm()
{
#if __CCE_AICORE__ == 100
return;
#endif
AscendC::SetMaskNorm();
}
__aicore__ inline void SetNdpara(uint16_t ndNum, uint16_t srcNdStride, uint16_t dstNdStride)
{
AscendC::SetFixpipeNz2ndFlag(ndNum, srcNdStride, dstNdStride);
}
template <typename IN_DTYPE>
__aicore__ inline void SetVectorMask(const uint64_t maskHigh, const uint64_t maskLow)
{
AscendC::SetVectorMask<IN_DTYPE>(maskHigh, maskLow);
}
__aicore__ inline int64_t GetSubBlockidx()
{
return AscendC::GetSubBlockIdx();
}
__aicore__ inline void WaitFlagDev(uint16_t flagId)
{
AscendC::WaitEvent(flagId);
}
template <pipe_t pipe, uint8_t mode>
__aicore__ inline void FftsCrossCoreSync(uint16_t flagId)
{
AscendC::CrossCoreSetFlag<mode, pipe>(flagId);
}
template <typename IN_DTYPE, bool setRelu = false>
__aicore__ inline void SetFpc(const AscendC::LocalTensor<IN_DTYPE> &preTensor, bool isUnitFlag = false)
{
AscendC::SetFixPipeConfig<IN_DTYPE, setRelu>(preTensor, isUnitFlag);
}
#endif
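`FftsCrossCoreSync` and `WaitFlagDev` are the set/wait halves of the flag-based handshake that orders the MLA pipeline stages (the `QUANT1`, `MM1`, ... stage IDs declared in `mla_preprocess.h` below). As a rough host-side analogue of that producer/consumer pattern, using a plain atomic in place of the hardware flag registers (an illustration of the idea only, not how AscendC implements it):

```cpp
#include <atomic>
#include <cstdio>
#include <thread>

std::atomic<int> flag{0};  // stands in for one hardware sync flag

void producer() {               // e.g. the vector core finishing a quant stage
    std::printf("stage done\n");
    flag.store(1, std::memory_order_release);            // ~ FftsCrossCoreSync(flagId)
}

void consumer() {               // e.g. the cube core waiting before its matmul
    while (flag.load(std::memory_order_acquire) == 0) {}  // ~ WaitFlagDev(flagId)
    std::printf("next stage starts\n");
}

int main() {
    std::thread c(consumer), p(producer);
    p.join();
    c.join();
    return 0;
}
```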
114
csrc/mla_preprocess/op_kernel/mla_preprocess.h
Normal file
@ -0,0 +1,114 @@
// Adapted from
// https://gitee.com/ascend/ascend-transformer-boost
//
// Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
// This file is a part of the CANN Open Software.
// Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
// Please refer to the License for details. You may not use this file except in compliance with the License.
// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
// See LICENSE in the root of the software repository for the full text of the License.
//

#ifndef __MLA_PREPROCESS_H__
#define __MLA_PREPROCESS_H__

// sync
constexpr int32_t QUANT1 = 1;
constexpr int32_t MM1 = 2;
constexpr int32_t MM1QUANT = 3;
constexpr int32_t RMSNORMQUANT2 = 4;
constexpr int32_t MM2 = 5;
constexpr int32_t MM2QUANT = 6;
constexpr int32_t BMM3 = 7;
constexpr int32_t BMM3SPLIT = 8;
constexpr int32_t MM2OUT = 9;
constexpr int32_t EINSUMOUT = 11;
constexpr int32_t EINSUMQUANT = 12;

// ropeConcat
constexpr uint32_t ELE_NUM_FP16 = 16;         // number of fp16 elements in one block
constexpr uint32_t ELE_NUM_FP32 = 8;          // number of fp32 elements in one block
constexpr uint8_t DEFAULT_REPEAT_STRIDE = 8;  // stride, 8 * 32 = 256

// rmsNormQuant
constexpr int32_t NUM_PER_REP_FP32 = 64;  // ONE_REPEAT_BYTE_SIZE / sizeof(float);
constexpr float ZERO = 0;
constexpr uint32_t BUF_FACTOR = 3;        // 1(g) + 1(sqx) + 1(sum) = 3
constexpr uint32_t OFFSET_GAMMA = 0;      // the offset of gamma is 0
constexpr uint32_t OFFSET_SQX = 1;        // the offset of sqx is 1
constexpr uint32_t OFFSET_SUM = 2;        // the offset of sum is 2
constexpr uint32_t OFFSET_WORKSPACE = 3;  // the offset of workspace is 3
constexpr uint32_t REPEAT_TIME_256 = 256;  // 256 default stride
constexpr uint32_t REPEAT_TIME_128 = 128;  // 128 default stride
constexpr uint32_t REPEAT_TIME_64 = 64;    // 64 default stride

constexpr uint8_t CACHE_MODE_KVCACHE = 0;       // single input single output
constexpr uint8_t CACHE_MODE_KROPE_CTKV = 1;    // double in and double out
constexpr uint8_t CACHE_MODE_INT8_NZCACHE = 2;  // high performance KV NZ format/quant int8
constexpr uint8_t CACHE_MODE_NZCACHE = 3;

// pp matmul
constexpr uint32_t HIDDTEN_STATE = 7168;
constexpr uint32_t FLOAT_BLOCK_SIZE = 64;
constexpr uint32_t HALF_BLOCK_SIZE = 64;
constexpr uint32_t HALF_VECTOR_SIZE = 64;
constexpr uint32_t MM1_OUT_SIZE = 2112;
constexpr uint32_t SPLIT_SIZE_ONE = 576;
constexpr uint32_t SPLIT_SIZE_TWO = 1536;
constexpr uint32_t SPLIT_RMSNRORM_SIZE_ONE = 512;
constexpr uint32_t SPLIT_RMSNRORM_SIZE_TWO = 64;
constexpr uint32_t ROPE_SPLIT_SIZE_ONE = 64;
constexpr uint32_t ROPE_SPLIT_SIZE_TWO = 128;

constexpr uint32_t MMSIZE1 = 128 * 192;  // 24576
constexpr uint32_t MMSIZE2 = 64 * 128;   // 8192

constexpr uint64_t L0_PINGPONG_BUFFER_LEN = 32768;   // 32 KB
constexpr uint64_t L1_PINGPONG_BUFFER_LEN = 262144;  // 256 KB
constexpr uint64_t BLOCK_SIZE_16 = 16;
constexpr uint64_t BLOCK_SIZE_32 = 32;
constexpr uint64_t CUBE_MATRIX_SIZE_512 = 16 * 32;  // 16 * 32
constexpr uint64_t FB_BUFF_SIZE = 1024 * 7;
constexpr uint64_t SCALE_L1_LEN = 4096;
constexpr uint64_t BIAS_L1_LEN = 2048;

constexpr uint64_t CONST_0 = 0;
constexpr uint64_t CONST_4 = 4;
constexpr uint64_t CONST_8 = 8;
constexpr uint64_t CONST_32 = 32;
constexpr uint64_t CONST_64 = 64;
constexpr uint64_t CONST_128 = 128;

// ropeConcat
constexpr uint32_t ROPE_CONCAT_NUM_BUFFER = 2;

// rmsNormQuant
constexpr uint32_t OFFSET_ABS = 3;             // the offset of abs is 3
constexpr uint32_t OFFSET_WORKSPACE_BF16 = 4;  // the offset of workspace is 4

// sync bf16
constexpr int32_t AIC_MM1_START = 2;
constexpr int32_t AIC_MM3_START = 3;
constexpr int32_t AIC_MM2_START = 6;
constexpr int32_t MMAIC = 7;
constexpr int32_t MMAIV = 8;

constexpr uint32_t MAX_HW_SYNC_COUNTER = 5;
constexpr uint32_t SYNC_MODE = 2;

// TilingKey
constexpr uint32_t KEY_FP16_CACHEMODE_0_QUANTMODE_0 = 0;
constexpr uint32_t KEY_FP16_CACHEMODE_1_QUANTMODE_0 = 1;
constexpr uint32_t KEY_BF16_CACHEMODE_0_QUANTMODE_0 = 256;
constexpr uint32_t KEY_BF16_CACHEMODE_1_QUANTMODE_0 = 257;
constexpr uint32_t KEY_BF16_CACHEMODE_3_QUANTMODE_0 = 259;

enum class QuantMode : int32_t {
PER_TENSOR_ASYMM_QUANT = 0,
PER_TOKEN_SYMM_QUANT,
PER_TOKEN_ASYMM_QUANT,
NO_QUANT,
};

#endif  // __MLA_PREPROCESS_H__
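The tiling-key constants pack the dispatch dimensions into one integer: judging from the values (0, 1, 256, 257, 259), bit 8 appears to distinguish bf16 from fp16 while the low bits carry the cache mode, which lets the host tiling code select one of the template instantiations in `mla_preprocess_kernel.cpp` below. A small sketch of that apparent encoding (an assumption inferred from the constants, not code taken from this PR):

```cpp
#include <cstdint>
#include <cstdio>

// Assumed encoding: bit 8 = dtype (set -> bf16), low bits = cache mode.
constexpr uint32_t kBf16Bit = 256;

uint32_t tiling_key(bool is_bf16, uint32_t cache_mode) {
    return (is_bf16 ? kBf16Bit : 0) + cache_mode;
}

int main() {
    std::printf("%u %u %u\n",
                tiling_key(false, 1),  // 1   == KEY_FP16_CACHEMODE_1_QUANTMODE_0
                tiling_key(true, 0),   // 256 == KEY_BF16_CACHEMODE_0_QUANTMODE_0
                tiling_key(true, 3));  // 259 == KEY_BF16_CACHEMODE_3_QUANTMODE_0
    return 0;
}
```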
299
csrc/mla_preprocess/op_kernel/mla_preprocess_kernel.cpp
Normal file
@ -0,0 +1,299 @@
// Adapted from
// https://gitee.com/ascend/ascend-transformer-boost
//
// Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
// This file is a part of the CANN Open Software.
// Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
// Please refer to the License for details. You may not use this file except in compliance with the License.
// THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
// INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
// See LICENSE in the root of the software repository for the full text of the License.
//

#include "kernel_operator.h"
#include "../../kernels/types.h"

#include "mla_preprocess_mix_fp16.hpp"
#include "mla_preprocess_mix_bf16.hpp"

#include "../op_host/tiling/mla_preprocess_tiling.h"

extern "C" __global__ __aicore__ void mla_preprocess(
GM_ADDR hiddenState, GM_ADDR gamma1, GM_ADDR beta1, GM_ADDR quantScale1, GM_ADDR quantOffset1, GM_ADDR wdqkv,
GM_ADDR bias1, GM_ADDR gamma2, GM_ADDR beta2, GM_ADDR quantScale2, GM_ADDR quantOffset2, GM_ADDR gamma3,
GM_ADDR sin1, GM_ADDR cos1, GM_ADDR sin2, GM_ADDR cos2, GM_ADDR keycache, GM_ADDR slotMapping, GM_ADDR wuq,
GM_ADDR bias2, GM_ADDR wuk, GM_ADDR descale1, GM_ADDR descale2, GM_ADDR ctkvScale, GM_ADDR qnopeScale, GM_ADDR q,
GM_ADDR keycacheOut, GM_ADDR q2, GM_ADDR keycacheOut2, GM_ADDR workspace, GM_ADDR tiling)
{
#if defined(__CCE_KT_TEST__) || (__CCE_AICORE__ == 220)
PRELOAD(2);
#endif

SetAtomicnone();
SetMasknorm();
#ifdef __DAV_C220_CUBE__
SetPadding<uint64_t>((uint64_t)0);
SetNdpara(1, 0, 0);
#endif

MlaTilingData mlaTilingData;
__gm__ MlaTilingData *tilingData = reinterpret_cast<__gm__ MlaTilingData *>(tiling);

mlaTilingData.tilingKey = tilingData->tilingKey;
mlaTilingData.n = tilingData->n;

mlaTilingData.mm1.numBatch = tilingData->mm1.numBatch;
mlaTilingData.mm1.m = tilingData->mm1.m;
mlaTilingData.mm1.k = tilingData->mm1.k;
mlaTilingData.mm1.n = tilingData->mm1.n;
mlaTilingData.mm1.m0 = tilingData->mm1.m0;
mlaTilingData.mm1.k0 = tilingData->mm1.k0;
mlaTilingData.mm1.n0 = tilingData->mm1.n0;
mlaTilingData.mm1.mLoop = tilingData->mm1.mLoop;
mlaTilingData.mm1.kLoop = tilingData->mm1.kLoop;
mlaTilingData.mm1.nLoop = tilingData->mm1.nLoop;
mlaTilingData.mm1.coreLoop = tilingData->mm1.coreLoop;
mlaTilingData.mm1.swizzleCount = tilingData->mm1.swizzleCount;
mlaTilingData.mm1.enShuffleK = tilingData->mm1.enShuffleK;
mlaTilingData.mm1.blockDim = tilingData->mm1.blockDim;
mlaTilingData.mm1.enLoadAllAmat = tilingData->mm1.enLoadAllAmat;
mlaTilingData.mm1.b0matPingPongBufferLen = tilingData->mm1.b0matPingPongBufferLen;

mlaTilingData.mm2.numBatch = tilingData->mm2.numBatch;
mlaTilingData.mm2.m = tilingData->mm2.m;
mlaTilingData.mm2.k = tilingData->mm2.k;
mlaTilingData.mm2.n = tilingData->mm2.n;
mlaTilingData.mm2.m0 = tilingData->mm2.m0;
mlaTilingData.mm2.k0 = tilingData->mm2.k0;
mlaTilingData.mm2.n0 = tilingData->mm2.n0;
mlaTilingData.mm2.mLoop = tilingData->mm2.mLoop;
mlaTilingData.mm2.kLoop = tilingData->mm2.kLoop;
mlaTilingData.mm2.nLoop = tilingData->mm2.nLoop;
mlaTilingData.mm2.coreLoop = tilingData->mm2.coreLoop;
mlaTilingData.mm2.swizzleCount = tilingData->mm2.swizzleCount;
mlaTilingData.mm2.enShuffleK = tilingData->mm2.enShuffleK;
mlaTilingData.mm2.blockDim = tilingData->mm2.blockDim;
mlaTilingData.mm2.enLoadAllAmat = tilingData->mm2.enLoadAllAmat;
mlaTilingData.mm2.b0matPingPongBufferLen = tilingData->mm2.b0matPingPongBufferLen;

mlaTilingData.mm3.numBatch = tilingData->mm3.numBatch;
mlaTilingData.mm3.m = tilingData->mm3.m;
mlaTilingData.mm3.k = tilingData->mm3.k;
mlaTilingData.mm3.n = tilingData->mm3.n;
mlaTilingData.mm3.m0 = tilingData->mm3.m0;
mlaTilingData.mm3.k0 = tilingData->mm3.k0;
mlaTilingData.mm3.n0 = tilingData->mm3.n0;
mlaTilingData.mm3.mLoop = tilingData->mm3.mLoop;
mlaTilingData.mm3.kLoop = tilingData->mm3.kLoop;
mlaTilingData.mm3.nLoop = tilingData->mm3.nLoop;
mlaTilingData.mm3.coreLoop = tilingData->mm3.coreLoop;
mlaTilingData.mm3.swizzleCount = tilingData->mm3.swizzleCount;
mlaTilingData.mm3.enShuffleK = tilingData->mm3.enShuffleK;
mlaTilingData.mm3.blockDim = tilingData->mm3.blockDim;

mlaTilingData.perTaskNum = tilingData->perTaskNum;
mlaTilingData.resTaskNum = tilingData->resTaskNum;
mlaTilingData.numCore = tilingData->numCore;

mlaTilingData.rmsNumCore1 = tilingData->rmsNumCore1;
mlaTilingData.rmsNumCol1 = tilingData->rmsNumCol1;
mlaTilingData.rmsNumCore2 = tilingData->rmsNumCore2;
mlaTilingData.rmsNumCol2 = tilingData->rmsNumCol2;

mlaTilingData.hiddenSizeQ = tilingData->hiddenSizeQ;
mlaTilingData.headNumQ = tilingData->headNumQ;
mlaTilingData.headDim = tilingData->headDim;
mlaTilingData.concatSize = tilingData->concatSize;
mlaTilingData.rotaryCoeff = tilingData->rotaryCoeff;
mlaTilingData.ntokens = tilingData->ntokens;
mlaTilingData.realCore = tilingData->realCore;
mlaTilingData.nlCoreRun = tilingData->nlCoreRun;
mlaTilingData.lCoreRun = tilingData->lCoreRun;
mlaTilingData.maxNPerLoopForUb = tilingData->maxNPerLoopForUb;
mlaTilingData.preCoreLoopTime = tilingData->preCoreLoopTime;
mlaTilingData.preCoreLoopNLast = tilingData->preCoreLoopNLast;
mlaTilingData.lastCoreLoopTime = tilingData->lastCoreLoopTime;
mlaTilingData.lastCoreLoopNLast = tilingData->lastCoreLoopNLast;

mlaTilingData.esqFrontCore = tilingData->esqFrontCore;
mlaTilingData.esqTailCore = tilingData->esqTailCore;
mlaTilingData.esqFrontCoreBatch = tilingData->esqFrontCoreBatch;
mlaTilingData.esqTailCoreBatch = tilingData->esqTailCoreBatch;
mlaTilingData.esqHeadNum = tilingData->esqHeadNum;
mlaTilingData.esqColNum = tilingData->esqColNum;
mlaTilingData.esqUbHeadLoop = tilingData->esqUbHeadLoop;
mlaTilingData.esqHeadPerLoop = tilingData->esqHeadPerLoop;
mlaTilingData.esqHeadTail = tilingData->esqHeadTail;
mlaTilingData.esqColLoop = tilingData->esqColLoop;
mlaTilingData.esqColTail = tilingData->esqColTail;

mlaTilingData.s1Offset = tilingData->s1Offset;
mlaTilingData.s2Offset = tilingData->s2Offset;
mlaTilingData.s3Offset = tilingData->s3Offset;
mlaTilingData.s4Offset = tilingData->s4Offset;
mlaTilingData.s5Offset = tilingData->s5Offset;

GM_ADDR s1 = workspace + static_cast<uint64_t>(mlaTilingData.s1Offset);
GM_ADDR s2 = workspace + static_cast<uint64_t>(mlaTilingData.s2Offset);
GM_ADDR s3 = workspace + static_cast<uint64_t>(mlaTilingData.s3Offset);
GM_ADDR s4 = workspace + static_cast<uint64_t>(mlaTilingData.s4Offset);
GM_ADDR s5 = workspace + static_cast<uint64_t>(mlaTilingData.s5Offset);

switch (mlaTilingData.tilingKey) {
case KEY_FP16_CACHEMODE_0_QUANTMODE_0: {
MLAPO_FP16::MLAOperation<CACHE_MODE_KVCACHE, DataFormat::NZ, DataFormat::NZ, DataFormat::ND> opFp16Cm0Qm0(
mlaTilingData, tiling);
opFp16Cm0Qm0.Init(hiddenState, gamma1, beta1, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2,
quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2,
s1, s2, s3);
if ASCEND_IS_AIC {
opFp16Cm0Qm0.ProcessCube();
}
if ASCEND_IS_AIV {
opFp16Cm0Qm0.ProcessVector();
}
break;
}
case KEY_FP16_CACHEMODE_1_QUANTMODE_0: {
MLAPO_FP16::MLAOperation<CACHE_MODE_KROPE_CTKV, DataFormat::NZ, DataFormat::NZ, DataFormat::ND>
opFp16Cm1Qm0(mlaTilingData, tiling);
opFp16Cm1Qm0.Init(hiddenState, gamma1, beta1, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2,
quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2,
s1, s2, s3);
if ASCEND_IS_AIC {
opFp16Cm1Qm0.ProcessCube();
}
if ASCEND_IS_AIV {
opFp16Cm1Qm0.ProcessVector();
}
break;
}
case KEY_BF16_CACHEMODE_0_QUANTMODE_0: {
MLAPO_BF16::MLAOperation<__bf16, 0, DataFormat::NZ, DataFormat::NZ, DataFormat::ND,
QuantMode::PER_TENSOR_ASYMM_QUANT>
opBf16Cm0Qm0(mlaTilingData, tiling);
opBf16Cm0Qm0.Init(hiddenState, gamma1, beta1, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2,
quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2,
s1, s2, s3, s4, s5);
if ASCEND_IS_AIC {
opBf16Cm0Qm0.ProcessCube();
}
if ASCEND_IS_AIV {
opBf16Cm0Qm0.ProcessVector();
}
break;
}
case KEY_BF16_CACHEMODE_1_QUANTMODE_0: {
MLAPO_BF16::MLAOperation<__bf16, 1, DataFormat::NZ, DataFormat::NZ, DataFormat::ND,
QuantMode::PER_TENSOR_ASYMM_QUANT>
opBf16Cm1Qm0(mlaTilingData, tiling);
opBf16Cm1Qm0.Init(hiddenState, gamma1, beta1, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2,
quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2,
s1, s2, s3, s4, s5);
if ASCEND_IS_AIC {
opBf16Cm1Qm0.ProcessCube();
}
if ASCEND_IS_AIV {
opBf16Cm1Qm0.ProcessVector();
}
break;
}
case KEY_BF16_CACHEMODE_3_QUANTMODE_0: {
MLAPO_BF16::MLAOperation<__bf16, 3, DataFormat::NZ, DataFormat::NZ, DataFormat::ND,
QuantMode::PER_TENSOR_ASYMM_QUANT>
opBf16Cm3Qm0(mlaTilingData, tiling);
opBf16Cm3Qm0.Init(hiddenState, gamma1, beta1, quantScale1, quantOffset1, wdqkv, bias1, gamma2, beta2,
quantScale2, quantOffset2, gamma3, sin1, cos1, sin2, cos2, keycache, slotMapping, wuq,
bias2, wuk, descale1, descale2, ctkvScale, qnopeScale, q, keycacheOut, q2, keycacheOut2,
s1, s2, s3, s4, s5);
if ASCEND_IS_AIC {
opBf16Cm3Qm0.ProcessCube();
}
if ASCEND_IS_AIV {
opBf16Cm3Qm0.ProcessVector();
}
break;
}
default: {
break;
}
}
return;
}

namespace vllm_ascend {

extern void mla_preprocess_impl(
void* stream,
void* hidden_state,
void* gamma1,
void* beta1,
void* quant_scale1,
void* quant_offset1,
void* wdqkv,
void* bias1,
void* gamma2,
void* beta2,
void* quant_scale2,
void* quant_offset2,
void* gamma3,
void* sin1,
void* cos1,
void* sin2,
void* cos2,
void* keycache,
void* slot_mapping,
void* wuq,
void* bias2,
void* wuk,
void* descale1,
void* descale2,
void* ctkv_scale,
void* qnope_scale,
void* q,
void* keycache_out,
void* q2,
void* keycache_out2,
void* workspace,
void* tiling,
const uint32_t block_dim)
{
mla_preprocess<<<block_dim, nullptr, stream>>>(
hidden_state,
gamma1,
beta1,
quant_scale1,
quant_offset1,
wdqkv,
bias1,
gamma2,
beta2,
quant_scale2,
quant_offset2,
gamma3,
sin1,
cos1,
sin2,
cos2,
keycache,
slot_mapping,
wuq,
bias2,
wuk,
descale1,
descale2,
ctkv_scale,
qnope_scale,
q,
keycache_out,
q2,
keycache_out2,
workspace,
tiling);
}

}
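Before dispatching, the kernel slices the single `workspace` buffer into the per-stage scratch regions `s1`..`s5` using byte offsets carried in the tiling data, so the host only has to allocate one contiguous block. A minimal host-side sketch of that partitioning scheme (names and sizes are illustrative; the real offsets come from the tiling function):

```cpp
#include <cstdint>
#include <vector>

// One contiguous workspace, carved into stage buffers by byte offsets,
// mirroring "sN = workspace + sNOffset" in the kernel above.
struct WorkspaceSlices {
    uint8_t *s1, *s2, *s3, *s4, *s5;
};

WorkspaceSlices slice_workspace(uint8_t *workspace, const uint64_t offsets[5]) {
    return {workspace + offsets[0], workspace + offsets[1], workspace + offsets[2],
            workspace + offsets[3], workspace + offsets[4]};
}

int main() {
    std::vector<uint8_t> workspace(1 << 20);                     // hypothetical 1 MiB allocation
    const uint64_t offsets[5] = {0, 4096, 8192, 12288, 16384};   // illustrative offsets
    WorkspaceSlices s = slice_workspace(workspace.data(), offsets);
    (void)s;
    return 0;
}
```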
2918
csrc/mla_preprocess/op_kernel/mla_preprocess_mix_bf16.hpp
Normal file
File diff suppressed because it is too large
2508
csrc/mla_preprocess/op_kernel/mla_preprocess_mix_fp16.hpp
Normal file
File diff suppressed because it is too large
36
csrc/ops.h
@ -124,4 +124,40 @@ namespace vllm_ascend {
uint32_t output_hidden_dim,
uint32_t slice_offset,
uint32_t output_full_dim);

extern void mla_preprocess_impl(
void* stream,
void* hidden_state,
void* gamma1,
void* beta1,
void* quant_scale1,
void* quant_offset1,
void* wdqkv,
void* bias1,
void* gamma2,
void* beta2,
void* quant_scale2,
void* quant_offset2,
void* gamma3,
void* sin1,
void* cos1,
void* sin2,
void* cos2,
void* keycache,
void* slot_mapping,
void* wuq,
void* bias2,
void* wuk,
void* descale1,
void* descale2,
void* ctkv_scale,
void* qnope_scale,
void* q,
void* keycache_out,
void* q2,
void* keycache_out2,
void* workspace,
void* tiling,
const uint32_t block_dim
);
}
@ -23,6 +23,7 @@
#include "acl/acl.h"
#include "ops.h"
#include "utils.h"
#include "mla_preprocess/op_host/mla_preprocess.h"

namespace vllm_ascend {

@ -106,6 +107,83 @@ std::tuple<at::Tensor, at::Tensor> rotary_embedding(at::Tensor &positions, at::T
return {query_dst, key_dst};
}

std::tuple<at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &> mla_preprocess(
const at::Tensor &hiddenState, const at::Tensor &gamma0, const at::Tensor &beta0, const at::Tensor &wdqkv,
const at::Tensor &descale0, const at::Tensor &gamma1, const at::Tensor &beta1, const at::Tensor &wuq,
const at::Tensor &descale1, const at::Tensor &gamma2, const at::Tensor &cos, const at::Tensor &sin,
const at::Tensor &wuk, const at::Tensor &kv_cache, const at::Tensor &kv_cache_rope, const at::Tensor &slotmapping,
const at::Tensor &quant_scale0, const at::Tensor &quant_offset0, const at::Tensor &bias0,
const at::Tensor &quant_scale1, const at::Tensor &quant_offset1, const at::Tensor &bias1,
const c10::optional<at::Tensor> &ctkv_scale, const c10::optional<at::Tensor> &q_nope_scale,
c10::optional<c10::string_view> cache_mode, c10::optional<c10::string_view> quant_mode, at::Tensor &q_out0,
at::Tensor &kv_cache_out0, at::Tensor &q_out1, at::Tensor &kv_cache_out1)
{
at::Tensor CtkvScale =
ctkv_scale.has_value()
? ctkv_scale.value()
: at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device()));
at::Tensor QnopeScale =
q_nope_scale.has_value()
? q_nope_scale.value()
: at::empty({1}, at::TensorOptions().dtype(at::kHalf).device(hiddenState.options().device()));

auto [workspace_tensor, tiling, block_dim] = mlapo::mla_preprocess_tiling(
hiddenState,
wuk,
cache_mode,
quant_mode
);

void *hidden_state_ptr = hiddenState.data_ptr();
void *gamma0_ptr = gamma0.data_ptr();
void *beta0_ptr = beta0.data_ptr();
void *quant_scale0_ptr = quant_scale0.data_ptr();
void *quant_offset0_ptr = quant_offset0.data_ptr();
void *wdqkv_ptr = wdqkv.data_ptr();
void *bias0_ptr = bias0.data_ptr();
void *gamma1_ptr = gamma1.data_ptr();
void *beta1_ptr = beta1.data_ptr();
void *quant_scale1_ptr = quant_scale1.data_ptr();
void *quant_offset1_ptr = quant_offset1.data_ptr();
void *gamma2_ptr = gamma2.data_ptr();
void *sin_ptr = sin.data_ptr();
void *cos_ptr = cos.data_ptr();
void *kv_cache_ptr = kv_cache.data_ptr();
void *slotmapping_ptr = slotmapping.data_ptr();
void *wuq_ptr = wuq.data_ptr();
void *bias1_ptr = bias1.data_ptr();
void *wuk_ptr = wuk.data_ptr();
void *descale0_ptr = descale0.data_ptr();
void *descale1_ptr = descale1.data_ptr();
void *ctkv_scale_ptr = CtkvScale.data_ptr();
void *qnope_scale_ptr = QnopeScale.data_ptr();
void *q_out0_ptr = q_out0.data_ptr();
void *kv_cache_out0_ptr = kv_cache_out0.data_ptr();
void *q_out1_ptr = q_out1.data_ptr();
void *kv_cache_out1_ptr = kv_cache_out1.data_ptr();
void *workspace_ptr = workspace_tensor.data_ptr();
void *tiling_ptr = tiling.data_ptr();

aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
at_npu::native::OpCommand cmd;
cmd.Name("mla_preprocess");

cmd.SetCustomHandler([stream, hidden_state_ptr, gamma0_ptr, beta0_ptr, quant_scale0_ptr, quant_offset0_ptr, wdqkv_ptr, bias0_ptr,
gamma1_ptr, beta1_ptr, quant_scale1_ptr, quant_offset1_ptr, gamma2_ptr, sin_ptr, cos_ptr,
kv_cache_ptr, slotmapping_ptr, wuq_ptr, bias1_ptr, wuk_ptr, descale0_ptr, descale1_ptr, ctkv_scale_ptr,
qnope_scale_ptr, q_out0_ptr, kv_cache_out0_ptr, q_out1_ptr, kv_cache_out1_ptr, workspace_ptr,
tiling_ptr, block_dim]() -> int {
mla_preprocess_impl(stream, hidden_state_ptr, gamma0_ptr, beta0_ptr, quant_scale0_ptr, quant_offset0_ptr, wdqkv_ptr, bias0_ptr,
gamma1_ptr, beta1_ptr, quant_scale1_ptr, quant_offset1_ptr, gamma2_ptr, sin_ptr, cos_ptr, sin_ptr, cos_ptr,
kv_cache_ptr, slotmapping_ptr, wuq_ptr, bias1_ptr, wuk_ptr, descale0_ptr, descale1_ptr, ctkv_scale_ptr,
qnope_scale_ptr, q_out0_ptr, kv_cache_out0_ptr, q_out1_ptr, kv_cache_out1_ptr, workspace_ptr,
tiling_ptr, block_dim);
return 0;
});
cmd.Run();
return std::forward_as_tuple(q_out0, kv_cache_out0, q_out1, kv_cache_out1);
}

std::tuple<at::Tensor, at::Tensor> get_masked_input_and_mask(
at::Tensor &input,
const int64_t org_vocab_start_index,
@ -422,4 +500,17 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
"sgmv_expand(Tensor! x, Tensor! weight, Tensor! lora_indices, Tensor! seq_len, Tensor! y,"
" int slice_offset, int slice_size) -> Tensor");
ops.impl("sgmv_expand", torch::kPrivateUse1, &vllm_ascend::sgmv_expand);

ops.def(
"mla_preprocess(Tensor hiddenState, Tensor gamma0, Tensor beta0, Tensor wdqkv,"
" Tensor descale0, Tensor gamma1, Tensor beta1, Tensor wuq, Tensor descale1,"
" Tensor gamma2, Tensor cos, Tensor sin, Tensor wuk, Tensor kv_cache,"
" Tensor kv_cache_rope, Tensor slotmapping, Tensor quant_scale0,"
" Tensor quant_offset0, Tensor bias0, Tensor quant_scale1, Tensor quant_offset1,"
" Tensor bias1, Tensor? ctkv_scale, Tensor? q_nope_scale, str? cache_mode,"
" str? quant_mode, Tensor! q_out0, Tensor! kv_cache_out0, Tensor! q_out1,"
" Tensor! kv_cache_out1) -> (Tensor q_out0, Tensor kv_cache_out0,"
" Tensor q_out1, Tensor kv_cache_out1)"
);
ops.impl("mla_preprocess", torch::kPrivateUse1, &vllm_ascend::mla_preprocess);
}
@ -81,6 +81,41 @@ at::Tensor sgmv_expand_meta(at::Tensor &x, at::Tensor &weight, at::Tensor &lora_
return y_out;
}

std::tuple<at::Tensor &, at::Tensor &, at::Tensor &, at::Tensor &> mla_preprocess(
const at::Tensor &hiddenState,
const at::Tensor &gamma0,
const at::Tensor &beta0,
const at::Tensor &wdqkv,
const at::Tensor &descale0,
const at::Tensor &gamma1,
const at::Tensor &beta1,
const at::Tensor &wuq,
const at::Tensor &descale1,
const at::Tensor &gamma2,
const at::Tensor &cos,
const at::Tensor &sin,
const at::Tensor &wuk,
const at::Tensor &kv_cache,
const at::Tensor &kv_cache_rope,
const at::Tensor &slotmapping,
const at::Tensor &quant_scale0,
const at::Tensor &quant_offset0,
const at::Tensor &bias0,
const at::Tensor &quant_scale1,
const at::Tensor &quant_offset1,
const at::Tensor &bias1,
const c10::optional<at::Tensor> &ctkv_scale,
const c10::optional<at::Tensor> &q_nope_scale,
c10::optional<c10::string_view> cache_mode,
c10::optional<c10::string_view> quant_mode,
at::Tensor &q_out0,
at::Tensor &kv_cache_out0,
at::Tensor &q_out1,
at::Tensor &kv_cache_out1)
{
return {q_out0, kv_cache_out0, q_out1, kv_cache_out1};
}

} // namespace meta
} // namespace vllm_ascend
@ -97,6 +132,7 @@ namespace {
ops.impl("bgmv_expand", &vllm_ascend::meta::bgmv_expand_meta);
// Sgmv expand
ops.impl("sgmv_expand", &vllm_ascend::meta::sgmv_expand_meta);

// MLA preprocess
ops.impl("mla_preprocess", &vllm_ascend::meta::mla_preprocess);
}
}
112
tests/e2e/singlecard/ops/test_mla_preprocess.py
Normal file
@ -0,0 +1,112 @@
import gc

import torch
import torch_npu

from vllm_ascend.utils import enable_custom_op

enable_custom_op()


@torch.inference_mode()
def test_mla_preprocess_kernel():
    token_num = 1
    head_num = 2
    N_7168 = 7168
    block_num = 1
    block_size = 128
    dtype = torch.bfloat16

    hidden_states = torch.randn((token_num, N_7168), dtype=dtype).npu()
    gamma0 = torch.randn((N_7168), dtype=dtype).npu()
    beta0 = torch.randn((N_7168), dtype=dtype).npu()
    quant_scale0 = torch.randn((1, ), dtype=dtype).npu()
    quant_offset0 = torch.randint(0, 7, (1, ), dtype=torch.int8).npu()

    wdqkv = torch.randint(0, 7, (1, 224, 2112, 32), dtype=torch.int8).npu()
    wdqkv = torch_npu.npu_format_cast(wdqkv.contiguous(), 29)

    de_scale0 = torch.rand((2112, ), dtype=torch.float).npu()
    bias0 = torch.randint(0, 7, (2112, ), dtype=torch.int32).npu()
    gamma1 = torch.randn((1536), dtype=dtype).npu()
    beta1 = torch.randn((1536), dtype=dtype).npu()
    quant_scale1 = torch.randn((1, ), dtype=dtype).npu()
    quant_offset1 = torch.randint(0, 7, (1, ), dtype=torch.int8).npu()

    wuq = torch.randint(0, 7, (1, 48, head_num * 192, 32),
                        dtype=torch.int8).npu()
    wuq = torch_npu.npu_format_cast(wuq.contiguous(), 29)

    de_scale1 = torch.rand((head_num * 192, ), dtype=torch.float).npu()
    bias1 = torch.randint(0, 7, (head_num * 192, ), dtype=torch.int32).npu()

    gamma2 = torch.randn((512), dtype=dtype).npu()

    cos = torch.randn((token_num, 64), dtype=dtype).npu()
    sin = torch.randn((token_num, 64), dtype=dtype).npu()

    wuk = torch.randn((head_num, 128, 512), dtype=dtype).npu()
    wuk = torch_npu.npu_format_cast(wuk, 29)
    kv_cache = torch.randint(0,
                             7,
                             (block_num, head_num * 512 // 32, block_size, 32),
                             dtype=dtype).npu()
    kv_cache_rope = torch.randn(
        (block_num, head_num * 64 // 16, block_size, 16), dtype=dtype).npu()

    slotmapping = torch.randint(0, 7, (token_num, ), dtype=torch.int32).npu()

    ctkv_scale = torch.randn((1, ), dtype=dtype).npu()
    qnope_scale = torch.randn((head_num), dtype=dtype).npu()

    q_nope_out = torch.empty(
        (hidden_states.shape[0], wuk.shape[0], kv_cache.shape[-1]),
        dtype=hidden_states.dtype,
        device=hidden_states.device,
    )
    q_rope_out = torch.empty(
        (hidden_states.shape[0], wuk.shape[0], kv_cache_rope.shape[-1]),
        dtype=hidden_states.dtype,
        device=hidden_states.device,
    )
    q_nope_old = q_nope_out.clone()
    q_rope_old = q_rope_out.clone()

    torch.ops._C_ascend.mla_preprocess(
        hidden_states,
        gamma0,
        beta0,
        wdqkv,
        de_scale0,
        gamma1,
        beta1,
        wuq,
        de_scale1,
        gamma2,
        cos,
        sin,
        wuk,
        kv_cache,
        kv_cache_rope,
        slotmapping,
        quant_scale0=quant_scale0,
        quant_offset0=quant_offset0,
        bias0=bias0,
        quant_scale1=quant_scale1,
        quant_offset1=quant_offset1,
        bias1=bias1,
        ctkv_scale=ctkv_scale,
        q_nope_scale=qnope_scale,
        cache_mode="krope_ctkv",
        quant_mode="per_tensor_quant_asymm",
        q_out0=q_nope_out,
        kv_cache_out0=kv_cache,
        q_out1=q_rope_out,
        kv_cache_out1=kv_cache_rope,
    )
    assert not torch.equal(q_nope_out, q_nope_old)
    assert not torch.equal(q_rope_out, q_rope_old)

    gc.collect()
    torch.npu.empty_cache()
    torch.npu.reset_peak_memory_stats()