cpu: rv64: conv: Refactor validation logic and cleanup

Co-authored-by: Fei Zhang <zhangfei@iscas.ac.cn>
This commit is contained in:
xiazhuozhao
2025-09-15 21:59:32 +08:00
committed by Vadim Pirogov
parent fe04323ab0
commit a4ac6806be
3 changed files with 16 additions and 40 deletions

View File

@ -73,7 +73,7 @@ using namespace dnnl::impl::cpu::x64;
#endif #endif
using namespace dnnl::impl::cpu::aarch64; using namespace dnnl::impl::cpu::aarch64;
#elif DNNL_RV64 #elif DNNL_RV64
#if DNNL_RISCV_USE_RVV_INTRINSICS #if defined(DNNL_RISCV_USE_RVV_INTRINSICS)
#include "cpu/rv64/rvv_gemm_convolution.hpp" #include "cpu/rv64/rvv_gemm_convolution.hpp"
using namespace dnnl::impl::cpu::rv64; using namespace dnnl::impl::cpu::rv64;
#endif // DNNL_RISCV_USE_RVV_INTRINSICS #endif // DNNL_RISCV_USE_RVV_INTRINSICS

View File

@ -149,7 +149,7 @@ status_t riscv_gemm_convolution_fwd_t::execute_forward_thr_nspc(
const dim_t LDC = M * jcp.ngroups; const dim_t LDC = M * jcp.ngroups;
const char *BT = jcp.im2col_sz ? "T" : "N"; const char *BT = jcp.im2col_sz ? "T" : "N";
const data_t onef = 1.f; const data_t onef = 1.f;
const float beta = this->beta_; const float beta = jcp.with_sum ? 1.0f : 0.0f;
const data_t *__restrict src_od const data_t *__restrict src_od
= src + od * jcp.oh * jcp.ow * jcp.ngroups * jcp.ic; = src + od * jcp.oh * jcp.ow * jcp.ngroups * jcp.ic;
status_t st = extended_sgemm("N", BT, &M, &N, &K, &onef, wei, &LDA, status_t st = extended_sgemm("N", BT, &M, &N, &K, &onef, wei, &LDA,
@ -318,8 +318,8 @@ status_t riscv_gemm_convolution_fwd_t::execute_forward_ncsp(
const dim_t LDB = jcp.ic * jcp.ks; const dim_t LDB = jcp.ic * jcp.ks;
const dim_t N = step.oc; const dim_t N = step.oc;
// TODO: what if this->beta_ != 0 && != 1 ? const float beta
const float beta = (curr.ic == 0) ? this->beta_ : one; = (curr.ic == 0) ? (jcp.with_sum ? 1.0f : 0.0f) : one;
const float *_source = jcp.im2col_sz const float *_source = jcp.im2col_sz
? _col ? _col
: _src + curr.ic * M + curr.od * jcp.os + curr.sp; : _src + curr.ic * M + curr.od * jcp.os + curr.sp;

View File

@ -21,6 +21,7 @@
#include "common/c_types_map.hpp" #include "common/c_types_map.hpp"
#include "common/memory_tracking.hpp" #include "common/memory_tracking.hpp"
#include "common/primitive.hpp" #include "common/primitive.hpp"
#include "common/utils.hpp"
#include "cpu/binary_injector_utils.hpp" #include "cpu/binary_injector_utils.hpp"
#include "cpu/cpu_convolution_pd.hpp" #include "cpu/cpu_convolution_pd.hpp"
@ -44,8 +45,16 @@ struct riscv_gemm_convolution_fwd_t : public primitive_t {
using namespace data_type; using namespace data_type;
VDISPATCH_CONV(is_fwd(), VERBOSE_BAD_PROPKIND); VDISPATCH_CONV(is_fwd(), VERBOSE_BAD_PROPKIND);
VDISPATCH_CONV(expect_data_types(f32, f32, f32, f32, f32),
VERBOSE_UNSUPPORTED_DT_CFG); if (with_bias()) {
VDISPATCH_CONV(expect_data_types(f32, f32, f32, f32, f32),
VERBOSE_UNSUPPORTED_DT_CFG);
} else {
VDISPATCH_CONV(
expect_data_types(f32, f32, data_type::undef, f32, f32),
VERBOSE_UNSUPPORTED_DT_CFG);
}
VDISPATCH_CONV(set_default_alg_kind(alg_kind::convolution_direct), VDISPATCH_CONV(set_default_alg_kind(alg_kind::convolution_direct),
VERBOSE_BAD_ALGORITHM); VERBOSE_BAD_ALGORITHM);
VDISPATCH_CONV(!has_zero_dim_memory(), VERBOSE_EMPTY_TENSOR, ""); VDISPATCH_CONV(!has_zero_dim_memory(), VERBOSE_EMPTY_TENSOR, "");
@ -68,36 +77,7 @@ struct riscv_gemm_convolution_fwd_t : public primitive_t {
protected: protected:
bool post_ops_ok() const { bool post_ops_ok() const {
auto const &po = attr()->post_ops_; return ref_post_ops_t::post_ops_ok(attr()->post_ops_);
auto is_sum_ok = [&](int idx) {
return IMPLICATION(po.entry_[idx].kind == primitive_kind::sum,
idx == 0 && po.entry_[idx].is_sum());
};
auto is_binary
= [&](int idx) { return po.entry_[idx].is_binary(); };
auto is_prelu = [&](int idx) { return po.entry_[idx].is_prelu(); };
auto is_binary_or_prelu_supported = [&](int idx) {
bool ok = dnnl::impl::get_rhs_arg_broadcasting_strategy(
binary_injector_utils::get_src1_desc(
po.entry_[idx], dst_md_),
dst_md_,
{broadcasting_strategy_t::scalar,
broadcasting_strategy_t::per_oc})
!= broadcasting_strategy_t::unsupported;
return ok;
};
if (!ref_post_ops_t::primitive_kind_ok(attr()->post_ops_))
return false;
for (int idx = 0; idx < po.len(); idx++) {
bool ok = is_sum_ok(idx)
&& IMPLICATION(is_binary(idx) || is_prelu(idx),
is_binary_or_prelu_supported(idx));
if (!ok) return false;
}
return true;
} }
}; };
@ -105,9 +85,7 @@ struct riscv_gemm_convolution_fwd_t : public primitive_t {
: primitive_t(apd), post_ops_(nullptr) {} : primitive_t(apd), post_ops_(nullptr) {}
status_t init(engine_t *engine) override { status_t init(engine_t *engine) override {
const data_t one = 1.0, zero = 0.0;
const auto &jcp = pd()->jcp_; const auto &jcp = pd()->jcp_;
beta_ = jcp.with_sum ? one : zero;
if (jcp.with_eltwise || jcp.with_binary) { if (jcp.with_eltwise || jcp.with_binary) {
CHECK(safe_ptr_assign(post_ops_, new ref_post_ops_t(jcp.post_ops))); CHECK(safe_ptr_assign(post_ops_, new ref_post_ops_t(jcp.post_ops)));
@ -132,8 +110,6 @@ private:
const memory_tracking::grantor_t &scratchpad) const; const memory_tracking::grantor_t &scratchpad) const;
const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
data_t beta_;
std::unique_ptr<ref_post_ops_t> post_ops_; std::unique_ptr<ref_post_ops_t> post_ops_;
}; };