https://github.com/uxlfoundation/oneDNN.git

cpu: x64: matmul: enable GEMV code path to brgemm matmul for avx512
@@ -736,7 +736,8 @@ status_t brgemm_blocking_tmm(brgemm_desc_t *brg) {
  * acc[bd] += dot(a0_reg, x_reg) // Accumulate partial results
  */
 status_t brgemm_blocking_vmm_gemv(brgemm_desc_t *brg) {
-    assert(utils::one_of(brg->isa_impl, avx2, avx2_vnni, avx2_vnni_2));
+    assert(is_superset(brg->isa_impl, avx2)
+            && !is_superset(brg->isa_impl, avx512_core_amx));
     assert(brg->load_dim == 1);

     brg->ld_block = 1;
@@ -749,11 +750,11 @@ status_t brgemm_blocking_vmm_gemv(brgemm_desc_t *brg) {
     brg->ldb2_tail = brg->ldb % brg->ld_block2;
     assert(brg->ldb2_tail == 0);

-    brg->bd_block = 8;
+    brg->bd_block = is_superset(brg->isa_impl, avx512_core) ? 24 : 8;
     brg->bdb = brg->bcast_dim / brg->bd_block;
     brg->bdb_tail = brg->bcast_dim % brg->bd_block;

-    const int simd_w = 8;
+    const int simd_w = is_superset(brg->isa_impl, avx512_core) ? 16 : 8;

     brg->rd_block = simd_w;
     brg->rdb = brg->reduce_dim / brg->rd_block;
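For orientation, here is a standalone sketch (not part of the commit; names such as make_gemv_blocking, M and K are made up) of the blocking the function above picks for a GEMV y = A * x: the load dimension stays a single column, bd_block grows from 8 to 24 rows on avx512_core, and rd_block follows the f32 SIMD width (16 lanes instead of 8). A scalar reference of the per-row accumulation described by the acc[bd] += dot(...) comment is included as well.

#include <cstdio>
#include <vector>

// Illustrative blocking for y = A * x (A is M x K). Mirrors the values picked
// by brgemm_blocking_vmm_gemv() above; all names here are hypothetical.
struct gemv_blocking_t {
    int bd_block, bdb, bdb_tail; // blocking over M (rows of A)
    int rd_block, rdb, rdb_tail; // blocking over K (reduction dim)
};

gemv_blocking_t make_gemv_blocking(int M, int K, bool is_avx512) {
    gemv_blocking_t b;
    b.bd_block = is_avx512 ? 24 : 8; // rows processed per microkernel call
    b.bdb = M / b.bd_block;
    b.bdb_tail = M % b.bd_block;
    const int simd_w = is_avx512 ? 16 : 8; // f32 lanes per vector register
    b.rd_block = simd_w;                   // K elements consumed per step
    b.rdb = K / b.rd_block;
    b.rdb_tail = K % b.rd_block;
    return b;
}

// Scalar reference of what each row block accumulates:
//     acc[bd] += dot(a_row_chunk, x_chunk)
void gemv_ref(const std::vector<float> &A, const std::vector<float> &x,
        std::vector<float> &y, int M, int K) {
    for (int m = 0; m < M; ++m) {
        float acc = 0.f;
        for (int k = 0; k < K; ++k)
            acc += A[m * K + k] * x[k];
        y[m] = acc;
    }
}

int main() {
    const auto b = make_gemv_blocking(/*M=*/100, /*K=*/70, /*is_avx512=*/true);
    std::printf("bd_block=%d bdb=%d bdb_tail=%d rd_block=%d rdb=%d rdb_tail=%d\n",
            b.bd_block, b.bdb, b.bdb_tail, b.rd_block, b.rdb, b.rdb_tail);

    const int M = 5, K = 7;
    std::vector<float> A(M * K, 1.f), x(K, 2.f), y(M, 0.f);
    gemv_ref(A, x, y, M, K);
    std::printf("y[0] = %g (expect %g)\n", y[0], 2.f * K);
    return 0;
}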
@@ -1123,15 +1123,16 @@ void jit_brgemm_kernel_t<Wmm>::apply_alpha_beta(
     for_(dim_t bd = 0; bd < bd_block; bd++)
     for (dim_t ld = 0; ld < ld_block2; ld++) {
         const bool is_tail = is_ld_tail && ld + 1 == ld_block2;
-        const auto k_mask = is_tail ? ld_tail_mask : ld_full_mask;
+        const bool mask_flag = is_tail || brg.is_gemv;
+        const auto k_mask = mask_flag ? ld_tail_mask : ld_full_mask;
         auto vmm = accm(ld_block2, bd, ld);
         auto ptr_C = ptr[reg_aux_C + C_offset(bd, ld)];
         if (use_vadd_for_beta) {
             if (brg.is_gemv)
                 uni_vaddss(Xmm(vmm.getIdx()), Xmm(vmm.getIdx()), ptr_C);
-            else if (IMPLICATION(
-                             is_tail, is_superset(brg.isa_impl, avx512_core))) {
-                auto vmm_masked = vmm_mask(vmm, is_tail, false, k_mask);
+            else if (IMPLICATION(mask_flag,
+                             is_superset(brg.isa_impl, avx512_core))) {
+                auto vmm_masked = vmm_mask(vmm, mask_flag, false, k_mask);
                 if (brg.is_int8)
                     uni_vpaddd(vmm_masked, vmm, ptr_C);
                 else
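The same masking rule recurs in this hunk and the ones that follow: a block is masked either because it is the ld tail block, or because the kernel runs in GEMV mode, where load_dim == 1 and every C/D access touches a single f32 lane. A minimal standalone restatement for clarity (the helper name is hypothetical):

#include <cstdio>

// Hypothetical standalone form of the mask_flag condition from the diff above.
bool use_tail_mask(bool is_ld_tail, int ld, int ld_block2, bool is_gemv) {
    const bool is_tail = is_ld_tail && ld + 1 == ld_block2;
    return is_tail || is_gemv; // GEMV always stores a single output column
}

int main() {
    std::printf("gemv, non-tail block: %d\n", use_tail_mask(false, 0, 2, true));
    std::printf("matmul, tail block:   %d\n", use_tail_mask(true, 1, 2, false));
    return 0;
}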
@@ -1145,7 +1146,7 @@ void jit_brgemm_kernel_t<Wmm>::apply_alpha_beta(
             }
         } else {
             const dim_t ld_size = is_tail ? brg.ldb_tail : brg.ld_block;
-            cvt2ps(brg.dt_c, vmm_prev_dst, ptr_C, is_tail, false, k_mask,
+            cvt2ps(brg.dt_c, vmm_prev_dst, ptr_C, mask_flag, false, k_mask,
                     ld_size);
             if (brg.beta == 1.f)
                 uni_vaddps(vmm, vmm, vmm_prev_dst);
@@ -1238,9 +1239,10 @@ void jit_brgemm_kernel_t<Wmm>::apply_post_ops(dim_t bd_block, dim_t ld_block2,
                 const auto addr = ptr[reg_aux_D + D_offset(bd, ld)];
                 const auto vmm_prev_dst = vmm_tmp(0);
                 const bool is_tail = is_ld_tail && ld + 1 == ld_block2;
-                const auto k_mask = is_tail ? ld_tail_mask : ld_full_mask;
+                const bool mask_flag = is_tail || brg.is_gemv;
+                const auto k_mask = mask_flag ? ld_tail_mask : ld_full_mask;
                 const dim_t ld_size = is_tail ? brg.ldb_tail : brg.ld_block;
-                cvt2ps(brg.sum_dt, vmm_prev_dst, addr, is_tail, false,
+                cvt2ps(brg.sum_dt, vmm_prev_dst, addr, mask_flag, false,
                         k_mask, ld_size);
                 if (p_sum_zp_reg_set)
                     uni_vsubps(vmm_prev_dst, vmm_prev_dst, vmm_sum_zp);
@@ -1291,7 +1293,7 @@ void jit_brgemm_kernel_t<Wmm>::reduce_gemv_accumulators(
 template <typename Wmm>
 void jit_brgemm_kernel_t<Wmm>::store_accumulators_apply_post_ops(dim_t bd_block,
         dim_t ld_block2, dim_t ldb_and_bdb_offset, bool is_ld_tail) {
-    auto k_mask = (!is_ld_tail) ? ld_full_mask : ld_tail_mask;
+    auto k_mask = (!is_ld_tail && !brg.is_gemv) ? ld_full_mask : ld_tail_mask;

     // if (brg.is_int8 && alpha_or_beta_applicable && !beta_uses_vadd) ->
     // accumulated values are already converted to ps in apply_alpha_beta()
@@ -1415,7 +1417,8 @@ void jit_brgemm_kernel_t<Wmm>::store_accumulators_apply_post_ops(dim_t bd_block,
         if (brg.with_bias) {
             auto ptr_bias = ptr[reg_aux_bias + bias_offset(ld)];
             const bool is_tail = is_ld_tail && ld + 1 == ld_block2;
-            cvt2ps(brg.dt_bias, vmm_bias, ptr_bias, is_tail, false, k_mask,
+            const bool mask_flag = is_tail || brg.is_gemv;
+            cvt2ps(brg.dt_bias, vmm_bias, ptr_bias, mask_flag, false, k_mask,
                     is_tail ? brg.ldb_tail : brg.ld_block);
         }
         for (dim_t bd = 0; bd < bd_block; bd++) {
@@ -1460,15 +1463,18 @@ void jit_brgemm_kernel_t<Wmm>::store_accumulators_apply_post_ops(dim_t bd_block,
             const bool is_tail = is_ld_tail && ld + 1 == ld_block2;
             if (brg.zp_type_c == brgemm_broadcast_t::per_n) {
                 dim_t zp_c_off = zp_c_values_offset(ld);
+                const bool mask_flag = is_tail || brg.is_gemv;
                 if (is_superset(brg.isa_impl, avx512_core)) {
                     auto zp_c_addr
                             = EVEX_compress_addr(reg_aux_zp_c_values, zp_c_off);
-                    cvt2ps(data_type::s32, vmm_zp_c, zp_c_addr, is_tail, false,
-                            k_mask, is_tail ? brg.ldb_tail : brg.ld_block);
+                    cvt2ps(data_type::s32, vmm_zp_c, zp_c_addr, mask_flag,
+                            false, k_mask,
+                            is_tail ? brg.ldb_tail : brg.ld_block);
                 } else {
                     cvt2ps(data_type::s32, vmm_zp_c,
-                            ptr[reg_aux_zp_c_values + zp_c_off], is_tail, false,
-                            k_mask, is_tail ? brg.ldb_tail : brg.ld_block);
+                            ptr[reg_aux_zp_c_values + zp_c_off], mask_flag,
+                            false, k_mask,
+                            is_tail ? brg.ldb_tail : brg.ld_block);
                 }
             }
             for (dim_t bd = 0; bd < bd_block; bd++) {
@@ -1514,11 +1520,12 @@ void jit_brgemm_kernel_t<Wmm>::store_accumulators_apply_post_ops(dim_t bd_block,
                 auto vmm_lower = Vmm_lower_t(vmm.getIdx());
                 const bool is_tail = is_ld_tail && ld + 1 == ld_block2;
                 if (is_superset(brg.isa_impl, avx512_core)) {
-                    const Vmm r_vmm = vmm_mask(vmm, is_tail, true, k_mask);
+                    const bool mask_flag = is_tail || brg.is_gemv;
+                    const Vmm r_vmm = vmm_mask(vmm, mask_flag, true, k_mask);
                     const Vmm_lower_t r_ymm
-                            = vmm_mask(vmm_lower, is_tail, true, k_mask);
+                            = vmm_mask(vmm_lower, mask_flag, true, k_mask);
                     const Xmm xmm = Xmm(vmm.getIdx());
-                    const Xmm r_xmm = vmm_mask(xmm, is_tail, true, k_mask);
+                    const Xmm r_xmm = vmm_mask(xmm, mask_flag, true, k_mask);
                     if (use_sat_cvt) {
                         assert(one_of(brg.dt_d, data_type::s8, data_type::u8));
                         auto vmm_perm = vmm_ubound();
@@ -2334,9 +2341,14 @@ void jit_brgemm_kernel_t<Wmm>::gemv_microkernel(
                 ? ptr[reg_aux_A + A_offset(row, col)]
                 : ptr[reg_aux_B + B_offset(row, col)];

-        if (is_rd_tail)
-            vmaskmovps(vec, vmm_tail_mask(), addr);
-        else
+        if (is_rd_tail) {
+            if (is_superset(brg.isa_impl, avx512_core)) {
+                auto vmm_comp_masked
+                        = vmm_mask(vec, is_rd_tail, false, rd_tail_mask);
+                uni_vmovups(vmm_comp_masked, addr);
+            } else
+                vmaskmovps(vec, vmm_tail_mask(), addr);
+        } else
             uni_vmovups(vec, addr);
     };

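On the tail of the reduce dimension, the avx512 path above uses an opmask-driven uni_vmovups instead of AVX2's vmaskmovps. A rough standalone intrinsics illustration of the same idea (function names and sizes are invented; this is not the JIT code):

#include <immintrin.h>
#include <cstdio>

// Load `n` (0 < n <= 16) floats, zeroing the remaining lanes.
// On AVX-512 a k-register mask drives a plain masked load;
// the AVX2 fallback builds a vector mask for vmaskmovps instead.
__m512 load_tail_avx512(const float *p, int n) {
    const __mmask16 k = (1u << n) - 1; // same (1 << tail) - 1 idiom as the kernel
    return _mm512_maskz_loadu_ps(k, p);
}

__m256 load_tail_avx2(const float *p, int n) {
    alignas(32) int mask[8] = {0};
    for (int i = 0; i < n && i < 8; ++i)
        mask[i] = -1; // all-ones lane -> element is loaded
    return _mm256_maskload_ps(p, _mm256_load_si256((const __m256i *)mask));
}

int main() {
    float data[16];
    for (int i = 0; i < 16; ++i)
        data[i] = float(i);

    float out16[16];
    _mm512_storeu_ps(out16, load_tail_avx512(data, 5));
    std::printf("avx512: lane 4 = %g, lane 5 = %g (zeroed)\n", out16[4], out16[5]);

    float out8[8];
    _mm256_storeu_ps(out8, load_tail_avx2(data, 3));
    std::printf("avx2:   lane 2 = %g, lane 3 = %g (zeroed)\n", out8[2], out8[3]);
    return 0;
}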
@@ -3058,13 +3070,21 @@ void jit_brgemm_kernel_t<Wmm>::generate() {

     if (is_superset(brg.isa_impl, avx512_core)) {
         const auto full_mask = size_t {0xffffffffffffffff};
-        const auto tail_mask = size_t((1 << brg.ldb_tail) - 1);
+        const auto ld_tail = brg.is_gemv ? 1 : brg.ldb_tail;
+        const auto ld_tail_mask_val = size_t((1 << ld_tail) - 1);
         reg64_t reg_mask = rax;

         mov(reg_mask, full_mask);
         kmovq(ld_full_mask, reg_mask);
-        mov(reg_mask, tail_mask);
+        mov(reg_mask, ld_tail_mask_val);
         kmovq(ld_tail_mask, reg_mask);
+
+        if (brg.is_gemv) {
+            const auto rd_tail_mask_val
+                    = (static_cast<size_t>(1) << brg.rdb_tail) - 1;
+            mov(reg_mask, rd_tail_mask_val);
+            kmovq(rd_tail_mask, reg_mask);
+        }
     }

     if (brg.is_int8 && !brg.has_int8_vnni) {
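The values loaded into the k-registers above are plain lane bitmasks. A small standalone sketch (hypothetical names) of what ld_tail_mask and rd_tail_mask end up holding:

#include <cstdio>
#include <cstdint>

// Mirrors the mask computation in generate() above; names are illustrative.
// For GEMV the load dimension is a single column, so the ld tail mask keeps
// only lane 0; the rd tail mask covers the leftover K elements of a row.
uint64_t ld_tail_mask_val(bool is_gemv, int ldb_tail) {
    const int ld_tail = is_gemv ? 1 : ldb_tail;
    return (uint64_t(1) << ld_tail) - 1;
}

uint64_t rd_tail_mask_val(int rdb_tail) {
    return (uint64_t(1) << rdb_tail) - 1;
}

int main() {
    // e.g. K = 70 with rd_block = 16 leaves rdb_tail = 6 -> mask 0b111111
    std::printf("ld mask (gemv) = 0x%llx\n",
            (unsigned long long)ld_tail_mask_val(true, /*ldb_tail=*/0));
    std::printf("rd mask (rdb_tail=6) = 0x%llx\n",
            (unsigned long long)rd_tail_mask_val(6));
    return 0;
}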
@@ -417,7 +417,8 @@ bool is_gemv_applicable(const brgemm_matmul_conf_t &bgmmc,
     if (bgmmc.with_reduce) return false;

     // BRGEMV currently supports only f32 and AVX2.
-    if (utils::one_of(false, bm_conf_utils.is_f32(), bgmmc.isa == avx2))
+    if (utils::one_of(false, bm_conf_utils.is_f32(),
+                bgmmc.isa == avx2 || bgmmc.isa == avx512_core))
         return false;

     if (utils::one_of(format_tag::undef, bm_conf_utils.get_gemv_A_tag(A_md),
@@ -51,9 +51,26 @@ void horizontal_add_ps(
     const Xbyak::Ymm ymm_ws {workspace.getIdx()};
     const Xbyak::Ymm ymm_src {src.getIdx()};

-    code->vextractf64x4(ymm_ws, src, 1);
+    // Extract upper 256 bits and add to lower 256 bits
+    code->vextractf32x8(ymm_ws, src, 1);
     code->vaddps(ymm_src, ymm_src, ymm_ws);
-    horizontal_add_ps(code, ymm_src, ymm_ws);
+
+    const Xbyak::Xmm xmm_ws {workspace.getIdx()};
+    const Xbyak::Xmm xmm_src {src.getIdx()};
+
+    // Add upper 128 bits to lower 128 bits within the YMM
+    code->vextractf32x4(xmm_ws, ymm_src, 1);
+    code->vaddps(xmm_src, xmm_src, xmm_ws);
+
+    // Horizontal add within 128 bits - swap 64-bit lanes and add
+    code->vshufps(
+            xmm_ws, xmm_src, xmm_src, 0x4E); // swap 64-bit lanes: [2,3,0,1]
+    code->vaddps(xmm_src, xmm_src, xmm_ws);
+
+    // Horizontal add within 64 bits - swap 32-bit elements and add
+    code->vpshufd(
+            xmm_ws, xmm_src, 0xB1); // swap adjacent 32-bit elements: [1,0,3,2]
+    code->vaddps(xmm_src, xmm_src, xmm_ws);
 }

 } // namespace regops
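The sequence above collapses the 16 f32 lanes of a zmm into a single sum. For reference, a standalone intrinsics sketch of the same reduction tree (assumes AVX512F/DQ; it mirrors, but is not, the JIT-emitted code):

#include <immintrin.h>
#include <cstdio>

// Same reduction tree as the emitted code above: 512 -> 256 -> 128 -> 64 -> 32 bits.
float horizontal_add_ps_ref(__m512 src) {
    // Extract upper 256 bits and add to lower 256 bits (vextractf32x8 + vaddps)
    __m256 lo256 = _mm512_castps512_ps256(src);
    __m256 hi256 = _mm512_extractf32x8_ps(src, 1);
    __m256 sum256 = _mm256_add_ps(lo256, hi256);

    // Add upper 128 bits to lower 128 bits within the YMM
    __m128 lo128 = _mm256_castps256_ps128(sum256);
    __m128 hi128 = _mm256_extractf128_ps(sum256, 1);
    __m128 sum128 = _mm_add_ps(lo128, hi128);

    // Swap 64-bit lanes and add (vshufps with 0x4E)
    sum128 = _mm_add_ps(sum128, _mm_shuffle_ps(sum128, sum128, 0x4E));
    // Swap adjacent 32-bit elements and add (vpshufd with 0xB1)
    sum128 = _mm_add_ps(sum128, _mm_shuffle_ps(sum128, sum128, 0xB1));
    return _mm_cvtss_f32(sum128);
}

int main() {
    float data[16];
    float expect = 0.f;
    for (int i = 0; i < 16; ++i) {
        data[i] = float(i + 1);
        expect += data[i];
    }
    std::printf("sum = %g (expect %g)\n",
            horizontal_add_ps_ref(_mm512_loadu_ps(data)), expect);
    return 0;
}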