mirror of
https://github.com/uxlfoundation/oneDNN.git
synced 2025-10-20 18:43:49 +08:00
cpu: rv64: eltwise: fix: remove early pointer casting
This commit is contained in:
committed by
Vadim Pirogov
parent
944c782e17
commit
a6956f00d4
@ -36,22 +36,16 @@ static inline void compute_eltwise_rvv_fwd(const alg_kind_t alg,
|
||||
const dim_t len, const data_type_t dt) {
|
||||
switch (dt) {
|
||||
case data_type::f32:
|
||||
rvv_eltwise_apply_fwd_f32(alg, reinterpret_cast<const float *>(src),
|
||||
reinterpret_cast<float *>(dst), len, alpha, beta);
|
||||
rvv_eltwise_apply_fwd_f32(alg, src, dst, len, alpha, beta, dt);
|
||||
break;
|
||||
case data_type::s32:
|
||||
rvv_eltwise_apply_fwd_s32(alg,
|
||||
reinterpret_cast<const int32_t *>(src),
|
||||
reinterpret_cast<int32_t *>(dst), len, alpha, beta);
|
||||
rvv_eltwise_apply_fwd_s32(alg, src, dst, len, alpha, beta, dt);
|
||||
break;
|
||||
case data_type::s8:
|
||||
rvv_eltwise_apply_fwd_s8(alg, reinterpret_cast<const int8_t *>(src),
|
||||
reinterpret_cast<int8_t *>(dst), len, alpha, beta);
|
||||
rvv_eltwise_apply_fwd_s8(alg, src, dst, len, alpha, beta, dt);
|
||||
break;
|
||||
case data_type::u8:
|
||||
rvv_eltwise_apply_fwd_u8(alg,
|
||||
reinterpret_cast<const uint8_t *>(src),
|
||||
reinterpret_cast<uint8_t *>(dst), len, alpha, beta);
|
||||
rvv_eltwise_apply_fwd_u8(alg, src, dst, len, alpha, beta, dt);
|
||||
break;
|
||||
default: assert(!"Unsupported data type for RVV eltwise");
|
||||
}
|
||||
@ -63,25 +57,20 @@ static inline void compute_eltwise_rvv_bwd(const alg_kind_t alg, void *diff_src,
|
||||
const float beta, const dim_t len, const data_type_t dt) {
|
||||
switch (dt) {
|
||||
case data_type::f32:
|
||||
rvv_eltwise_apply_bwd_f32(alg, reinterpret_cast<float *>(diff_src),
|
||||
reinterpret_cast<const float *>(diff_dst),
|
||||
reinterpret_cast<const float *>(src), len, alpha, beta);
|
||||
rvv_eltwise_apply_bwd_f32(
|
||||
alg, diff_src, diff_dst, src, len, alpha, beta, dt);
|
||||
break;
|
||||
case data_type::s32:
|
||||
rvv_eltwise_apply_bwd_s32(alg,
|
||||
reinterpret_cast<int32_t *>(diff_src),
|
||||
reinterpret_cast<const int32_t *>(diff_dst),
|
||||
reinterpret_cast<const int32_t *>(src), len, alpha, beta);
|
||||
rvv_eltwise_apply_bwd_s32(
|
||||
alg, diff_src, diff_dst, src, len, alpha, beta, dt);
|
||||
break;
|
||||
case data_type::s8:
|
||||
rvv_eltwise_apply_bwd_s8(alg, reinterpret_cast<int8_t *>(diff_src),
|
||||
reinterpret_cast<const int8_t *>(diff_dst),
|
||||
reinterpret_cast<const int8_t *>(src), len, alpha, beta);
|
||||
rvv_eltwise_apply_bwd_s8(
|
||||
alg, diff_src, diff_dst, src, len, alpha, beta, dt);
|
||||
break;
|
||||
case data_type::u8:
|
||||
rvv_eltwise_apply_bwd_u8(alg, reinterpret_cast<uint8_t *>(diff_src),
|
||||
reinterpret_cast<const uint8_t *>(diff_dst),
|
||||
reinterpret_cast<const uint8_t *>(src), len, alpha, beta);
|
||||
rvv_eltwise_apply_bwd_u8(
|
||||
alg, diff_src, diff_dst, src, len, alpha, beta, dt);
|
||||
break;
|
||||
default: assert(!"Unsupported data type for RVV eltwise");
|
||||
}
|
||||
|
@ -47,43 +47,67 @@ using eval_bwd_u8_fn_t
|
||||
= vuint8m1_t (*)(vuint8m1_t, vuint8m1_t, float, float, size_t);
|
||||
|
||||
/*** Kernels for forward pass ***/
|
||||
inline void rvv_eltwise_fwd_kernel_f32(const float *src, float *dst, dim_t len,
|
||||
float alpha, float beta, eval_fwd_f32_fn_t eval) {
|
||||
inline void rvv_eltwise_fwd_kernel_f32(const void *src_base, void *dst_base,
|
||||
dim_t len, float alpha, float beta, eval_fwd_f32_fn_t eval,
|
||||
const data_type_t dt) {
|
||||
for (dim_t i = 0; i < len;) {
|
||||
size_t vl = __riscv_vsetvl_e32m1(static_cast<size_t>(len - i));
|
||||
vfloat32m1_t vin = __riscv_vle32_v_f32m1(src + i, vl);
|
||||
const float *src = reinterpret_cast<const float *>(
|
||||
static_cast<const char *>(src_base)
|
||||
+ i * types::data_type_size(dt));
|
||||
float *dst = reinterpret_cast<float *>(
|
||||
static_cast<char *>(dst_base) + i * types::data_type_size(dt));
|
||||
vfloat32m1_t vin = __riscv_vle32_v_f32m1(src, vl);
|
||||
vfloat32m1_t vout = eval(vin, alpha, beta, vl);
|
||||
__riscv_vse32_v_f32m1(dst + i, vout, vl);
|
||||
__riscv_vse32_v_f32m1(dst, vout, vl);
|
||||
i += static_cast<dim_t>(vl);
|
||||
}
|
||||
}
|
||||
inline void rvv_eltwise_fwd_kernel_s32(const int32_t *src, int32_t *dst,
|
||||
dim_t len, float alpha, float beta, eval_fwd_s32_fn_t eval) {
|
||||
inline void rvv_eltwise_fwd_kernel_s32(const void *src_base, void *dst_base,
|
||||
dim_t len, float alpha, float beta, eval_fwd_s32_fn_t eval,
|
||||
const data_type_t dt) {
|
||||
for (dim_t i = 0; i < len;) {
|
||||
size_t vl = __riscv_vsetvl_e32m1(static_cast<size_t>(len - i));
|
||||
vint32m1_t vin = __riscv_vle32_v_i32m1(src + i, vl);
|
||||
const int32_t *src = reinterpret_cast<const int32_t *>(
|
||||
static_cast<const char *>(src_base)
|
||||
+ i * types::data_type_size(dt));
|
||||
int32_t *dst = reinterpret_cast<int32_t *>(
|
||||
static_cast<char *>(dst_base) + i * types::data_type_size(dt));
|
||||
vint32m1_t vin = __riscv_vle32_v_i32m1(src, vl);
|
||||
vint32m1_t vout = eval(vin, alpha, beta, vl);
|
||||
__riscv_vse32_v_i32m1(dst + i, vout, vl);
|
||||
__riscv_vse32_v_i32m1(dst, vout, vl);
|
||||
i += static_cast<dim_t>(vl);
|
||||
}
|
||||
}
|
||||
inline void rvv_eltwise_fwd_kernel_s8(const int8_t *src, int8_t *dst, dim_t len,
|
||||
float alpha, float beta, eval_fwd_s8_fn_t eval) {
|
||||
inline void rvv_eltwise_fwd_kernel_s8(const void *src_base, void *dst_base,
|
||||
dim_t len, float alpha, float beta, eval_fwd_s8_fn_t eval,
|
||||
const data_type_t dt) {
|
||||
for (dim_t i = 0; i < len;) {
|
||||
size_t vl = __riscv_vsetvl_e8m1(static_cast<size_t>(len - i));
|
||||
vint8m1_t vin = __riscv_vle8_v_i8m1(src + i, vl);
|
||||
const int8_t *src = reinterpret_cast<const int8_t *>(
|
||||
static_cast<const char *>(src_base)
|
||||
+ i * types::data_type_size(dt));
|
||||
int8_t *dst = reinterpret_cast<int8_t *>(
|
||||
static_cast<char *>(dst_base) + i * types::data_type_size(dt));
|
||||
vint8m1_t vin = __riscv_vle8_v_i8m1(src, vl);
|
||||
vint8m1_t vout = eval(vin, alpha, beta, vl);
|
||||
__riscv_vse8_v_i8m1(dst + i, vout, vl);
|
||||
__riscv_vse8_v_i8m1(dst, vout, vl);
|
||||
i += static_cast<dim_t>(vl);
|
||||
}
|
||||
}
|
||||
inline void rvv_eltwise_fwd_kernel_u8(const uint8_t *src, uint8_t *dst,
|
||||
dim_t len, float alpha, float beta, eval_fwd_u8_fn_t eval) {
|
||||
inline void rvv_eltwise_fwd_kernel_u8(const void *src_base, void *dst_base,
|
||||
dim_t len, float alpha, float beta, eval_fwd_u8_fn_t eval,
|
||||
const data_type_t dt) {
|
||||
for (dim_t i = 0; i < len;) {
|
||||
size_t vl = __riscv_vsetvl_e8m1(static_cast<size_t>(len - i));
|
||||
vuint8m1_t vin = __riscv_vle8_v_u8m1(src + i, vl);
|
||||
const uint8_t *src = reinterpret_cast<const uint8_t *>(
|
||||
static_cast<const char *>(src_base)
|
||||
+ i * types::data_type_size(dt));
|
||||
uint8_t *dst = reinterpret_cast<uint8_t *>(
|
||||
static_cast<char *>(dst_base) + i * types::data_type_size(dt));
|
||||
vuint8m1_t vin = __riscv_vle8_v_u8m1(src, vl);
|
||||
vuint8m1_t vout = eval(vin, alpha, beta, vl);
|
||||
__riscv_vse8_v_u8m1(dst + i, vout, vl);
|
||||
__riscv_vse8_v_u8m1(dst, vout, vl);
|
||||
i += static_cast<dim_t>(vl);
|
||||
}
|
||||
}
|
||||
@ -411,93 +435,120 @@ inline eval_fwd_u8_fn_t get_eval_fwd_u8(alg_kind_t alg) {
|
||||
}
|
||||
|
||||
/*** Apply methods for forward pass ***/
|
||||
inline void rvv_eltwise_apply_fwd_f32(alg_kind_t alg, const float *src,
|
||||
float *dst, dim_t len, float alpha, float beta) {
|
||||
inline void rvv_eltwise_apply_fwd_f32(alg_kind_t alg, const void *src,
|
||||
void *dst, dim_t len, float alpha, float beta, const data_type_t dt) {
|
||||
auto eval = get_eval_fwd_f32(alg);
|
||||
if (!eval) {
|
||||
assert(!"[rvv_eltwise_apply_fwd_f32] unknown eltwise alg_kind");
|
||||
return;
|
||||
}
|
||||
rvv_eltwise_fwd_kernel_f32(src, dst, len, alpha, beta, eval);
|
||||
rvv_eltwise_fwd_kernel_f32(src, dst, len, alpha, beta, eval, dt);
|
||||
}
|
||||
inline void rvv_eltwise_apply_fwd_s32(alg_kind_t alg, const int32_t *src,
|
||||
int32_t *dst, dim_t len, float alpha, float beta) {
|
||||
inline void rvv_eltwise_apply_fwd_s32(alg_kind_t alg, const void *src,
|
||||
void *dst, dim_t len, float alpha, float beta, const data_type_t dt) {
|
||||
auto eval = get_eval_fwd_s32(alg);
|
||||
if (!eval) {
|
||||
assert(!"[rvv_eltwise_apply_fwd_s32] unknown eltwise alg_kind");
|
||||
return;
|
||||
}
|
||||
rvv_eltwise_fwd_kernel_s32(src, dst, len, alpha, beta, eval);
|
||||
rvv_eltwise_fwd_kernel_s32(src, dst, len, alpha, beta, eval, dt);
|
||||
}
|
||||
inline void rvv_eltwise_apply_fwd_s8(alg_kind_t alg, const int8_t *src,
|
||||
int8_t *dst, dim_t len, float alpha, float beta) {
|
||||
inline void rvv_eltwise_apply_fwd_s8(alg_kind_t alg, const void *src, void *dst,
|
||||
dim_t len, float alpha, float beta, const data_type_t dt) {
|
||||
auto eval = get_eval_fwd_s8(alg);
|
||||
if (!eval) {
|
||||
assert(!"[rvv_eltwise_apply_fwd_s8] unknown eltwise alg_kind");
|
||||
return;
|
||||
}
|
||||
rvv_eltwise_fwd_kernel_s8(src, dst, len, alpha, beta, eval);
|
||||
rvv_eltwise_fwd_kernel_s8(src, dst, len, alpha, beta, eval, dt);
|
||||
}
|
||||
inline void rvv_eltwise_apply_fwd_u8(alg_kind_t alg, const uint8_t *src,
|
||||
uint8_t *dst, dim_t len, float alpha, float beta) {
|
||||
inline void rvv_eltwise_apply_fwd_u8(alg_kind_t alg, const void *src, void *dst,
|
||||
dim_t len, float alpha, float beta, const data_type_t dt) {
|
||||
auto eval = get_eval_fwd_u8(alg);
|
||||
if (!eval) {
|
||||
assert(!"[rvv_eltwise_apply_fwd_u8] unknown eltwise alg_kind");
|
||||
return;
|
||||
}
|
||||
rvv_eltwise_fwd_kernel_u8(src, dst, len, alpha, beta, eval);
|
||||
rvv_eltwise_fwd_kernel_u8(src, dst, len, alpha, beta, eval, dt);
|
||||
}
|
||||
|
||||
/* --- Backward pass --- */
|
||||
// For backward pass, we need to compute the gradient of the loss with respect to the input
|
||||
// and the parameters. Thus, diff_src is the output, diff_dst and src are the inputs.
|
||||
|
||||
/*** Kernels for backward pass ***/
|
||||
inline void rvv_eltwise_bwd_kernel_f32(float *diff_src, const float *diff_dst,
|
||||
const float *src, dim_t len, float alpha, float beta,
|
||||
eval_bwd_f32_fn_t eval) {
|
||||
inline void rvv_eltwise_bwd_kernel_f32(void *diff_src, const void *diff_dst,
|
||||
const void *src, dim_t len, float alpha, float beta,
|
||||
eval_bwd_f32_fn_t eval, const data_type_t dt) {
|
||||
for (dim_t i = 0; i < len;) {
|
||||
size_t vl = __riscv_vsetvl_e32m1(static_cast<size_t>(len - i));
|
||||
vfloat32m1_t vdiff_dst = __riscv_vle32_v_f32m1(diff_dst + i, vl);
|
||||
vfloat32m1_t vsrc = __riscv_vle32_v_f32m1(src + i, vl);
|
||||
vfloat32m1_t vdiff_dst = __riscv_vle32_v_f32m1(
|
||||
reinterpret_cast<const float *>(diff_dst) + i, vl);
|
||||
vfloat32m1_t vsrc = __riscv_vle32_v_f32m1(
|
||||
reinterpret_cast<const float *>(src) + i, vl);
|
||||
vfloat32m1_t vdiff_src = eval(vdiff_dst, vsrc, alpha, beta, vl);
|
||||
__riscv_vse32_v_f32m1(diff_src + i, vdiff_src, vl);
|
||||
__riscv_vse32_v_f32m1(
|
||||
reinterpret_cast<float *>(diff_src) + i, vdiff_src, vl);
|
||||
i += static_cast<dim_t>(vl);
|
||||
}
|
||||
}
|
||||
inline void rvv_eltwise_bwd_kernel_s32(int32_t *diff_src,
|
||||
const int32_t *diff_dst, const int32_t *src, dim_t len, float alpha,
|
||||
float beta, eval_bwd_s32_fn_t eval) {
|
||||
inline void rvv_eltwise_bwd_kernel_s32(void *diff_src_base,
|
||||
const void *diff_dst_base, const void *src_base, dim_t len, float alpha,
|
||||
float beta, eval_bwd_s32_fn_t eval, const data_type_t dt) {
|
||||
for (dim_t i = 0; i < len;) {
|
||||
size_t vl = __riscv_vsetvl_e32m1(static_cast<size_t>(len - i));
|
||||
vint32m1_t vdiff_dst = __riscv_vle32_v_i32m1(diff_dst + i, vl);
|
||||
vint32m1_t vsrc = __riscv_vle32_v_i32m1(src + i, vl);
|
||||
const int32_t *diff_dst = reinterpret_cast<const int32_t *>(
|
||||
static_cast<const char *>(diff_dst_base)
|
||||
+ i * types::data_type_size(dt));
|
||||
const int32_t *src = reinterpret_cast<const int32_t *>(
|
||||
static_cast<const char *>(src_base)
|
||||
+ i * types::data_type_size(dt));
|
||||
int32_t *diff_src
|
||||
= reinterpret_cast<int32_t *>(static_cast<char *>(diff_src_base)
|
||||
+ i * types::data_type_size(dt));
|
||||
vint32m1_t vdiff_dst = __riscv_vle32_v_i32m1(diff_dst, vl);
|
||||
vint32m1_t vsrc = __riscv_vle32_v_i32m1(src, vl);
|
||||
vint32m1_t vdiff_src = eval(vdiff_dst, vsrc, alpha, beta, vl);
|
||||
__riscv_vse32_v_i32m1(diff_src + i, vdiff_src, vl);
|
||||
__riscv_vse32_v_i32m1(diff_src, vdiff_src, vl);
|
||||
i += static_cast<dim_t>(vl);
|
||||
}
|
||||
}
|
||||
inline void rvv_eltwise_bwd_kernel_s8(int8_t *diff_src, const int8_t *diff_dst,
|
||||
const int8_t *src, dim_t len, float alpha, float beta,
|
||||
eval_bwd_s8_fn_t eval) {
|
||||
inline void rvv_eltwise_bwd_kernel_s8(void *diff_src_base,
|
||||
const void *diff_dst_base, const void *src_base, dim_t len, float alpha,
|
||||
float beta, eval_bwd_s8_fn_t eval, const data_type_t dt) {
|
||||
for (dim_t i = 0; i < len;) {
|
||||
size_t vl = __riscv_vsetvl_e8m1(static_cast<size_t>(len - i));
|
||||
vint8m1_t vdiff_dst = __riscv_vle8_v_i8m1(diff_dst + i, vl);
|
||||
vint8m1_t vsrc = __riscv_vle8_v_i8m1(src + i, vl);
|
||||
const int8_t *diff_dst = reinterpret_cast<const int8_t *>(
|
||||
static_cast<const char *>(diff_dst_base)
|
||||
+ i * types::data_type_size(dt));
|
||||
const int8_t *src = reinterpret_cast<const int8_t *>(
|
||||
static_cast<const char *>(src_base)
|
||||
+ i * types::data_type_size(dt));
|
||||
int8_t *diff_src
|
||||
= reinterpret_cast<int8_t *>(static_cast<char *>(diff_src_base)
|
||||
+ i * types::data_type_size(dt));
|
||||
vint8m1_t vdiff_dst = __riscv_vle8_v_i8m1(diff_dst, vl);
|
||||
vint8m1_t vsrc = __riscv_vle8_v_i8m1(src, vl);
|
||||
vint8m1_t vdiff_src = eval(vdiff_dst, vsrc, alpha, beta, vl);
|
||||
__riscv_vse8_v_i8m1(diff_src + i, vdiff_src, vl);
|
||||
__riscv_vse8_v_i8m1(diff_src, vdiff_src, vl);
|
||||
i += static_cast<dim_t>(vl);
|
||||
}
|
||||
}
|
||||
inline void rvv_eltwise_bwd_kernel_u8(uint8_t *diff_src,
|
||||
const uint8_t *diff_dst, const uint8_t *src, dim_t len, float alpha,
|
||||
float beta, eval_bwd_u8_fn_t eval) {
|
||||
inline void rvv_eltwise_bwd_kernel_u8(void *diff_src_base,
|
||||
const void *diff_dst_base, const void *src_base, dim_t len, float alpha,
|
||||
float beta, eval_bwd_u8_fn_t eval, const data_type_t dt) {
|
||||
for (dim_t i = 0; i < len;) {
|
||||
size_t vl = __riscv_vsetvl_e8m1(static_cast<size_t>(len - i));
|
||||
vuint8m1_t vdiff_dst = __riscv_vle8_v_u8m1(diff_dst + i, vl);
|
||||
vuint8m1_t vsrc = __riscv_vle8_v_u8m1(src + i, vl);
|
||||
const uint8_t *diff_dst = reinterpret_cast<const uint8_t *>(
|
||||
static_cast<const char *>(diff_dst_base)
|
||||
+ i * types::data_type_size(dt));
|
||||
const uint8_t *src = reinterpret_cast<const uint8_t *>(
|
||||
static_cast<const char *>(src_base)
|
||||
+ i * types::data_type_size(dt));
|
||||
uint8_t *diff_src
|
||||
= reinterpret_cast<uint8_t *>(static_cast<char *>(diff_src_base)
|
||||
+ i * types::data_type_size(dt));
|
||||
vuint8m1_t vdiff_dst = __riscv_vle8_v_u8m1(diff_dst, vl);
|
||||
vuint8m1_t vsrc = __riscv_vle8_v_u8m1(src, vl);
|
||||
vuint8m1_t vdiff_src = eval(vdiff_dst, vsrc, alpha, beta, vl);
|
||||
__riscv_vse8_v_u8m1(diff_src + i, vdiff_src, vl);
|
||||
__riscv_vse8_v_u8m1(diff_src, vdiff_src, vl);
|
||||
i += static_cast<dim_t>(vl);
|
||||
}
|
||||
}
|
||||
@ -920,45 +971,49 @@ inline eval_bwd_u8_fn_t get_eval_bwd_u8(alg_kind_t alg) {
|
||||
}
|
||||
|
||||
/*** Apply methods for backward pass ***/
|
||||
inline void rvv_eltwise_apply_bwd_f32(alg_kind_t alg, float *diff_src,
|
||||
const float *diff_dst, const float *src, dim_t len, float alpha,
|
||||
float beta) {
|
||||
inline void rvv_eltwise_apply_bwd_f32(alg_kind_t alg, void *diff_src,
|
||||
const void *diff_dst, const void *src, dim_t len, float alpha,
|
||||
float beta, const data_type_t dt) {
|
||||
auto eval = get_eval_bwd_f32(alg);
|
||||
if (!eval) {
|
||||
assert(!"[rvv_eltwise_apply_bwd_f32] unknown eltwise alg_kind");
|
||||
return;
|
||||
}
|
||||
rvv_eltwise_bwd_kernel_f32(diff_src, diff_dst, src, len, alpha, beta, eval);
|
||||
rvv_eltwise_bwd_kernel_f32(
|
||||
diff_src, diff_dst, src, len, alpha, beta, eval, dt);
|
||||
}
|
||||
inline void rvv_eltwise_apply_bwd_s32(alg_kind_t alg, int32_t *diff_src,
|
||||
const int32_t *diff_dst, const int32_t *src, dim_t len, float alpha,
|
||||
float beta) {
|
||||
inline void rvv_eltwise_apply_bwd_s32(alg_kind_t alg, void *diff_src,
|
||||
const void *diff_dst, const void *src, dim_t len, float alpha,
|
||||
float beta, const data_type_t dt) {
|
||||
auto eval = get_eval_bwd_s32(alg);
|
||||
if (!eval) {
|
||||
assert(!"[rvv_eltwise_apply_bwd_s32] unknown eltwise alg_kind");
|
||||
return;
|
||||
}
|
||||
rvv_eltwise_bwd_kernel_s32(diff_src, diff_dst, src, len, alpha, beta, eval);
|
||||
rvv_eltwise_bwd_kernel_s32(
|
||||
diff_src, diff_dst, src, len, alpha, beta, eval, dt);
|
||||
}
|
||||
inline void rvv_eltwise_apply_bwd_s8(alg_kind_t alg, int8_t *diff_src,
|
||||
const int8_t *diff_dst, const int8_t *src, dim_t len, float alpha,
|
||||
float beta) {
|
||||
inline void rvv_eltwise_apply_bwd_s8(alg_kind_t alg, void *diff_src,
|
||||
const void *diff_dst, const void *src, dim_t len, float alpha,
|
||||
float beta, const data_type_t dt) {
|
||||
auto eval = get_eval_bwd_s8(alg);
|
||||
if (!eval) {
|
||||
assert(!"[rvv_eltwise_apply_bwd_s8] unknown eltwise alg_kind");
|
||||
return;
|
||||
}
|
||||
rvv_eltwise_bwd_kernel_s8(diff_src, diff_dst, src, len, alpha, beta, eval);
|
||||
rvv_eltwise_bwd_kernel_s8(
|
||||
diff_src, diff_dst, src, len, alpha, beta, eval, dt);
|
||||
}
|
||||
inline void rvv_eltwise_apply_bwd_u8(alg_kind_t alg, uint8_t *diff_src,
|
||||
const uint8_t *diff_dst, const uint8_t *src, dim_t len, float alpha,
|
||||
float beta) {
|
||||
inline void rvv_eltwise_apply_bwd_u8(alg_kind_t alg, void *diff_src,
|
||||
const void *diff_dst, const void *src, dim_t len, float alpha,
|
||||
float beta, const data_type_t dt) {
|
||||
auto eval = get_eval_bwd_u8(alg);
|
||||
if (!eval) {
|
||||
assert(!"[rvv_eltwise_apply_bwd_u8] unknown eltwise alg_kind");
|
||||
return;
|
||||
}
|
||||
rvv_eltwise_bwd_kernel_u8(diff_src, diff_dst, src, len, alpha, beta, eval);
|
||||
rvv_eltwise_bwd_kernel_u8(
|
||||
diff_src, diff_dst, src, len, alpha, beta, eval, dt);
|
||||
}
|
||||
|
||||
} // namespace rv64
|
||||
|
Reference in New Issue
Block a user