mirror of
https://github.com/uxlfoundation/oneDNN.git
synced 2025-10-20 18:43:49 +08:00
cpu: aarch64: conv: Enable SVE 128 jit uni depthwise convolutions
This commit is contained in:
@ -402,6 +402,7 @@ void jit_uni_dw_conv_fwd_kernel_f32_t<isa>::generate() {
|
||||
|
||||
template struct jit_uni_dw_conv_fwd_kernel_f32_t<sve_512>;
|
||||
template struct jit_uni_dw_conv_fwd_kernel_f32_t<sve_256>;
|
||||
template struct jit_uni_dw_conv_fwd_kernel_f32_t<sve_128>;
|
||||
|
||||
template <cpu_isa_t isa>
|
||||
inline void jit_uni_dw_conv_bwd_data_kernel_f32_t<isa>::load_ddst(
|
||||
|
@ -107,8 +107,30 @@ status_t jit_uni_dw_conv_fwd_kernel_t<isa, kernel_dt>::init_conf(
|
||||
// Currently this kernel only supports 2D convolutions.
|
||||
if (ndims != 4) return status::unimplemented;
|
||||
|
||||
const auto blocked_tag = isa == sve_512 ? nChw16c : nChw8c;
|
||||
const auto wei_tag = isa == sve_512 ? Goihw16g : Goihw8g;
|
||||
format_tag_t blocked_tag;
|
||||
format_tag_t wei_tag;
|
||||
switch (isa) {
|
||||
case sve_512:
|
||||
blocked_tag = nChw16c;
|
||||
wei_tag = Goihw16g;
|
||||
jcp.ur_w = 6;
|
||||
jcp.nb_ch_blocking = 4;
|
||||
break;
|
||||
case sve_256:
|
||||
blocked_tag = nChw8c;
|
||||
wei_tag = Goihw8g;
|
||||
jcp.ur_w = 4;
|
||||
jcp.nb_ch_blocking = 3;
|
||||
break;
|
||||
case sve_128:
|
||||
blocked_tag = nChw4c;
|
||||
wei_tag = Goihw4g;
|
||||
jcp.ur_w = 8;
|
||||
jcp.nb_ch_blocking = 1;
|
||||
break;
|
||||
default: return status::unimplemented;
|
||||
}
|
||||
|
||||
const auto nxc_tag = nhwc;
|
||||
jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef;
|
||||
|
||||
@ -147,7 +169,7 @@ status_t jit_uni_dw_conv_fwd_kernel_t<isa, kernel_dt>::init_conf(
|
||||
|
||||
if (!mayiuse(isa)) return status::unimplemented;
|
||||
|
||||
const int simd_w = isa == sve_512 ? 16 : 8;
|
||||
const int simd_w = cpu_isa_traits<isa>::vlen / sizeof(float);
|
||||
jcp.prop_kind = cd.prop_kind;
|
||||
|
||||
const bool with_groups = weights_d.ndims() == src_d.ndims() + 1;
|
||||
@ -189,18 +211,18 @@ status_t jit_uni_dw_conv_fwd_kernel_t<isa, kernel_dt>::init_conf(
|
||||
|| ext_kw <= jcp.r_pad || ext_kh <= jcp.t_pad
|
||||
|| ext_kh <= jcp.b_pad;
|
||||
if (kernel_outside_src) return status::unimplemented;
|
||||
if (isa == sve_128 && jcp.iw == 1)
|
||||
return status::unimplemented; // fallback to brdgemm since it's faster
|
||||
|
||||
jcp.typesize_out = types::data_type_size(dst_d.data_type());
|
||||
jcp.typesize_in = types::data_type_size(src_d.data_type());
|
||||
|
||||
jcp.loop_order = loop_ngcw;
|
||||
|
||||
jcp.ur_w = isa == sve_512 ? 6 : isa == sve_256 ? 4 : 3;
|
||||
jcp.ur_w = nstl::min(jcp.ur_w, jcp.ow);
|
||||
|
||||
jcp.ch_block = simd_w;
|
||||
jcp.nb_ch = div_up(jcp.oc, jcp.ch_block);
|
||||
jcp.nb_ch_blocking = isa == sve_512 ? 4 : isa == sve_256 ? 3 : 2;
|
||||
if (jcp.nb_ch < jcp.nb_ch_blocking) jcp.nb_ch_blocking = jcp.nb_ch;
|
||||
|
||||
if (is_data_layout_nxc) {
|
||||
@ -258,7 +280,8 @@ status_t jit_uni_dw_conv_fwd_kernel_t<isa, kernel_dt>::init_conf(
|
||||
if (dst_d.data_type() == data_type::s32) return status::unimplemented;
|
||||
}
|
||||
bool ok_to_pad_channels = true && jcp.oc == jcp.ngroups
|
||||
&& jcp.ic == jcp.ngroups && (isa == sve_256 || isa == sve_512);
|
||||
&& jcp.ic == jcp.ngroups
|
||||
&& (utils::one_of(isa, sve_128, sve_256, sve_512));
|
||||
if (ok_to_pad_channels) {
|
||||
jcp.oc = rnd_up(jcp.oc, simd_w);
|
||||
jcp.ic = rnd_up(jcp.oc, simd_w);
|
||||
@ -290,7 +313,9 @@ void jit_uni_dw_conv_fwd_kernel_t<isa, kernel_dt>::init_scratchpad(
|
||||
|
||||
template struct jit_uni_dw_conv_fwd_kernel_t<sve_512, data_type::f32>;
|
||||
template struct jit_uni_dw_conv_fwd_kernel_t<sve_256, data_type::f32>;
|
||||
template struct jit_uni_dw_conv_fwd_kernel_t<sve_128, data_type::f32>;
|
||||
template struct jit_uni_dw_conv_fwd_kernel_t<sve_256, data_type::bf16>;
|
||||
template struct jit_uni_dw_conv_fwd_kernel_t<sve_128, data_type::bf16>;
|
||||
|
||||
template <cpu_isa_t isa, data_type_t kernel_dt>
|
||||
struct jit_uni_dw_conv_bwd_data_kernel_t {
|
||||
|
@ -162,7 +162,9 @@ void jit_uni_dw_convolution_fwd_t<isa, src_type, dst_type>::execute_forward(
|
||||
|
||||
template struct jit_uni_dw_convolution_fwd_t<sve_512, data_type::f32>;
|
||||
template struct jit_uni_dw_convolution_fwd_t<sve_256, data_type::f32>;
|
||||
template struct jit_uni_dw_convolution_fwd_t<sve_128, data_type::f32>;
|
||||
template struct jit_uni_dw_convolution_fwd_t<sve_256, data_type::bf16>;
|
||||
template struct jit_uni_dw_convolution_fwd_t<sve_128, data_type::bf16>;
|
||||
|
||||
template <cpu_isa_t isa, data_type_t diff_dst_type, data_type_t diff_src_type>
|
||||
void jit_uni_dw_convolution_bwd_data_t<isa, diff_dst_type,
|
||||
|
@ -105,6 +105,8 @@ using jit_sve_256_dw_convolution_fwd_t
|
||||
= jit_uni_dw_convolution_fwd_t<sve_256, data_type::f32>;
|
||||
using jit_sve_256_dw_convolution_bf16_fwd_t
|
||||
= jit_uni_dw_convolution_fwd_t<sve_256, data_type::bf16>;
|
||||
using jit_sve_128_dw_convolution_bf16_fwd_t
|
||||
= jit_uni_dw_convolution_fwd_t<sve_128, data_type::bf16>;
|
||||
|
||||
template <cpu_isa_t isa, data_type_t diff_dst_type,
|
||||
data_type_t diff_src_type = diff_dst_type>
|
||||
|
@ -152,6 +152,7 @@ const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map()
|
||||
CPU_INSTANCE_AARCH64(jit_sve_1x1_convolution_fwd_t<f32,f32,f32,sve_256>)
|
||||
CPU_INSTANCE_AARCH64(jit_sve_convolution_fwd_t<f32,f32,f32,sve_256>)
|
||||
CPU_INSTANCE_AARCH64(brdgmm_dw_convolution_fwd_t<sve_256>)
|
||||
CPU_INSTANCE_AARCH64(jit_uni_dw_convolution_fwd_t<sve_128,data_type::f32>)
|
||||
CPU_INSTANCE_AARCH64(brdgmm_dw_convolution_fwd_t<sve_128>)
|
||||
CPU_INSTANCE_AARCH64_ACL(acl_depthwise_convolution_fwd_t)
|
||||
CPU_INSTANCE_AARCH64_ACL(acl_indirect_gemm_convolution_fwd_t)
|
||||
@ -225,6 +226,7 @@ const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map()
|
||||
CPU_INSTANCE_AARCH64(jit_uni_dw_convolution_fwd_t<sve_256, bf16, bf16>)
|
||||
CPU_INSTANCE_AARCH64(brdgmm_dw_convolution_fwd_t<sve_256>)
|
||||
CPU_INSTANCE_AARCH64(brdgmm_dw_convolution_fwd_t<sve_128>)
|
||||
CPU_INSTANCE_AARCH64(jit_uni_dw_convolution_fwd_t<sve_128, bf16, bf16>)
|
||||
CPU_INSTANCE_AARCH64_ACL(acl_indirect_gemm_convolution_fwd_t)
|
||||
CPU_INSTANCE_AARCH64(brgemm_1x1_convolution_fwd_t<sve_256>)
|
||||
CPU_INSTANCE_AARCH64(brgemm_convolution_fwd_t<sve_256>)
|
||||
|
Reference in New Issue
Block a user