mirror of
https://github.com/uxlfoundation/oneDNN.git
synced 2025-10-20 18:43:49 +08:00
cpu: rv64: eltwise: fix templatization and ifdef line
This commit is contained in:
committed by
Vadim Pirogov
parent
a0961ab37e
commit
d3eb5ed7ac
@ -32,7 +32,7 @@ using namespace dnnl::impl::cpu::x64;
|
||||
#endif // DNNL_AARCH64_USE_ACL
|
||||
using namespace dnnl::impl::cpu::aarch64;
|
||||
#elif DNNL_RV64
|
||||
#if DNNL_RISCV_USE_RVV_INTRINSICS
|
||||
#ifdef DNNL_RISCV_USE_RVV_INTRINSICS
|
||||
#include "cpu/rv64/rvv_eltwise.hpp"
|
||||
using namespace dnnl::impl::cpu::rv64;
|
||||
#endif // DNNL_RISCV_USE_RVV_INTRINSICS
|
||||
@ -81,11 +81,7 @@ const std::map<pk_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map() {
|
||||
CPU_INSTANCE_AARCH64(jit_uni_eltwise_int_fwd_t<sve_512, s8>)
|
||||
CPU_INSTANCE_AARCH64(jit_uni_eltwise_int_fwd_t<sve_512, u8>)
|
||||
CPU_INSTANCE_AARCH64_ACL(acl_eltwise_fwd_t)
|
||||
CPU_INSTANCE_RV64GCV(rvv_eltwise_fwd_t<f32>)
|
||||
CPU_INSTANCE_RV64GCV(rvv_eltwise_fwd_t<f16>)
|
||||
CPU_INSTANCE_RV64GCV(rvv_eltwise_fwd_t<s32>)
|
||||
CPU_INSTANCE_RV64GCV(rvv_eltwise_fwd_t<s8>)
|
||||
CPU_INSTANCE_RV64GCV(rvv_eltwise_fwd_t<u8>)
|
||||
CPU_INSTANCE_RV64GCV(rvv_eltwise_fwd_t)
|
||||
CPU_INSTANCE(ref_eltwise_fwd_t<f32>)
|
||||
CPU_INSTANCE(ref_eltwise_fwd_t<bf16>)
|
||||
CPU_INSTANCE(ref_eltwise_fwd_t<f16>)
|
||||
@ -108,11 +104,7 @@ const std::map<pk_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map() {
|
||||
CPU_INSTANCE_X64(jit_uni_eltwise_bwd_t<avx, f32>)
|
||||
CPU_INSTANCE_X64(jit_uni_eltwise_bwd_t<sse41, f32>)
|
||||
CPU_INSTANCE_AARCH64(jit_uni_eltwise_bwd_t<sve_128, f32>)
|
||||
CPU_INSTANCE_RV64GCV(rvv_eltwise_bwd_t<f32>)
|
||||
CPU_INSTANCE_RV64GCV(rvv_eltwise_bwd_t<f16>)
|
||||
CPU_INSTANCE_RV64GCV(rvv_eltwise_bwd_t<s32>)
|
||||
CPU_INSTANCE_RV64GCV(rvv_eltwise_bwd_t<s8>)
|
||||
CPU_INSTANCE_RV64GCV(rvv_eltwise_bwd_t<u8>)
|
||||
CPU_INSTANCE_RV64GCV(rvv_eltwise_bwd_t)
|
||||
CPU_INSTANCE(ref_eltwise_bwd_t<f32>)
|
||||
CPU_INSTANCE(ref_eltwise_bwd_t<bf16>)
|
||||
CPU_INSTANCE(ref_eltwise_bwd_t<f16>)
|
||||
|
@ -22,17 +22,15 @@
|
||||
#include "common/math_utils.hpp"
|
||||
#include "common/type_helpers.hpp"
|
||||
|
||||
#include "cpu/primitive_attr_postops.hpp"
|
||||
#include "cpu/rv64/rvv_eltwise_kernels.hpp"
|
||||
|
||||
#include "cpu/rv64/rvv_eltwise.hpp"
|
||||
#include "cpu/rv64/rvv_eltwise_kernels.hpp"
|
||||
|
||||
namespace dnnl {
|
||||
namespace impl {
|
||||
namespace cpu {
|
||||
namespace rv64 {
|
||||
|
||||
// Data type dispatch for RVV eltwise forward (per-dtype apply)
|
||||
// Data type dispatch for RVV eltwise forward
|
||||
static inline void compute_eltwise_rvv_fwd(const alg_kind_t alg,
|
||||
const void *src, void *dst, const float alpha, const float beta,
|
||||
const dim_t len, const data_type_t dt) {
|
||||
@ -64,7 +62,7 @@ static inline void compute_eltwise_rvv_fwd(const alg_kind_t alg,
|
||||
}
|
||||
}
|
||||
|
||||
// Data type dispatch for RVV eltwise backward (per-dtype apply)
|
||||
// Data type dispatch for RVV eltwise backward
|
||||
static inline void compute_eltwise_rvv_bwd(const alg_kind_t alg, void *diff_src,
|
||||
const void *diff_dst, const void *src, const float alpha,
|
||||
const float beta, const dim_t len, const data_type_t dt) {
|
||||
@ -101,34 +99,36 @@ static inline void compute_eltwise_rvv_bwd(const alg_kind_t alg, void *diff_src,
|
||||
}
|
||||
|
||||
// Forward execute
|
||||
template <data_type_t data_type>
|
||||
status_t rvv_eltwise_fwd_t<data_type>::execute_forward(
|
||||
const exec_ctx_t &ctx) const {
|
||||
status_t rvv_eltwise_fwd_t::execute(const exec_ctx_t &ctx) const {
|
||||
if (pd()->has_zero_dim_memory()) return status::success;
|
||||
|
||||
status_t status = status::success;
|
||||
auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
|
||||
auto dst = CTX_OUT_CLEAN_MEM(data_t *, DNNL_ARG_DST, status);
|
||||
const void *src = CTX_IN_MEM(const void *, DNNL_ARG_SRC);
|
||||
void *dst = CTX_OUT_CLEAN_MEM(void *, DNNL_ARG_DST, status);
|
||||
CHECK(status);
|
||||
|
||||
const memory_desc_wrapper data_d(pd()->src_md());
|
||||
const auto nelems = data_d.nelems(true);
|
||||
const memory_desc_wrapper src_d(pd()->src_md());
|
||||
const memory_desc_wrapper dst_d(pd()->dst_md());
|
||||
const auto nelems = dst_d.nelems(true);
|
||||
|
||||
const auto alg_kind = pd()->desc()->alg_kind;
|
||||
const float alpha = pd()->desc()->alpha;
|
||||
const float beta = pd()->desc()->beta;
|
||||
|
||||
if (pd()->use_dense_) {
|
||||
src += data_d.offset0();
|
||||
dst += data_d.offset0();
|
||||
const size_t esize = types::data_type_size(pd()->src_md()->data_type);
|
||||
const char *src_base
|
||||
= static_cast<const char *>(src) + src_d.offset0() * esize;
|
||||
char *dst_base = static_cast<char *>(dst) + dst_d.offset0() * esize;
|
||||
|
||||
parallel(0, [&](const int ithr, const int nthr) {
|
||||
dim_t start = 0, end = 0;
|
||||
balance211(nelems, nthr, ithr, start, end);
|
||||
if (start == end) return;
|
||||
|
||||
const void *thr_src = static_cast<const void *>(src + start);
|
||||
void *thr_dst = static_cast<void *>(dst + start);
|
||||
const void *thr_src
|
||||
= static_cast<const void *>(src_base + start * esize);
|
||||
void *thr_dst = static_cast<void *>(dst_base + start * esize);
|
||||
const dim_t len = end - start;
|
||||
|
||||
compute_eltwise_rvv_fwd(alg_kind, thr_src, thr_dst, alpha, beta,
|
||||
@ -138,22 +138,26 @@ status_t rvv_eltwise_fwd_t<data_type>::execute_forward(
|
||||
return status::success;
|
||||
}
|
||||
|
||||
// nCspBc padded path: iterate over blocks and handle tail with zero-preserve
|
||||
if (pd()->use_nCspBc_padded_) {
|
||||
const blocking_desc_t &blk = data_d.blocking_desc();
|
||||
const blocking_desc_t &blk = src_d.blocking_desc();
|
||||
const dim_t block = blk.inner_blks[0];
|
||||
|
||||
const dim_t MB = pd()->MB();
|
||||
const dim_t C = pd()->C() / block;
|
||||
const dim_t C_PADDED = data_d.padded_dims()[1] / block;
|
||||
const dim_t C_PADDED = src_d.padded_dims()[1] / block;
|
||||
const dim_t tail = pd()->C() % block;
|
||||
const dim_t SP = pd()->D() * pd()->H() * pd()->W();
|
||||
|
||||
const size_t esize = types::data_type_size(pd()->src_md()->data_type);
|
||||
const char *src_bytes = static_cast<const char *>(src);
|
||||
char *dst_bytes = static_cast<char *>(dst);
|
||||
|
||||
parallel_nd(MB, C_PADDED, SP, [&](dim_t n, dim_t c, dim_t sp) {
|
||||
auto d_off = (n * C_PADDED * SP + c * SP + sp) * block;
|
||||
|
||||
const void *thr_src = static_cast<const void *>(src + d_off);
|
||||
void *thr_dst = static_cast<void *>(dst + d_off);
|
||||
const void *thr_src
|
||||
= static_cast<const void *>(src_bytes + d_off * esize);
|
||||
void *thr_dst = static_cast<void *>(dst_bytes + d_off * esize);
|
||||
|
||||
if (c < C) {
|
||||
// full block
|
||||
@ -176,31 +180,28 @@ status_t rvv_eltwise_fwd_t<data_type>::execute_forward(
|
||||
}
|
||||
|
||||
// Backward execute
|
||||
template <data_type_t data_type>
|
||||
status_t rvv_eltwise_bwd_t<data_type>::execute_backward(
|
||||
const exec_ctx_t &ctx) const {
|
||||
if (pd()->has_zero_dim_memory()) return status::success;
|
||||
|
||||
status_t status = status::success;
|
||||
auto data = pd()->use_dst() ? CTX_IN_MEM(const data_t *, DNNL_ARG_DST)
|
||||
: CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
|
||||
auto diff_dst = CTX_IN_MEM(const data_t *, DNNL_ARG_DIFF_DST);
|
||||
auto diff_src = CTX_OUT_CLEAN_MEM(data_t *, DNNL_ARG_DIFF_SRC, status);
|
||||
CHECK(status);
|
||||
status_t rvv_eltwise_bwd_t::execute(const exec_ctx_t &ctx) const {
|
||||
auto data = pd()->use_dst() ? CTX_IN_MEM(const void *, DNNL_ARG_DST)
|
||||
: CTX_IN_MEM(const void *, DNNL_ARG_SRC);
|
||||
auto diff_dst = CTX_IN_MEM(const void *, DNNL_ARG_DIFF_DST);
|
||||
auto diff_src = CTX_OUT_MEM(void *, DNNL_ARG_DIFF_SRC);
|
||||
|
||||
const memory_desc_wrapper data_d(pd()->data_md());
|
||||
const memory_desc_wrapper diff_d(pd()->diff_src_md());
|
||||
const memory_desc_wrapper diff_src_d(pd()->diff_src_md());
|
||||
const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
|
||||
|
||||
const auto nelems = diff_d.nelems(true);
|
||||
const auto nelems = diff_src_d.nelems(true);
|
||||
const auto alg_kind = pd()->desc()->alg_kind;
|
||||
const float alpha = pd()->desc()->alpha;
|
||||
const float beta = pd()->desc()->beta;
|
||||
|
||||
if (pd()->use_dense_) {
|
||||
const dim_t off = diff_d.offset0();
|
||||
data_t *ds_ptr = diff_src + off;
|
||||
const data_t *dd_ptr = diff_dst + off;
|
||||
const data_t *data_ptr = data + off;
|
||||
const size_t esize = types::data_type_size(pd()->src_md()->data_type);
|
||||
const dim_t off = diff_src_d.offset0();
|
||||
char *ds_bytes = static_cast<char *>(diff_src) + off * esize;
|
||||
const char *dd_bytes
|
||||
= static_cast<const char *>(diff_dst) + off * esize;
|
||||
const char *data_bytes = static_cast<const char *>(data) + off * esize;
|
||||
|
||||
parallel(0, [&](const int ithr, const int nthr) {
|
||||
dim_t start = 0, end = 0;
|
||||
@ -208,10 +209,10 @@ status_t rvv_eltwise_bwd_t<data_type>::execute_backward(
|
||||
if (start == end) return;
|
||||
|
||||
compute_eltwise_rvv_bwd(alg_kind,
|
||||
static_cast<void *>(ds_ptr + start),
|
||||
static_cast<const void *>(dd_ptr + start),
|
||||
static_cast<const void *>(data_ptr + start), alpha, beta,
|
||||
end - start, pd()->src_md()->data_type);
|
||||
static_cast<void *>(ds_bytes + start * esize),
|
||||
static_cast<const void *>(dd_bytes + start * esize),
|
||||
static_cast<const void *>(data_bytes + start * esize),
|
||||
alpha, beta, end - start, pd()->src_md()->data_type);
|
||||
});
|
||||
return status::success;
|
||||
}
|
||||
@ -226,19 +227,24 @@ status_t rvv_eltwise_bwd_t<data_type>::execute_backward(
|
||||
const dim_t tail = pd()->C() % block;
|
||||
const dim_t SP = pd()->D() * pd()->H() * pd()->W();
|
||||
|
||||
const size_t esize = types::data_type_size(pd()->src_md()->data_type);
|
||||
const char *data_bytes = static_cast<const char *>(data);
|
||||
const char *dd_bytes = static_cast<const char *>(diff_dst);
|
||||
char *ds_bytes = static_cast<char *>(diff_src);
|
||||
|
||||
parallel_nd(MB, C_PADDED, SP, [&](dim_t n, dim_t c, dim_t sp) {
|
||||
auto base_off = (n * C_PADDED * SP + c * SP + sp) * block;
|
||||
|
||||
auto data_p = data + base_off;
|
||||
auto dd_p = diff_dst + base_off;
|
||||
auto ds_p = diff_src + base_off;
|
||||
const void *data_p
|
||||
= static_cast<const void *>(data_bytes + base_off * esize);
|
||||
const void *dd_p
|
||||
= static_cast<const void *>(dd_bytes + base_off * esize);
|
||||
void *ds_p = static_cast<void *>(ds_bytes + base_off * esize);
|
||||
|
||||
const dim_t len = (c < C) ? block : tail;
|
||||
if (len == 0) return;
|
||||
compute_eltwise_rvv_bwd(alg_kind, static_cast<void *>(ds_p),
|
||||
static_cast<const void *>(dd_p),
|
||||
static_cast<const void *>(data_p), alpha, beta, len,
|
||||
pd()->src_md()->data_type);
|
||||
compute_eltwise_rvv_bwd(alg_kind, ds_p, dd_p, data_p, alpha, beta,
|
||||
len, pd()->src_md()->data_type);
|
||||
});
|
||||
return status::success;
|
||||
}
|
||||
@ -246,18 +252,6 @@ status_t rvv_eltwise_bwd_t<data_type>::execute_backward(
|
||||
return status::unimplemented;
|
||||
}
|
||||
|
||||
template struct rvv_eltwise_fwd_t<data_type::f32>;
|
||||
template struct rvv_eltwise_fwd_t<data_type::f16>;
|
||||
template struct rvv_eltwise_fwd_t<data_type::s32>;
|
||||
template struct rvv_eltwise_fwd_t<data_type::s8>;
|
||||
template struct rvv_eltwise_fwd_t<data_type::u8>;
|
||||
|
||||
template struct rvv_eltwise_bwd_t<data_type::f32>;
|
||||
template struct rvv_eltwise_bwd_t<data_type::f16>;
|
||||
template struct rvv_eltwise_bwd_t<data_type::s32>;
|
||||
template struct rvv_eltwise_bwd_t<data_type::s8>;
|
||||
template struct rvv_eltwise_bwd_t<data_type::u8>;
|
||||
|
||||
} // namespace rv64
|
||||
} // namespace cpu
|
||||
} // namespace impl
|
||||
|
@ -32,10 +32,6 @@ namespace impl {
|
||||
namespace cpu {
|
||||
namespace rv64 {
|
||||
|
||||
// RVV forward eltwise primitive (RV64GCV). Single unified path.
|
||||
// Key compute kernels are intentionally left for RVV intrinsics implementation.
|
||||
|
||||
template <impl::data_type_t data_type>
|
||||
struct rvv_eltwise_fwd_t : public primitive_t {
|
||||
struct pd_t : public cpu_eltwise_fwd_pd_t {
|
||||
using cpu_eltwise_fwd_pd_t::cpu_eltwise_fwd_pd_t;
|
||||
@ -44,136 +40,121 @@ struct rvv_eltwise_fwd_t : public primitive_t {
|
||||
|
||||
status_t init(engine_t *engine) {
|
||||
UNUSED(engine);
|
||||
using namespace utils;
|
||||
|
||||
const memory_desc_wrapper src_d(src_md());
|
||||
const memory_desc_wrapper dst_d(dst_md());
|
||||
|
||||
VDISPATCH_ELTWISE(utils::everyone_is(data_type, src_md()->data_type,
|
||||
dst_md()->data_type),
|
||||
const data_type_t d_type = dst_md()->data_type;
|
||||
using namespace dnnl::impl::data_type;
|
||||
bool type_ok = d_type == f32 || d_type == f16 || d_type == s32
|
||||
|| d_type == s8 || d_type == u8;
|
||||
VDISPATCH_ELTWISE(type_ok, VERBOSE_UNSUPPORTED_DT);
|
||||
VDISPATCH_ELTWISE(
|
||||
src_md()->data_type == d_type, VERBOSE_UNSUPPORTED_DT);
|
||||
VDISPATCH_ELTWISE(platform::has_data_type_support(d_type),
|
||||
VERBOSE_UNSUPPORTED_DT);
|
||||
VDISPATCH_ELTWISE(platform::has_data_type_support(data_type),
|
||||
VERBOSE_UNSUPPORTED_DT);
|
||||
|
||||
VDISPATCH_ELTWISE(!has_zero_dim_memory(), VERBOSE_EMPTY_TENSOR, "");
|
||||
VDISPATCH_ELTWISE(
|
||||
attr()->has_default_values(), VERBOSE_UNSUPPORTED_ATTR);
|
||||
|
||||
// filter out what algs we implemented
|
||||
VDISPATCH_ELTWISE(
|
||||
utils::one_of(desc()->alg_kind, alg_kind::eltwise_relu,
|
||||
alg_kind::eltwise_square, alg_kind::eltwise_abs,
|
||||
alg_kind::eltwise_sqrt, alg_kind::eltwise_linear,
|
||||
alg_kind::eltwise_clip,
|
||||
alg_kind::eltwise_hardsigmoid,
|
||||
alg_kind::eltwise_hardswish),
|
||||
"Unsupported alg_kind for rvv extension");
|
||||
|
||||
VDISPATCH_ELTWISE(
|
||||
set_default_formats_common(), VERBOSE_UNSUPPORTED_TAG);
|
||||
VDISPATCH_ELTWISE(
|
||||
src_d == dst_d, VERBOSE_INCONSISTENT_MDS, "src", "dst");
|
||||
VDISPATCH_ELTWISE(check_alg_kind(), VERBOSE_UNSUPPORTED_TAG);
|
||||
|
||||
// Determine supported memory cases: dense or nCspBc padded tail.
|
||||
use_dense_ = src_d.is_dense(true) && dst_d.is_dense(true)
|
||||
&& IMPLICATION(!src_d.is_dense() || !dst_d.is_dense(),
|
||||
is_zero_preserved());
|
||||
|
||||
use_nCspBc_padded_ = !use_dense_
|
||||
&& src_d.blocking_desc().inner_nblks == 1
|
||||
&& one_of(src_d.blocking_desc().inner_blks[0], 8, 16)
|
||||
&& utils::one_of(src_d.blocking_desc().inner_blks[0], 8, 16)
|
||||
&& src_d.blocking_desc().inner_idxs[0] == 1
|
||||
&& src_d.only_padded_dim(1) && src_d.is_dense(true);
|
||||
VDISPATCH_ELTWISE(use_dense_ || use_nCspBc_padded_,
|
||||
VERBOSE_UNSUPPORTED_SPARSE_CFG);
|
||||
VDISPATCH_ELTWISE(
|
||||
use_dense_ || use_nCspBc_padded_, VERBOSE_UNSUPPORTED_TAG);
|
||||
|
||||
return status::success;
|
||||
}
|
||||
|
||||
bool use_dense_, use_nCspBc_padded_;
|
||||
|
||||
bool check_alg_kind() const {
|
||||
return utils::one_of(desc()->alg_kind, alg_kind::eltwise_relu,
|
||||
alg_kind::eltwise_square, alg_kind::eltwise_abs,
|
||||
alg_kind::eltwise_sqrt, alg_kind::eltwise_linear,
|
||||
alg_kind::eltwise_clip, alg_kind::eltwise_hardsigmoid,
|
||||
alg_kind::eltwise_hardswish);
|
||||
}
|
||||
};
|
||||
|
||||
rvv_eltwise_fwd_t(const pd_t *apd) : primitive_t(apd) {}
|
||||
|
||||
using data_t = typename prec_traits_t<data_type>::type;
|
||||
|
||||
status_t execute(const exec_ctx_t &ctx) const override {
|
||||
return execute_forward(ctx);
|
||||
}
|
||||
status_t execute(const exec_ctx_t &ctx) const;
|
||||
|
||||
private:
|
||||
const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
|
||||
status_t execute_forward(const exec_ctx_t &ctx) const;
|
||||
};
|
||||
|
||||
template <impl::data_type_t data_type>
|
||||
struct rvv_eltwise_bwd_t : public primitive_t {
|
||||
struct pd_t : public cpu_eltwise_bwd_pd_t {
|
||||
using cpu_eltwise_bwd_pd_t::cpu_eltwise_bwd_pd_t;
|
||||
|
||||
DECLARE_COMMON_PD_T_("rv64gcv", rvv_eltwise_bwd_t)
|
||||
DECLARE_COMMON_PD_T_("RISCV64GCV", rvv_eltwise_bwd_t)
|
||||
|
||||
status_t init(engine_t *engine) {
|
||||
UNUSED(engine);
|
||||
using namespace utils;
|
||||
using namespace data_type;
|
||||
|
||||
const memory_desc_wrapper diff_src_d(diff_src_md());
|
||||
const memory_desc_wrapper diff_dst_d(diff_dst_md());
|
||||
const memory_desc_wrapper data_d(data_md());
|
||||
const memory_desc_wrapper src_d(data_md());
|
||||
|
||||
const data_type_t d_type = src_md()->data_type;
|
||||
using namespace dnnl::impl::data_type;
|
||||
bool type_ok = d_type == f32 || d_type == f16 || d_type == s32
|
||||
|| d_type == s8 || d_type == u8;
|
||||
VDISPATCH_ELTWISE(type_ok, VERBOSE_UNSUPPORTED_DT);
|
||||
VDISPATCH_ELTWISE(
|
||||
utils::everyone_is(data_type, data_md()->data_type,
|
||||
diff_src_md()->data_type, diff_dst_md()->data_type),
|
||||
utils::everyone_is(d_type, diff_src_md()->data_type,
|
||||
diff_dst_md()->data_type),
|
||||
VERBOSE_UNSUPPORTED_DT);
|
||||
VDISPATCH_ELTWISE(platform::has_data_type_support(data_type),
|
||||
VDISPATCH_ELTWISE(platform::has_data_type_support(d_type),
|
||||
VERBOSE_UNSUPPORTED_DT);
|
||||
|
||||
VDISPATCH_ELTWISE(
|
||||
attr()->has_default_values(), VERBOSE_UNSUPPORTED_ATTR);
|
||||
|
||||
// filter out what algs we implemented
|
||||
VDISPATCH_ELTWISE(
|
||||
utils::one_of(desc()->alg_kind, alg_kind::eltwise_relu,
|
||||
alg_kind::eltwise_square, alg_kind::eltwise_abs,
|
||||
alg_kind::eltwise_sqrt, alg_kind::eltwise_linear,
|
||||
alg_kind::eltwise_clip,
|
||||
alg_kind::eltwise_hardsigmoid,
|
||||
alg_kind::eltwise_hardswish),
|
||||
"Unsupported alg_kind for rvv extension");
|
||||
|
||||
VDISPATCH_ELTWISE(
|
||||
set_default_formats_common(), VERBOSE_UNSUPPORTED_TAG);
|
||||
VDISPATCH_ELTWISE(diff_dst_d == diff_src_d,
|
||||
VDISPATCH_ELTWISE(diff_src_d == diff_dst_d,
|
||||
VERBOSE_INCONSISTENT_MDS, "diff_src", "diff_dst");
|
||||
VDISPATCH_ELTWISE(check_alg_kind(), VERBOSE_UNSUPPORTED_TAG);
|
||||
|
||||
// Layout support: dense or nCspBc-padded (blocked C with tail)
|
||||
use_dense_ = diff_dst_d.is_dense()
|
||||
|| (diff_dst_d.is_dense(true) && is_zero_preserved());
|
||||
use_nCspBc_padded_ = !use_dense_
|
||||
&& data_d.blocking_desc().inner_nblks == 1
|
||||
&& one_of(data_d.blocking_desc().inner_blks[0], 8, 16)
|
||||
&& data_d.blocking_desc().inner_idxs[0] == 1
|
||||
&& data_d.only_padded_dim(1) && data_d.is_dense(true);
|
||||
VDISPATCH_ELTWISE(use_dense_ || use_nCspBc_padded_,
|
||||
VERBOSE_UNSUPPORTED_SPARSE_CFG);
|
||||
&& src_d.blocking_desc().inner_nblks == 1
|
||||
&& utils::one_of(src_d.blocking_desc().inner_blks[0], 8, 16)
|
||||
&& src_d.blocking_desc().inner_idxs[0] == 1
|
||||
&& src_d.only_padded_dim(1) && src_d.is_dense(true);
|
||||
VDISPATCH_ELTWISE(
|
||||
use_dense_ || use_nCspBc_padded_, VERBOSE_UNSUPPORTED_TAG);
|
||||
|
||||
return status::success;
|
||||
}
|
||||
|
||||
bool use_dense_, use_nCspBc_padded_;
|
||||
|
||||
bool check_alg_kind() const {
|
||||
return utils::one_of(desc()->alg_kind, alg_kind::eltwise_relu,
|
||||
alg_kind::eltwise_square, alg_kind::eltwise_abs,
|
||||
alg_kind::eltwise_sqrt, alg_kind::eltwise_linear,
|
||||
alg_kind::eltwise_clip, alg_kind::eltwise_hardsigmoid,
|
||||
alg_kind::eltwise_hardswish);
|
||||
}
|
||||
};
|
||||
|
||||
rvv_eltwise_bwd_t(const pd_t *apd) : primitive_t(apd) {}
|
||||
|
||||
using data_t = typename prec_traits_t<data_type>::type;
|
||||
|
||||
status_t execute(const exec_ctx_t &ctx) const override {
|
||||
return execute_backward(ctx);
|
||||
}
|
||||
status_t execute(const exec_ctx_t &ctx) const;
|
||||
|
||||
private:
|
||||
const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
|
||||
status_t execute_backward(const exec_ctx_t &ctx) const;
|
||||
};
|
||||
|
||||
} // namespace rv64
|
||||
|
Reference in New Issue
Block a user