mirror of
https://github.com/uxlfoundation/oneDNN.git
synced 2025-10-20 18:43:49 +08:00
cpu: rv64: conv: Refactor validation logic and cleanup
Co-authored-by: Fei Zhang <zhangfei@iscas.ac.cn>
This commit is contained in:
committed by
Vadim Pirogov
parent
fe04323ab0
commit
a4ac6806be
@ -73,7 +73,7 @@ using namespace dnnl::impl::cpu::x64;
|
|||||||
#endif
|
#endif
|
||||||
using namespace dnnl::impl::cpu::aarch64;
|
using namespace dnnl::impl::cpu::aarch64;
|
||||||
#elif DNNL_RV64
|
#elif DNNL_RV64
|
||||||
#if DNNL_RISCV_USE_RVV_INTRINSICS
|
#if defined(DNNL_RISCV_USE_RVV_INTRINSICS)
|
||||||
#include "cpu/rv64/rvv_gemm_convolution.hpp"
|
#include "cpu/rv64/rvv_gemm_convolution.hpp"
|
||||||
using namespace dnnl::impl::cpu::rv64;
|
using namespace dnnl::impl::cpu::rv64;
|
||||||
#endif // DNNL_RISCV_USE_RVV_INTRINSICS
|
#endif // DNNL_RISCV_USE_RVV_INTRINSICS
|
||||||
|
@ -149,7 +149,7 @@ status_t riscv_gemm_convolution_fwd_t::execute_forward_thr_nspc(
|
|||||||
const dim_t LDC = M * jcp.ngroups;
|
const dim_t LDC = M * jcp.ngroups;
|
||||||
const char *BT = jcp.im2col_sz ? "T" : "N";
|
const char *BT = jcp.im2col_sz ? "T" : "N";
|
||||||
const data_t onef = 1.f;
|
const data_t onef = 1.f;
|
||||||
const float beta = this->beta_;
|
const float beta = jcp.with_sum ? 1.0f : 0.0f;
|
||||||
const data_t *__restrict src_od
|
const data_t *__restrict src_od
|
||||||
= src + od * jcp.oh * jcp.ow * jcp.ngroups * jcp.ic;
|
= src + od * jcp.oh * jcp.ow * jcp.ngroups * jcp.ic;
|
||||||
status_t st = extended_sgemm("N", BT, &M, &N, &K, &onef, wei, &LDA,
|
status_t st = extended_sgemm("N", BT, &M, &N, &K, &onef, wei, &LDA,
|
||||||
@ -318,8 +318,8 @@ status_t riscv_gemm_convolution_fwd_t::execute_forward_ncsp(
|
|||||||
const dim_t LDB = jcp.ic * jcp.ks;
|
const dim_t LDB = jcp.ic * jcp.ks;
|
||||||
const dim_t N = step.oc;
|
const dim_t N = step.oc;
|
||||||
|
|
||||||
// TODO: what if this->beta_ != 0 && != 1 ?
|
const float beta
|
||||||
const float beta = (curr.ic == 0) ? this->beta_ : one;
|
= (curr.ic == 0) ? (jcp.with_sum ? 1.0f : 0.0f) : one;
|
||||||
const float *_source = jcp.im2col_sz
|
const float *_source = jcp.im2col_sz
|
||||||
? _col
|
? _col
|
||||||
: _src + curr.ic * M + curr.od * jcp.os + curr.sp;
|
: _src + curr.ic * M + curr.od * jcp.os + curr.sp;
|
||||||
|
@ -21,6 +21,7 @@
|
|||||||
#include "common/c_types_map.hpp"
|
#include "common/c_types_map.hpp"
|
||||||
#include "common/memory_tracking.hpp"
|
#include "common/memory_tracking.hpp"
|
||||||
#include "common/primitive.hpp"
|
#include "common/primitive.hpp"
|
||||||
|
#include "common/utils.hpp"
|
||||||
|
|
||||||
#include "cpu/binary_injector_utils.hpp"
|
#include "cpu/binary_injector_utils.hpp"
|
||||||
#include "cpu/cpu_convolution_pd.hpp"
|
#include "cpu/cpu_convolution_pd.hpp"
|
||||||
@ -44,8 +45,16 @@ struct riscv_gemm_convolution_fwd_t : public primitive_t {
|
|||||||
using namespace data_type;
|
using namespace data_type;
|
||||||
|
|
||||||
VDISPATCH_CONV(is_fwd(), VERBOSE_BAD_PROPKIND);
|
VDISPATCH_CONV(is_fwd(), VERBOSE_BAD_PROPKIND);
|
||||||
VDISPATCH_CONV(expect_data_types(f32, f32, f32, f32, f32),
|
|
||||||
VERBOSE_UNSUPPORTED_DT_CFG);
|
if (with_bias()) {
|
||||||
|
VDISPATCH_CONV(expect_data_types(f32, f32, f32, f32, f32),
|
||||||
|
VERBOSE_UNSUPPORTED_DT_CFG);
|
||||||
|
} else {
|
||||||
|
VDISPATCH_CONV(
|
||||||
|
expect_data_types(f32, f32, data_type::undef, f32, f32),
|
||||||
|
VERBOSE_UNSUPPORTED_DT_CFG);
|
||||||
|
}
|
||||||
|
|
||||||
VDISPATCH_CONV(set_default_alg_kind(alg_kind::convolution_direct),
|
VDISPATCH_CONV(set_default_alg_kind(alg_kind::convolution_direct),
|
||||||
VERBOSE_BAD_ALGORITHM);
|
VERBOSE_BAD_ALGORITHM);
|
||||||
VDISPATCH_CONV(!has_zero_dim_memory(), VERBOSE_EMPTY_TENSOR, "");
|
VDISPATCH_CONV(!has_zero_dim_memory(), VERBOSE_EMPTY_TENSOR, "");
|
||||||
@ -68,36 +77,7 @@ struct riscv_gemm_convolution_fwd_t : public primitive_t {
|
|||||||
|
|
||||||
protected:
|
protected:
|
||||||
bool post_ops_ok() const {
|
bool post_ops_ok() const {
|
||||||
auto const &po = attr()->post_ops_;
|
return ref_post_ops_t::post_ops_ok(attr()->post_ops_);
|
||||||
auto is_sum_ok = [&](int idx) {
|
|
||||||
return IMPLICATION(po.entry_[idx].kind == primitive_kind::sum,
|
|
||||||
idx == 0 && po.entry_[idx].is_sum());
|
|
||||||
};
|
|
||||||
auto is_binary
|
|
||||||
= [&](int idx) { return po.entry_[idx].is_binary(); };
|
|
||||||
auto is_prelu = [&](int idx) { return po.entry_[idx].is_prelu(); };
|
|
||||||
auto is_binary_or_prelu_supported = [&](int idx) {
|
|
||||||
bool ok = dnnl::impl::get_rhs_arg_broadcasting_strategy(
|
|
||||||
binary_injector_utils::get_src1_desc(
|
|
||||||
po.entry_[idx], dst_md_),
|
|
||||||
dst_md_,
|
|
||||||
{broadcasting_strategy_t::scalar,
|
|
||||||
broadcasting_strategy_t::per_oc})
|
|
||||||
!= broadcasting_strategy_t::unsupported;
|
|
||||||
return ok;
|
|
||||||
};
|
|
||||||
|
|
||||||
if (!ref_post_ops_t::primitive_kind_ok(attr()->post_ops_))
|
|
||||||
return false;
|
|
||||||
|
|
||||||
for (int idx = 0; idx < po.len(); idx++) {
|
|
||||||
bool ok = is_sum_ok(idx)
|
|
||||||
&& IMPLICATION(is_binary(idx) || is_prelu(idx),
|
|
||||||
is_binary_or_prelu_supported(idx));
|
|
||||||
if (!ok) return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -105,9 +85,7 @@ struct riscv_gemm_convolution_fwd_t : public primitive_t {
|
|||||||
: primitive_t(apd), post_ops_(nullptr) {}
|
: primitive_t(apd), post_ops_(nullptr) {}
|
||||||
|
|
||||||
status_t init(engine_t *engine) override {
|
status_t init(engine_t *engine) override {
|
||||||
const data_t one = 1.0, zero = 0.0;
|
|
||||||
const auto &jcp = pd()->jcp_;
|
const auto &jcp = pd()->jcp_;
|
||||||
beta_ = jcp.with_sum ? one : zero;
|
|
||||||
|
|
||||||
if (jcp.with_eltwise || jcp.with_binary) {
|
if (jcp.with_eltwise || jcp.with_binary) {
|
||||||
CHECK(safe_ptr_assign(post_ops_, new ref_post_ops_t(jcp.post_ops)));
|
CHECK(safe_ptr_assign(post_ops_, new ref_post_ops_t(jcp.post_ops)));
|
||||||
@ -132,8 +110,6 @@ private:
|
|||||||
const memory_tracking::grantor_t &scratchpad) const;
|
const memory_tracking::grantor_t &scratchpad) const;
|
||||||
const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
|
const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
|
||||||
|
|
||||||
data_t beta_;
|
|
||||||
|
|
||||||
std::unique_ptr<ref_post_ops_t> post_ops_;
|
std::unique_ptr<ref_post_ops_t> post_ops_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user