Mirror of https://github.com/uxlfoundation/oneDNN.git (synced 2025-10-20 10:03:50 +08:00)
cpu: rv64: eltwise: remove channel blocked layouts support
This commit is contained in:
Committed by: Vadim Pirogov
Parent: a6956f00d4
Commit: de3f018b62
@ -30,10 +30,12 @@ namespace impl {
|
||||
namespace cpu {
|
||||
namespace rv64 {
|
||||
|
||||
namespace {
|
||||
|
||||
// Data type dispatch for RVV eltwise forward
|
||||
static inline void compute_eltwise_rvv_fwd(const alg_kind_t alg,
|
||||
const void *src, void *dst, const float alpha, const float beta,
|
||||
const dim_t len, const data_type_t dt) {
|
||||
void compute_eltwise_rvv_fwd(const alg_kind_t alg, const void *src, void *dst,
|
||||
const float alpha, const float beta, const dim_t len,
|
||||
const data_type_t dt) {
|
||||
switch (dt) {
|
||||
case data_type::f32:
|
||||
rvv_eltwise_apply_fwd_f32(alg, src, dst, len, alpha, beta, dt);
|
||||
@ -52,7 +54,7 @@ static inline void compute_eltwise_rvv_fwd(const alg_kind_t alg,
|
||||
}
|
||||
|
||||
// Data type dispatch for RVV eltwise backward
|
||||
static inline void compute_eltwise_rvv_bwd(const alg_kind_t alg, void *diff_src,
|
||||
void compute_eltwise_rvv_bwd(const alg_kind_t alg, void *diff_src,
|
||||
const void *diff_dst, const void *src, const float alpha,
|
||||
const float beta, const dim_t len, const data_type_t dt) {
|
||||
switch (dt) {
|
||||
@ -76,6 +78,8 @@ static inline void compute_eltwise_rvv_bwd(const alg_kind_t alg, void *diff_src,
|
||||
}
|
||||
}
|
||||
|
||||
} // unnamed namespace
|
||||
|
||||
// Forward execute
|
||||
status_t rvv_eltwise_fwd_t::execute(const exec_ctx_t &ctx) const {
|
||||
if (pd()->has_zero_dim_memory()) return status::success;
|
||||
@ -112,49 +116,8 @@ status_t rvv_eltwise_fwd_t::execute(const exec_ctx_t &ctx) const {
|
||||
compute_eltwise_rvv_fwd(alg_kind, thr_src, thr_dst, alpha, beta,
|
||||
len, pd()->src_md()->data_type);
|
||||
});
|
||||
|
||||
return status::success;
|
||||
}
|
||||
|
||||
if (pd()->use_nCspBc_padded_) {
|
||||
const blocking_desc_t &blk = src_d.blocking_desc();
|
||||
const dim_t block = blk.inner_blks[0];
|
||||
|
||||
const dim_t MB = pd()->MB();
|
||||
const dim_t C = pd()->C() / block;
|
||||
const dim_t C_PADDED = src_d.padded_dims()[1] / block;
|
||||
const dim_t tail = pd()->C() % block;
|
||||
const dim_t SP = pd()->D() * pd()->H() * pd()->W();
|
||||
|
||||
const size_t esize = types::data_type_size(pd()->src_md()->data_type);
|
||||
const char *src_bytes = static_cast<const char *>(src);
|
||||
char *dst_bytes = static_cast<char *>(dst);
|
||||
|
||||
parallel_nd(MB, C_PADDED, SP, [&](dim_t n, dim_t c, dim_t sp) {
|
||||
auto d_off = (n * C_PADDED * SP + c * SP + sp) * block;
|
||||
|
||||
const void *thr_src
|
||||
= static_cast<const void *>(src_bytes + d_off * esize);
|
||||
void *thr_dst = static_cast<void *>(dst_bytes + d_off * esize);
|
||||
|
||||
if (c < C) {
|
||||
// full block
|
||||
compute_eltwise_rvv_fwd(alg_kind, thr_src, thr_dst, alpha, beta,
|
||||
block, pd()->src_md()->data_type);
|
||||
} else {
|
||||
// tail: process only valid channels, keep padding zero-preserved
|
||||
const dim_t len = tail;
|
||||
if (len > 0) {
|
||||
compute_eltwise_rvv_fwd(alg_kind, thr_src, thr_dst, alpha,
|
||||
beta, len, pd()->src_md()->data_type);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
return status::success;
|
||||
}
|
||||
|
||||
return status::unimplemented;
|
||||
return status::success;
|
||||
}
|
||||
|
||||
// Backward execute
|
||||
@ -192,42 +155,8 @@ status_t rvv_eltwise_bwd_t::execute(const exec_ctx_t &ctx) const {
|
||||
static_cast<const void *>(data_bytes + start * esize),
|
||||
alpha, beta, end - start, pd()->src_md()->data_type);
|
||||
});
|
||||
return status::success;
|
||||
}
|
||||
|
||||
if (pd()->use_nCspBc_padded_) {
|
||||
const blocking_desc_t &blk = data_d.blocking_desc();
|
||||
const dim_t block = blk.inner_blks[0];
|
||||
|
||||
const dim_t MB = pd()->MB();
|
||||
const dim_t C = pd()->C() / block;
|
||||
const dim_t C_PADDED = data_d.padded_dims()[1] / block;
|
||||
const dim_t tail = pd()->C() % block;
|
||||
const dim_t SP = pd()->D() * pd()->H() * pd()->W();
|
||||
|
||||
const size_t esize = types::data_type_size(pd()->src_md()->data_type);
|
||||
const char *data_bytes = static_cast<const char *>(data);
|
||||
const char *dd_bytes = static_cast<const char *>(diff_dst);
|
||||
char *ds_bytes = static_cast<char *>(diff_src);
|
||||
|
||||
parallel_nd(MB, C_PADDED, SP, [&](dim_t n, dim_t c, dim_t sp) {
|
||||
auto base_off = (n * C_PADDED * SP + c * SP + sp) * block;
|
||||
|
||||
const void *data_p
|
||||
= static_cast<const void *>(data_bytes + base_off * esize);
|
||||
const void *dd_p
|
||||
= static_cast<const void *>(dd_bytes + base_off * esize);
|
||||
void *ds_p = static_cast<void *>(ds_bytes + base_off * esize);
|
||||
|
||||
const dim_t len = (c < C) ? block : tail;
|
||||
if (len == 0) return;
|
||||
compute_eltwise_rvv_bwd(alg_kind, ds_p, dd_p, data_p, alpha, beta,
|
||||
len, pd()->src_md()->data_type);
|
||||
});
|
||||
return status::success;
|
||||
}
|
||||
|
||||
return status::unimplemented;
|
||||
return status::success;
|
||||
}
|
||||
|
||||
} // namespace rv64
|
||||
|
@ -17,8 +17,6 @@
|
||||
#ifndef CPU_RV64_RVV_ELTWISE_HPP
|
||||
#define CPU_RV64_RVV_ELTWISE_HPP
|
||||
|
||||
#include <assert.h>
|
||||
|
||||
#include "common/c_types_map.hpp"
|
||||
#include "common/primitive.hpp"
|
||||
#include "common/type_helpers.hpp"
|
||||
@ -46,8 +44,7 @@ struct rvv_eltwise_fwd_t : public primitive_t {
|
||||
|
||||
const data_type_t d_type = dst_md()->data_type;
|
||||
using namespace dnnl::impl::data_type;
|
||||
bool type_ok = d_type == f32 || d_type == s32 || d_type == s8
|
||||
|| d_type == u8;
|
||||
bool type_ok = utils::one_of(d_type, f32, s32, s8, u8);
|
||||
VDISPATCH_ELTWISE(type_ok, VERBOSE_UNSUPPORTED_DT);
|
||||
VDISPATCH_ELTWISE(
|
||||
src_md()->data_type == d_type, VERBOSE_UNSUPPORTED_DT);
|
||||
@ -65,18 +62,12 @@ struct rvv_eltwise_fwd_t : public primitive_t {
|
||||
use_dense_ = src_d.is_dense(true) && dst_d.is_dense(true)
|
||||
&& IMPLICATION(!src_d.is_dense() || !dst_d.is_dense(),
|
||||
is_zero_preserved());
|
||||
use_nCspBc_padded_ = !use_dense_
|
||||
&& src_d.blocking_desc().inner_nblks == 1
|
||||
&& utils::one_of(src_d.blocking_desc().inner_blks[0], 8, 16)
|
||||
&& src_d.blocking_desc().inner_idxs[0] == 1
|
||||
&& src_d.only_padded_dim(1) && src_d.is_dense(true);
|
||||
VDISPATCH_ELTWISE(
|
||||
use_dense_ || use_nCspBc_padded_, VERBOSE_UNSUPPORTED_TAG);
|
||||
VDISPATCH_ELTWISE(use_dense_, VERBOSE_UNSUPPORTED_TAG);
|
||||
|
||||
return status::success;
|
||||
}
|
||||
|
||||
bool use_dense_, use_nCspBc_padded_;
|
||||
bool use_dense_;
|
||||
|
||||
bool check_alg_kind() const {
|
||||
return utils::one_of(desc()->alg_kind, alg_kind::eltwise_relu,
|
||||
@ -109,8 +100,7 @@ struct rvv_eltwise_bwd_t : public primitive_t {
|
||||
|
||||
const data_type_t d_type = src_md()->data_type;
|
||||
using namespace dnnl::impl::data_type;
|
||||
bool type_ok = d_type == f32 || d_type == s32 || d_type == s8
|
||||
|| d_type == u8;
|
||||
bool type_ok = utils::one_of(d_type, f32, s32, s8, u8);
|
||||
VDISPATCH_ELTWISE(type_ok, VERBOSE_UNSUPPORTED_DT);
|
||||
VDISPATCH_ELTWISE(
|
||||
utils::everyone_is(d_type, diff_src_md()->data_type,
|
||||
@ -128,18 +118,12 @@ struct rvv_eltwise_bwd_t : public primitive_t {
|
||||
|
||||
use_dense_ = diff_dst_d.is_dense()
|
||||
|| (diff_dst_d.is_dense(true) && is_zero_preserved());
|
||||
use_nCspBc_padded_ = !use_dense_
|
||||
&& src_d.blocking_desc().inner_nblks == 1
|
||||
&& utils::one_of(src_d.blocking_desc().inner_blks[0], 8, 16)
|
||||
&& src_d.blocking_desc().inner_idxs[0] == 1
|
||||
&& src_d.only_padded_dim(1) && src_d.is_dense(true);
|
||||
VDISPATCH_ELTWISE(
|
||||
use_dense_ || use_nCspBc_padded_, VERBOSE_UNSUPPORTED_TAG);
|
||||
VDISPATCH_ELTWISE(use_dense_, VERBOSE_UNSUPPORTED_TAG);
|
||||
|
||||
return status::success;
|
||||
}
|
||||
|
||||
bool use_dense_, use_nCspBc_padded_;
|
||||
bool use_dense_;
|
||||
|
||||
bool check_alg_kind() const {
|
||||
return utils::one_of(desc()->alg_kind, alg_kind::eltwise_relu,
|
||||
|
Reference in New Issue
Block a user