diff --git a/src/cpu/x64/brgemm/brgemm_utils.cpp b/src/cpu/x64/brgemm/brgemm_utils.cpp
index a97a778a1b..b8272e6e9d 100644
--- a/src/cpu/x64/brgemm/brgemm_utils.cpp
+++ b/src/cpu/x64/brgemm/brgemm_utils.cpp
@@ -736,7 +736,8 @@ status_t brgemm_blocking_tmm(brgemm_desc_t *brg) {
  * acc[bd] += dot(a0_reg, x_reg) // Accumulate partial results
  */
 status_t brgemm_blocking_vmm_gemv(brgemm_desc_t *brg) {
-    assert(utils::one_of(brg->isa_impl, avx2, avx2_vnni, avx2_vnni_2));
+    assert(is_superset(brg->isa_impl, avx2)
+            && !is_superset(brg->isa_impl, avx512_core_amx));
     assert(brg->load_dim == 1);
 
     brg->ld_block = 1;
@@ -749,11 +750,11 @@ status_t brgemm_blocking_vmm_gemv(brgemm_desc_t *brg) {
     brg->ldb2_tail = brg->ldb % brg->ld_block2;
     assert(brg->ldb2_tail == 0);
 
-    brg->bd_block = 8;
+    brg->bd_block = is_superset(brg->isa_impl, avx512_core) ? 24 : 8;
     brg->bdb = brg->bcast_dim / brg->bd_block;
     brg->bdb_tail = brg->bcast_dim % brg->bd_block;
 
-    const int simd_w = 8;
+    const int simd_w = is_superset(brg->isa_impl, avx512_core) ? 16 : 8;
     brg->rd_block = simd_w;
     brg->rdb = brg->reduce_dim / brg->rd_block;
 
diff --git a/src/cpu/x64/brgemm/jit_brgemm_kernel.cpp b/src/cpu/x64/brgemm/jit_brgemm_kernel.cpp
index 9ebe2f4203..426edf9459 100644
--- a/src/cpu/x64/brgemm/jit_brgemm_kernel.cpp
+++ b/src/cpu/x64/brgemm/jit_brgemm_kernel.cpp
@@ -1123,15 +1123,16 @@ void jit_brgemm_kernel_t<Wmm>::apply_alpha_beta(
     for_(dim_t bd = 0; bd < bd_block; bd++)
     for (dim_t ld = 0; ld < ld_block2; ld++) {
         const bool is_tail = is_ld_tail && ld + 1 == ld_block2;
-        const auto k_mask = is_tail ? ld_tail_mask : ld_full_mask;
+        const bool mask_flag = is_tail || brg.is_gemv;
+        const auto k_mask = mask_flag ? ld_tail_mask : ld_full_mask;
         auto vmm = accm(ld_block2, bd, ld);
         auto ptr_C = ptr[reg_aux_C + C_offset(bd, ld)];
         if (use_vadd_for_beta) {
             if (brg.is_gemv)
                 uni_vaddss(Xmm(vmm.getIdx()), Xmm(vmm.getIdx()), ptr_C);
-            else if (IMPLICATION(
-                             is_tail, is_superset(brg.isa_impl, avx512_core))) {
-                auto vmm_masked = vmm_mask(vmm, is_tail, false, k_mask);
+            else if (IMPLICATION(mask_flag,
+                             is_superset(brg.isa_impl, avx512_core))) {
+                auto vmm_masked = vmm_mask(vmm, mask_flag, false, k_mask);
                 if (brg.is_int8)
                     uni_vpaddd(vmm_masked, vmm, ptr_C);
                 else
@@ -1145,7 +1146,7 @@ void jit_brgemm_kernel_t<Wmm>::apply_alpha_beta(
             }
         } else {
             const dim_t ld_size = is_tail ? brg.ldb_tail : brg.ld_block;
-            cvt2ps(brg.dt_c, vmm_prev_dst, ptr_C, is_tail, false, k_mask,
+            cvt2ps(brg.dt_c, vmm_prev_dst, ptr_C, mask_flag, false, k_mask,
                     ld_size);
             if (brg.beta == 1.f)
                 uni_vaddps(vmm, vmm, vmm_prev_dst);
@@ -1238,9 +1239,10 @@ void jit_brgemm_kernel_t<Wmm>::apply_post_ops(dim_t bd_block, dim_t ld_block2,
             const auto addr = ptr[reg_aux_D + D_offset(bd, ld)];
             const auto vmm_prev_dst = vmm_tmp(0);
             const bool is_tail = is_ld_tail && ld + 1 == ld_block2;
-            const auto k_mask = is_tail ? ld_tail_mask : ld_full_mask;
+            const bool mask_flag = is_tail || brg.is_gemv;
+            const auto k_mask = mask_flag ? ld_tail_mask : ld_full_mask;
             const dim_t ld_size = is_tail ? brg.ldb_tail : brg.ld_block;
-            cvt2ps(brg.sum_dt, vmm_prev_dst, addr, is_tail, false,
+            cvt2ps(brg.sum_dt, vmm_prev_dst, addr, mask_flag, false,
                     k_mask, ld_size);
             if (p_sum_zp_reg_set)
                 uni_vsubps(vmm_prev_dst, vmm_prev_dst, vmm_sum_zp);
@@ -1291,7 +1293,7 @@ void jit_brgemm_kernel_t<Wmm>::reduce_gemv_accumulators(
 template <typename Wmm>
 void jit_brgemm_kernel_t<Wmm>::store_accumulators_apply_post_ops(dim_t bd_block,
         dim_t ld_block2, dim_t ldb_and_bdb_offset, bool is_ld_tail) {
-    auto k_mask = (!is_ld_tail) ? ld_full_mask : ld_tail_mask;
+    auto k_mask = (!is_ld_tail && !brg.is_gemv) ? ld_full_mask : ld_tail_mask;
 
     // if (brg.is_int8 && alpha_or_beta_applicable && !beta_uses_vadd) ->
     // accumulated values are already converted to ps in apply_alpha_beta()
@@ -1415,7 +1417,8 @@ void jit_brgemm_kernel_t<Wmm>::store_accumulators_apply_post_ops(dim_t bd_block,
         if (brg.with_bias) {
             auto ptr_bias = ptr[reg_aux_bias + bias_offset(ld)];
             const bool is_tail = is_ld_tail && ld + 1 == ld_block2;
-            cvt2ps(brg.dt_bias, vmm_bias, ptr_bias, is_tail, false, k_mask,
+            const bool mask_flag = is_tail || brg.is_gemv;
+            cvt2ps(brg.dt_bias, vmm_bias, ptr_bias, mask_flag, false, k_mask,
                     is_tail ? brg.ldb_tail : brg.ld_block);
         }
         for (dim_t bd = 0; bd < bd_block; bd++) {
@@ -1460,15 +1463,18 @@ void jit_brgemm_kernel_t<Wmm>::store_accumulators_apply_post_ops(dim_t bd_block,
         const bool is_tail = is_ld_tail && ld + 1 == ld_block2;
         if (brg.zp_type_c == brgemm_broadcast_t::per_n) {
             dim_t zp_c_off = zp_c_values_offset(ld);
+            const bool mask_flag = is_tail || brg.is_gemv;
             if (is_superset(brg.isa_impl, avx512_core)) {
                 auto zp_c_addr
                         = EVEX_compress_addr(reg_aux_zp_c_values, zp_c_off);
-                cvt2ps(data_type::s32, vmm_zp_c, zp_c_addr, is_tail, false,
-                        k_mask, is_tail ? brg.ldb_tail : brg.ld_block);
+                cvt2ps(data_type::s32, vmm_zp_c, zp_c_addr, mask_flag,
+                        false, k_mask,
+                        is_tail ? brg.ldb_tail : brg.ld_block);
             } else {
                 cvt2ps(data_type::s32, vmm_zp_c,
-                        ptr[reg_aux_zp_c_values + zp_c_off], is_tail, false,
-                        k_mask, is_tail ? brg.ldb_tail : brg.ld_block);
+                        ptr[reg_aux_zp_c_values + zp_c_off], mask_flag,
+                        false, k_mask,
+                        is_tail ? brg.ldb_tail : brg.ld_block);
             }
         }
         for (dim_t bd = 0; bd < bd_block; bd++) {
@@ -1514,11 +1520,12 @@ void jit_brgemm_kernel_t<Wmm>::store_accumulators_apply_post_ops(dim_t bd_block,
             auto vmm_lower = Vmm_lower_t(vmm.getIdx());
             const bool is_tail = is_ld_tail && ld + 1 == ld_block2;
             if (is_superset(brg.isa_impl, avx512_core)) {
-                const Vmm r_vmm = vmm_mask(vmm, is_tail, true, k_mask);
+                const bool mask_flag = is_tail || brg.is_gemv;
+                const Vmm r_vmm = vmm_mask(vmm, mask_flag, true, k_mask);
                 const Vmm_lower_t r_ymm
-                        = vmm_mask(vmm_lower, is_tail, true, k_mask);
+                        = vmm_mask(vmm_lower, mask_flag, true, k_mask);
                 const Xmm xmm = Xmm(vmm.getIdx());
-                const Xmm r_xmm = vmm_mask(xmm, is_tail, true, k_mask);
+                const Xmm r_xmm = vmm_mask(xmm, mask_flag, true, k_mask);
                 if (use_sat_cvt) {
                     assert(one_of(brg.dt_d, data_type::s8, data_type::u8));
                     auto vmm_perm = vmm_ubound();
@@ -2334,9 +2341,14 @@ void jit_brgemm_kernel_t<Wmm>::gemv_microkernel(
                 ? ptr[reg_aux_A + A_offset(row, col)]
                 : ptr[reg_aux_B + B_offset(row, col)];
 
-        if (is_rd_tail)
-            vmaskmovps(vec, vmm_tail_mask(), addr);
-        else
+        if (is_rd_tail) {
+            if (is_superset(brg.isa_impl, avx512_core)) {
+                auto vmm_comp_masked
+                        = vmm_mask(vec, is_rd_tail, false, rd_tail_mask);
+                uni_vmovups(vmm_comp_masked, addr);
+            } else
+                vmaskmovps(vec, vmm_tail_mask(), addr);
+        } else
             uni_vmovups(vec, addr);
     };
 
@@ -3058,13 +3070,21 @@ void jit_brgemm_kernel_t<Wmm>::generate() {
 
     if (is_superset(brg.isa_impl, avx512_core)) {
         const auto full_mask = size_t {0xffffffffffffffff};
-        const auto tail_mask = size_t((1 << brg.ldb_tail) - 1);
+        const auto ld_tail = brg.is_gemv ? 1 : brg.ldb_tail;
+        const auto ld_tail_mask_val = size_t((1 << ld_tail) - 1);
         reg64_t reg_mask = rax;
 
         mov(reg_mask, full_mask);
         kmovq(ld_full_mask, reg_mask);
-        mov(reg_mask, tail_mask);
+        mov(reg_mask, ld_tail_mask_val);
         kmovq(ld_tail_mask, reg_mask);
+
+        if (brg.is_gemv) {
+            const auto rd_tail_mask_val
+                    = (static_cast<size_t>(1) << brg.rdb_tail) - 1;
+            mov(reg_mask, rd_tail_mask_val);
+            kmovq(rd_tail_mask, reg_mask);
+        }
     }
 
     if (brg.is_int8 && !brg.has_int8_vnni) {
diff --git a/src/cpu/x64/matmul/brgemm_matmul_utils.cpp b/src/cpu/x64/matmul/brgemm_matmul_utils.cpp
index 01864e3368..a4c232aa68 100644
--- a/src/cpu/x64/matmul/brgemm_matmul_utils.cpp
+++ b/src/cpu/x64/matmul/brgemm_matmul_utils.cpp
@@ -417,7 +417,8 @@ bool is_gemv_applicable(const brgemm_matmul_conf_t &bgmmc,
     if (bgmmc.with_reduce) return false;
 
-    // BRGEMV currently supports only f32 and AVX2.
-    if (utils::one_of(false, bm_conf_utils.is_f32(), bgmmc.isa == avx2))
+    // BRGEMV currently supports only f32 on AVX2 and AVX-512.
+    if (utils::one_of(false, bm_conf_utils.is_f32(),
+                bgmmc.isa == avx2 || bgmmc.isa == avx512_core))
         return false;
 
     if (utils::one_of(format_tag::undef, bm_conf_utils.get_gemv_A_tag(A_md),
diff --git a/src/cpu/x64/utils/jit_regops.cpp b/src/cpu/x64/utils/jit_regops.cpp
index bee029df47..fed4998a91 100644
--- a/src/cpu/x64/utils/jit_regops.cpp
+++ b/src/cpu/x64/utils/jit_regops.cpp
@@ -51,9 +51,26 @@ void horizontal_add_ps(
     const Xbyak::Ymm ymm_ws {workspace.getIdx()};
     const Xbyak::Ymm ymm_src {src.getIdx()};
 
-    code->vextractf64x4(ymm_ws, src, 1);
+    // Extract upper 256 bits and add to lower 256 bits
+    code->vextractf32x8(ymm_ws, src, 1);
     code->vaddps(ymm_src, ymm_src, ymm_ws);
-    horizontal_add_ps(code, ymm_src, ymm_ws);
+
+    const Xbyak::Xmm xmm_ws {workspace.getIdx()};
+    const Xbyak::Xmm xmm_src {src.getIdx()};
+
+    // Add upper 128 bits to lower 128 bits within the YMM
+    code->vextractf32x4(xmm_ws, ymm_src, 1);
+    code->vaddps(xmm_src, xmm_src, xmm_ws);
+
+    // Horizontal add within 128 bits - swap 64-bit lanes and add
+    code->vshufps(
+            xmm_ws, xmm_src, xmm_src, 0x4E); // swap 64-bit lanes: [2,3,0,1]
+    code->vaddps(xmm_src, xmm_src, xmm_ws);
+
+    // Horizontal add within 64 bits - swap 32-bit elements and add
+    code->vpshufd(
+            xmm_ws, xmm_src, 0xB1); // swap adjacent 32-bit elements: [1,0,3,2]
+    code->vaddps(xmm_src, xmm_src, xmm_ws);
 }
 
 } // namespace regops
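
Reviewer note on the gemv tail handling in jit_brgemm_kernel.cpp: on AVX-512 the reduce-dimension tail is now loaded through a k-mask register (rd_tail_mask, set up in generate() via kmovq) driving a masked uni_vmovups, while the AVX2 path keeps vmaskmovps with a vector mask. A minimal intrinsics sketch of the two strategies outside the JIT; the helper names load_tail_avx512/load_tail_avx2 are invented here and assume a compiler with the corresponding ISA flags enabled:

// Sketch only (not oneDNN code): the two tail-load strategies used by
// gemv_microkernel(), expressed as intrinsics. Helper names are invented.
#include <immintrin.h>

// AVX-512: a k-mask register drives a zero-masking load, mirroring
// kmovq(rd_tail_mask, ...) + masked uni_vmovups in the JIT.
static inline __m512 load_tail_avx512(const float *p, int tail /* < 16 */) {
    const __mmask16 k = static_cast<__mmask16>((1u << tail) - 1);
    return _mm512_maskz_loadu_ps(k, p);
}

// AVX2: no k-mask registers, so a per-lane vector mask feeds vmaskmovps,
// mirroring vmaskmovps(vec, vmm_tail_mask(), addr) in the JIT.
static inline __m256 load_tail_avx2(const float *p, int tail /* < 8 */) {
    alignas(32) int mask[8];
    for (int i = 0; i < 8; ++i)
        mask[i] = i < tail ? -1 : 0;
    return _mm256_maskload_ps(
            p, _mm256_load_si256(reinterpret_cast<const __m256i *>(mask)));
}

The k-mask variant needs no in-register lane mask, which is what lets the AVX-512 path drop vmm_tail_mask() for the reduce-dimension tail.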
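
Reviewer note on the jit_regops.cpp change: horizontal_add_ps for a ZMM source no longer recurses into the YMM overload; it emits the full 512 -> 256 -> 128 -> 64 -> 32-bit reduction inline. A standalone intrinsics sketch of the same reduction, for reference only; hsum_ps_zmm is an invented name, AVX512F/AVX512DQ are assumed, and the last step uses _mm_shuffle_ps where the JIT emits vpshufd (the permutation is the same):

// Sketch only: scalar horizontal sum of a 512-bit float vector, mirroring
// the vextractf32x8 / vextractf32x4 / vshufps(0x4E) / vshufps(0xB1) steps.
#include <immintrin.h>

static inline float hsum_ps_zmm(__m512 src) {
    // Step 1: add upper 256 bits to lower 256 bits.
    __m256 lo256 = _mm512_castps512_ps256(src);
    __m256 hi256 = _mm512_extractf32x8_ps(src, 1);
    __m256 sum256 = _mm256_add_ps(lo256, hi256);

    // Step 2: add upper 128 bits to lower 128 bits.
    __m128 lo128 = _mm256_castps256_ps128(sum256);
    __m128 hi128 = _mm256_extractf128_ps(sum256, 1);
    __m128 sum128 = _mm_add_ps(lo128, hi128);

    // Step 3: swap the two 64-bit halves ([2,3,0,1]) and add.
    __m128 shuf = _mm_shuffle_ps(sum128, sum128, 0x4E);
    sum128 = _mm_add_ps(sum128, shuf);

    // Step 4: swap adjacent 32-bit elements ([1,0,3,2]) and add.
    shuf = _mm_shuffle_ps(sum128, sum128, 0xB1);
    sum128 = _mm_add_ps(sum128, shuf);

    // All four lanes now hold the full sum; return lane 0.
    return _mm_cvtss_f32(sum128);
}

Each step halves the active width, so four adds in total leave the full sum replicated in every lane of the low XMM, matching the four vaddps the JIT sequence emits.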