mirror of
https://github.com/uxlfoundation/oneDNN.git
synced 2025-10-20 18:43:49 +08:00
cpu: risc-v: Restore specific checks in post_ops_ok for gemm convolution
Co-authored-by: Fei Zhang <zhangfei@iscas.ac.cn>
This commit is contained in:
committed by
Vadim Pirogov
parent
532111e2d6
commit
dca0b6c7d6
@ -77,7 +77,35 @@ struct riscv_gemm_convolution_fwd_t : public primitive_t {
|
||||
|
||||
protected:
|
||||
bool post_ops_ok() const {
|
||||
return ref_post_ops_t::post_ops_ok(attr()->post_ops_);
|
||||
auto const &po = attr()->post_ops_;
|
||||
auto is_sum_ok = [&](int idx) {
|
||||
return IMPLICATION(po.entry_[idx].kind == primitive_kind::sum,
|
||||
idx == 0 && po.entry_[idx].is_sum());
|
||||
};
|
||||
auto is_binary
|
||||
= [&](int idx) { return po.entry_[idx].is_binary(); };
|
||||
auto is_prelu = [&](int idx) { return po.entry_[idx].is_prelu(); };
|
||||
auto is_binary_or_prelu_supported = [&](int idx) {
|
||||
bool ok = dnnl::impl::get_rhs_arg_broadcasting_strategy(
|
||||
binary_injector_utils::get_src1_desc(
|
||||
po.entry_[idx], dst_md_),
|
||||
dst_md_,
|
||||
{broadcasting_strategy_t::scalar,
|
||||
broadcasting_strategy_t::per_oc})
|
||||
!= broadcasting_strategy_t::unsupported;
|
||||
return ok;
|
||||
};
|
||||
|
||||
if (!ref_post_ops_t::post_ops_ok(attr()->post_ops_)) return false;
|
||||
|
||||
for (int idx = 0; idx < po.len(); idx++) {
|
||||
bool ok = is_sum_ok(idx)
|
||||
&& IMPLICATION(is_binary(idx) || is_prelu(idx),
|
||||
is_binary_or_prelu_supported(idx));
|
||||
if (!ok) return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
|
Reference in New Issue
Block a user