From a6956f00d470347c61dc28a2050b34d8aa2ac0cb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E5=81=A510355098?= Date: Thu, 9 Oct 2025 14:37:15 +0800 Subject: [PATCH] cpu: rv64: eltwise: fix: remove early pointer casting --- src/cpu/rv64/rvv_eltwise.cpp | 35 ++--- src/cpu/rv64/rvv_eltwise_kernels.hpp | 197 +++++++++++++++++---------- 2 files changed, 138 insertions(+), 94 deletions(-) diff --git a/src/cpu/rv64/rvv_eltwise.cpp b/src/cpu/rv64/rvv_eltwise.cpp index 539140b1da..3888742b80 100644 --- a/src/cpu/rv64/rvv_eltwise.cpp +++ b/src/cpu/rv64/rvv_eltwise.cpp @@ -36,22 +36,16 @@ static inline void compute_eltwise_rvv_fwd(const alg_kind_t alg, const dim_t len, const data_type_t dt) { switch (dt) { case data_type::f32: - rvv_eltwise_apply_fwd_f32(alg, reinterpret_cast(src), - reinterpret_cast(dst), len, alpha, beta); + rvv_eltwise_apply_fwd_f32(alg, src, dst, len, alpha, beta, dt); break; case data_type::s32: - rvv_eltwise_apply_fwd_s32(alg, - reinterpret_cast(src), - reinterpret_cast(dst), len, alpha, beta); + rvv_eltwise_apply_fwd_s32(alg, src, dst, len, alpha, beta, dt); break; case data_type::s8: - rvv_eltwise_apply_fwd_s8(alg, reinterpret_cast(src), - reinterpret_cast(dst), len, alpha, beta); + rvv_eltwise_apply_fwd_s8(alg, src, dst, len, alpha, beta, dt); break; case data_type::u8: - rvv_eltwise_apply_fwd_u8(alg, - reinterpret_cast(src), - reinterpret_cast(dst), len, alpha, beta); + rvv_eltwise_apply_fwd_u8(alg, src, dst, len, alpha, beta, dt); break; default: assert(!"Unsupported data type for RVV eltwise"); } @@ -63,25 +57,20 @@ static inline void compute_eltwise_rvv_bwd(const alg_kind_t alg, void *diff_src, const float beta, const dim_t len, const data_type_t dt) { switch (dt) { case data_type::f32: - rvv_eltwise_apply_bwd_f32(alg, reinterpret_cast(diff_src), - reinterpret_cast(diff_dst), - reinterpret_cast(src), len, alpha, beta); + rvv_eltwise_apply_bwd_f32( + alg, diff_src, diff_dst, src, len, alpha, beta, dt); break; case data_type::s32: - rvv_eltwise_apply_bwd_s32(alg, - reinterpret_cast(diff_src), - reinterpret_cast(diff_dst), - reinterpret_cast(src), len, alpha, beta); + rvv_eltwise_apply_bwd_s32( + alg, diff_src, diff_dst, src, len, alpha, beta, dt); break; case data_type::s8: - rvv_eltwise_apply_bwd_s8(alg, reinterpret_cast(diff_src), - reinterpret_cast(diff_dst), - reinterpret_cast(src), len, alpha, beta); + rvv_eltwise_apply_bwd_s8( + alg, diff_src, diff_dst, src, len, alpha, beta, dt); break; case data_type::u8: - rvv_eltwise_apply_bwd_u8(alg, reinterpret_cast(diff_src), - reinterpret_cast(diff_dst), - reinterpret_cast(src), len, alpha, beta); + rvv_eltwise_apply_bwd_u8( + alg, diff_src, diff_dst, src, len, alpha, beta, dt); break; default: assert(!"Unsupported data type for RVV eltwise"); } diff --git a/src/cpu/rv64/rvv_eltwise_kernels.hpp b/src/cpu/rv64/rvv_eltwise_kernels.hpp index 9856c21308..91d4c02c41 100644 --- a/src/cpu/rv64/rvv_eltwise_kernels.hpp +++ b/src/cpu/rv64/rvv_eltwise_kernels.hpp @@ -47,43 +47,67 @@ using eval_bwd_u8_fn_t = vuint8m1_t (*)(vuint8m1_t, vuint8m1_t, float, float, size_t); /*** Kernels for forward pass ***/ -inline void rvv_eltwise_fwd_kernel_f32(const float *src, float *dst, dim_t len, - float alpha, float beta, eval_fwd_f32_fn_t eval) { +inline void rvv_eltwise_fwd_kernel_f32(const void *src_base, void *dst_base, + dim_t len, float alpha, float beta, eval_fwd_f32_fn_t eval, + const data_type_t dt) { for (dim_t i = 0; i < len;) { size_t vl = __riscv_vsetvl_e32m1(static_cast(len - i)); - vfloat32m1_t vin = __riscv_vle32_v_f32m1(src + i, vl); + const float *src = reinterpret_cast( + static_cast(src_base) + + i * types::data_type_size(dt)); + float *dst = reinterpret_cast( + static_cast(dst_base) + i * types::data_type_size(dt)); + vfloat32m1_t vin = __riscv_vle32_v_f32m1(src, vl); vfloat32m1_t vout = eval(vin, alpha, beta, vl); - __riscv_vse32_v_f32m1(dst + i, vout, vl); + __riscv_vse32_v_f32m1(dst, vout, vl); i += static_cast(vl); } } -inline void rvv_eltwise_fwd_kernel_s32(const int32_t *src, int32_t *dst, - dim_t len, float alpha, float beta, eval_fwd_s32_fn_t eval) { +inline void rvv_eltwise_fwd_kernel_s32(const void *src_base, void *dst_base, + dim_t len, float alpha, float beta, eval_fwd_s32_fn_t eval, + const data_type_t dt) { for (dim_t i = 0; i < len;) { size_t vl = __riscv_vsetvl_e32m1(static_cast(len - i)); - vint32m1_t vin = __riscv_vle32_v_i32m1(src + i, vl); + const int32_t *src = reinterpret_cast( + static_cast(src_base) + + i * types::data_type_size(dt)); + int32_t *dst = reinterpret_cast( + static_cast(dst_base) + i * types::data_type_size(dt)); + vint32m1_t vin = __riscv_vle32_v_i32m1(src, vl); vint32m1_t vout = eval(vin, alpha, beta, vl); - __riscv_vse32_v_i32m1(dst + i, vout, vl); + __riscv_vse32_v_i32m1(dst, vout, vl); i += static_cast(vl); } } -inline void rvv_eltwise_fwd_kernel_s8(const int8_t *src, int8_t *dst, dim_t len, - float alpha, float beta, eval_fwd_s8_fn_t eval) { +inline void rvv_eltwise_fwd_kernel_s8(const void *src_base, void *dst_base, + dim_t len, float alpha, float beta, eval_fwd_s8_fn_t eval, + const data_type_t dt) { for (dim_t i = 0; i < len;) { size_t vl = __riscv_vsetvl_e8m1(static_cast(len - i)); - vint8m1_t vin = __riscv_vle8_v_i8m1(src + i, vl); + const int8_t *src = reinterpret_cast( + static_cast(src_base) + + i * types::data_type_size(dt)); + int8_t *dst = reinterpret_cast( + static_cast(dst_base) + i * types::data_type_size(dt)); + vint8m1_t vin = __riscv_vle8_v_i8m1(src, vl); vint8m1_t vout = eval(vin, alpha, beta, vl); - __riscv_vse8_v_i8m1(dst + i, vout, vl); + __riscv_vse8_v_i8m1(dst, vout, vl); i += static_cast(vl); } } -inline void rvv_eltwise_fwd_kernel_u8(const uint8_t *src, uint8_t *dst, - dim_t len, float alpha, float beta, eval_fwd_u8_fn_t eval) { +inline void rvv_eltwise_fwd_kernel_u8(const void *src_base, void *dst_base, + dim_t len, float alpha, float beta, eval_fwd_u8_fn_t eval, + const data_type_t dt) { for (dim_t i = 0; i < len;) { size_t vl = __riscv_vsetvl_e8m1(static_cast(len - i)); - vuint8m1_t vin = __riscv_vle8_v_u8m1(src + i, vl); + const uint8_t *src = reinterpret_cast( + static_cast(src_base) + + i * types::data_type_size(dt)); + uint8_t *dst = reinterpret_cast( + static_cast(dst_base) + i * types::data_type_size(dt)); + vuint8m1_t vin = __riscv_vle8_v_u8m1(src, vl); vuint8m1_t vout = eval(vin, alpha, beta, vl); - __riscv_vse8_v_u8m1(dst + i, vout, vl); + __riscv_vse8_v_u8m1(dst, vout, vl); i += static_cast(vl); } } @@ -411,93 +435,120 @@ inline eval_fwd_u8_fn_t get_eval_fwd_u8(alg_kind_t alg) { } /*** Apply methods for forward pass ***/ -inline void rvv_eltwise_apply_fwd_f32(alg_kind_t alg, const float *src, - float *dst, dim_t len, float alpha, float beta) { +inline void rvv_eltwise_apply_fwd_f32(alg_kind_t alg, const void *src, + void *dst, dim_t len, float alpha, float beta, const data_type_t dt) { auto eval = get_eval_fwd_f32(alg); if (!eval) { assert(!"[rvv_eltwise_apply_fwd_f32] unknown eltwise alg_kind"); return; } - rvv_eltwise_fwd_kernel_f32(src, dst, len, alpha, beta, eval); + rvv_eltwise_fwd_kernel_f32(src, dst, len, alpha, beta, eval, dt); } -inline void rvv_eltwise_apply_fwd_s32(alg_kind_t alg, const int32_t *src, - int32_t *dst, dim_t len, float alpha, float beta) { +inline void rvv_eltwise_apply_fwd_s32(alg_kind_t alg, const void *src, + void *dst, dim_t len, float alpha, float beta, const data_type_t dt) { auto eval = get_eval_fwd_s32(alg); if (!eval) { assert(!"[rvv_eltwise_apply_fwd_s32] unknown eltwise alg_kind"); return; } - rvv_eltwise_fwd_kernel_s32(src, dst, len, alpha, beta, eval); + rvv_eltwise_fwd_kernel_s32(src, dst, len, alpha, beta, eval, dt); } -inline void rvv_eltwise_apply_fwd_s8(alg_kind_t alg, const int8_t *src, - int8_t *dst, dim_t len, float alpha, float beta) { +inline void rvv_eltwise_apply_fwd_s8(alg_kind_t alg, const void *src, void *dst, + dim_t len, float alpha, float beta, const data_type_t dt) { auto eval = get_eval_fwd_s8(alg); if (!eval) { assert(!"[rvv_eltwise_apply_fwd_s8] unknown eltwise alg_kind"); return; } - rvv_eltwise_fwd_kernel_s8(src, dst, len, alpha, beta, eval); + rvv_eltwise_fwd_kernel_s8(src, dst, len, alpha, beta, eval, dt); } -inline void rvv_eltwise_apply_fwd_u8(alg_kind_t alg, const uint8_t *src, - uint8_t *dst, dim_t len, float alpha, float beta) { +inline void rvv_eltwise_apply_fwd_u8(alg_kind_t alg, const void *src, void *dst, + dim_t len, float alpha, float beta, const data_type_t dt) { auto eval = get_eval_fwd_u8(alg); if (!eval) { assert(!"[rvv_eltwise_apply_fwd_u8] unknown eltwise alg_kind"); return; } - rvv_eltwise_fwd_kernel_u8(src, dst, len, alpha, beta, eval); + rvv_eltwise_fwd_kernel_u8(src, dst, len, alpha, beta, eval, dt); } /* --- Backward pass --- */ -// For backward pass, we need to compute the gradient of the loss with respect to the input -// and the parameters. Thus, diff_src is the output, diff_dst and src are the inputs. - /*** Kernels for backward pass ***/ -inline void rvv_eltwise_bwd_kernel_f32(float *diff_src, const float *diff_dst, - const float *src, dim_t len, float alpha, float beta, - eval_bwd_f32_fn_t eval) { +inline void rvv_eltwise_bwd_kernel_f32(void *diff_src, const void *diff_dst, + const void *src, dim_t len, float alpha, float beta, + eval_bwd_f32_fn_t eval, const data_type_t dt) { for (dim_t i = 0; i < len;) { size_t vl = __riscv_vsetvl_e32m1(static_cast(len - i)); - vfloat32m1_t vdiff_dst = __riscv_vle32_v_f32m1(diff_dst + i, vl); - vfloat32m1_t vsrc = __riscv_vle32_v_f32m1(src + i, vl); + vfloat32m1_t vdiff_dst = __riscv_vle32_v_f32m1( + reinterpret_cast(diff_dst) + i, vl); + vfloat32m1_t vsrc = __riscv_vle32_v_f32m1( + reinterpret_cast(src) + i, vl); vfloat32m1_t vdiff_src = eval(vdiff_dst, vsrc, alpha, beta, vl); - __riscv_vse32_v_f32m1(diff_src + i, vdiff_src, vl); + __riscv_vse32_v_f32m1( + reinterpret_cast(diff_src) + i, vdiff_src, vl); i += static_cast(vl); } } -inline void rvv_eltwise_bwd_kernel_s32(int32_t *diff_src, - const int32_t *diff_dst, const int32_t *src, dim_t len, float alpha, - float beta, eval_bwd_s32_fn_t eval) { +inline void rvv_eltwise_bwd_kernel_s32(void *diff_src_base, + const void *diff_dst_base, const void *src_base, dim_t len, float alpha, + float beta, eval_bwd_s32_fn_t eval, const data_type_t dt) { for (dim_t i = 0; i < len;) { size_t vl = __riscv_vsetvl_e32m1(static_cast(len - i)); - vint32m1_t vdiff_dst = __riscv_vle32_v_i32m1(diff_dst + i, vl); - vint32m1_t vsrc = __riscv_vle32_v_i32m1(src + i, vl); + const int32_t *diff_dst = reinterpret_cast( + static_cast(diff_dst_base) + + i * types::data_type_size(dt)); + const int32_t *src = reinterpret_cast( + static_cast(src_base) + + i * types::data_type_size(dt)); + int32_t *diff_src + = reinterpret_cast(static_cast(diff_src_base) + + i * types::data_type_size(dt)); + vint32m1_t vdiff_dst = __riscv_vle32_v_i32m1(diff_dst, vl); + vint32m1_t vsrc = __riscv_vle32_v_i32m1(src, vl); vint32m1_t vdiff_src = eval(vdiff_dst, vsrc, alpha, beta, vl); - __riscv_vse32_v_i32m1(diff_src + i, vdiff_src, vl); + __riscv_vse32_v_i32m1(diff_src, vdiff_src, vl); i += static_cast(vl); } } -inline void rvv_eltwise_bwd_kernel_s8(int8_t *diff_src, const int8_t *diff_dst, - const int8_t *src, dim_t len, float alpha, float beta, - eval_bwd_s8_fn_t eval) { +inline void rvv_eltwise_bwd_kernel_s8(void *diff_src_base, + const void *diff_dst_base, const void *src_base, dim_t len, float alpha, + float beta, eval_bwd_s8_fn_t eval, const data_type_t dt) { for (dim_t i = 0; i < len;) { size_t vl = __riscv_vsetvl_e8m1(static_cast(len - i)); - vint8m1_t vdiff_dst = __riscv_vle8_v_i8m1(diff_dst + i, vl); - vint8m1_t vsrc = __riscv_vle8_v_i8m1(src + i, vl); + const int8_t *diff_dst = reinterpret_cast( + static_cast(diff_dst_base) + + i * types::data_type_size(dt)); + const int8_t *src = reinterpret_cast( + static_cast(src_base) + + i * types::data_type_size(dt)); + int8_t *diff_src + = reinterpret_cast(static_cast(diff_src_base) + + i * types::data_type_size(dt)); + vint8m1_t vdiff_dst = __riscv_vle8_v_i8m1(diff_dst, vl); + vint8m1_t vsrc = __riscv_vle8_v_i8m1(src, vl); vint8m1_t vdiff_src = eval(vdiff_dst, vsrc, alpha, beta, vl); - __riscv_vse8_v_i8m1(diff_src + i, vdiff_src, vl); + __riscv_vse8_v_i8m1(diff_src, vdiff_src, vl); i += static_cast(vl); } } -inline void rvv_eltwise_bwd_kernel_u8(uint8_t *diff_src, - const uint8_t *diff_dst, const uint8_t *src, dim_t len, float alpha, - float beta, eval_bwd_u8_fn_t eval) { +inline void rvv_eltwise_bwd_kernel_u8(void *diff_src_base, + const void *diff_dst_base, const void *src_base, dim_t len, float alpha, + float beta, eval_bwd_u8_fn_t eval, const data_type_t dt) { for (dim_t i = 0; i < len;) { size_t vl = __riscv_vsetvl_e8m1(static_cast(len - i)); - vuint8m1_t vdiff_dst = __riscv_vle8_v_u8m1(diff_dst + i, vl); - vuint8m1_t vsrc = __riscv_vle8_v_u8m1(src + i, vl); + const uint8_t *diff_dst = reinterpret_cast( + static_cast(diff_dst_base) + + i * types::data_type_size(dt)); + const uint8_t *src = reinterpret_cast( + static_cast(src_base) + + i * types::data_type_size(dt)); + uint8_t *diff_src + = reinterpret_cast(static_cast(diff_src_base) + + i * types::data_type_size(dt)); + vuint8m1_t vdiff_dst = __riscv_vle8_v_u8m1(diff_dst, vl); + vuint8m1_t vsrc = __riscv_vle8_v_u8m1(src, vl); vuint8m1_t vdiff_src = eval(vdiff_dst, vsrc, alpha, beta, vl); - __riscv_vse8_v_u8m1(diff_src + i, vdiff_src, vl); + __riscv_vse8_v_u8m1(diff_src, vdiff_src, vl); i += static_cast(vl); } } @@ -920,45 +971,49 @@ inline eval_bwd_u8_fn_t get_eval_bwd_u8(alg_kind_t alg) { } /*** Apply methods for backward pass ***/ -inline void rvv_eltwise_apply_bwd_f32(alg_kind_t alg, float *diff_src, - const float *diff_dst, const float *src, dim_t len, float alpha, - float beta) { +inline void rvv_eltwise_apply_bwd_f32(alg_kind_t alg, void *diff_src, + const void *diff_dst, const void *src, dim_t len, float alpha, + float beta, const data_type_t dt) { auto eval = get_eval_bwd_f32(alg); if (!eval) { assert(!"[rvv_eltwise_apply_bwd_f32] unknown eltwise alg_kind"); return; } - rvv_eltwise_bwd_kernel_f32(diff_src, diff_dst, src, len, alpha, beta, eval); + rvv_eltwise_bwd_kernel_f32( + diff_src, diff_dst, src, len, alpha, beta, eval, dt); } -inline void rvv_eltwise_apply_bwd_s32(alg_kind_t alg, int32_t *diff_src, - const int32_t *diff_dst, const int32_t *src, dim_t len, float alpha, - float beta) { +inline void rvv_eltwise_apply_bwd_s32(alg_kind_t alg, void *diff_src, + const void *diff_dst, const void *src, dim_t len, float alpha, + float beta, const data_type_t dt) { auto eval = get_eval_bwd_s32(alg); if (!eval) { assert(!"[rvv_eltwise_apply_bwd_s32] unknown eltwise alg_kind"); return; } - rvv_eltwise_bwd_kernel_s32(diff_src, diff_dst, src, len, alpha, beta, eval); + rvv_eltwise_bwd_kernel_s32( + diff_src, diff_dst, src, len, alpha, beta, eval, dt); } -inline void rvv_eltwise_apply_bwd_s8(alg_kind_t alg, int8_t *diff_src, - const int8_t *diff_dst, const int8_t *src, dim_t len, float alpha, - float beta) { +inline void rvv_eltwise_apply_bwd_s8(alg_kind_t alg, void *diff_src, + const void *diff_dst, const void *src, dim_t len, float alpha, + float beta, const data_type_t dt) { auto eval = get_eval_bwd_s8(alg); if (!eval) { assert(!"[rvv_eltwise_apply_bwd_s8] unknown eltwise alg_kind"); return; } - rvv_eltwise_bwd_kernel_s8(diff_src, diff_dst, src, len, alpha, beta, eval); + rvv_eltwise_bwd_kernel_s8( + diff_src, diff_dst, src, len, alpha, beta, eval, dt); } -inline void rvv_eltwise_apply_bwd_u8(alg_kind_t alg, uint8_t *diff_src, - const uint8_t *diff_dst, const uint8_t *src, dim_t len, float alpha, - float beta) { +inline void rvv_eltwise_apply_bwd_u8(alg_kind_t alg, void *diff_src, + const void *diff_dst, const void *src, dim_t len, float alpha, + float beta, const data_type_t dt) { auto eval = get_eval_bwd_u8(alg); if (!eval) { assert(!"[rvv_eltwise_apply_bwd_u8] unknown eltwise alg_kind"); return; } - rvv_eltwise_bwd_kernel_u8(diff_src, diff_dst, src, len, alpha, beta, eval); + rvv_eltwise_bwd_kernel_u8( + diff_src, diff_dst, src, len, alpha, beta, eval, dt); } } // namespace rv64