cpu: rv64: add support for rvv eltwise feature

张健10355098
2025-09-08 10:19:49 +08:00
committed by Vadim Pirogov
parent dca0b6c7d6
commit a0961ab37e
4 changed files with 1649 additions and 0 deletions


@@ -31,6 +31,11 @@ using namespace dnnl::impl::cpu::x64;
#include "cpu/aarch64/acl_eltwise.hpp"
#endif // DNNL_AARCH64_USE_ACL
using namespace dnnl::impl::cpu::aarch64;
#elif DNNL_RV64
#if DNNL_RISCV_USE_RVV_INTRINSICS
#include "cpu/rv64/rvv_eltwise.hpp"
using namespace dnnl::impl::cpu::rv64;
#endif // DNNL_RISCV_USE_RVV_INTRINSICS
#endif
namespace dnnl {
@@ -76,6 +81,11 @@ const std::map<pk_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map() {
CPU_INSTANCE_AARCH64(jit_uni_eltwise_int_fwd_t<sve_512, s8>)
CPU_INSTANCE_AARCH64(jit_uni_eltwise_int_fwd_t<sve_512, u8>)
CPU_INSTANCE_AARCH64_ACL(acl_eltwise_fwd_t)
CPU_INSTANCE_RV64GCV(rvv_eltwise_fwd_t<f32>)
CPU_INSTANCE_RV64GCV(rvv_eltwise_fwd_t<f16>)
CPU_INSTANCE_RV64GCV(rvv_eltwise_fwd_t<s32>)
CPU_INSTANCE_RV64GCV(rvv_eltwise_fwd_t<s8>)
CPU_INSTANCE_RV64GCV(rvv_eltwise_fwd_t<u8>)
CPU_INSTANCE(ref_eltwise_fwd_t<f32>)
CPU_INSTANCE(ref_eltwise_fwd_t<bf16>)
CPU_INSTANCE(ref_eltwise_fwd_t<f16>)
@@ -98,6 +108,11 @@ const std::map<pk_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map() {
CPU_INSTANCE_X64(jit_uni_eltwise_bwd_t<avx, f32>)
CPU_INSTANCE_X64(jit_uni_eltwise_bwd_t<sse41, f32>)
CPU_INSTANCE_AARCH64(jit_uni_eltwise_bwd_t<sve_128, f32>)
CPU_INSTANCE_RV64GCV(rvv_eltwise_bwd_t<f32>)
CPU_INSTANCE_RV64GCV(rvv_eltwise_bwd_t<f16>)
CPU_INSTANCE_RV64GCV(rvv_eltwise_bwd_t<s32>)
CPU_INSTANCE_RV64GCV(rvv_eltwise_bwd_t<s8>)
CPU_INSTANCE_RV64GCV(rvv_eltwise_bwd_t<u8>)
CPU_INSTANCE(ref_eltwise_bwd_t<f32>)
CPU_INSTANCE(ref_eltwise_bwd_t<bf16>)
CPU_INSTANCE(ref_eltwise_bwd_t<f16>)
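
Not part of the diff: a minimal usage sketch, assuming the oneDNN v3.x C++ API, of the eltwise_forward call this dispatch list serves. On a build targeting RV64GCV the new rvv_eltwise_fwd_t<f32> entry sits ahead of ref_eltwise_fwd_t<f32>, so a plain dense f32 ReLU like the one below could be picked up by the RVV implementation; the snippet is illustrative, not code from the commit.

    #include "dnnl.hpp"

    using namespace dnnl;

    int main() {
        engine eng(engine::kind::cpu, 0);
        stream strm(eng);

        // Plain dense f32 tensor; dense layouts are what the rvv_eltwise
        // use_dense_ path handles.
        memory::desc md({8, 16, 32, 32}, memory::data_type::f32,
                memory::format_tag::nchw);
        memory src_m(md, eng), dst_m(md, eng);

        // eltwise_relu with alpha = 0 (negative slope), one of the alg kinds
        // the RVV implementation advertises.
        auto pd = eltwise_forward::primitive_desc(eng,
                prop_kind::forward_inference, algorithm::eltwise_relu, md, md,
                /*alpha=*/0.f, /*beta=*/0.f);
        eltwise_forward(pd).execute(
                strm, {{DNNL_ARG_SRC, src_m}, {DNNL_ARG_DST, dst_m}});
        strm.wait();
        return 0;
    }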


@@ -0,0 +1,264 @@
/*******************************************************************************
* Copyright 2019-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <assert.h>
#include <riscv_vector.h>
#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/math_utils.hpp"
#include "common/type_helpers.hpp"
#include "cpu/primitive_attr_postops.hpp"
#include "cpu/rv64/rvv_eltwise_kernels.hpp"
#include "cpu/rv64/rvv_eltwise.hpp"
namespace dnnl {
namespace impl {
namespace cpu {
namespace rv64 {
// Data type dispatch for RVV eltwise forward (per-dtype apply)
static inline void compute_eltwise_rvv_fwd(const alg_kind_t alg,
const void *src, void *dst, const float alpha, const float beta,
const dim_t len, const data_type_t dt) {
switch (dt) {
case data_type::f32:
rvv_eltwise_apply_fwd_f32(alg, reinterpret_cast<const float *>(src),
reinterpret_cast<float *>(dst), len, alpha, beta);
break;
case data_type::f16:
rvv_eltwise_apply_fwd_f16(alg,
reinterpret_cast<const _Float16 *>(src),
reinterpret_cast<_Float16 *>(dst), len, alpha, beta);
break;
case data_type::s32:
rvv_eltwise_apply_fwd_s32(alg,
reinterpret_cast<const int32_t *>(src),
reinterpret_cast<int32_t *>(dst), len, alpha, beta);
break;
case data_type::s8:
rvv_eltwise_apply_fwd_s8(alg, reinterpret_cast<const int8_t *>(src),
reinterpret_cast<int8_t *>(dst), len, alpha, beta);
break;
case data_type::u8:
rvv_eltwise_apply_fwd_u8(alg,
reinterpret_cast<const uint8_t *>(src),
reinterpret_cast<uint8_t *>(dst), len, alpha, beta);
break;
default: assert(!"Unsupported data type for RVV eltwise");
}
}
// Data type dispatch for RVV eltwise backward (per-dtype apply)
static inline void compute_eltwise_rvv_bwd(const alg_kind_t alg, void *diff_src,
const void *diff_dst, const void *src, const float alpha,
const float beta, const dim_t len, const data_type_t dt) {
switch (dt) {
case data_type::f32:
rvv_eltwise_apply_bwd_f32(alg, reinterpret_cast<float *>(diff_src),
reinterpret_cast<const float *>(diff_dst),
reinterpret_cast<const float *>(src), len, alpha, beta);
break;
case data_type::f16:
rvv_eltwise_apply_bwd_f16(alg,
reinterpret_cast<_Float16 *>(diff_src),
reinterpret_cast<const _Float16 *>(diff_dst),
reinterpret_cast<const _Float16 *>(src), len, alpha, beta);
break;
case data_type::s32:
rvv_eltwise_apply_bwd_s32(alg,
reinterpret_cast<int32_t *>(diff_src),
reinterpret_cast<const int32_t *>(diff_dst),
reinterpret_cast<const int32_t *>(src), len, alpha, beta);
break;
case data_type::s8:
rvv_eltwise_apply_bwd_s8(alg, reinterpret_cast<int8_t *>(diff_src),
reinterpret_cast<const int8_t *>(diff_dst),
reinterpret_cast<const int8_t *>(src), len, alpha, beta);
break;
case data_type::u8:
rvv_eltwise_apply_bwd_u8(alg, reinterpret_cast<uint8_t *>(diff_src),
reinterpret_cast<const uint8_t *>(diff_dst),
reinterpret_cast<const uint8_t *>(src), len, alpha, beta);
break;
default: assert(!"Unsupported data type for RVV eltwise");
}
}
// Forward execute
template <data_type_t data_type>
status_t rvv_eltwise_fwd_t<data_type>::execute_forward(
const exec_ctx_t &ctx) const {
if (pd()->has_zero_dim_memory()) return status::success;
status_t status = status::success;
auto src = CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
auto dst = CTX_OUT_CLEAN_MEM(data_t *, DNNL_ARG_DST, status);
CHECK(status);
const memory_desc_wrapper data_d(pd()->src_md());
const auto nelems = data_d.nelems(true);
const auto alg_kind = pd()->desc()->alg_kind;
const float alpha = pd()->desc()->alpha;
const float beta = pd()->desc()->beta;
if (pd()->use_dense_) {
src += data_d.offset0();
dst += data_d.offset0();
parallel(0, [&](const int ithr, const int nthr) {
dim_t start = 0, end = 0;
balance211(nelems, nthr, ithr, start, end);
if (start == end) return;
const void *thr_src = static_cast<const void *>(src + start);
void *thr_dst = static_cast<void *>(dst + start);
const dim_t len = end - start;
compute_eltwise_rvv_fwd(alg_kind, thr_src, thr_dst, alpha, beta,
len, pd()->src_md()->data_type);
});
return status::success;
}
// nCspBc padded path: iterate over blocks and handle the channel tail while leaving the zero padding untouched
if (pd()->use_nCspBc_padded_) {
const blocking_desc_t &blk = data_d.blocking_desc();
const dim_t block = blk.inner_blks[0];
const dim_t MB = pd()->MB();
const dim_t C = pd()->C() / block;
const dim_t C_PADDED = data_d.padded_dims()[1] / block;
const dim_t tail = pd()->C() % block;
const dim_t SP = pd()->D() * pd()->H() * pd()->W();
parallel_nd(MB, C_PADDED, SP, [&](dim_t n, dim_t c, dim_t sp) {
auto d_off = (n * C_PADDED * SP + c * SP + sp) * block;
const void *thr_src = static_cast<const void *>(src + d_off);
void *thr_dst = static_cast<void *>(dst + d_off);
if (c < C) {
// full block
compute_eltwise_rvv_fwd(alg_kind, thr_src, thr_dst, alpha, beta,
block, pd()->src_md()->data_type);
} else {
// tail: process only valid channels, keep padding zero-preserved
const dim_t len = tail;
if (len > 0) {
compute_eltwise_rvv_fwd(alg_kind, thr_src, thr_dst, alpha,
beta, len, pd()->src_md()->data_type);
}
}
});
return status::success;
}
return status::unimplemented;
}
// Backward execute
template <data_type_t data_type>
status_t rvv_eltwise_bwd_t<data_type>::execute_backward(
const exec_ctx_t &ctx) const {
if (pd()->has_zero_dim_memory()) return status::success;
status_t status = status::success;
auto data = pd()->use_dst() ? CTX_IN_MEM(const data_t *, DNNL_ARG_DST)
: CTX_IN_MEM(const data_t *, DNNL_ARG_SRC);
auto diff_dst = CTX_IN_MEM(const data_t *, DNNL_ARG_DIFF_DST);
auto diff_src = CTX_OUT_CLEAN_MEM(data_t *, DNNL_ARG_DIFF_SRC, status);
CHECK(status);
const memory_desc_wrapper data_d(pd()->data_md());
const memory_desc_wrapper diff_d(pd()->diff_src_md());
const auto nelems = diff_d.nelems(true);
const auto alg_kind = pd()->desc()->alg_kind;
const float alpha = pd()->desc()->alpha;
const float beta = pd()->desc()->beta;
if (pd()->use_dense_) {
const dim_t off = diff_d.offset0();
data_t *ds_ptr = diff_src + off;
const data_t *dd_ptr = diff_dst + off;
const data_t *data_ptr = data + off;
parallel(0, [&](const int ithr, const int nthr) {
dim_t start = 0, end = 0;
balance211(nelems, nthr, ithr, start, end);
if (start == end) return;
compute_eltwise_rvv_bwd(alg_kind,
static_cast<void *>(ds_ptr + start),
static_cast<const void *>(dd_ptr + start),
static_cast<const void *>(data_ptr + start), alpha, beta,
end - start, pd()->data_md()->data_type);
});
return status::success;
}
if (pd()->use_nCspBc_padded_) {
const blocking_desc_t &blk = data_d.blocking_desc();
const dim_t block = blk.inner_blks[0];
const dim_t MB = pd()->MB();
const dim_t C = pd()->C() / block;
const dim_t C_PADDED = data_d.padded_dims()[1] / block;
const dim_t tail = pd()->C() % block;
const dim_t SP = pd()->D() * pd()->H() * pd()->W();
parallel_nd(MB, C_PADDED, SP, [&](dim_t n, dim_t c, dim_t sp) {
auto base_off = (n * C_PADDED * SP + c * SP + sp) * block;
auto data_p = data + base_off;
auto dd_p = diff_dst + base_off;
auto ds_p = diff_src + base_off;
const dim_t len = (c < C) ? block : tail;
if (len == 0) return;
compute_eltwise_rvv_bwd(alg_kind, static_cast<void *>(ds_p),
static_cast<const void *>(dd_p),
static_cast<const void *>(data_p), alpha, beta, len,
pd()->data_md()->data_type);
});
return status::success;
}
return status::unimplemented;
}
template struct rvv_eltwise_fwd_t<data_type::f32>;
template struct rvv_eltwise_fwd_t<data_type::f16>;
template struct rvv_eltwise_fwd_t<data_type::s32>;
template struct rvv_eltwise_fwd_t<data_type::s8>;
template struct rvv_eltwise_fwd_t<data_type::u8>;
template struct rvv_eltwise_bwd_t<data_type::f32>;
template struct rvv_eltwise_bwd_t<data_type::f16>;
template struct rvv_eltwise_bwd_t<data_type::s32>;
template struct rvv_eltwise_bwd_t<data_type::s8>;
template struct rvv_eltwise_bwd_t<data_type::u8>;
} // namespace rv64
} // namespace cpu
} // namespace impl
} // namespace dnnl
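
To make the nCspBc index math above concrete, here is a standalone sketch with hypothetical sizes (C = 17 channels, block = 16, so one full block plus a 1-channel tail). It reproduces the same d_off formula and full-block/tail split used by execute_forward and execute_backward; it is an illustration, not code from the commit.

    #include <cstdint>
    #include <cstdio>

    int main() {
        const int64_t block = 16, C_total = 17;          // hypothetical sizes
        const int64_t C_full = C_total / block;          // 1 full 16-channel block
        const int64_t C_padded = (C_total + block - 1) / block; // 2 blocks incl. tail
        const int64_t tail = C_total % block;            // 1 valid channel in last block
        const int64_t MB = 1, SP = 4;                    // SP = D * H * W

        for (int64_t n = 0; n < MB; ++n)
            for (int64_t c = 0; c < C_padded; ++c)
                for (int64_t sp = 0; sp < SP; ++sp) {
                    // Same offset formula as the padded path above.
                    const int64_t off = (n * C_padded * SP + c * SP + sp) * block;
                    // Full blocks process `block` values, the last block only
                    // `tail`, so the zero padding written by CTX_OUT_CLEAN_MEM
                    // stays intact.
                    const int64_t len = (c < C_full) ? block : tail;
                    std::printf("n=%lld c=%lld sp=%lld -> off=%lld len=%lld\n",
                            (long long)n, (long long)c, (long long)sp,
                            (long long)off, (long long)len);
                }
        return 0;
    }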


@@ -0,0 +1,184 @@
/*******************************************************************************
* Copyright 2019-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#ifndef CPU_RV64_RVV_ELTWISE_HPP
#define CPU_RV64_RVV_ELTWISE_HPP
#include <assert.h>
#include "common/c_types_map.hpp"
#include "common/primitive.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"
#include "cpu/cpu_eltwise_pd.hpp"
#include "cpu/platform.hpp"
namespace dnnl {
namespace impl {
namespace cpu {
namespace rv64 {
// RVV eltwise forward primitive (RV64GCV): one templated implementation covering all supported data types.
// The per-dtype compute kernels are provided by RVV intrinsics in cpu/rv64/rvv_eltwise_kernels.hpp.
template <impl::data_type_t data_type>
struct rvv_eltwise_fwd_t : public primitive_t {
struct pd_t : public cpu_eltwise_fwd_pd_t {
using cpu_eltwise_fwd_pd_t::cpu_eltwise_fwd_pd_t;
DECLARE_COMMON_PD_T_("RISCV64GCV", rvv_eltwise_fwd_t);
status_t init(engine_t *engine) {
UNUSED(engine);
using namespace utils;
const memory_desc_wrapper src_d(src_md());
const memory_desc_wrapper dst_d(dst_md());
VDISPATCH_ELTWISE(utils::everyone_is(data_type, src_md()->data_type,
dst_md()->data_type),
VERBOSE_UNSUPPORTED_DT);
VDISPATCH_ELTWISE(platform::has_data_type_support(data_type),
VERBOSE_UNSUPPORTED_DT);
VDISPATCH_ELTWISE(
attr()->has_default_values(), VERBOSE_UNSUPPORTED_ATTR);
// Restrict dispatch to the algorithm kinds implemented by the RVV kernels
VDISPATCH_ELTWISE(
utils::one_of(desc()->alg_kind, alg_kind::eltwise_relu,
alg_kind::eltwise_square, alg_kind::eltwise_abs,
alg_kind::eltwise_sqrt, alg_kind::eltwise_linear,
alg_kind::eltwise_clip,
alg_kind::eltwise_hardsigmoid,
alg_kind::eltwise_hardswish),
"Unsupported alg_kind for rvv extension");
VDISPATCH_ELTWISE(
set_default_formats_common(), VERBOSE_UNSUPPORTED_TAG);
VDISPATCH_ELTWISE(
src_d == dst_d, VERBOSE_INCONSISTENT_MDS, "src", "dst");
// Determine supported memory cases: dense or nCspBc padded tail.
use_dense_ = src_d.is_dense(true) && dst_d.is_dense(true)
&& IMPLICATION(!src_d.is_dense() || !dst_d.is_dense(),
is_zero_preserved());
use_nCspBc_padded_ = !use_dense_
&& src_d.blocking_desc().inner_nblks == 1
&& one_of(src_d.blocking_desc().inner_blks[0], 8, 16)
&& src_d.blocking_desc().inner_idxs[0] == 1
&& src_d.only_padded_dim(1) && src_d.is_dense(true);
VDISPATCH_ELTWISE(use_dense_ || use_nCspBc_padded_,
VERBOSE_UNSUPPORTED_SPARSE_CFG);
return status::success;
}
bool use_dense_, use_nCspBc_padded_;
};
rvv_eltwise_fwd_t(const pd_t *apd) : primitive_t(apd) {}
using data_t = typename prec_traits_t<data_type>::type;
status_t execute(const exec_ctx_t &ctx) const override {
return execute_forward(ctx);
}
private:
const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
status_t execute_forward(const exec_ctx_t &ctx) const;
};
template <impl::data_type_t data_type>
struct rvv_eltwise_bwd_t : public primitive_t {
struct pd_t : public cpu_eltwise_bwd_pd_t {
using cpu_eltwise_bwd_pd_t::cpu_eltwise_bwd_pd_t;
DECLARE_COMMON_PD_T_("rv64gcv", rvv_eltwise_bwd_t)
status_t init(engine_t *engine) {
UNUSED(engine);
using namespace utils;
using namespace data_type;
const memory_desc_wrapper diff_src_d(diff_src_md());
const memory_desc_wrapper diff_dst_d(diff_dst_md());
const memory_desc_wrapper data_d(data_md());
VDISPATCH_ELTWISE(
utils::everyone_is(data_type, data_md()->data_type,
diff_src_md()->data_type, diff_dst_md()->data_type),
VERBOSE_UNSUPPORTED_DT);
VDISPATCH_ELTWISE(platform::has_data_type_support(data_type),
VERBOSE_UNSUPPORTED_DT);
VDISPATCH_ELTWISE(
attr()->has_default_values(), VERBOSE_UNSUPPORTED_ATTR);
// Restrict dispatch to the algorithm kinds implemented by the RVV kernels
VDISPATCH_ELTWISE(
utils::one_of(desc()->alg_kind, alg_kind::eltwise_relu,
alg_kind::eltwise_square, alg_kind::eltwise_abs,
alg_kind::eltwise_sqrt, alg_kind::eltwise_linear,
alg_kind::eltwise_clip,
alg_kind::eltwise_hardsigmoid,
alg_kind::eltwise_hardswish),
"Unsupported alg_kind for rvv extension");
VDISPATCH_ELTWISE(
set_default_formats_common(), VERBOSE_UNSUPPORTED_TAG);
VDISPATCH_ELTWISE(diff_dst_d == diff_src_d,
VERBOSE_INCONSISTENT_MDS, "diff_src", "diff_dst");
// Layout support: dense or nCspBc-padded (blocked C with tail)
use_dense_ = diff_dst_d.is_dense()
|| (diff_dst_d.is_dense(true) && is_zero_preserved());
use_nCspBc_padded_ = !use_dense_
&& data_d.blocking_desc().inner_nblks == 1
&& one_of(data_d.blocking_desc().inner_blks[0], 8, 16)
&& data_d.blocking_desc().inner_idxs[0] == 1
&& data_d.only_padded_dim(1) && data_d.is_dense(true);
VDISPATCH_ELTWISE(use_dense_ || use_nCspBc_padded_,
VERBOSE_UNSUPPORTED_SPARSE_CFG);
return status::success;
}
bool use_dense_, use_nCspBc_padded_;
};
rvv_eltwise_bwd_t(const pd_t *apd) : primitive_t(apd) {}
using data_t = typename prec_traits_t<data_type>::type;
status_t execute(const exec_ctx_t &ctx) const override {
return execute_backward(ctx);
}
private:
const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
status_t execute_backward(const exec_ctx_t &ctx) const;
};
} // namespace rv64
} // namespace cpu
} // namespace impl
} // namespace dnnl
#endif // CPU_RV64_RVV_ELTWISE_HPP
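
For reference, the "nCspBc padded" shape that the pd_t::init checks above accept corresponds to formats such as nChw16c with a channel count that is not a multiple of the block. A small sketch, assuming the oneDNN v3.x C++ API, showing the padding the tail handling has to respect:

    #include <cstdio>
    #include "dnnl.hpp"

    using namespace dnnl;

    int main() {
        // C = 17 with an nChw16c layout: one inner block of 16 on the channel
        // dim, and only that dim is padded (17 -> 32) -- the exact kind of
        // descriptor that use_nCspBc_padded_ admits.
        memory::desc md({2, 17, 5, 5}, memory::data_type::f32,
                memory::format_tag::nChw16c);
        const auto dims = md.get_dims();
        const auto pdims = md.get_padded_dims();
        std::printf("C = %lld, padded C = %lld, inner blocks = %d\n",
                (long long)dims[1], (long long)pdims[1],
                (int)md.get_inner_nblks());
        return 0;
    }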

File diff suppressed because it is too large.
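
The fourth file's diff is not shown; judging from the include above it is presumably cpu/rv64/rvv_eltwise_kernels.hpp, which supplies the per-dtype rvv_eltwise_apply_* functions. Purely as a hedged illustration of what such a kernel can look like (not the commit's actual code), here is a strip-mined f32 ReLU with negative slope alpha using RVV intrinsics, assuming a toolchain with the __riscv_-prefixed intrinsics and -march=rv64gcv:

    #include <riscv_vector.h>
    #include <stdint.h>

    // Hypothetical sketch of a forward ReLU kernel: relu(x) = max(x, 0) + alpha * min(x, 0).
    static void relu_fwd_f32_rvv_sketch(
            const float *src, float *dst, int64_t len, float alpha) {
        for (int64_t i = 0; i < len;) {
            // Number of elements processed in this strip, chosen by the hardware.
            size_t vl = __riscv_vsetvl_e32m8((size_t)(len - i));
            vfloat32m8_t x = __riscv_vle32_v_f32m8(src + i, vl);
            vfloat32m8_t pos = __riscv_vfmax_vf_f32m8(x, 0.f, vl); // max(x, 0)
            vfloat32m8_t neg = __riscv_vfmin_vf_f32m8(x, 0.f, vl); // min(x, 0)
            // pos + alpha * neg
            vfloat32m8_t y = __riscv_vfmacc_vf_f32m8(pos, alpha, neg, vl);
            __riscv_vse32_v_f32m8(dst + i, y, vl);
            i += (int64_t)vl;
        }
    }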