x64: matmul: Enable AVX2 matmul weight decompression

Ovchinnikov Dmitriy
2025-10-07 02:34:38 -07:00
committed by Dmitriy Ovchinnikov
parent 0ba6d85080
commit 846ba7c1d4
3 changed files with 184 additions and 51 deletions

View File

@ -477,6 +477,11 @@ inline bool isa_has_bf16(cpu_isa_t isa) {
return is_superset(isa, avx512_core_bf16);
}
inline bool isa_has_f16(cpu_isa_t isa) {
return is_superset(isa, avx512_core_fp16)
|| is_superset(isa, avx10_1_512_amx_fp16);
}
inline bool isa_has_masks(cpu_isa_t isa) {
return is_superset(isa, avx512_core);
}

View File

@ -2118,7 +2118,7 @@ protected:
* @param zmm Destination ZMM register where data will be inserted
* @param ymm_half Source YMM register containing data for the upper half
*/
void copy_half_int4(const Zmm &zmm, const Ymm &ymm_half) {
void copy_half_reg(const Zmm &zmm, const Ymm &ymm_half) {
vinserti64x4(zmm, zmm, ymm_half, 1);
}
@ -2127,13 +2127,88 @@ protected:
* @param ymm Destination YMM register where data will be inserted
* @param xmm_half Source XMM register containing data for the upper half
*/
void copy_half_int4(const Ymm &ymm, const Xmm &xmm_half) {
void copy_half_reg(const Ymm &ymm, const Xmm &xmm_half) {
vinserti128(ymm, ymm, xmm_half, 1);
}
/** Restores the vmm register that holds the permute indices for int4 processing.
 * Due to the shortage of registers it is reused as a temporary in `prepare_loaded_int4`.
 */
void restore_vmm_permd(const Ymm &vmm_permd) {
alignas(64) static constexpr const uint32_t int4_permute_avx2[8]
= {0, 4, 1, 5, 2, 6, 3, 7};
const auto reg_tmp = r15;
mov(reg_tmp, reinterpret_cast<size_t>(int4_permute_avx2));
vmovdqa(vmm_permd, ptr[reg_tmp]);
}
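A minimal scalar sketch (not part of the patch) of what `copy_half_reg` followed by `vpermd` with this {0, 4, 1, 5, 2, 6, 3, 7} table achieves in the 8-dword (YMM) case, assuming the four loaded bytes already sit widened in dwords d0..d3:
#include <array>
#include <cstddef>
#include <cstdint>

// Scalar model: mirror the lower half, then permute so that every source
// byte lands in a pair of adjacent dwords, ready for nibble extraction.
std::array<uint32_t, 8> duplicate_bytes_to_dword_pairs(
        const std::array<uint32_t, 4> &d) {
    // copy_half_reg: lower half copied into the upper half.
    const std::array<uint32_t, 8> v {
            d[0], d[1], d[2], d[3], d[0], d[1], d[2], d[3]};
    // vpermd with indices {0, 4, 1, 5, 2, 6, 3, 7}.
    constexpr std::array<uint32_t, 8> idx {0, 4, 1, 5, 2, 6, 3, 7};
    std::array<uint32_t, 8> out {};
    for (size_t i = 0; i < 8; ++i)
        out[i] = v[idx[i]];
    return out; // {d0, d0, d1, d1, d2, d2, d3, d3}
}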
/** Unpacks int4 values from a register whose lower half holds the packed bytes (2 int4 values each).
 * The idea is to duplicate each byte into two adjacent dwords using `copy_half_reg` and `vpermd`,
 * then shift left or right depending on the nibble position (even/odd) and whether the type is
 * signed or unsigned.
 */
template <typename Vmm>
void prepare_loaded_int4(
const Vmm &reg, const Vmm &vmm_permd, const bool is_signed) {
using Vmm_lower_t = typename vreg_traits_t<Vmm>::Vmm_lower_t;
const auto vmm_lower = Vmm_lower_t(reg.getIdx());
copy_half_reg(reg, vmm_lower);
vpermd(reg, vmm_permd, reg);
// Without opmask registers the int4 path needs 2 temporary registers
// to emulate the masked shifts with VPAND.
if (!isa_has_masks(conf_->isa)) {
// TODO: Unify register usage over kernels
const auto mask_vmm = Vmm(conf_->transposed_B ? 13 : 0);
const auto tmp_vmm = vmm_permd;
// const auto tmp_vmm2 = Vmm(conf_->transposed_B ? 13: 4);
// The f32 and transposed kernels use the same register for regq_tmp.
const auto reg_tmp = r15;
alignas(64) static constexpr const uint32_t odd_indices[8] = {
0, 0xffffffff, 0, 0xffffffff, 0, 0xffffffff, 0, 0xffffffff};
alignas(64) static constexpr const uint32_t even_indices[8] = {
0xffffffff, 0, 0xffffffff, 0, 0xffffffff, 0, 0xffffffff, 0};
// Even lanes: extract the low nibbles (even-index int4 elements)
mov(reg_tmp, reinterpret_cast<size_t>(even_indices));
vmovdqa(mask_vmm, ptr[reg_tmp]);
uni_vpand(tmp_vmm, reg, mask_vmm);
uni_vpslld(tmp_vmm, tmp_vmm, 28);
if (is_signed) {
vpsrad(tmp_vmm, tmp_vmm, 28);
} else {
vpsrld(tmp_vmm, tmp_vmm, 28);
}
// Odd lanes: extract the high nibbles (odd-index int4 elements)
mov(reg_tmp, reinterpret_cast<size_t>(odd_indices));
vmovdqa(mask_vmm, ptr[reg_tmp]);
uni_vpand(reg, reg, mask_vmm);
if (is_signed) {
vpsrad(reg, reg, 4);
} else {
vpsrld(reg, reg, 4);
}
// Combine the low and high nibbles into the destination register
vpaddd(reg, reg, tmp_vmm);
// Clean tmp regs
vpxor(mask_vmm, mask_vmm, mask_vmm);
restore_vmm_permd(vmm_permd);
} else {
uni_vpslld(reg | k5555, reg, 28);
if (is_signed) {
vpsrad(reg | k5555, reg, 28);
vpsrad(reg | kAAAA, reg, 4);
} else {
vpsrld(reg | k5555, reg, 28);
vpsrld(reg | kAAAA, reg, 4);
}
}
}
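A hedged scalar sketch (not from the patch) of the per-pair nibble extraction that `prepare_loaded_int4` ends up computing on both the masked (AVX-512) and the VPAND (AVX2) paths, assuming the packed byte has already been sign- or zero-extended into both dwords of the pair:
#include <cstdint>

// Scalar model: the even lane keeps the low nibble, the odd lane keeps the
// high nibble; arithmetic shifts provide sign extension for s4.
// (Relies on two's-complement arithmetic right shifts of negative values.)
void unpack_int4_pair(uint8_t packed, bool is_signed, int32_t &even_lane,
        int32_t &odd_lane) {
    if (is_signed) {
        // vpslld 28 + vpsrad 28: sign-extended low nibble.
        even_lane = (int32_t)((uint32_t)packed << 28) >> 28;
        // vpsrad 4 of the sign-extended byte: sign-extended high nibble.
        odd_lane = (int32_t)(int8_t)packed >> 4;
    } else {
        // vpslld 28 + vpsrld 28: zero-extended low nibble.
        even_lane = (int32_t)(((uint32_t)packed << 28) >> 28);
        // vpsrld 4 of the zero-extended byte: zero-extended high nibble.
        odd_lane = (int32_t)((uint32_t)packed >> 4);
    }
}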
/**
* @brief Loads and converts data of various types into vector registers with appropriate handling
*
* Integer types s4, u4, s8, u8, s32 will be loaded and converted to s32 during the loading.
* Floating point types f16, bf16, f32 will be loaded and converted to f32 during the loading.
* @tparam Vmm Vector register type for computation
* @param reg Destination vector register
* @param op Source memory operand to load from
@ -2187,28 +2262,18 @@ protected:
uni_vpmovsxbd(
maybe_mask(vmm_lower, is_tail, /* is_int4 = */ true),
op);
copy_half_int4(reg, vmm_lower);
vpermd(reg, vmm_permd, reg);
uni_vpslld(reg | k5555, reg, 28);
vpsrad(reg | k5555, reg, 28);
vpsrad(reg | kAAAA, reg, 4);
prepare_loaded_int4(reg, vmm_permd, /* is_signed = */ true);
break;
case data_type::u4:
uni_vpmovzxbd(
maybe_mask(vmm_lower, is_tail, /* is_int4 = */ true),
op);
copy_half_int4(reg, vmm_lower);
vpermd(reg, vmm_permd, reg);
uni_vpslld(reg | k5555, reg, 28);
vpsrld(reg | k5555, reg, 28);
vpsrld(reg | kAAAA, reg, 4);
prepare_loaded_int4(reg, vmm_permd, /* is_signed = */ false);
break;
case data_type::f4_e2m1:
case data_type::f4_e3m0:
uni_vpmovzxbd(maybe_mask(vmm_lower, is_tail), op);
copy_half_int4(vmm_in, vmm_lower);
copy_half_reg(vmm_in, vmm_lower);
vpermd(vmm_in, vmm_permd, vmm_in);
uni_vpslld(vmm_in | k5555, vmm_in, 28);
vpsrld(vmm_in | k5555, vmm_in, 28);
@ -2415,7 +2480,7 @@ protected:
const auto src_vmm_lower1 = Vmm_lower_t(reg2.getIdx());
vcvtps2phx(src_vmm_lower0, reg1);
vcvtps2phx(src_vmm_lower1, reg2);
vinsertf64x4(reg1, reg1, src_vmm_lower1, 1);
copy_half_reg(reg1, src_vmm_lower1);
break;
}
case data_type::f32:
@ -3496,9 +3561,11 @@ void jit_brgemm_matmul_copy_b_bf16_t<Vmm>::copy_2x32(
auto load_addr = maybe_EVEX_compress_addr(reg_src_load, offset);
if (!isa_has_masks(conf_->isa)) {
if (is_tail)
load_bytes(src_load, load_addr, columns_tail * tr_typesize);
load_bytes(src_load, load_addr,
columns_tail * tr_typesize / elems_per_byte);
else
uni_vmovups(src_load, load_addr);
load_value(src_reg, src_reg, vmm_permd, conf_->orig_wei_dt, false);
} else {
load_value(
src_reg, load_addr, vmm_permd, conf_->orig_wei_dt, is_tail);
@ -3867,7 +3934,6 @@ struct jit_brgemm_matmul_copy_b_f32_t
, req_apply_wei_scales_(conf->apply_scales_in_buffer_b)
, typesize_in_(types::data_type_size(dt_in_))
, src_elems_per_byte_(is_src_int4_ || is_src_f4_ ? 2 : 1)
, wei_scales_typesize_(conf_->wei_scales_dt_sz)
, src_stride_(conf_->copy_B_wei_stride)
, tr_src_stride_(conf_->LDB * typesize_out_) {}
@ -3882,10 +3948,11 @@ private:
using opmask_t = const Xbyak::Opmask;
using Vmm_lower_t = typename vreg_traits_t<Vmm>::Vmm_lower_t;
static constexpr bool is_ymm_ = std::is_same<Vmm, Xbyak::Ymm>::value;
const data_type_t dt_in_;
const int simd_w_;
const bool is_src_f4_, is_src_int4_, req_zp_b_shift_, req_apply_wei_scales_;
const size_t typesize_in_, src_elems_per_byte_, wei_scales_typesize_;
const size_t typesize_in_, src_elems_per_byte_;
const size_t typesize_out_ = sizeof(float);
dim_t src_stride_, tr_src_stride_;
@ -3940,9 +4007,11 @@ void jit_brgemm_matmul_copy_b_f32_t<Vmm>::copy_16_x_n_block(
const bool is_tail = ncolumns - n < simd_w_;
auto addr = maybe_EVEX_compress_addr(reg_src,
(k * src_stride_ + n * typesize_in_) / src_elems_per_byte_);
if (is_tail && !isa_has_masks(conf_->isa))
vmaskmovps(src_vmm, ymm_tail_mask, addr);
else
if (is_tail && !isa_has_masks(conf_->isa)) {
load_bytes(src_vmm, addr,
(ncolumns % simd_w_) * typesize_in_ / src_elems_per_byte_);
load_value(src_vmm, src_vmm, vmm_permd, conf_->orig_wei_dt, false);
} else
load_value(src_vmm, addr, vmm_permd, conf_->orig_wei_dt, is_tail);
decompress_reg(maybe_mask(src_vmm, is_tail), vmm_zp_b_shift,
@ -3962,6 +4031,11 @@ void jit_brgemm_matmul_copy_b_f32_t<Vmm>::copy_16_x_n_block(
= one_of(zp_dt, data_type::s4, data_type::u4) ? 2 : 1;
const auto offset = n * zp_dt_sz / elems_per_byte;
const auto addr = maybe_EVEX_compress_addr(reg_zp_ptr, offset);
if (is_tail && !isa_has_masks(conf_->isa)) {
load_bytes(vmm_zp_b_shift, addr,
(ncolumns % simd_w_) * zp_dt_sz / elems_per_byte);
load_value(vmm_zp_b_shift, vmm_zp_b_shift, vmm_permd, zp_dt, false);
} else
load_value(vmm_zp_b_shift, addr, vmm_permd, zp_dt, is_tail);
};
@ -3977,6 +4051,11 @@ void jit_brgemm_matmul_copy_b_f32_t<Vmm>::copy_16_x_n_block(
const auto scales_dt_sz = types::data_type_size(scales_dt);
const auto offset = n * scales_dt_sz;
const auto addr = maybe_EVEX_compress_addr(reg_wei_scales, offset);
if (is_tail && !isa_has_masks(conf_->isa)) {
load_bytes(
vmm_wei_scales, addr, (ncolumns % simd_w_) * scales_dt_sz);
load_scale_value(vmm_wei_scales, vmm_wei_scales, scales_dt, false);
} else
load_scale_value(vmm_wei_scales, addr, scales_dt, is_tail);
};
@ -3990,8 +4069,6 @@ void jit_brgemm_matmul_copy_b_f32_t<Vmm>::copy_16_x_n_block(
= (1 << (columns_tail / src_elems_per_byte_)) - 1;
kmovw(kTail_int4, tail_mask_4bit);
}
} else {
init_f32_avx2_mask_ymm(ymm_tail_mask, reg_tmp, columns_tail);
}
}
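A small worked sketch (not part of the patch) of the tail arithmetic above, assuming int4 weights pack two elements per byte (`src_elems_per_byte_` == 2); the AVX2 tail paths compute an equivalent byte count for their `load_bytes` calls:
#include <cassert>
#include <cstdint>

int main() {
    const int columns_tail = 10; // 10 trailing int4 columns
    const int src_elems_per_byte = 2; // two int4 values per byte

    // Bytes actually occupied by the tail, as loaded by load_bytes.
    const int tail_bytes = columns_tail / src_elems_per_byte; // 5

    // One mask bit per packed byte, as in
    // (1 << (columns_tail / src_elems_per_byte_)) - 1.
    const uint32_t tail_mask_4bit = (1u << tail_bytes) - 1; // 0b11111

    assert(tail_bytes == 5);
    assert(tail_mask_4bit == 0x1f);
    return 0;
}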
@ -4056,17 +4133,24 @@ void jit_brgemm_matmul_copy_b_f32_t<Vmm>::generate() {
mov(reg_wei_scales, ptr[param1 + GET_OFF(wei_scales_ptr)]);
mov(reg_zp_ptr, ptr[param1 + GET_OFF(zp_b_value_ptr)]);
kmovw(kFFFF, 0xffff); // 1111111111111111
if (is_src_int4_ || is_src_f4_) {
kmovw(kAAAA, 0xaaaa);
kmovw(k5555, 0x5555);
if (is_superset(conf_->isa, avx512_core)) {
alignas(64) static constexpr const uint32_t int4_permute[16]
= {0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15};
mov(reg_tmp, reinterpret_cast<size_t>(int4_permute));
vmovdqa32(vmm_permd, ptr[reg_tmp]);
kmovw(kAAAA, 0xaaaa);
kmovw(k5555, 0x5555);
} else if (is_superset(conf_->isa, avx2)) {
alignas(64) static constexpr const uint32_t int4_permute_avx2[8]
= {0, 4, 1, 5, 2, 6, 3, 7};
mov(reg_tmp, reinterpret_cast<size_t>(int4_permute_avx2));
vmovdqa(vmm_permd, ptr[reg_tmp]);
}
}
if (is_src_f4_) {
if (is_src_f4_ && is_superset(conf_->isa, avx512_core)) {
alignas(64) static constexpr const float f4_e2m1_table[16]
= {0.0f, .5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f, -0.0f, -.5f,
-1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f};
@ -4347,9 +4431,9 @@ private:
const auto &scales_dt_sz = conf_->wei_scales_dt_sz;
const auto offset = n * scales_dt_sz;
const auto masked_vmm = maybe_mask(vmm_wei_scales, is_tail);
const auto addr = EVEX_compress_addr(
const auto addr = maybe_EVEX_compress_addr(
reg_wei_scales, offset, scales_dt == data_type::f16);
vpxord(vmm_wei_scales, vmm_wei_scales, vmm_wei_scales);
uni_vpxor(vmm_wei_scales, vmm_wei_scales, vmm_wei_scales);
switch (scales_dt) {
case data_type::f32: uni_vbroadcastss(vmm_wei_scales, addr); break;
case data_type::bf16:
@ -4383,6 +4467,7 @@ template <typename Vmm>
void jit_brgemm_matmul_copy_b_transposed_t<Vmm>::init_tail_mask(
const int columns_tail, const bool use_int4_mask) {
assert(IMPLICATION(use_int4_mask, is_src_int4_));
assert(isa_has_masks(conf_->isa));
if (columns_tail > 0) {
const int dt_step = req_cvtps2xf16_ || use_fp16_instructions_
|| use_bf16_instructions_
@ -4423,7 +4508,7 @@ template <typename Vmm>
bool jit_brgemm_matmul_copy_b_transposed_t<Vmm>::preload_int4(const Xmm &xmm_in,
const int i, const int columns_tail, const bool is_tail,
const dim_t offset) {
const auto addr = EVEX_compress_addr(reg_src, offset);
const auto addr = maybe_EVEX_compress_addr(reg_src, offset);
const bool need_preload_int4 = is_src_int4_ && (i * src_stride_) % 2 != 0;
const auto max_shift_sz = 8;
if (need_preload_int4) {
@ -4437,8 +4522,8 @@ bool jit_brgemm_matmul_copy_b_transposed_t<Vmm>::preload_int4(const Xmm &xmm_in,
} else {
const auto xmm_tmp = Xmm(tmp_vmm(3).getIdx());
load_bytes(xmm_in, addr, load_sz);
load_bytes(
xmm_tmp, EVEX_compress_addr(reg_src, offset + 1), load_sz);
load_bytes(xmm_tmp, maybe_EVEX_compress_addr(reg_src, offset + 1),
load_sz);
vpsrlq(xmm_in, xmm_in, 4);
vpsllq(xmm_tmp, xmm_tmp, 4);
vpord(xmm_in, xmm_in, xmm_tmp);
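A scalar sketch (not part of the patch) of the odd-offset repack above: because a second byte is loaded from `offset + 1`, the qword-wide vpsrlq/vpsllq/vpord sequence behaves, per output byte, like combining the high nibble of src[i] with the low nibble of src[i + 1]:
#include <cstdint>

// Per-byte scalar model of the odd-start nibble repack in preload_int4.
static inline uint8_t repack_odd_start_byte(uint8_t lo_src, uint8_t hi_src) {
    // Output low nibble  = high nibble of src[i]     (lo_src >> 4)
    // Output high nibble = low nibble of src[i + 1]  (hi_src & 0x0f)
    return (uint8_t)((lo_src >> 4) | ((hi_src & 0x0f) << 4));
}
// Example: lo_src = 0xBA (nibbles A, B), hi_src = 0xDC (nibbles C, D)
// -> 0xCB, i.e. the int4 stream shifted by one element.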
@ -4491,12 +4576,12 @@ void jit_brgemm_matmul_copy_b_transposed_t<Vmm>::copy_row_x_col(
jg(general_load); // i < dynamic nrows -> general load
// i >= dynamic nrows -> zero out values in src_reg
vpxord(src_reg, src_reg, src_reg);
uni_vpxor(src_reg, src_reg, src_reg);
jmp(load_done);
L(general_load);
} else if (i >= nrows) {
vpxord(src_reg, src_reg, src_reg);
uni_vpxor(src_reg, src_reg, src_reg);
return;
}
@ -4527,7 +4612,7 @@ void jit_brgemm_matmul_copy_b_transposed_t<Vmm>::copy_row_x_col(
assert(!"Unsupported data type in loading");
if (ncolumns <= req_cvt_bf16_k_blk_step_) {
vpxord(src_reg_next, src_reg_next, src_reg_next);
uni_vpxor(src_reg_next, src_reg_next, src_reg_next);
} else {
const auto is_tail = columns_tail > 0;
auto src_next_masked = maybe_mask(src_reg_next, is_tail);
@ -4570,12 +4655,12 @@ void jit_brgemm_matmul_copy_b_transposed_t<Vmm>::copy_row_x_col(
jg(general_load); // i < dynamic nrows -> general load
// i >= dynamic nrows -> zero out values in src_reg
vpxord(src_reg, src_reg, src_reg);
uni_vpxor(src_reg, src_reg, src_reg);
jmp(load_done);
L(general_load);
} else if (i >= nrows) {
vpxord(src_reg, src_reg, src_reg);
uni_vpxor(src_reg, src_reg, src_reg);
return;
}
@ -4772,7 +4857,7 @@ void jit_brgemm_matmul_copy_b_transposed_t<Ymm>::copy_row_x_col(
jg(general_load); // i < dynamic nrows -> general load
// i >= dynamic nrows -> zero out values in src_reg
vpxord(vmm_src, vmm_src, vmm_src);
uni_vpxor(vmm_src, vmm_src, vmm_src);
jmp(load_done);
L(general_load);
@ -4781,7 +4866,34 @@ void jit_brgemm_matmul_copy_b_transposed_t<Ymm>::copy_row_x_col(
return;
}
if (columns_tail > 0) {
const bool is_tail = columns_tail > 0;
if (conf_->is_f32_with_int_wei) {
const auto src_offset = (i * src_stride_) / src_elems_per_byte_;
const auto addr = maybe_EVEX_compress_addr(reg_src, src_offset);
const auto xmm_preload = Xmm(vmm_src.getIdx());
MAYBE_UNUSED(xmm_preload);
const bool preloaded_int4 = preload_int4(
xmm_preload, i, columns_tail, is_tail, src_offset);
const auto &src_op = preloaded_int4
? static_cast<const Xbyak::Operand &>(xmm_preload)
: static_cast<const Xbyak::Operand &>(addr);
if (is_tail && !preloaded_int4) {
load_bytes(vmm_src, addr,
columns_tail * typesize_ / src_elems_per_byte_);
load_value(
vmm_src, vmm_src, vmm_permd, conf_->orig_wei_dt, false);
} else {
load_value(vmm_src, src_op, vmm_permd, conf_->orig_wei_dt,
is_tail);
}
load_zero_point(i, is_tail);
load_scales(i, is_tail);
decompress_reg(
vmm_src, vmm_zp_b_val, vmm_wei_scales, conf_->orig_wei_dt);
} else if (is_tail) {
load_bytes(vmm_src, reg_src, i * src_stride_,
columns_tail * typesize_);
if (use_fp16_instructions_) {
@ -4883,7 +4995,10 @@ void jit_brgemm_matmul_copy_b_transposed_t<Ymm>::copy_row_x_col(
const auto src0 = src_vmm(i);
if (do_compute_compensation_)
dot_product(vmm_comp_acc, vmm_comp_mul, src0);
uni_vmovups(ptr[reg_tr_src + i * tr_src_stride_], src0);
const auto addr
= maybe_EVEX_compress_addr(reg_tr_src, i * tr_src_stride_);
if (columns_tail > 0 && i >= columns_tail) break;
uni_vmovups(addr, src0);
}
}
@ -5051,11 +5166,18 @@ void jit_brgemm_matmul_copy_b_transposed_t<Vmm>::generate() {
kmovw(k0F0F, 0x0f0f);
kmovw(kF0F0, 0xf0f0);
}
if (is_src_int4_ && is_superset(conf_->isa, avx512_core)) {
if (is_src_int4_) {
if (is_superset(conf_->isa, avx512_core)) {
alignas(64) static constexpr const uint32_t int4_permute[16]
= {0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15};
mov(regq_tmp, reinterpret_cast<size_t>(int4_permute));
vmovdqa32(vmm_permd, ptr[regq_tmp]);
} else if (is_superset(conf_->isa, avx2)) {
alignas(64) static constexpr const uint32_t int4_permute_avx2[8]
= {0, 4, 1, 5, 2, 6, 3, 7};
mov(regq_tmp, reinterpret_cast<size_t>(int4_permute_avx2));
vmovdqa(vmm_permd, ptr[regq_tmp]);
}
}
load_common_zp_value(vmm_zp_b_val, reg_zp_ptr);

View File

@ -1413,6 +1413,12 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc,
|| IMPLICATION(bgmmc.is_wei_scale_per_k,
bgmmc.with_wei_decompression),
VERBOSE_UNSUPPORTED_SCALES_CFG);
// Check that the ISA supports f16/bf16 weight scales
VCONDCHECK_BG(IMPLICATION(bgmmc.wei_scales_dt == f16, isa_has_f16(isa))
&& IMPLICATION(
bgmmc.wei_scales_dt == bf16, isa_has_bf16(isa)),
VERBOSE_UNSUPPORTED_SCALES_CFG);
}
const auto &dst_scales = attr.scales_.get(DNNL_ARG_DST);