cpu: x64: matmul: refactor pick_blocked_B_layout function

This commit is contained in:
Denis Samoilov
2025-10-08 13:09:48 -07:00
parent b531c1a4fb
commit 08beeba4c9

View File

@ -621,20 +621,26 @@ status_t brgemm_matmul_conf_utils_t::set_B_flags(memory_desc_t &B_md) const {
format_tag_t brgemm_matmul_conf_utils_t::pick_blocked_B_layout(
int n_blk) const {
if (bgmmc.ndims > 3) return format_tag::undef;
if (this->is_int8() || this->is_f8()) switch (n_blk) {
if (is_int8() || is_f8()) {
switch (n_blk) {
case 64: return bgmmc.ndims == 3 ? aCB16b64c4b : BA16a64b4a;
case 48: return bgmmc.ndims == 3 ? aCB16b48c4b : BA16a48b4a;
case 32: return bgmmc.ndims == 3 ? aCB16b32c4b : BA16a32b4a;
case 16: return bgmmc.ndims == 3 ? aCB16b16c4b : BA16a16b4a;
default: return format_tag::undef;
}
}
if (this->is_bf16() || this->is_bf16_with_int_wei()
|| ((this->is_f16() || this->is_f32_f16() || this->is_f32_bf16()
|| this->is_f16_with_int_wei())
&& (is_superset(bgmmc.isa, avx512_core_amx)
|| is_superset(bgmmc.isa, avx2_vnni_2))))
const bool is_amx_or_avx2_vnni_2 = is_superset(bgmmc.isa, avx512_core_amx)
|| is_superset(bgmmc.isa, avx2_vnni_2);
const bool prefer_amx_or_avx2_vnni_2 = is_f16() || is_f32_f16()
|| is_f32_bf16() || is_f16_with_int_wei();
if ((prefer_amx_or_avx2_vnni_2 && is_amx_or_avx2_vnni_2) || is_bf16()
|| is_bf16_with_int_wei()) {
switch (n_blk) {
case 64: return bgmmc.ndims == 3 ? aCB16b64c2b : BA16a64b2a;
case 48: return bgmmc.ndims == 3 ? aCB16b48c2b : BA16a48b2a;
@ -642,10 +648,11 @@ format_tag_t brgemm_matmul_conf_utils_t::pick_blocked_B_layout(
case 16: return bgmmc.ndims == 3 ? aCB16b16c2b : BA16a16b2a;
default: return format_tag::undef;
}
}
// Note: bf32 assumes f32 blocking
if (this->is_f32() || this->is_bf32() || this->is_f16()
|| this->is_f32_f16() || this->is_f32_bf16()
|| this->is_f16_with_int_wei() || this->is_tf32())
if (is_f32() || is_bf32() || is_f16() || is_f32_f16() || is_f32_bf16()
|| is_f16_with_int_wei() || is_tf32()) {
switch (n_blk) {
case 64: return bgmmc.ndims == 3 ? aCB16b64c : BA16a64b;
case 48: return bgmmc.ndims == 3 ? aCB16b48c : BA16a48b;
@ -653,6 +660,8 @@ format_tag_t brgemm_matmul_conf_utils_t::pick_blocked_B_layout(
case 16: return bgmmc.ndims == 3 ? aCB16b16c : BA16a16b;
default: return format_tag::undef;
}
}
return format_tag::undef;
}