mirror of
https://github.com/vllm-project/vllm-ascend.git
synced 2025-10-20 13:43:53 +08:00
[Platform] Add initial experimental support for Altlas 300I series (#1333)
### What this PR does / why we need it? Add initial experimental support for Ascend 310P, this patch squash below PR into one to help validation: - https://github.com/vllm-project/vllm-ascend/pull/914 - https://github.com/vllm-project/vllm-ascend/pull/1318 - https://github.com/vllm-project/vllm-ascend/pull/1327 ### Does this PR introduce _any_ user-facing change? User can run vLLM on Altlas 300I DUO series ### How was this patch tested? CI passed with: - E2E image build for 310P - CI test on A2 with e2e test and longterm test - Unit test missing because need a real 310P image to have the test, will add in a separate PR later. - Manually e2e test: - Qwen2.5-7b-instruct, Qwen2.5-0.5b, Qwen3-0.6B, Qwen3-4B, Qwen3-8B: https://github.com/vllm-project/vllm-ascend/pull/914#issuecomment-2942989322 - Pangu MGoE 72B The patch has been tested locally on Ascend 310P hardware to ensure that the changes do not break existing functionality and that the new features work as intended. #### ENV information CANN, NNAL version: 8.1.RC1 > [!IMPORTANT] > PTA 2.5.1 version >= torch_npu-2.5.1.post1.dev20250528 to support NZ format and calling NNAL operators on 310P #### Code example ##### Build vllm-ascend from source code ```shell # download source code as vllm-ascend cd vllm-ascend export SOC_VERSION=Ascend310P3 pip install -v -e . cd .. ``` ##### Run offline inference ```python from vllm import LLM, SamplingParams prompts = ["水的沸点是100摄氏度吗?请回答是或者否。", "若腋下体温为38摄氏度,请问这人是否发烧?请回答是或者否。", "水的沸点是100摄氏度吗?请回答是或者否。", "若腋下体温为38摄氏度,请问这人是否发烧?请回答是或者否。"] # Create a sampling params object. sampling_params = SamplingParams(temperature=0.0, top_p=0.95, max_tokens=10) # Create an LLM. llm = LLM( model="Qwen/Qwen2.5-7B-Instruct", max_model_len=4096, max_num_seqs=4, dtype="float16", # IMPORTANT cause some ATB ops cannot support bf16 on 310P disable_custom_all_reduce=True, trust_remote_code=True, tensor_parallel_size=2, compilation_config={"custom_ops":['none', "+rms_norm", "+rotary_embedding"]}, ) # Generate texts from the prompts. outputs = llm.generate(prompts, sampling_params) for output in outputs: prompt = output.prompt generated_text = output.outputs[0].text print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") ``` --------- Signed-off-by: Vincent Yuan <farawayboat@gmail.com> Signed-off-by: Yikun Jiang <yikunkero@gmail.com> Signed-off-by: angazenn <zengyanjia@huawei.com> Co-authored-by: Vincent Yuan <farawayboat@gmail.com> Co-authored-by: angazenn <zengyanjia@huawei.com> Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com> Co-authored-by: leo-pony <nengjunma@outlook.com> Co-authored-by: shen-shanshan <467638484@qq.com>
This commit is contained in:
@ -54,6 +54,7 @@ public:
|
||||
pipe.InitBuffer(maskQueue, 1, size_ * sizeof(bool));
|
||||
|
||||
// Initialize calculation buffers
|
||||
// NOTE: calc_buf_1 and calc_buf_2 are also used for int16 casting on older archs.
|
||||
pipe.InitBuffer(calc_buf_1, size_ * sizeof(float));
|
||||
pipe.InitBuffer(calc_buf_2, size_ * sizeof(float));
|
||||
|
||||
@ -66,7 +67,7 @@ public:
|
||||
// Initialize temporary buffers
|
||||
pipe.InitBuffer(start_buf, size_ * sizeof(float));
|
||||
pipe.InitBuffer(end_buf, size_ * sizeof(float));
|
||||
pipe.InitBuffer(inputFloat_buf, size_ * sizeof(float));
|
||||
pipe.InitBuffer(inputFloat_buf, size_ * sizeof(float)); // Also used for half intermediate in casting
|
||||
pipe.InitBuffer(validOffset_buf, size_ * sizeof(float));
|
||||
pipe.InitBuffer(vocabMask_buf_, size_ * sizeof(int8_t));
|
||||
pipe.InitBuffer(ones_buf_, size_ * sizeof(float));
|
||||
@ -121,7 +122,6 @@ private:
|
||||
const float start_value,
|
||||
const float end_value) {
|
||||
|
||||
// Use already initialized buffers
|
||||
AscendC::LocalTensor<float> start_value_tensor = start_buf.Get<float>();
|
||||
AscendC::LocalTensor<float> end_value_tensor = end_buf.Get<float>();
|
||||
|
||||
@ -134,7 +134,35 @@ private:
|
||||
CompareWithValue(ge_result, start_value_tensor, input, true);
|
||||
CompareWithValue(lt_result, input, end_value_tensor, false);
|
||||
|
||||
#if (__CCE_AICORE__ >= 220)
|
||||
AscendC::And(range_mask, ge_result, lt_result, size_);
|
||||
#else
|
||||
{
|
||||
// WORKAROUND for older arch
|
||||
// No direct int8->int16 cast. Use half as intermediate.
|
||||
// No direct int8 And. Use int16 And.
|
||||
AscendC::LocalTensor<int16_t> ge_result_i16 = calc_buf_1.Get<int16_t>();
|
||||
AscendC::LocalTensor<int16_t> lt_result_i16 = calc_buf_2.Get<int16_t>();
|
||||
AscendC::LocalTensor<int16_t> range_mask_i16 = ge_result_i16;
|
||||
|
||||
// Use a temporary buffer for half type
|
||||
AscendC::LocalTensor<half> tmp_half = inputFloat_buf.Get<half>();
|
||||
|
||||
// 1. Cast inputs: int8_t -> half -> int16_t
|
||||
AscendC::Cast(tmp_half, ge_result, AscendC::RoundMode::CAST_NONE, size_);
|
||||
AscendC::Cast(ge_result_i16, tmp_half, AscendC::RoundMode::CAST_NONE, size_);
|
||||
|
||||
AscendC::Cast(tmp_half, lt_result, AscendC::RoundMode::CAST_NONE, size_);
|
||||
AscendC::Cast(lt_result_i16, tmp_half, AscendC::RoundMode::CAST_NONE, size_);
|
||||
|
||||
// 2. Perform And on int16_t tensors
|
||||
AscendC::And(range_mask_i16, ge_result_i16, lt_result_i16, size_);
|
||||
|
||||
// 3. Cast result back: int16_t -> half -> int8_t
|
||||
AscendC::Cast(tmp_half, range_mask_i16, AscendC::RoundMode::CAST_NONE, size_);
|
||||
AscendC::Cast(range_mask, tmp_half, AscendC::RoundMode::CAST_NONE, size_);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
__aicore__ inline void Compute() {
|
||||
@ -145,24 +173,18 @@ private:
|
||||
AscendC::LocalTensor<float> inputFloat = inputFloat_buf.Get<float>();
|
||||
AscendC::Cast(inputFloat, inputLocal, AscendC::RoundMode::CAST_NONE, size_);
|
||||
|
||||
// Calculate mask for org_vocab range
|
||||
// org_vocab_mask = (input_ >= org_vocab_start_index) & (input_ < org_vocab_end_index)
|
||||
AscendC::LocalTensor<int8_t> orgVocabMask = result_org_mask_que.AllocTensor<int8_t>();
|
||||
ComputeRangeMask(orgVocabMask,
|
||||
inputFloat,
|
||||
static_cast<float>(org_vocab_start_index_),
|
||||
static_cast<float>(org_vocab_end_index_));
|
||||
|
||||
// Calculate mask for added_vocab range
|
||||
// added_vocab_mask = (input_ >= added_vocab_start_index) & (input_ < added_vocab_end_index)
|
||||
AscendC::LocalTensor<int8_t> addedVocabMask = result_add_mask_que.AllocTensor<int8_t>();
|
||||
ComputeRangeMask(addedVocabMask,
|
||||
inputFloat,
|
||||
static_cast<float>(added_vocab_start_index_),
|
||||
static_cast<float>(added_vocab_end_index_));
|
||||
|
||||
// Calculate validOffset
|
||||
// valid_offset = (org_vocab_start_index * org_vocab_mask) + (added_offset * added_vocab_mask)
|
||||
AscendC::LocalTensor<float> validOffset = validOffset_buf.Get<float>();
|
||||
AscendC::LocalTensor<float> constOrgStartIndex = start_buf.Get<float>();
|
||||
|
||||
@ -173,10 +195,7 @@ private:
|
||||
AscendC::Cast(orgVocabMask_fp16, orgVocabMask, AscendC::RoundMode::CAST_NONE, size_);
|
||||
AscendC::Cast(orgVocabMask_fp32, orgVocabMask_fp16, AscendC::RoundMode::CAST_NONE, size_);
|
||||
|
||||
AscendC::Mul(validOffset,
|
||||
constOrgStartIndex,
|
||||
orgVocabMask_fp32,
|
||||
size_);
|
||||
AscendC::Mul(validOffset, constOrgStartIndex, orgVocabMask_fp32, size_);
|
||||
|
||||
AscendC::LocalTensor<float> addedOffset;
|
||||
AscendC::LocalTensor<float> addedOffsetTensor = end_buf.Get<float>();
|
||||
@ -187,44 +206,61 @@ private:
|
||||
AscendC::Cast(addedVocabMask_fp16, addedVocabMask, AscendC::RoundMode::CAST_NONE, size_);
|
||||
AscendC::Cast(addedVocabMask_fp32, addedVocabMask_fp16, AscendC::RoundMode::CAST_NONE, size_);
|
||||
|
||||
AscendC::Mul(addedOffset,
|
||||
addedOffsetTensor,
|
||||
addedVocabMask_fp32,
|
||||
size_);
|
||||
|
||||
AscendC::Mul(addedOffset, addedOffsetTensor, addedVocabMask_fp32, size_);
|
||||
AscendC::Add(validOffset, validOffset, addedOffset, size_);
|
||||
|
||||
// vocab_mask = org_vocab_mask | added_vocab_mask
|
||||
AscendC::LocalTensor<int8_t> vocabMask = vocabMask_buf_.Get<int8_t>();
|
||||
|
||||
|
||||
#if (__CCE_AICORE__ >= 220)
|
||||
AscendC::Or(vocabMask,
|
||||
orgVocabMask,
|
||||
addedVocabMask,
|
||||
size_);
|
||||
|
||||
#else
|
||||
{
|
||||
// WORKAROUND for older arch
|
||||
// No direct int8->int16 cast. Use half as intermediate.
|
||||
// No direct int8 Or. Use int16 Or.
|
||||
AscendC::LocalTensor<int16_t> orgVocabMask_i16 = calc_buf_1.Get<int16_t>();
|
||||
AscendC::LocalTensor<int16_t> addedVocabMask_i16 = calc_buf_2.Get<int16_t>();
|
||||
AscendC::LocalTensor<int16_t> vocabMask_i16 = orgVocabMask_i16;
|
||||
|
||||
// Use a temporary buffer for half type. inputFloat_buf is free now.
|
||||
AscendC::LocalTensor<half> tmp_half = inputFloat_buf.Get<half>();
|
||||
|
||||
// 1. Cast inputs: int8_t -> half -> int16_t
|
||||
AscendC::Cast(tmp_half, orgVocabMask, AscendC::RoundMode::CAST_NONE, size_);
|
||||
AscendC::Cast(orgVocabMask_i16, tmp_half, AscendC::RoundMode::CAST_NONE, size_);
|
||||
|
||||
AscendC::Cast(tmp_half, addedVocabMask, AscendC::RoundMode::CAST_NONE, size_);
|
||||
AscendC::Cast(addedVocabMask_i16, tmp_half, AscendC::RoundMode::CAST_NONE, size_);
|
||||
|
||||
// 2. Perform Or on int16_t tensors
|
||||
AscendC::Or(vocabMask_i16, orgVocabMask_i16, addedVocabMask_i16, size_);
|
||||
|
||||
// 3. Cast result back: int16_t -> half -> int8_t
|
||||
AscendC::Cast(tmp_half, vocabMask_i16, AscendC::RoundMode::CAST_NONE, size_);
|
||||
AscendC::Cast(vocabMask, tmp_half, AscendC::RoundMode::CAST_NONE, size_);
|
||||
}
|
||||
#endif
|
||||
|
||||
AscendC::Sub(inputFloat, inputFloat, validOffset, size_);
|
||||
|
||||
// input_ = vocab_mask * (input_ - valid_offset)
|
||||
AscendC::LocalTensor<half> vocabMask_fp16;
|
||||
AscendC::LocalTensor<float> vocabMask_fp32;
|
||||
AscendC::Cast(vocabMask_fp16, vocabMask, AscendC::RoundMode::CAST_NONE, size_);
|
||||
AscendC::Cast(vocabMask_fp32, vocabMask_fp16, AscendC::RoundMode::CAST_NONE, size_);
|
||||
|
||||
AscendC::LocalTensor<float> inputFloat_fp32;
|
||||
AscendC::Mul(inputFloat, inputFloat, vocabMask_fp32, size_);
|
||||
|
||||
AscendC::Cast(maskedLocal, inputFloat, AscendC::RoundMode::CAST_CEIL, size_);
|
||||
outQueue.EnQue(maskedLocal);
|
||||
|
||||
// ~vocab_mask
|
||||
AscendC::LocalTensor<float> ones_tensor = ones_buf_.Get<float>();
|
||||
AscendC::Duplicate(ones_tensor, (float)1, size_);
|
||||
AscendC::LocalTensor<float> maskLocal_fp32;
|
||||
|
||||
AscendC::Sub(maskLocal_fp32,
|
||||
ones_tensor,
|
||||
vocabMask_fp32,
|
||||
size_);
|
||||
AscendC::Sub(maskLocal_fp32, ones_tensor, vocabMask_fp32, size_);
|
||||
|
||||
AscendC::LocalTensor<half> maskLocal_fp16;
|
||||
AscendC::Cast(maskLocal_fp16, maskLocal_fp32, AscendC::RoundMode::CAST_NONE, size_);
|
||||
@ -262,8 +298,6 @@ private:
|
||||
// Temporary buffers
|
||||
AscendC::TBuf<AscendC::TPosition::VECCALC> start_buf;
|
||||
AscendC::TBuf<AscendC::TPosition::VECCALC> end_buf;
|
||||
|
||||
// Temporary buffers continued
|
||||
AscendC::TBuf<AscendC::TPosition::VECCALC> inputFloat_buf;
|
||||
AscendC::TBuf<AscendC::TPosition::VECCALC> validOffset_buf;
|
||||
AscendC::TBuf<AscendC::TPosition::VECCALC> vocabMask_buf_;
|
||||
@ -342,4 +376,3 @@ void get_masked_input_and_mask_impl(
|
||||
}
|
||||
|
||||
} // namespace vllm_ascend
|
||||
|
||||
|
@ -30,7 +30,11 @@ using vllm_ascend::local_mem_copy;
|
||||
template <typename scalar_t, bool isNeox> class RotaryEmbedding {
|
||||
// NOTE(ganyi): we use 512B as load stride for pipe, need to find another way to
|
||||
// retrieve this size from runtime for more Soc support
|
||||
static int constexpr loadSize = 512;
|
||||
#if (__CCE_AICORE__ >= 220)
|
||||
static int constexpr loadSize = 512;
|
||||
#else
|
||||
static int constexpr loadSize = 1024 * 4;
|
||||
#endif
|
||||
using dst_t = scalar_t;
|
||||
using acc_t = typename AccType<scalar_t>::type;
|
||||
// only half tensor have cast instruct to int8, hardcode acc_dst_t as half
|
||||
@ -326,7 +330,9 @@ private:
|
||||
|
||||
// Declare all the kernel entry here
|
||||
ROPE_CUSTOM_KERNEL_DECLARE(half)
|
||||
ROPE_CUSTOM_KERNEL_DECLARE(bfloat16_t)
|
||||
#if (__CCE_AICORE__ >= 220)
|
||||
ROPE_CUSTOM_KERNEL_DECLARE(bfloat16_t)
|
||||
#endif
|
||||
|
||||
namespace vllm_ascend {
|
||||
|
||||
@ -342,7 +348,7 @@ namespace vllm_ascend {
|
||||
reinterpret_cast<TYPE *>(cosSinCache), rotDim, queryStride, keyStride, dstQueryStride, dstKeyStride, \
|
||||
numHeads, numKvHeads, headSize, numTokens, loopCnt, blockDim);
|
||||
|
||||
// maximum number for runtime to launch a ascendc kernel.
|
||||
// maximum number for runtime to launch a ascendc kernel.
|
||||
// we use this to constrain the maximum number of block size
|
||||
static const int64_t maxParallelSize = 65535;
|
||||
|
||||
@ -357,9 +363,13 @@ extern void rotary_embedding_impl(AscendType type, bool isNeox, void *stream, in
|
||||
int blockDim = maxParallelSize > numTokens ? numTokens : maxParallelSize;
|
||||
if (type == AscendType::FP16) {
|
||||
ROTARY_EMBEDDING_KERNEL_CALL(half);
|
||||
} else if (type == AscendType::BF16) {
|
||||
}
|
||||
#if (__CCE_AICORE__ >= 220)
|
||||
else if (type == AscendType::BF16) {
|
||||
ROTARY_EMBEDDING_KERNEL_CALL(bfloat16_t);
|
||||
} else {
|
||||
}
|
||||
#endif
|
||||
else {
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
@ -20,9 +20,11 @@ namespace vllm_ascend {
|
||||
|
||||
template <typename scalar_t> struct AccType;
|
||||
|
||||
#if (__CCE_AICORE__ >= 220)
|
||||
template <> struct AccType<bfloat16_t> {
|
||||
using type = float;
|
||||
using type = float;
|
||||
};
|
||||
#endif
|
||||
|
||||
template <> struct AccType<half> {
|
||||
using type = half;
|
||||
|
Reference in New Issue
Block a user