cpu: rv64: eltwise: fix templatization and ifdef line

This commit is contained in:
张健10355098
2025-09-15 15:58:11 +08:00
committed by Vadim Pirogov
parent a0961ab37e
commit d3eb5ed7ac
3 changed files with 107 additions and 140 deletions

View File

@ -32,7 +32,7 @@ using namespace dnnl::impl::cpu::x64;
#endif // DNNL_AARCH64_USE_ACL
using namespace dnnl::impl::cpu::aarch64;
#elif DNNL_RV64
#if DNNL_RISCV_USE_RVV_INTRINSICS
#ifdef DNNL_RISCV_USE_RVV_INTRINSICS
#include "cpu/rv64/rvv_eltwise.hpp"
using namespace dnnl::impl::cpu::rv64;
#endif // DNNL_RISCV_USE_RVV_INTRINSICS
@ -81,11 +81,7 @@ const std::map<pk_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map() {
CPU_INSTANCE_AARCH64(jit_uni_eltwise_int_fwd_t<sve_512, s8>)
CPU_INSTANCE_AARCH64(jit_uni_eltwise_int_fwd_t<sve_512, u8>)
CPU_INSTANCE_AARCH64_ACL(acl_eltwise_fwd_t)
CPU_INSTANCE_RV64GCV(rvv_eltwise_fwd_t<f32>)
CPU_INSTANCE_RV64GCV(rvv_eltwise_fwd_t<f16>)
CPU_INSTANCE_RV64GCV(rvv_eltwise_fwd_t<s32>)
CPU_INSTANCE_RV64GCV(rvv_eltwise_fwd_t<s8>)
CPU_INSTANCE_RV64GCV(rvv_eltwise_fwd_t<u8>)
CPU_INSTANCE_RV64GCV(rvv_eltwise_fwd_t)
CPU_INSTANCE(ref_eltwise_fwd_t<f32>)
CPU_INSTANCE(ref_eltwise_fwd_t<bf16>)
CPU_INSTANCE(ref_eltwise_fwd_t<f16>)
@ -108,11 +104,7 @@ const std::map<pk_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map() {
CPU_INSTANCE_X64(jit_uni_eltwise_bwd_t<avx, f32>)
CPU_INSTANCE_X64(jit_uni_eltwise_bwd_t<sse41, f32>)
CPU_INSTANCE_AARCH64(jit_uni_eltwise_bwd_t<sve_128, f32>)
CPU_INSTANCE_RV64GCV(rvv_eltwise_bwd_t<f32>)
CPU_INSTANCE_RV64GCV(rvv_eltwise_bwd_t<f16>)
CPU_INSTANCE_RV64GCV(rvv_eltwise_bwd_t<s32>)
CPU_INSTANCE_RV64GCV(rvv_eltwise_bwd_t<s8>)
CPU_INSTANCE_RV64GCV(rvv_eltwise_bwd_t<u8>)
CPU_INSTANCE_RV64GCV(rvv_eltwise_bwd_t)
CPU_INSTANCE(ref_eltwise_bwd_t<f32>)
CPU_INSTANCE(ref_eltwise_bwd_t<bf16>)
CPU_INSTANCE(ref_eltwise_bwd_t<f16>)

View File

@ -22,17 +22,15 @@
#include "common/math_utils.hpp"
#include "common/type_helpers.hpp"
#include "cpu/primitive_attr_postops.hpp"
#include "cpu/rv64/rvv_eltwise_kernels.hpp"
#include "cpu/rv64/rvv_eltwise.hpp"
#include "cpu/rv64/rvv_eltwise_kernels.hpp"
namespace dnnl {
namespace impl {
namespace cpu {
namespace rv64 {
// Data type dispatch for RVV eltwise forward (per-dtype apply)
// Data type dispatch for RVV eltwise forward
static inline void compute_eltwise_rvv_fwd(const alg_kind_t alg,
const void *src, void *dst, const float alpha, const float beta,
const dim_t len, const data_type_t dt) {
@ -64,7 +62,7 @@ static inline void compute_eltwise_rvv_fwd(const alg_kind_t alg,
}
}
// Data type dispatch for RVV eltwise backward (per-dtype apply)
// Data type dispatch for RVV eltwise backward
static inline void compute_eltwise_rvv_bwd(const alg_kind_t alg, void *diff_src,
const void *diff_dst, const void *src, const float alpha,
const float beta, const dim_t len, const data_type_t dt) {
@ -101,34 +99,36 @@ static inline void compute_eltwise_rvv_bwd(const alg_kind_t alg, void *diff_src,
}
// Forward execute
template <data_type_t data_type>
status_t rvv_eltwise_fwd_t<data_type>::execute_forward(
const exec_ctx_t &ctx) const {
status_t rvv_eltwise_fwd_t::execute(const exec_ctx_t &ctx) const {
if (pd()->has_zero_dim_memory()) return status::success;
status_t status = status::success;
auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
auto dst = CTX_OUT_CLEAN_MEM(data_t *, DNNL_ARG_DST, status);
const void *src = CTX_IN_MEM(const void *, DNNL_ARG_SRC);
void *dst = CTX_OUT_CLEAN_MEM(void *, DNNL_ARG_DST, status);
CHECK(status);
const memory_desc_wrapper data_d(pd()->src_md());
const auto nelems = data_d.nelems(true);
const memory_desc_wrapper src_d(pd()->src_md());
const memory_desc_wrapper dst_d(pd()->dst_md());
const auto nelems = dst_d.nelems(true);
const auto alg_kind = pd()->desc()->alg_kind;
const float alpha = pd()->desc()->alpha;
const float beta = pd()->desc()->beta;
if (pd()->use_dense_) {
src += data_d.offset0();
dst += data_d.offset0();
const size_t esize = types::data_type_size(pd()->src_md()->data_type);
const char *src_base
= static_cast<const char *>(src) + src_d.offset0() * esize;
char *dst_base = static_cast<char *>(dst) + dst_d.offset0() * esize;
parallel(0, [&](const int ithr, const int nthr) {
dim_t start = 0, end = 0;
balance211(nelems, nthr, ithr, start, end);
if (start == end) return;
const void *thr_src = static_cast<const void *>(src + start);
void *thr_dst = static_cast<void *>(dst + start);
const void *thr_src
= static_cast<const void *>(src_base + start * esize);
void *thr_dst = static_cast<void *>(dst_base + start * esize);
const dim_t len = end - start;
compute_eltwise_rvv_fwd(alg_kind, thr_src, thr_dst, alpha, beta,
@ -138,22 +138,26 @@ status_t rvv_eltwise_fwd_t<data_type>::execute_forward(
return status::success;
}
// nCspBc padded path: iterate over blocks and handle tail with zero-preserve
if (pd()->use_nCspBc_padded_) {
const blocking_desc_t &blk = data_d.blocking_desc();
const blocking_desc_t &blk = src_d.blocking_desc();
const dim_t block = blk.inner_blks[0];
const dim_t MB = pd()->MB();
const dim_t C = pd()->C() / block;
const dim_t C_PADDED = data_d.padded_dims()[1] / block;
const dim_t C_PADDED = src_d.padded_dims()[1] / block;
const dim_t tail = pd()->C() % block;
const dim_t SP = pd()->D() * pd()->H() * pd()->W();
const size_t esize = types::data_type_size(pd()->src_md()->data_type);
const char *src_bytes = static_cast<const char *>(src);
char *dst_bytes = static_cast<char *>(dst);
parallel_nd(MB, C_PADDED, SP, [&](dim_t n, dim_t c, dim_t sp) {
auto d_off = (n * C_PADDED * SP + c * SP + sp) * block;
const void *thr_src = static_cast<const void *>(src + d_off);
void *thr_dst = static_cast<void *>(dst + d_off);
const void *thr_src
= static_cast<const void *>(src_bytes + d_off * esize);
void *thr_dst = static_cast<void *>(dst_bytes + d_off * esize);
if (c < C) {
// full block
@ -176,31 +180,28 @@ status_t rvv_eltwise_fwd_t<data_type>::execute_forward(
}
// Backward execute
template <data_type_t data_type>
status_t rvv_eltwise_bwd_t<data_type>::execute_backward(
const exec_ctx_t &ctx) const {
if (pd()->has_zero_dim_memory()) return status::success;
status_t status = status::success;
auto data = pd()->use_dst() ? CTX_IN_MEM(const data_t *, DNNL_ARG_DST)
: CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
auto diff_dst = CTX_IN_MEM(const data_t *, DNNL_ARG_DIFF_DST);
auto diff_src = CTX_OUT_CLEAN_MEM(data_t *, DNNL_ARG_DIFF_SRC, status);
CHECK(status);
status_t rvv_eltwise_bwd_t::execute(const exec_ctx_t &ctx) const {
auto data = pd()->use_dst() ? CTX_IN_MEM(const void *, DNNL_ARG_DST)
: CTX_IN_MEM(const void *, DNNL_ARG_SRC);
auto diff_dst = CTX_IN_MEM(const void *, DNNL_ARG_DIFF_DST);
auto diff_src = CTX_OUT_MEM(void *, DNNL_ARG_DIFF_SRC);
const memory_desc_wrapper data_d(pd()->data_md());
const memory_desc_wrapper diff_d(pd()->diff_src_md());
const memory_desc_wrapper diff_src_d(pd()->diff_src_md());
const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
const auto nelems = diff_d.nelems(true);
const auto nelems = diff_src_d.nelems(true);
const auto alg_kind = pd()->desc()->alg_kind;
const float alpha = pd()->desc()->alpha;
const float beta = pd()->desc()->beta;
if (pd()->use_dense_) {
const dim_t off = diff_d.offset0();
data_t *ds_ptr = diff_src + off;
const data_t *dd_ptr = diff_dst + off;
const data_t *data_ptr = data + off;
const size_t esize = types::data_type_size(pd()->src_md()->data_type);
const dim_t off = diff_src_d.offset0();
char *ds_bytes = static_cast<char *>(diff_src) + off * esize;
const char *dd_bytes
= static_cast<const char *>(diff_dst) + off * esize;
const char *data_bytes = static_cast<const char *>(data) + off * esize;
parallel(0, [&](const int ithr, const int nthr) {
dim_t start = 0, end = 0;
@ -208,10 +209,10 @@ status_t rvv_eltwise_bwd_t<data_type>::execute_backward(
if (start == end) return;
compute_eltwise_rvv_bwd(alg_kind,
static_cast<void *>(ds_ptr + start),
static_cast<const void *>(dd_ptr + start),
static_cast<const void *>(data_ptr + start), alpha, beta,
end - start, pd()->src_md()->data_type);
static_cast<void *>(ds_bytes + start * esize),
static_cast<const void *>(dd_bytes + start * esize),
static_cast<const void *>(data_bytes + start * esize),
alpha, beta, end - start, pd()->src_md()->data_type);
});
return status::success;
}
@ -226,19 +227,24 @@ status_t rvv_eltwise_bwd_t<data_type>::execute_backward(
const dim_t tail = pd()->C() % block;
const dim_t SP = pd()->D() * pd()->H() * pd()->W();
const size_t esize = types::data_type_size(pd()->src_md()->data_type);
const char *data_bytes = static_cast<const char *>(data);
const char *dd_bytes = static_cast<const char *>(diff_dst);
char *ds_bytes = static_cast<char *>(diff_src);
parallel_nd(MB, C_PADDED, SP, [&](dim_t n, dim_t c, dim_t sp) {
auto base_off = (n * C_PADDED * SP + c * SP + sp) * block;
auto data_p = data + base_off;
auto dd_p = diff_dst + base_off;
auto ds_p = diff_src + base_off;
const void *data_p
= static_cast<const void *>(data_bytes + base_off * esize);
const void *dd_p
= static_cast<const void *>(dd_bytes + base_off * esize);
void *ds_p = static_cast<void *>(ds_bytes + base_off * esize);
const dim_t len = (c < C) ? block : tail;
if (len == 0) return;
compute_eltwise_rvv_bwd(alg_kind, static_cast<void *>(ds_p),
static_cast<const void *>(dd_p),
static_cast<const void *>(data_p), alpha, beta, len,
pd()->src_md()->data_type);
compute_eltwise_rvv_bwd(alg_kind, ds_p, dd_p, data_p, alpha, beta,
len, pd()->src_md()->data_type);
});
return status::success;
}
@ -246,18 +252,6 @@ status_t rvv_eltwise_bwd_t<data_type>::execute_backward(
return status::unimplemented;
}
template struct rvv_eltwise_fwd_t<data_type::f32>;
template struct rvv_eltwise_fwd_t<data_type::f16>;
template struct rvv_eltwise_fwd_t<data_type::s32>;
template struct rvv_eltwise_fwd_t<data_type::s8>;
template struct rvv_eltwise_fwd_t<data_type::u8>;
template struct rvv_eltwise_bwd_t<data_type::f32>;
template struct rvv_eltwise_bwd_t<data_type::f16>;
template struct rvv_eltwise_bwd_t<data_type::s32>;
template struct rvv_eltwise_bwd_t<data_type::s8>;
template struct rvv_eltwise_bwd_t<data_type::u8>;
} // namespace rv64
} // namespace cpu
} // namespace impl

View File

@ -32,10 +32,6 @@ namespace impl {
namespace cpu {
namespace rv64 {
// RVV forward eltwise primitive (RV64GCV). Single unified path.
// Key compute kernels are intentionally left for RVV intrinsics implementation.
template <impl::data_type_t data_type>
struct rvv_eltwise_fwd_t : public primitive_t {
struct pd_t : public cpu_eltwise_fwd_pd_t {
using cpu_eltwise_fwd_pd_t::cpu_eltwise_fwd_pd_t;
@ -44,136 +40,121 @@ struct rvv_eltwise_fwd_t : public primitive_t {
status_t init(engine_t *engine) {
UNUSED(engine);
using namespace utils;
const memory_desc_wrapper src_d(src_md());
const memory_desc_wrapper dst_d(dst_md());
VDISPATCH_ELTWISE(utils::everyone_is(data_type, src_md()->data_type,
dst_md()->data_type),
const data_type_t d_type = dst_md()->data_type;
using namespace dnnl::impl::data_type;
bool type_ok = d_type == f32 || d_type == f16 || d_type == s32
|| d_type == s8 || d_type == u8;
VDISPATCH_ELTWISE(type_ok, VERBOSE_UNSUPPORTED_DT);
VDISPATCH_ELTWISE(
src_md()->data_type == d_type, VERBOSE_UNSUPPORTED_DT);
VDISPATCH_ELTWISE(platform::has_data_type_support(d_type),
VERBOSE_UNSUPPORTED_DT);
VDISPATCH_ELTWISE(platform::has_data_type_support(data_type),
VERBOSE_UNSUPPORTED_DT);
VDISPATCH_ELTWISE(!has_zero_dim_memory(), VERBOSE_EMPTY_TENSOR, "");
VDISPATCH_ELTWISE(
attr()->has_default_values(), VERBOSE_UNSUPPORTED_ATTR);
// filter out what algs we implemented
VDISPATCH_ELTWISE(
utils::one_of(desc()->alg_kind, alg_kind::eltwise_relu,
alg_kind::eltwise_square, alg_kind::eltwise_abs,
alg_kind::eltwise_sqrt, alg_kind::eltwise_linear,
alg_kind::eltwise_clip,
alg_kind::eltwise_hardsigmoid,
alg_kind::eltwise_hardswish),
"Unsupported alg_kind for rvv extension");
VDISPATCH_ELTWISE(
set_default_formats_common(), VERBOSE_UNSUPPORTED_TAG);
VDISPATCH_ELTWISE(
src_d == dst_d, VERBOSE_INCONSISTENT_MDS, "src", "dst");
VDISPATCH_ELTWISE(check_alg_kind(), VERBOSE_UNSUPPORTED_TAG);
// Determine supported memory cases: dense or nCspBc padded tail.
use_dense_ = src_d.is_dense(true) && dst_d.is_dense(true)
&& IMPLICATION(!src_d.is_dense() || !dst_d.is_dense(),
is_zero_preserved());
use_nCspBc_padded_ = !use_dense_
&& src_d.blocking_desc().inner_nblks == 1
&& one_of(src_d.blocking_desc().inner_blks[0], 8, 16)
&& utils::one_of(src_d.blocking_desc().inner_blks[0], 8, 16)
&& src_d.blocking_desc().inner_idxs[0] == 1
&& src_d.only_padded_dim(1) && src_d.is_dense(true);
VDISPATCH_ELTWISE(use_dense_ || use_nCspBc_padded_,
VERBOSE_UNSUPPORTED_SPARSE_CFG);
VDISPATCH_ELTWISE(
use_dense_ || use_nCspBc_padded_, VERBOSE_UNSUPPORTED_TAG);
return status::success;
}
bool use_dense_, use_nCspBc_padded_;
bool check_alg_kind() const {
return utils::one_of(desc()->alg_kind, alg_kind::eltwise_relu,
alg_kind::eltwise_square, alg_kind::eltwise_abs,
alg_kind::eltwise_sqrt, alg_kind::eltwise_linear,
alg_kind::eltwise_clip, alg_kind::eltwise_hardsigmoid,
alg_kind::eltwise_hardswish);
}
};
rvv_eltwise_fwd_t(const pd_t *apd) : primitive_t(apd) {}
using data_t = typename prec_traits_t<data_type>::type;
status_t execute(const exec_ctx_t &ctx) const override {
return execute_forward(ctx);
}
status_t execute(const exec_ctx_t &ctx) const;
private:
const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
status_t execute_forward(const exec_ctx_t &ctx) const;
};
template <impl::data_type_t data_type>
struct rvv_eltwise_bwd_t : public primitive_t {
struct pd_t : public cpu_eltwise_bwd_pd_t {
using cpu_eltwise_bwd_pd_t::cpu_eltwise_bwd_pd_t;
DECLARE_COMMON_PD_T_("rv64gcv", rvv_eltwise_bwd_t)
DECLARE_COMMON_PD_T_("RISCV64GCV", rvv_eltwise_bwd_t)
status_t init(engine_t *engine) {
UNUSED(engine);
using namespace utils;
using namespace data_type;
const memory_desc_wrapper diff_src_d(diff_src_md());
const memory_desc_wrapper diff_dst_d(diff_dst_md());
const memory_desc_wrapper data_d(data_md());
const memory_desc_wrapper src_d(data_md());
const data_type_t d_type = src_md()->data_type;
using namespace dnnl::impl::data_type;
bool type_ok = d_type == f32 || d_type == f16 || d_type == s32
|| d_type == s8 || d_type == u8;
VDISPATCH_ELTWISE(type_ok, VERBOSE_UNSUPPORTED_DT);
VDISPATCH_ELTWISE(
utils::everyone_is(data_type, data_md()->data_type,
diff_src_md()->data_type, diff_dst_md()->data_type),
utils::everyone_is(d_type, diff_src_md()->data_type,
diff_dst_md()->data_type),
VERBOSE_UNSUPPORTED_DT);
VDISPATCH_ELTWISE(platform::has_data_type_support(data_type),
VDISPATCH_ELTWISE(platform::has_data_type_support(d_type),
VERBOSE_UNSUPPORTED_DT);
VDISPATCH_ELTWISE(
attr()->has_default_values(), VERBOSE_UNSUPPORTED_ATTR);
// filter out what algs we implemented
VDISPATCH_ELTWISE(
utils::one_of(desc()->alg_kind, alg_kind::eltwise_relu,
alg_kind::eltwise_square, alg_kind::eltwise_abs,
alg_kind::eltwise_sqrt, alg_kind::eltwise_linear,
alg_kind::eltwise_clip,
alg_kind::eltwise_hardsigmoid,
alg_kind::eltwise_hardswish),
"Unsupported alg_kind for rvv extension");
VDISPATCH_ELTWISE(
set_default_formats_common(), VERBOSE_UNSUPPORTED_TAG);
VDISPATCH_ELTWISE(diff_dst_d == diff_src_d,
VDISPATCH_ELTWISE(diff_src_d == diff_dst_d,
VERBOSE_INCONSISTENT_MDS, "diff_src", "diff_dst");
VDISPATCH_ELTWISE(check_alg_kind(), VERBOSE_UNSUPPORTED_TAG);
// Layout support: dense or nCspBc-padded (blocked C with tail)
use_dense_ = diff_dst_d.is_dense()
|| (diff_dst_d.is_dense(true) && is_zero_preserved());
use_nCspBc_padded_ = !use_dense_
&& data_d.blocking_desc().inner_nblks == 1
&& one_of(data_d.blocking_desc().inner_blks[0], 8, 16)
&& data_d.blocking_desc().inner_idxs[0] == 1
&& data_d.only_padded_dim(1) && data_d.is_dense(true);
VDISPATCH_ELTWISE(use_dense_ || use_nCspBc_padded_,
VERBOSE_UNSUPPORTED_SPARSE_CFG);
&& src_d.blocking_desc().inner_nblks == 1
&& utils::one_of(src_d.blocking_desc().inner_blks[0], 8, 16)
&& src_d.blocking_desc().inner_idxs[0] == 1
&& src_d.only_padded_dim(1) && src_d.is_dense(true);
VDISPATCH_ELTWISE(
use_dense_ || use_nCspBc_padded_, VERBOSE_UNSUPPORTED_TAG);
return status::success;
}
bool use_dense_, use_nCspBc_padded_;
bool check_alg_kind() const {
return utils::one_of(desc()->alg_kind, alg_kind::eltwise_relu,
alg_kind::eltwise_square, alg_kind::eltwise_abs,
alg_kind::eltwise_sqrt, alg_kind::eltwise_linear,
alg_kind::eltwise_clip, alg_kind::eltwise_hardsigmoid,
alg_kind::eltwise_hardswish);
}
};
rvv_eltwise_bwd_t(const pd_t *apd) : primitive_t(apd) {}
using data_t = typename prec_traits_t<data_type>::type;
status_t execute(const exec_ctx_t &ctx) const override {
return execute_backward(ctx);
}
status_t execute(const exec_ctx_t &ctx) const;
private:
const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
status_t execute_backward(const exec_ctx_t &ctx) const;
};
} // namespace rv64