cpu: x64: matmul: enable GEMV code path to brgemm matmul for avx512

commit d5aca6bfda
parent d2b59266e1
Author: Simonov, Alexander
Date: 2025-10-10 11:25:37 -07:00

4 changed files with 66 additions and 27 deletions


@@ -736,7 +736,8 @@ status_t brgemm_blocking_tmm(brgemm_desc_t *brg) {
  * acc[bd] += dot(a0_reg, x_reg) // Accumulate partial results
  */
 status_t brgemm_blocking_vmm_gemv(brgemm_desc_t *brg) {
-    assert(utils::one_of(brg->isa_impl, avx2, avx2_vnni, avx2_vnni_2));
+    assert(is_superset(brg->isa_impl, avx2)
+            && !is_superset(brg->isa_impl, avx512_core_amx));
     assert(brg->load_dim == 1);
     brg->ld_block = 1;
@@ -749,11 +750,11 @@ status_t brgemm_blocking_vmm_gemv(brgemm_desc_t *brg) {
     brg->ldb2_tail = brg->ldb % brg->ld_block2;
     assert(brg->ldb2_tail == 0);
-    brg->bd_block = 8;
+    brg->bd_block = is_superset(brg->isa_impl, avx512_core) ? 24 : 8;
     brg->bdb = brg->bcast_dim / brg->bd_block;
     brg->bdb_tail = brg->bcast_dim % brg->bd_block;
-    const int simd_w = 8;
+    const int simd_w = is_superset(brg->isa_impl, avx512_core) ? 16 : 8;
     brg->rd_block = simd_w;
     brg->rdb = brg->reduce_dim / brg->rd_block;
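
For reference, the blocking above carves the broadcast (M) dimension into bd_block rows per iteration (24 on AVX512 vs 8 on AVX2) and steps the reduce (K) dimension one SIMD register of f32 at a time (16 zmm lanes vs 8 ymm lanes), leaving remainders as tails. A minimal stand-alone sketch of the same arithmetic; the struct and helper names are illustrative, not oneDNN API:

    #include <cassert>
    #include <cstdio>

    // Hypothetical model of the GEMV blocking computed in the hunk above.
    struct gemv_blocking_t {
        int bd_block, bdb, bdb_tail; // broadcast (M) dimension
        int rd_block, rdb, rdb_tail; // reduce (K) dimension
    };

    gemv_blocking_t compute_gemv_blocking(
            int bcast_dim, int reduce_dim, bool is_avx512) {
        gemv_blocking_t b;
        // AVX512 keeps 24 accumulator rows in flight, AVX2 keeps 8.
        b.bd_block = is_avx512 ? 24 : 8;
        b.bdb = bcast_dim / b.bd_block;
        b.bdb_tail = bcast_dim % b.bd_block;
        // One register of f32 along K: 16 lanes (zmm) vs 8 (ymm).
        b.rd_block = is_avx512 ? 16 : 8;
        b.rdb = reduce_dim / b.rd_block;
        b.rdb_tail = reduce_dim % b.rd_block;
        return b;
    }

    int main() {
        // M = 100, K = 70 on AVX512: 4 full row blocks plus a 4-row tail,
        // 4 full K steps plus a 6-element tail (handled via rd_tail_mask).
        const auto b = compute_gemv_blocking(100, 70, /*is_avx512=*/true);
        assert(b.bdb == 4 && b.bdb_tail == 4);
        assert(b.rdb == 4 && b.rdb_tail == 6);
        printf("bdb=%d bdb_tail=%d rdb=%d rdb_tail=%d\n",
                b.bdb, b.bdb_tail, b.rdb, b.rdb_tail);
    }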


@@ -1123,15 +1123,16 @@ void jit_brgemm_kernel_t<Wmm>::apply_alpha_beta(
     for_(dim_t bd = 0; bd < bd_block; bd++)
     for (dim_t ld = 0; ld < ld_block2; ld++) {
         const bool is_tail = is_ld_tail && ld + 1 == ld_block2;
-        const auto k_mask = is_tail ? ld_tail_mask : ld_full_mask;
+        const bool mask_flag = is_tail || brg.is_gemv;
+        const auto k_mask = mask_flag ? ld_tail_mask : ld_full_mask;
         auto vmm = accm(ld_block2, bd, ld);
         auto ptr_C = ptr[reg_aux_C + C_offset(bd, ld)];
         if (use_vadd_for_beta) {
             if (brg.is_gemv)
                 uni_vaddss(Xmm(vmm.getIdx()), Xmm(vmm.getIdx()), ptr_C);
-            else if (IMPLICATION(
-                             is_tail, is_superset(brg.isa_impl, avx512_core))) {
-                auto vmm_masked = vmm_mask(vmm, is_tail, false, k_mask);
+            else if (IMPLICATION(mask_flag,
+                             is_superset(brg.isa_impl, avx512_core))) {
+                auto vmm_masked = vmm_mask(vmm, mask_flag, false, k_mask);
                 if (brg.is_int8)
                     uni_vpaddd(vmm_masked, vmm, ptr_C);
                 else
@@ -1145,7 +1146,7 @@ void jit_brgemm_kernel_t<Wmm>::apply_alpha_beta(
             }
         } else {
             const dim_t ld_size = is_tail ? brg.ldb_tail : brg.ld_block;
-            cvt2ps(brg.dt_c, vmm_prev_dst, ptr_C, is_tail, false, k_mask,
+            cvt2ps(brg.dt_c, vmm_prev_dst, ptr_C, mask_flag, false, k_mask,
                     ld_size);
             if (brg.beta == 1.f)
                 uni_vaddps(vmm, vmm, vmm_prev_dst);
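
The recurring mask_flag = is_tail || brg.is_gemv change in this file captures one idea: a GEMV writes a single output column (load_dim == 1), so every load and store along the N dimension must be narrowed with the tail mask rather than the full-width mask. A hedged scalar model of the selection, with illustrative names rather than the kernel's actual registers:

    #include <cassert>
    #include <cstdint>

    // Which 16-bit k-mask a 16-lane (zmm) N-dimension access would use.
    uint16_t pick_ld_mask(bool is_tail, bool is_gemv, int ldb_tail) {
        // For GEMV the tail mask is built with ld_tail = 1 (see the
        // generate() hunk further down), so only lane 0 is ever touched.
        const int ld_tail = is_gemv ? 1 : ldb_tail;
        const auto tail_mask = static_cast<uint16_t>((1u << ld_tail) - 1);
        const uint16_t full_mask = 0xffff;
        const bool mask_flag = is_tail || is_gemv;
        return mask_flag ? tail_mask : full_mask;
    }

    int main() {
        assert(pick_ld_mask(false, false, 3) == 0xffff); // full-width store
        assert(pick_ld_mask(true, false, 3) == 0x0007); // 3-lane N tail
        assert(pick_ld_mask(false, true, 3) == 0x0001); // GEMV: one column
    }
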
@@ -1238,9 +1239,10 @@ void jit_brgemm_kernel_t<Wmm>::apply_post_ops(dim_t bd_block, dim_t ld_block2,
                 const auto addr = ptr[reg_aux_D + D_offset(bd, ld)];
                 const auto vmm_prev_dst = vmm_tmp(0);
                 const bool is_tail = is_ld_tail && ld + 1 == ld_block2;
-                const auto k_mask = is_tail ? ld_tail_mask : ld_full_mask;
+                const bool mask_flag = is_tail || brg.is_gemv;
+                const auto k_mask = mask_flag ? ld_tail_mask : ld_full_mask;
                 const dim_t ld_size = is_tail ? brg.ldb_tail : brg.ld_block;
-                cvt2ps(brg.sum_dt, vmm_prev_dst, addr, is_tail, false,
+                cvt2ps(brg.sum_dt, vmm_prev_dst, addr, mask_flag, false,
                         k_mask, ld_size);
                 if (p_sum_zp_reg_set)
                     uni_vsubps(vmm_prev_dst, vmm_prev_dst, vmm_sum_zp);
@@ -1291,7 +1293,7 @@ void jit_brgemm_kernel_t<Wmm>::reduce_gemv_accumulators(
 template <typename Wmm>
 void jit_brgemm_kernel_t<Wmm>::store_accumulators_apply_post_ops(dim_t bd_block,
         dim_t ld_block2, dim_t ldb_and_bdb_offset, bool is_ld_tail) {
-    auto k_mask = (!is_ld_tail) ? ld_full_mask : ld_tail_mask;
+    auto k_mask = (!is_ld_tail && !brg.is_gemv) ? ld_full_mask : ld_tail_mask;
     // if (brg.is_int8 && alpha_or_beta_applicable && !beta_uses_vadd) ->
     // accumulated values are already converted to ps in apply_alpha_beta()
@@ -1415,7 +1417,8 @@ void jit_brgemm_kernel_t<Wmm>::store_accumulators_apply_post_ops(dim_t bd_block,
         if (brg.with_bias) {
             auto ptr_bias = ptr[reg_aux_bias + bias_offset(ld)];
             const bool is_tail = is_ld_tail && ld + 1 == ld_block2;
-            cvt2ps(brg.dt_bias, vmm_bias, ptr_bias, is_tail, false, k_mask,
+            const bool mask_flag = is_tail || brg.is_gemv;
+            cvt2ps(brg.dt_bias, vmm_bias, ptr_bias, mask_flag, false, k_mask,
                     is_tail ? brg.ldb_tail : brg.ld_block);
         }
         for (dim_t bd = 0; bd < bd_block; bd++) {
@@ -1460,15 +1463,18 @@ void jit_brgemm_kernel_t<Wmm>::store_accumulators_apply_post_ops(dim_t bd_block,
             const bool is_tail = is_ld_tail && ld + 1 == ld_block2;
             if (brg.zp_type_c == brgemm_broadcast_t::per_n) {
                 dim_t zp_c_off = zp_c_values_offset(ld);
+                const bool mask_flag = is_tail || brg.is_gemv;
                 if (is_superset(brg.isa_impl, avx512_core)) {
                     auto zp_c_addr
                             = EVEX_compress_addr(reg_aux_zp_c_values, zp_c_off);
-                    cvt2ps(data_type::s32, vmm_zp_c, zp_c_addr, is_tail, false,
-                            k_mask, is_tail ? brg.ldb_tail : brg.ld_block);
+                    cvt2ps(data_type::s32, vmm_zp_c, zp_c_addr, mask_flag,
+                            false, k_mask,
+                            is_tail ? brg.ldb_tail : brg.ld_block);
                 } else {
                     cvt2ps(data_type::s32, vmm_zp_c,
-                            ptr[reg_aux_zp_c_values + zp_c_off], is_tail, false,
-                            k_mask, is_tail ? brg.ldb_tail : brg.ld_block);
+                            ptr[reg_aux_zp_c_values + zp_c_off], mask_flag,
+                            false, k_mask,
+                            is_tail ? brg.ldb_tail : brg.ld_block);
                 }
             }
             for (dim_t bd = 0; bd < bd_block; bd++) {
@@ -1514,11 +1520,12 @@ void jit_brgemm_kernel_t<Wmm>::store_accumulators_apply_post_ops(dim_t bd_block,
             auto vmm_lower = Vmm_lower_t(vmm.getIdx());
             const bool is_tail = is_ld_tail && ld + 1 == ld_block2;
             if (is_superset(brg.isa_impl, avx512_core)) {
-                const Vmm r_vmm = vmm_mask(vmm, is_tail, true, k_mask);
+                const bool mask_flag = is_tail || brg.is_gemv;
+                const Vmm r_vmm = vmm_mask(vmm, mask_flag, true, k_mask);
                 const Vmm_lower_t r_ymm
-                        = vmm_mask(vmm_lower, is_tail, true, k_mask);
+                        = vmm_mask(vmm_lower, mask_flag, true, k_mask);
                 const Xmm xmm = Xmm(vmm.getIdx());
-                const Xmm r_xmm = vmm_mask(xmm, is_tail, true, k_mask);
+                const Xmm r_xmm = vmm_mask(xmm, mask_flag, true, k_mask);
                 if (use_sat_cvt) {
                     assert(one_of(brg.dt_d, data_type::s8, data_type::u8));
                     auto vmm_perm = vmm_ubound();
@@ -2334,9 +2341,14 @@ void jit_brgemm_kernel_t<Wmm>::gemv_microkernel(
                 ? ptr[reg_aux_A + A_offset(row, col)]
                 : ptr[reg_aux_B + B_offset(row, col)];
-        if (is_rd_tail)
-            vmaskmovps(vec, vmm_tail_mask(), addr);
-        else
+        if (is_rd_tail) {
+            if (is_superset(brg.isa_impl, avx512_core)) {
+                auto vmm_comp_masked
+                        = vmm_mask(vec, is_rd_tail, false, rd_tail_mask);
+                uni_vmovups(vmm_comp_masked, addr);
+            } else
+                vmaskmovps(vec, vmm_tail_mask(), addr);
+        } else
             uni_vmovups(vec, addr);
     };
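
The hunk above distinguishes the two K-tail load strategies: AVX512 can issue a single vmovups under a k-register mask (rd_tail_mask), while AVX2 has no k-registers and falls back to vmaskmovps driven by a vector mask. A sketch of the same behavior written with intrinsics for clarity; the JIT emits the equivalent instructions rather than calling these helpers:

    #include <immintrin.h>
    #include <cstdint>

    // AVX512: a k-mask zeroes the lanes past the n valid tail elements.
    __m512 tail_load_avx512(const float *p, int n) {
        const __mmask16 k = static_cast<__mmask16>((1u << n) - 1);
        return _mm512_maskz_loadu_ps(k, p); // masked vmovups
    }

    // AVX2: build a vector mask (high bit set -> load lane) for vmaskmovps.
    __m256 tail_load_avx2(const float *p, int n) {
        alignas(32) int32_t mask[8] = {0};
        for (int i = 0; i < n; i++)
            mask[i] = -1;
        return _mm256_maskload_ps(
                p, _mm256_load_si256(reinterpret_cast<const __m256i *>(mask)));
    }
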
@@ -3058,13 +3070,21 @@ void jit_brgemm_kernel_t<Wmm>::generate() {
     if (is_superset(brg.isa_impl, avx512_core)) {
         const auto full_mask = size_t {0xffffffffffffffff};
-        const auto tail_mask = size_t((1 << brg.ldb_tail) - 1);
+        const auto ld_tail = brg.is_gemv ? 1 : brg.ldb_tail;
+        const auto ld_tail_mask_val = size_t((1 << ld_tail) - 1);
         reg64_t reg_mask = rax;
         mov(reg_mask, full_mask);
         kmovq(ld_full_mask, reg_mask);
-        mov(reg_mask, tail_mask);
+        mov(reg_mask, ld_tail_mask_val);
         kmovq(ld_tail_mask, reg_mask);
+        if (brg.is_gemv) {
+            const auto rd_tail_mask_val
+                    = (static_cast<size_t>(1) << brg.rdb_tail) - 1;
+            mov(reg_mask, rd_tail_mask_val);
+            kmovq(rd_tail_mask, reg_mask);
+        }
     }
     if (brg.is_int8 && !brg.has_int8_vnni) {
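
This mask setup runs once at kernel entry: each constant is materialized in a GPR and moved into a k-register with kmovq. A small model of the values involved, assuming f32 data (16 lanes per zmm) and an example K tail of 6; the concrete values here are illustrative:

    #include <cassert>
    #include <cstddef>

    int main() {
        const bool is_gemv = true;
        const int ldb_tail = 0; // GEMV: load_dim == 1
        const int rdb_tail = 6; // example: K % 16 == 6

        // GEMV forces a one-lane ld tail mask: one f32 output per row.
        const auto ld_tail = is_gemv ? 1 : ldb_tail;
        const auto ld_tail_mask_val = size_t((1 << ld_tail) - 1);
        assert(ld_tail_mask_val == 0x1);

        // K tail mask: low rdb_tail bits set. The 64-bit shift
        // (static_cast<size_t>(1) << n) is well-defined for any n < 64.
        const auto rd_tail_mask_val
                = (static_cast<size_t>(1) << rdb_tail) - 1;
        assert(rd_tail_mask_val == 0x3f);
    }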


@@ -417,7 +417,8 @@ bool is_gemv_applicable(const brgemm_matmul_conf_t &bgmmc,
     if (bgmmc.with_reduce) return false;
-    // BRGEMV currently supports only f32 and AVX2.
-    if (utils::one_of(false, bm_conf_utils.is_f32(), bgmmc.isa == avx2))
+    // BRGEMV currently supports only f32 on AVX2 and AVX512.
+    if (utils::one_of(false, bm_conf_utils.is_f32(),
+                bgmmc.isa == avx2 || bgmmc.isa == avx512_core))
         return false;
     if (utils::one_of(format_tag::undef, bm_conf_utils.get_gemv_A_tag(A_md),
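
The utils::one_of(false, ...) construct reads inversely to its name: it is true when any listed condition is false, i.e. the GEMV path is rejected unless every requirement holds. A minimal re-implementation of the idiom for illustration (oneDNN's own utils::one_of has this behavior, but this sketch is not its actual source):

    #include <cassert>

    // one_of(v, xs...) is true when v equals any of xs; passing v == false
    // therefore asks "is any of these conditions false?"
    template <typename T>
    bool one_of(T) { return false; }
    template <typename T, typename... Ts>
    bool one_of(T v, T x, Ts... rest) { return v == x || one_of(v, rest...); }

    int main() {
        const bool is_f32 = true;
        const bool isa_ok = true; // bgmmc.isa == avx2 || avx512_core
        assert(!one_of(false, is_f32, isa_ok)); // all hold -> accepted
        assert(one_of(false, is_f32, false)); // one fails -> rejected
    }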

View File

@@ -51,9 +51,26 @@ void horizontal_add_ps(
     const Xbyak::Ymm ymm_ws {workspace.getIdx()};
     const Xbyak::Ymm ymm_src {src.getIdx()};
-    code->vextractf64x4(ymm_ws, src, 1);
+    // Extract upper 256 bits and add to lower 256 bits
+    code->vextractf32x8(ymm_ws, src, 1);
     code->vaddps(ymm_src, ymm_src, ymm_ws);
-    horizontal_add_ps(code, ymm_src, ymm_ws);
+    const Xbyak::Xmm xmm_ws {workspace.getIdx()};
+    const Xbyak::Xmm xmm_src {src.getIdx()};
+    // Add upper 128 bits to lower 128 bits within the YMM
+    code->vextractf32x4(xmm_ws, ymm_src, 1);
+    code->vaddps(xmm_src, xmm_src, xmm_ws);
+    // Horizontal add within 128 bits - swap 64-bit lanes and add
+    code->vshufps(
+            xmm_ws, xmm_src, xmm_src, 0x4E); // swap 64-bit lanes: [2,3,0,1]
+    code->vaddps(xmm_src, xmm_src, xmm_ws);
+    // Horizontal add within 64 bits - swap 32-bit elements and add
+    code->vpshufd(
+            xmm_ws, xmm_src, 0xB1); // swap adjacent 32-bit elements: [1,0,3,2]
+    code->vaddps(xmm_src, xmm_src, xmm_ws);
 }
 } // namespace regops
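
The inlined reduction folds a 512-bit register down to a scalar in lane 0: 16 lanes to 8 (vextractf32x8 + vaddps), 8 to 4 (vextractf32x4 + vaddps), then two swap-and-add steps within the xmm. A scalar model that checks the shuffle immediates (0x4E selects elements [2,3,0,1]; 0xB1 selects [1,0,3,2]); this is a verification sketch, not kernel code:

    #include <cassert>
    #include <cstdio>

    // Model of the vshufps/vpshufd ladder on a 4-lane "xmm": after two
    // swap+add rounds, lane 0 holds the sum of all four lanes.
    void hadd4(float x[4]) {
        float t[4];
        // vshufps imm 0x4E: t = {x[2], x[3], x[0], x[1]}
        t[0] = x[2]; t[1] = x[3]; t[2] = x[0]; t[3] = x[1];
        for (int i = 0; i < 4; i++) x[i] += t[i];
        // vpshufd imm 0xB1: t = {x[1], x[0], x[3], x[2]}
        t[0] = x[1]; t[1] = x[0]; t[2] = x[3]; t[3] = x[2];
        for (int i = 0; i < 4; i++) x[i] += t[i];
    }

    int main() {
        float z[16], y[8], x[4];
        for (int i = 0; i < 16; i++) z[i] = float(i + 1); // sum = 136
        for (int i = 0; i < 8; i++) y[i] = z[i] + z[i + 8]; // zmm -> ymm fold
        for (int i = 0; i < 4; i++) x[i] = y[i] + y[i + 4]; // ymm -> xmm fold
        hadd4(x);
        assert(x[0] == 136.f);
        printf("horizontal sum = %g\n", x[0]);
    }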