cpu: risc-v: Restore specific checks in post_ops_ok for gemm convolution

Co-authored-by: Fei Zhang <zhangfei@iscas.ac.cn>
This commit is contained in:
Xia Zhuozhao
2025-10-15 20:36:55 +08:00
committed by Vadim Pirogov
parent 532111e2d6
commit dca0b6c7d6

View File

@ -77,7 +77,35 @@ struct riscv_gemm_convolution_fwd_t : public primitive_t {
protected:
bool post_ops_ok() const {
return ref_post_ops_t::post_ops_ok(attr()->post_ops_);
auto const &po = attr()->post_ops_;
auto is_sum_ok = [&](int idx) {
return IMPLICATION(po.entry_[idx].kind == primitive_kind::sum,
idx == 0 && po.entry_[idx].is_sum());
};
auto is_binary
= [&](int idx) { return po.entry_[idx].is_binary(); };
auto is_prelu = [&](int idx) { return po.entry_[idx].is_prelu(); };
auto is_binary_or_prelu_supported = [&](int idx) {
bool ok = dnnl::impl::get_rhs_arg_broadcasting_strategy(
binary_injector_utils::get_src1_desc(
po.entry_[idx], dst_md_),
dst_md_,
{broadcasting_strategy_t::scalar,
broadcasting_strategy_t::per_oc})
!= broadcasting_strategy_t::unsupported;
return ok;
};
if (!ref_post_ops_t::post_ops_ok(attr()->post_ops_)) return false;
for (int idx = 0; idx < po.len(); idx++) {
bool ok = is_sum_ok(idx)
&& IMPLICATION(is_binary(idx) || is_prelu(idx),
is_binary_or_prelu_supported(idx));
if (!ok) return false;
}
return true;
}
};