mirror of
https://github.com/vllm-project/vllm-ascend.git
synced 2025-10-21 06:10:33 +08:00
### What this PR does / why we need it? Add initial experimental support for Ascend 310P; this patch squashes the PRs below into one to help validation: - https://github.com/vllm-project/vllm-ascend/pull/914 - https://github.com/vllm-project/vllm-ascend/pull/1318 - https://github.com/vllm-project/vllm-ascend/pull/1327 ### Does this PR introduce _any_ user-facing change? Users can run vLLM on the Atlas 300I DUO series. ### How was this patch tested? CI passed with: - E2E image build for 310P - CI test on A2 with e2e test and longterm test - Unit tests are missing because they need a real 310P image; they will be added in a separate PR later. - Manual e2e tests: - Qwen2.5-7b-instruct, Qwen2.5-0.5b, Qwen3-0.6B, Qwen3-4B, Qwen3-8B: https://github.com/vllm-project/vllm-ascend/pull/914#issuecomment-2942989322 - Pangu MGoE 72B The patch has been tested locally on Ascend 310P hardware to ensure that the changes do not break existing functionality and that the new features work as intended. #### ENV information CANN, NNAL version: 8.1.RC1 > [!IMPORTANT] > PTA 2.5.1 version >= torch_npu-2.5.1.post1.dev20250528 to support NZ format and calling NNAL operators on 310P #### Code example ##### Build vllm-ascend from source code ```shell # download source code as vllm-ascend cd vllm-ascend export SOC_VERSION=Ascend310P3 pip install -v -e . cd .. ``` ##### Run offline inference ```python from vllm import LLM, SamplingParams prompts = ["水的沸点是100摄氏度吗?请回答是或者否。", "若腋下体温为38摄氏度,请问这人是否发烧?请回答是或者否。", "水的沸点是100摄氏度吗?请回答是或者否。", "若腋下体温为38摄氏度,请问这人是否发烧?请回答是或者否。"] # Create a sampling params object. sampling_params = SamplingParams(temperature=0.0, top_p=0.95, max_tokens=10) # Create an LLM. 
llm = LLM( model="Qwen/Qwen2.5-7B-Instruct", max_model_len=4096, max_num_seqs=4, dtype="float16", # IMPORTANT cause some ATB ops cannot support bf16 on 310P disable_custom_all_reduce=True, trust_remote_code=True, tensor_parallel_size=2, compilation_config={"custom_ops":['none', "+rms_norm", "+rotary_embedding"]}, ) # Generate texts from the prompts. outputs = llm.generate(prompts, sampling_params) for output in outputs: prompt = output.prompt generated_text = output.outputs[0].text print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") ``` --------- Signed-off-by: Vincent Yuan <farawayboat@gmail.com> Signed-off-by: Yikun Jiang <yikunkero@gmail.com> Signed-off-by: angazenn <zengyanjia@huawei.com> Co-authored-by: Vincent Yuan <farawayboat@gmail.com> Co-authored-by: angazenn <zengyanjia@huawei.com> Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com> Co-authored-by: leo-pony <nengjunma@outlook.com> Co-authored-by: shen-shanshan <467638484@qq.com>
379 lines
16 KiB
C++
379 lines
16 KiB
C++
/*
|
|
* Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
|
|
*/
|
|
|
|
#include "kernel_operator.h"
|
|
#include "kernel_tensor_impl.h"
|
|
#include "kernel_type.h"
|
|
#include "types.h"
|
|
#include "utils.h"
|
|
using vllm_ascend::AccType;
|
|
|
|
template<typename scalar_t>
|
|
class GetMaskedInputAndMask {
|
|
public:
|
|
__aicore__ inline GetMaskedInputAndMask() {}
|
|
|
|
__aicore__ inline ~GetMaskedInputAndMask() {
|
|
pipe.Reset();
|
|
}
|
|
|
|
|
|
__aicore__ inline void Init(
|
|
__gm__ scalar_t* input,
|
|
__gm__ scalar_t* masked_input,
|
|
__gm__ bool* mask_out,
|
|
const int64_t org_vocab_start_index,
|
|
const int64_t org_vocab_end_index,
|
|
const int64_t num_org_vocab_padding,
|
|
const int64_t added_vocab_start_index,
|
|
const int64_t added_vocab_end_index,
|
|
const int64_t size)
|
|
{
|
|
// Initialize basic parameters
|
|
input_ = input;
|
|
masked_input_ = masked_input;
|
|
mask_out_ = mask_out;
|
|
org_vocab_start_index_ = org_vocab_start_index;
|
|
org_vocab_end_index_ = org_vocab_end_index;
|
|
size_ = ((size + 31) / 32) * 32;
|
|
added_offset_ = added_vocab_start_index -
|
|
(org_vocab_end_index - org_vocab_start_index) -
|
|
num_org_vocab_padding;
|
|
added_vocab_start_index_ = added_vocab_start_index;
|
|
added_vocab_end_index_ = added_vocab_end_index;
|
|
|
|
// Initialize global tensors
|
|
inputGlobal.SetGlobalBuffer(input);
|
|
maskedOutputGlobal.SetGlobalBuffer(masked_input);
|
|
maskOutGlobal.SetGlobalBuffer(mask_out);
|
|
|
|
// Initialize queues
|
|
pipe.InitBuffer(inQueue, 1, size_ * sizeof(scalar_t));
|
|
pipe.InitBuffer(outQueue, 1, size_ * sizeof(scalar_t));
|
|
pipe.InitBuffer(maskQueue, 1, size_ * sizeof(bool));
|
|
|
|
// Initialize calculation buffers
|
|
// NOTE: calc_buf_1 and calc_buf_2 are also used for int16 casting on older archs.
|
|
pipe.InitBuffer(calc_buf_1, size_ * sizeof(float));
|
|
pipe.InitBuffer(calc_buf_2, size_ * sizeof(float));
|
|
|
|
// Initialize result queues
|
|
pipe.InitBuffer(result_ge_que, BUFFER_NUM, size_ * sizeof(float));
|
|
pipe.InitBuffer(result_le_que, BUFFER_NUM, size_ * sizeof(float));
|
|
pipe.InitBuffer(result_org_mask_que, BUFFER_NUM, size_ * sizeof(float));
|
|
pipe.InitBuffer(result_add_mask_que, BUFFER_NUM, size_ * sizeof(float));
|
|
|
|
// Initialize temporary buffers
|
|
pipe.InitBuffer(start_buf, size_ * sizeof(float));
|
|
pipe.InitBuffer(end_buf, size_ * sizeof(float));
|
|
pipe.InitBuffer(inputFloat_buf, size_ * sizeof(float)); // Also used for half intermediate in casting
|
|
pipe.InitBuffer(validOffset_buf, size_ * sizeof(float));
|
|
pipe.InitBuffer(vocabMask_buf_, size_ * sizeof(int8_t));
|
|
pipe.InitBuffer(ones_buf_, size_ * sizeof(float));
|
|
}
|
|
|
|
__aicore__ inline void Process()
|
|
{
|
|
CopyIn();
|
|
Compute();
|
|
CopyOut();
|
|
}
|
|
|
|
private:
|
|
__aicore__ inline void CopyIn()
|
|
{
|
|
AscendC::LocalTensor<scalar_t> inputLocal = inQueue.AllocTensor<scalar_t>();
|
|
AscendC::DataCopy(inputLocal, inputGlobal, size_);
|
|
inQueue.EnQue(inputLocal);
|
|
}
|
|
|
|
__aicore__ inline void CompareWithValue(
|
|
AscendC::LocalTensor<int8_t>& result,
|
|
const AscendC::LocalTensor<float>& input,
|
|
const AscendC::LocalTensor<float>& compare_value,
|
|
bool is_greater_equal) {
|
|
|
|
AscendC::LocalTensor<float> compute_buf = calc_buf_1.Get<float>();
|
|
if (is_greater_equal) {
|
|
AscendC::Max(compute_buf, input, compare_value, size_);
|
|
AscendC::Sub(compute_buf, compare_value, compute_buf, size_);
|
|
} else {
|
|
AscendC::Max(compute_buf, input, compare_value, size_);
|
|
AscendC::Sub(compute_buf, compute_buf, compare_value, size_);
|
|
}
|
|
|
|
AscendC::Abs(compute_buf, compute_buf, size_);
|
|
AscendC::Mins(compute_buf, compute_buf, MIN_ACCURACY_FP32, size_);
|
|
AscendC::Muls(compute_buf, compute_buf, MAX_MUL_1_FP32, size_);
|
|
AscendC::Muls(compute_buf, compute_buf, MAX_MUL_1_FP32, size_);
|
|
AscendC::Muls(compute_buf, compute_buf, MAX_MUL_2_FP32, size_);
|
|
AscendC::Adds(compute_buf, compute_buf, NEGATIVE_ONE_FP32, size_);
|
|
AscendC::Abs(compute_buf, compute_buf, size_);
|
|
|
|
AscendC::LocalTensor<half> compute_buf_fp16 = calc_buf_2.Get<half>();
|
|
AscendC::Cast(compute_buf_fp16, compute_buf, AscendC::RoundMode::CAST_NONE, size_);
|
|
AscendC::Cast(result, compute_buf_fp16, AscendC::RoundMode::CAST_NONE, size_);
|
|
}
|
|
|
|
__aicore__ inline void ComputeRangeMask(
|
|
AscendC::LocalTensor<int8_t>& range_mask,
|
|
const AscendC::LocalTensor<float>& input,
|
|
const float start_value,
|
|
const float end_value) {
|
|
|
|
AscendC::LocalTensor<float> start_value_tensor = start_buf.Get<float>();
|
|
AscendC::LocalTensor<float> end_value_tensor = end_buf.Get<float>();
|
|
|
|
AscendC::Duplicate(start_value_tensor, start_value, size_);
|
|
AscendC::Duplicate(end_value_tensor, end_value, size_);
|
|
|
|
AscendC::LocalTensor<int8_t> ge_result = result_ge_que.AllocTensor<int8_t>();
|
|
AscendC::LocalTensor<int8_t> lt_result = result_le_que.AllocTensor<int8_t>();
|
|
|
|
CompareWithValue(ge_result, start_value_tensor, input, true);
|
|
CompareWithValue(lt_result, input, end_value_tensor, false);
|
|
|
|
#if (__CCE_AICORE__ >= 220)
|
|
AscendC::And(range_mask, ge_result, lt_result, size_);
|
|
#else
|
|
{
|
|
// WORKAROUND for older arch
|
|
// No direct int8->int16 cast. Use half as intermediate.
|
|
// No direct int8 And. Use int16 And.
|
|
AscendC::LocalTensor<int16_t> ge_result_i16 = calc_buf_1.Get<int16_t>();
|
|
AscendC::LocalTensor<int16_t> lt_result_i16 = calc_buf_2.Get<int16_t>();
|
|
AscendC::LocalTensor<int16_t> range_mask_i16 = ge_result_i16;
|
|
|
|
// Use a temporary buffer for half type
|
|
AscendC::LocalTensor<half> tmp_half = inputFloat_buf.Get<half>();
|
|
|
|
// 1. Cast inputs: int8_t -> half -> int16_t
|
|
AscendC::Cast(tmp_half, ge_result, AscendC::RoundMode::CAST_NONE, size_);
|
|
AscendC::Cast(ge_result_i16, tmp_half, AscendC::RoundMode::CAST_NONE, size_);
|
|
|
|
AscendC::Cast(tmp_half, lt_result, AscendC::RoundMode::CAST_NONE, size_);
|
|
AscendC::Cast(lt_result_i16, tmp_half, AscendC::RoundMode::CAST_NONE, size_);
|
|
|
|
// 2. Perform And on int16_t tensors
|
|
AscendC::And(range_mask_i16, ge_result_i16, lt_result_i16, size_);
|
|
|
|
// 3. Cast result back: int16_t -> half -> int8_t
|
|
AscendC::Cast(tmp_half, range_mask_i16, AscendC::RoundMode::CAST_NONE, size_);
|
|
AscendC::Cast(range_mask, tmp_half, AscendC::RoundMode::CAST_NONE, size_);
|
|
}
|
|
#endif
|
|
}
|
|
|
|
__aicore__ inline void Compute() {
|
|
AscendC::LocalTensor<scalar_t> inputLocal = inQueue.DeQue<scalar_t>();
|
|
AscendC::LocalTensor<scalar_t> maskedLocal = outQueue.AllocTensor<scalar_t>();
|
|
AscendC::LocalTensor<int8_t> maskLocal = maskQueue.AllocTensor<int8_t>();
|
|
|
|
AscendC::LocalTensor<float> inputFloat = inputFloat_buf.Get<float>();
|
|
AscendC::Cast(inputFloat, inputLocal, AscendC::RoundMode::CAST_NONE, size_);
|
|
|
|
AscendC::LocalTensor<int8_t> orgVocabMask = result_org_mask_que.AllocTensor<int8_t>();
|
|
ComputeRangeMask(orgVocabMask,
|
|
inputFloat,
|
|
static_cast<float>(org_vocab_start_index_),
|
|
static_cast<float>(org_vocab_end_index_));
|
|
|
|
AscendC::LocalTensor<int8_t> addedVocabMask = result_add_mask_que.AllocTensor<int8_t>();
|
|
ComputeRangeMask(addedVocabMask,
|
|
inputFloat,
|
|
static_cast<float>(added_vocab_start_index_),
|
|
static_cast<float>(added_vocab_end_index_));
|
|
|
|
AscendC::LocalTensor<float> validOffset = validOffset_buf.Get<float>();
|
|
AscendC::LocalTensor<float> constOrgStartIndex = start_buf.Get<float>();
|
|
|
|
AscendC::Duplicate(constOrgStartIndex, float(org_vocab_start_index_), size_);
|
|
|
|
AscendC::LocalTensor<half> orgVocabMask_fp16;
|
|
AscendC::LocalTensor<float> orgVocabMask_fp32;
|
|
AscendC::Cast(orgVocabMask_fp16, orgVocabMask, AscendC::RoundMode::CAST_NONE, size_);
|
|
AscendC::Cast(orgVocabMask_fp32, orgVocabMask_fp16, AscendC::RoundMode::CAST_NONE, size_);
|
|
|
|
AscendC::Mul(validOffset, constOrgStartIndex, orgVocabMask_fp32, size_);
|
|
|
|
AscendC::LocalTensor<float> addedOffset;
|
|
AscendC::LocalTensor<float> addedOffsetTensor = end_buf.Get<float>();
|
|
AscendC::Duplicate(addedOffsetTensor, float(added_offset_), size_);
|
|
|
|
AscendC::LocalTensor<half> addedVocabMask_fp16;
|
|
AscendC::LocalTensor<float> addedVocabMask_fp32;
|
|
AscendC::Cast(addedVocabMask_fp16, addedVocabMask, AscendC::RoundMode::CAST_NONE, size_);
|
|
AscendC::Cast(addedVocabMask_fp32, addedVocabMask_fp16, AscendC::RoundMode::CAST_NONE, size_);
|
|
|
|
AscendC::Mul(addedOffset, addedOffsetTensor, addedVocabMask_fp32, size_);
|
|
AscendC::Add(validOffset, validOffset, addedOffset, size_);
|
|
|
|
AscendC::LocalTensor<int8_t> vocabMask = vocabMask_buf_.Get<int8_t>();
|
|
|
|
#if (__CCE_AICORE__ >= 220)
|
|
AscendC::Or(vocabMask,
|
|
orgVocabMask,
|
|
addedVocabMask,
|
|
size_);
|
|
#else
|
|
{
|
|
// WORKAROUND for older arch
|
|
// No direct int8->int16 cast. Use half as intermediate.
|
|
// No direct int8 Or. Use int16 Or.
|
|
AscendC::LocalTensor<int16_t> orgVocabMask_i16 = calc_buf_1.Get<int16_t>();
|
|
AscendC::LocalTensor<int16_t> addedVocabMask_i16 = calc_buf_2.Get<int16_t>();
|
|
AscendC::LocalTensor<int16_t> vocabMask_i16 = orgVocabMask_i16;
|
|
|
|
// Use a temporary buffer for half type. inputFloat_buf is free now.
|
|
AscendC::LocalTensor<half> tmp_half = inputFloat_buf.Get<half>();
|
|
|
|
// 1. Cast inputs: int8_t -> half -> int16_t
|
|
AscendC::Cast(tmp_half, orgVocabMask, AscendC::RoundMode::CAST_NONE, size_);
|
|
AscendC::Cast(orgVocabMask_i16, tmp_half, AscendC::RoundMode::CAST_NONE, size_);
|
|
|
|
AscendC::Cast(tmp_half, addedVocabMask, AscendC::RoundMode::CAST_NONE, size_);
|
|
AscendC::Cast(addedVocabMask_i16, tmp_half, AscendC::RoundMode::CAST_NONE, size_);
|
|
|
|
// 2. Perform Or on int16_t tensors
|
|
AscendC::Or(vocabMask_i16, orgVocabMask_i16, addedVocabMask_i16, size_);
|
|
|
|
// 3. Cast result back: int16_t -> half -> int8_t
|
|
AscendC::Cast(tmp_half, vocabMask_i16, AscendC::RoundMode::CAST_NONE, size_);
|
|
AscendC::Cast(vocabMask, tmp_half, AscendC::RoundMode::CAST_NONE, size_);
|
|
}
|
|
#endif
|
|
|
|
AscendC::Sub(inputFloat, inputFloat, validOffset, size_);
|
|
|
|
AscendC::LocalTensor<half> vocabMask_fp16;
|
|
AscendC::LocalTensor<float> vocabMask_fp32;
|
|
AscendC::Cast(vocabMask_fp16, vocabMask, AscendC::RoundMode::CAST_NONE, size_);
|
|
AscendC::Cast(vocabMask_fp32, vocabMask_fp16, AscendC::RoundMode::CAST_NONE, size_);
|
|
|
|
AscendC::Mul(inputFloat, inputFloat, vocabMask_fp32, size_);
|
|
|
|
AscendC::Cast(maskedLocal, inputFloat, AscendC::RoundMode::CAST_CEIL, size_);
|
|
outQueue.EnQue(maskedLocal);
|
|
|
|
AscendC::LocalTensor<float> ones_tensor = ones_buf_.Get<float>();
|
|
AscendC::Duplicate(ones_tensor, (float)1, size_);
|
|
AscendC::LocalTensor<float> maskLocal_fp32;
|
|
|
|
AscendC::Sub(maskLocal_fp32, ones_tensor, vocabMask_fp32, size_);
|
|
|
|
AscendC::LocalTensor<half> maskLocal_fp16;
|
|
AscendC::Cast(maskLocal_fp16, maskLocal_fp32, AscendC::RoundMode::CAST_NONE, size_);
|
|
AscendC::Cast(maskLocal, maskLocal_fp16, AscendC::RoundMode::CAST_NONE, size_);
|
|
maskQueue.EnQue(maskLocal);
|
|
inQueue.FreeTensor(inputLocal);
|
|
}
|
|
|
|
__aicore__ inline void CopyOut()
|
|
{
|
|
AscendC::LocalTensor<scalar_t> maskedLocal = outQueue.DeQue<scalar_t>();
|
|
AscendC::LocalTensor<bool> maskLocal = maskQueue.DeQue<bool>();
|
|
|
|
AscendC::DataCopy(maskedOutputGlobal, maskedLocal, size_);
|
|
AscendC::DataCopy(maskOutGlobal, maskLocal, size_);
|
|
|
|
outQueue.FreeTensor(maskedLocal);
|
|
maskQueue.FreeTensor(maskLocal);
|
|
}
|
|
|
|
private:
|
|
static constexpr int32_t BUFFER_NUM = 2;
|
|
AscendC::TPipe pipe;
|
|
AscendC::TQue<AscendC::TPosition::VECIN, 1> inQueue;
|
|
AscendC::TQue<AscendC::TPosition::VECOUT, 1> outQueue, maskQueue;
|
|
AscendC::GlobalTensor<scalar_t> inputGlobal, maskedOutputGlobal;
|
|
AscendC::GlobalTensor<bool> maskOutGlobal;
|
|
AscendC::TBuf<AscendC::TPosition::VECCALC> calc_buf_1;
|
|
AscendC::TBuf<AscendC::TPosition::VECCALC> calc_buf_2;
|
|
AscendC::TQue<AscendC::QuePosition::VECOUT, BUFFER_NUM> result_ge_que;
|
|
AscendC::TQue<AscendC::QuePosition::VECOUT, BUFFER_NUM> result_le_que;
|
|
AscendC::TQue<AscendC::QuePosition::VECOUT, BUFFER_NUM> result_org_mask_que;
|
|
AscendC::TQue<AscendC::QuePosition::VECOUT, BUFFER_NUM> result_add_mask_que;
|
|
|
|
// Temporary buffers
|
|
AscendC::TBuf<AscendC::TPosition::VECCALC> start_buf;
|
|
AscendC::TBuf<AscendC::TPosition::VECCALC> end_buf;
|
|
AscendC::TBuf<AscendC::TPosition::VECCALC> inputFloat_buf;
|
|
AscendC::TBuf<AscendC::TPosition::VECCALC> validOffset_buf;
|
|
AscendC::TBuf<AscendC::TPosition::VECCALC> vocabMask_buf_;
|
|
AscendC::TBuf<AscendC::TPosition::VECCALC> ones_buf_;
|
|
|
|
__gm__ scalar_t *input_, *masked_input_;
|
|
__gm__ bool *mask_out_;
|
|
int64_t size_;
|
|
int64_t org_vocab_start_index_, org_vocab_end_index_;
|
|
int64_t added_vocab_start_index_, added_vocab_end_index_;
|
|
int64_t added_offset_;
|
|
|
|
static constexpr float MIN_ACCURACY_FP32 = 1.1754943508222875e-38;
|
|
static constexpr float MAX_MUL_1_FP32 = 1125899906842624;
|
|
static constexpr float MAX_MUL_2_FP32 = 67108864;
|
|
static constexpr float NEGATIVE_ONE_FP32 = -1.0f;
|
|
};
|
|
|
|
/**
 * Device entry point.
 *
 * Splits the `size` token ids into `loop_cnt` chunks of size / loop_cnt
 * elements. It walks them with a stride of `aiv_num` starting at this core's
 * block index, so with `aiv_num` launched blocks chunk i is handled by core
 * i % aiv_num.
 *
 * NOTE(review): when size % loop_cnt != 0 the chunk offsets
 * (i * size / loop_cnt) and lengths (size / loop_cnt) leave some elements
 * unprocessed — confirm the host always passes a divisible size.
 */
extern "C" __global__ __aicore__ void get_masked_input_and_mask_kernel(
    __gm__ int32_t* input,
    __gm__ int32_t* masked_input,
    __gm__ bool* mask_out,
    const int64_t org_vocab_start_index,
    const int64_t org_vocab_end_index,
    const int64_t num_org_vocab_padding,
    const int64_t added_vocab_start_index,
    const int64_t added_vocab_end_index,
    const int64_t size,
    const uint32_t loop_cnt,
    const uint32_t aiv_num)
{
    for (int64_t i = AscendC::GetBlockIdx(); i < loop_cnt; i += aiv_num) {
        const int64_t offset = i * size / loop_cnt;
        const int64_t chunk_size = size / loop_cnt;

        // One operator (and TPipe) per chunk. Init() allocates its UB buffers
        // from the pipe; calling it again on the same pipe without a Reset
        // would stack a second set of buffers on top of the first. Scoping the
        // operator to this iteration lets its destructor (pipe.Reset())
        // release them before the next chunk.
        GetMaskedInputAndMask<int32_t> op{};
        op.Init(input + offset,
                masked_input + offset,
                mask_out + offset,
                org_vocab_start_index, org_vocab_end_index,
                num_org_vocab_padding, added_vocab_start_index,
                added_vocab_end_index, chunk_size);
        op.Process();
    }
}
|
|
|
|
namespace vllm_ascend {
|
|
|
|
void get_masked_input_and_mask_impl(
|
|
void* stream,
|
|
void* input,
|
|
void* masked_input,
|
|
void* mask_out,
|
|
const int64_t org_vocab_start_index,
|
|
const int64_t org_vocab_end_index,
|
|
const int64_t num_org_vocab_padding,
|
|
const int64_t added_vocab_start_index,
|
|
const int64_t added_vocab_end_index,
|
|
const int64_t size,
|
|
const uint32_t loop_cnt,
|
|
const uint32_t aiv_num)
|
|
{
|
|
get_masked_input_and_mask_kernel<<<aiv_num, nullptr, stream>>>(
|
|
static_cast<int32_t*>(input),
|
|
static_cast<int32_t*>(masked_input),
|
|
static_cast<bool*>(mask_out),
|
|
org_vocab_start_index,
|
|
org_vocab_end_index,
|
|
num_org_vocab_padding,
|
|
added_vocab_start_index,
|
|
added_vocab_end_index,
|
|
size,
|
|
loop_cnt,
|
|
aiv_num);
|
|
}
|
|
|
|
} // namespace vllm_ascend
|