From 846ba7c1d413e698fc44426f6adfc1e3f37f539f Mon Sep 17 00:00:00 2001
From: Ovchinnikov Dmitriy
Date: Tue, 7 Oct 2025 02:34:38 -0700
Subject: [PATCH] x64: matmul: Enable AVX2 matmul weight decompression
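
Enable the f32-with-compressed-weights (is_f32_with_int_wei) copy-B
kernels on AVX2, where no AVX-512 opmask registers are available:

- cpu_isa_traits.hpp: add isa_has_f16() so f16 weight scales can be
  gated on ISA support.
- brgemm_matmul_copy_utils.cpp: factor the int4 unpacking into
  prepare_loaded_int4(), which emulates the opmask-based nibble
  extraction with VPAND masks and an AVX2 permute table; handle tail
  columns with load_bytes() instead of masked loads.
- brgemm_matmul_utils.cpp: reject f16/bf16 weight scales on ISAs
  without native support for those types.

Per 32-bit lane, once each packed byte has been duplicated into two
adjacent dwords, the nibble extraction amounts to (scalar model, for
illustration only; arithmetic shifts for s4, logical shifts for u4):

    int32_t lo = (x << 28) >> 28; // even lane keeps the low nibble
    int32_t hi = x >> 4;          // odd lane keeps the high nibble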
---
 src/cpu/x64/cpu_isa_traits.hpp                |   5 +
 .../x64/matmul/brgemm_matmul_copy_utils.cpp   | 224 ++++++++++++++----
 src/cpu/x64/matmul/brgemm_matmul_utils.cpp    |   6 +
 3 files changed, 184 insertions(+), 51 deletions(-)
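
Reviewer note (below the three-dash line, so git-am drops it): a
compilable model of the AVX2 branch of prepare_loaded_int4() for eight
s4 values, written with intrinsics. The function name is illustrative,
and the final OR stands in for the kernel's VPAND + VPADDD combination,
which is equivalent here because the selected lanes are disjoint.

    #include <immintrin.h>

    // 4 packed bytes -> 8 sign-extended s4 values, one per dword lane.
    static inline __m256i unpack_s4x8(const void *src) {
        // vpmovsxbd: 4 bytes -> 4 sign-extended dwords in the lower half
        __m128i lo = _mm_cvtepi8_epi32(_mm_loadu_si32(src));
        // copy_half_reg + vpermd with {0, 4, 1, 5, 2, 6, 3, 7}:
        // duplicate each dword into two adjacent lanes
        __m256i dup = _mm256_permutevar8x32_epi32(_mm256_set_m128i(lo, lo),
                _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7));
        // even lanes: sign-extended low nibble via a 28-bit shift pair
        __m256i lo_nib = _mm256_srai_epi32(_mm256_slli_epi32(dup, 28), 28);
        // odd lanes: high nibble via an arithmetic shift by 4
        __m256i hi_nib = _mm256_srai_epi32(dup, 4);
        // select lanes with the alternating even/odd dword masks
        const __m256i even = _mm256_setr_epi32(-1, 0, -1, 0, -1, 0, -1, 0);
        return _mm256_or_si256(_mm256_and_si256(lo_nib, even),
                _mm256_andnot_si256(even, hi_nib));
    }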
diff --git a/src/cpu/x64/cpu_isa_traits.hpp b/src/cpu/x64/cpu_isa_traits.hpp
index f5da85d6a5..9226588b4a 100644
--- a/src/cpu/x64/cpu_isa_traits.hpp
+++ b/src/cpu/x64/cpu_isa_traits.hpp
@@ -477,6 +477,11 @@ inline bool isa_has_bf16(cpu_isa_t isa) {
     return is_superset(isa, avx512_core_bf16);
 }
 
+inline bool isa_has_f16(cpu_isa_t isa) {
+    return is_superset(isa, avx512_core_fp16)
+            || is_superset(isa, avx10_1_512_amx_fp16);
+}
+
 inline bool isa_has_masks(cpu_isa_t isa) {
     return is_superset(isa, avx512_core);
 }
diff --git a/src/cpu/x64/matmul/brgemm_matmul_copy_utils.cpp b/src/cpu/x64/matmul/brgemm_matmul_copy_utils.cpp
index 5bb6220dfa..1c37c100d2 100644
--- a/src/cpu/x64/matmul/brgemm_matmul_copy_utils.cpp
+++ b/src/cpu/x64/matmul/brgemm_matmul_copy_utils.cpp
@@ -2118,7 +2118,7 @@ protected:
      * @param zmm Destination ZMM register where data will be inserted
      * @param ymm_half Source YMM register containing data for the upper half
      */
-    void copy_half_int4(const Zmm &zmm, const Ymm &ymm_half) {
+    void copy_half_reg(const Zmm &zmm, const Ymm &ymm_half) {
         vinserti64x4(zmm, zmm, ymm_half, 1);
     }
 
@@ -2127,13 +2127,88 @@ protected:
      * @param ymm Destination YMM register where data will be inserted
      * @param xmm_half Source XMM register containing data for the upper half
      */
-    void copy_half_int4(const Ymm &ymm, const Xmm &xmm_half) {
+    void copy_half_reg(const Ymm &ymm, const Xmm &xmm_half) {
         vinserti128(ymm, ymm, xmm_half, 1);
     }
 
+    /** Restores the vmm register that holds the permute indices for int4
+     * processing; due to register pressure it is reused as a temporary in
+     * `prepare_loaded_int4`.
+     */
+    void restore_vmm_permd(const Ymm &vmm_permd) {
+        alignas(64) static constexpr const uint32_t int4_permute_avx2[8]
+                = {0, 4, 1, 5, 2, 6, 3, 7};
+        const auto reg_tmp = r15;
+        mov(reg_tmp, reinterpret_cast<size_t>(int4_permute_avx2));
+        vmovdqa(vmm_permd, ptr[reg_tmp]);
+    }
+
+    /** Unpacks a register whose loaded half contains packed bytes
+     * (2 int4 values each). The idea is to duplicate each byte into two
+     * dwords using `copy_half_reg` and `vpermd`, then shift left or right
+     * depending on the position (odd/even) and the signedness of the type.
+     */
+    template <typename Vmm>
+    void prepare_loaded_int4(
+            const Vmm &reg, const Vmm &vmm_permd, const bool is_signed) {
+        using Vmm_lower_t = typename vreg_traits_t<Vmm>::Vmm_lower_t;
+        const auto vmm_lower = Vmm_lower_t(reg.getIdx());
+
+        copy_half_reg(reg, vmm_lower);
+        vpermd(reg, vmm_permd, reg);
+        // Without opmask support the int4 path needs 2 tmp registers
+        // to emulate the masking with VPAND.
+        if (!isa_has_masks(conf_->isa)) {
+            // TODO: Unify register usage across kernels
+            const auto mask_vmm = Vmm(conf_->transposed_B ? 13 : 0);
+            const auto tmp_vmm = vmm_permd;
+            // const auto tmp_vmm2 = Vmm(conf_->transposed_B ? 13 : 4);
+            // The f32 and transposed kernels use the same register for
+            // regq_tmp
+            const auto reg_tmp = r15;
+            alignas(64) static constexpr const uint32_t odd_indices[8] = {
+                    0, 0xffffffff, 0, 0xffffffff, 0, 0xffffffff, 0, 0xffffffff};
+            alignas(64) static constexpr const uint32_t even_indices[8] = {
+                    0xffffffff, 0, 0xffffffff, 0, 0xffffffff, 0, 0xffffffff, 0};
+            // Even dword positions: extract the low nibbles
+            mov(reg_tmp, reinterpret_cast<size_t>(even_indices));
+            vmovdqa(mask_vmm, ptr[reg_tmp]);
+            uni_vpand(tmp_vmm, reg, mask_vmm);
+            uni_vpslld(tmp_vmm, tmp_vmm, 28);
+            if (is_signed) {
+                vpsrad(tmp_vmm, tmp_vmm, 28);
+            } else {
+                vpsrld(tmp_vmm, tmp_vmm, 28);
+            }
+
+            // Odd dword positions: extract the high nibbles
+            mov(reg_tmp, reinterpret_cast<size_t>(odd_indices));
+            vmovdqa(mask_vmm, ptr[reg_tmp]);
+            uni_vpand(reg, reg, mask_vmm);
+            if (is_signed) {
+                vpsrad(reg, reg, 4);
+            } else {
+                vpsrld(reg, reg, 4);
+            }
+            // Combine both halves in the destination register
+            vpaddd(reg, reg, tmp_vmm);
+            // Clean up the temporaries
+            vpxor(mask_vmm, mask_vmm, mask_vmm);
+            restore_vmm_permd(vmm_permd);
+        } else {
+            uni_vpslld(reg | k5555, reg, 28);
+            if (is_signed) {
+                vpsrad(reg | k5555, reg, 28);
+                vpsrad(reg | kAAAA, reg, 4);
+            } else {
+                vpsrld(reg | k5555, reg, 28);
+                vpsrld(reg | kAAAA, reg, 4);
+            }
+        }
+    }
+
     /**
      * @brief Loads and converts data of various types into vector registers with appropriate handling
-     *
+     * Integer types (s4, u4, s8, u8, s32) are converted to s32 on load;
+     * floating-point types (f16, bf16, f32) are converted to f32 on load.
      * @tparam Vmm Vector register type for computation
      * @param reg Destination vector register
      * @param op Source memory operand to load from
@@ -2187,28 +2262,18 @@ protected:
                 uni_vpmovsxbd(
                         maybe_mask(vmm_lower, is_tail, /* is_int4 = */ true),
                         op);
-                copy_half_int4(reg, vmm_lower);
-
-                vpermd(reg, vmm_permd, reg);
-                uni_vpslld(reg | k5555, reg, 28);
-                vpsrad(reg | k5555, reg, 28);
-                vpsrad(reg | kAAAA, reg, 4);
+                prepare_loaded_int4(reg, vmm_permd, /* is_signed = */ true);
                 break;
             case data_type::u4:
                 uni_vpmovzxbd(
                         maybe_mask(vmm_lower, is_tail, /* is_int4 = */ true),
                         op);
-                copy_half_int4(reg, vmm_lower);
-
-                vpermd(reg, vmm_permd, reg);
-                uni_vpslld(reg | k5555, reg, 28);
-                vpsrld(reg | k5555, reg, 28);
-                vpsrld(reg | kAAAA, reg, 4);
+                prepare_loaded_int4(reg, vmm_permd, /* is_signed = */ false);
                 break;
             case data_type::f4_e2m1:
             case data_type::f4_e3m0:
                 uni_vpmovzxbd(maybe_mask(vmm_lower, is_tail), op);
-                copy_half_int4(vmm_in, vmm_lower);
+                copy_half_reg(vmm_in, vmm_lower);
                 vpermd(vmm_in, vmm_permd, vmm_in);
                 uni_vpslld(vmm_in | k5555, vmm_in, 28);
                 vpsrld(vmm_in | k5555, vmm_in, 28);
@@ -2415,7 +2480,7 @@ protected:
             const auto src_vmm_lower1 = Vmm_lower_t(reg2.getIdx());
             vcvtps2phx(src_vmm_lower0, reg1);
             vcvtps2phx(src_vmm_lower1, reg2);
-            vinsertf64x4(reg1, reg1, src_vmm_lower1, 1);
+            copy_half_reg(reg1, src_vmm_lower1);
             break;
         }
         case data_type::f32:
@@ -3496,9 +3561,11 @@ void jit_brgemm_matmul_copy_b_bf16_t<Vmm>::copy_2x32(
     auto load_addr = maybe_EVEX_compress_addr(reg_src_load, offset);
     if (!isa_has_masks(conf_->isa)) {
         if (is_tail)
-            load_bytes(src_load, load_addr, columns_tail * tr_typesize);
+            load_bytes(src_load, load_addr,
+                    columns_tail * tr_typesize / elems_per_byte);
         else
             uni_vmovups(src_load, load_addr);
+        load_value(src_reg, src_reg, vmm_permd, conf_->orig_wei_dt, false);
     } else {
         load_value(
                 src_reg, load_addr, vmm_permd, conf_->orig_wei_dt, is_tail);
@@ -3867,7 +3934,6 @@ struct jit_brgemm_matmul_copy_b_f32_t
         , req_apply_wei_scales_(conf->apply_scales_in_buffer_b)
         , typesize_in_(types::data_type_size(dt_in_))
         , src_elems_per_byte_(is_src_int4_ || is_src_f4_ ? 2 : 1)
-        , wei_scales_typesize_(conf_->wei_scales_dt_sz)
         , src_stride_(conf_->copy_B_wei_stride)
         , tr_src_stride_(conf_->LDB * typesize_out_) {}
 
@@ -3882,10 +3948,11 @@ private:
     using opmask_t = const Xbyak::Opmask;
     using Vmm_lower_t = typename vreg_traits_t<Vmm>::Vmm_lower_t;
 
+    static constexpr bool is_ymm_ = std::is_same<Vmm, Ymm>::value;
     const data_type_t dt_in_;
     const int simd_w_;
     const bool is_src_f4_, is_src_int4_, req_zp_b_shift_, req_apply_wei_scales_;
-    const size_t typesize_in_, src_elems_per_byte_, wei_scales_typesize_;
+    const size_t typesize_in_, src_elems_per_byte_;
     const size_t typesize_out_ = sizeof(float);
     dim_t src_stride_, tr_src_stride_;
 
@@ -3940,9 +4007,11 @@ void jit_brgemm_matmul_copy_b_f32_t<Vmm>::copy_16_x_n_block(
         const bool is_tail = ncolumns - n < simd_w_;
         auto addr = maybe_EVEX_compress_addr(reg_src,
                 (k * src_stride_ + n * typesize_in_) / src_elems_per_byte_);
-        if (is_tail && !isa_has_masks(conf_->isa))
-            vmaskmovps(src_vmm, ymm_tail_mask, addr);
-        else
+        if (is_tail && !isa_has_masks(conf_->isa)) {
+            load_bytes(src_vmm, addr,
+                    (ncolumns % simd_w_) * typesize_in_ / src_elems_per_byte_);
+            load_value(src_vmm, src_vmm, vmm_permd, conf_->orig_wei_dt, false);
+        } else
             load_value(src_vmm, addr, vmm_permd, conf_->orig_wei_dt, is_tail);
 
         decompress_reg(maybe_mask(src_vmm, is_tail), vmm_zp_b_shift,
@@ -3962,7 +4031,12 @@ void jit_brgemm_matmul_copy_b_f32_t<Vmm>::copy_16_x_n_block(
                 = one_of(zp_dt, data_type::s4, data_type::u4) ? 2 : 1;
         const auto offset = n * zp_dt_sz / elems_per_byte;
         const auto addr = maybe_EVEX_compress_addr(reg_zp_ptr, offset);
-        load_value(vmm_zp_b_shift, addr, vmm_permd, zp_dt, is_tail);
+        if (is_tail && !isa_has_masks(conf_->isa)) {
+            load_bytes(vmm_zp_b_shift, addr,
+                    (ncolumns % simd_w_) * zp_dt_sz / elems_per_byte);
+            load_value(vmm_zp_b_shift, vmm_zp_b_shift, vmm_permd, zp_dt, false);
+        } else
+            load_value(vmm_zp_b_shift, addr, vmm_permd, zp_dt, is_tail);
     };
 
     /** Loads scales, when is_wei_scale_per_n is set.
@@ -3977,7 +4051,12 @@ void jit_brgemm_matmul_copy_b_f32_t<Vmm>::copy_16_x_n_block(
         const auto scales_dt_sz = types::data_type_size(scales_dt);
         const auto offset = n * scales_dt_sz;
         const auto addr = maybe_EVEX_compress_addr(reg_wei_scales, offset);
-        load_scale_value(vmm_wei_scales, addr, scales_dt, is_tail);
+        if (is_tail && !isa_has_masks(conf_->isa)) {
+            load_bytes(
+                    vmm_wei_scales, addr, (ncolumns % simd_w_) * scales_dt_sz);
+            load_scale_value(vmm_wei_scales, vmm_wei_scales, scales_dt, false);
+        } else
+            load_scale_value(vmm_wei_scales, addr, scales_dt, is_tail);
     };
 
     const int columns_tail = ncolumns % simd_w_;
@@ -3990,8 +4069,6 @@ void jit_brgemm_matmul_copy_b_f32_t<Vmm>::copy_16_x_n_block(
                     = (1 << (columns_tail / src_elems_per_byte_)) - 1;
             kmovw(kTail_int4, tail_mask_4bit);
         }
-    } else {
-        init_f32_avx2_mask_ymm(ymm_tail_mask, reg_tmp, columns_tail);
     }
 }
 
@@ -4056,17 +4133,24 @@ void jit_brgemm_matmul_copy_b_f32_t<Vmm>::generate() {
     mov(reg_wei_scales, ptr[param1 + GET_OFF(wei_scales_ptr)]);
     mov(reg_zp_ptr, ptr[param1 + GET_OFF(zp_b_value_ptr)]);
     kmovw(kFFFF, 0xffff); // 1111111111111111
-    if (is_src_int4_ || is_src_f4_) {
-        alignas(64) static constexpr const uint32_t int4_permute[16]
-                = {0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15};
-        mov(reg_tmp, reinterpret_cast<size_t>(int4_permute));
-        vmovdqa32(vmm_permd, ptr[reg_tmp]);
+    if (is_src_int4_ || is_src_f4_) {
         kmovw(kAAAA, 0xaaaa);
         kmovw(k5555, 0x5555);
+        if (is_superset(conf_->isa, avx512_core)) {
+            alignas(64) static constexpr const uint32_t int4_permute[16]
+                    = {0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15};
+            mov(reg_tmp, reinterpret_cast<size_t>(int4_permute));
+            vmovdqa32(vmm_permd, ptr[reg_tmp]);
+        } else if (is_superset(conf_->isa, avx2)) {
+            alignas(64) static constexpr const uint32_t int4_permute_avx2[8]
+                    = {0, 4, 1, 5, 2, 6, 3, 7};
+            mov(reg_tmp, reinterpret_cast<size_t>(int4_permute_avx2));
+            vmovdqa(vmm_permd, ptr[reg_tmp]);
+        }
     }
-    if (is_src_f4_) {
+    if (is_src_f4_ && is_superset(conf_->isa, avx512_core)) {
         alignas(64) static constexpr const float f4_e2m1_table[16]
                 = {0.0f, .5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f,
                         -0.0f, -.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f};
@@ -4347,9 +4431,9 @@ private:
         const auto &scales_dt_sz = conf_->wei_scales_dt_sz;
         const auto offset = n * scales_dt_sz;
         const auto masked_vmm = maybe_mask(vmm_wei_scales, is_tail);
-        const auto addr = EVEX_compress_addr(
+        const auto addr = maybe_EVEX_compress_addr(
                 reg_wei_scales, offset, scales_dt == data_type::f16);
-        vpxord(vmm_wei_scales, vmm_wei_scales, vmm_wei_scales);
+        uni_vpxor(vmm_wei_scales, vmm_wei_scales, vmm_wei_scales);
         switch (scales_dt) {
             case data_type::f32: uni_vbroadcastss(vmm_wei_scales, addr); break;
             case data_type::bf16:
@@ -4383,6 +4467,7 @@ template <typename Vmm>
 void jit_brgemm_matmul_copy_b_transposed_t<Vmm>::init_tail_mask(
         const int columns_tail, const bool use_int4_mask) {
     assert(IMPLICATION(use_int4_mask, is_src_int4_));
+    assert(isa_has_masks(conf_->isa));
     if (columns_tail > 0) {
         const int dt_step = req_cvtps2xf16_ || use_fp16_instructions_
                         || use_bf16_instructions_
@@ -4423,7 +4508,7 @@ template <typename Vmm>
 bool jit_brgemm_matmul_copy_b_transposed_t<Vmm>::preload_int4(const Xmm &xmm_in,
         const int i, const int columns_tail, const bool is_tail,
         const dim_t offset) {
-    const auto addr = EVEX_compress_addr(reg_src, offset);
+    const auto addr = maybe_EVEX_compress_addr(reg_src, offset);
     const bool need_preload_int4 = is_src_int4_ && (i * src_stride_) % 2 != 0;
     const auto max_shift_sz = 8;
     if (need_preload_int4) {
@@ -4437,8 +4522,8 @@ bool jit_brgemm_matmul_copy_b_transposed_t<Vmm>::preload_int4(const Xmm &xmm_in,
     } else {
         const auto xmm_tmp = Xmm(tmp_vmm(3).getIdx());
         load_bytes(xmm_in, addr, load_sz);
-        load_bytes(
-                xmm_tmp, EVEX_compress_addr(reg_src, offset + 1), load_sz);
+        load_bytes(xmm_tmp, maybe_EVEX_compress_addr(reg_src, offset + 1),
+                load_sz);
         vpsrlq(xmm_in, xmm_in, 4);
         vpsllq(xmm_tmp, xmm_tmp, 4);
         vpord(xmm_in, xmm_in, xmm_tmp);
@@ -4491,12 +4576,12 @@ void jit_brgemm_matmul_copy_b_transposed_t<Vmm>::copy_row_x_col(
             jg(general_load); // i < dynamic nrows -> general load
 
             // i >= dynamic nrows -> zero out values in src_reg
-            vpxord(src_reg, src_reg, src_reg);
+            uni_vpxor(src_reg, src_reg, src_reg);
             jmp(load_done);
 
             L(general_load);
         } else if (i >= nrows) {
-            vpxord(src_reg, src_reg, src_reg);
+            uni_vpxor(src_reg, src_reg, src_reg);
             return;
         }
 
@@ -4527,7 +4612,7 @@ void jit_brgemm_matmul_copy_b_transposed_t<Vmm>::copy_row_x_col(
             assert(!"Unsupported data type in loading");
 
         if (ncolumns <= req_cvt_bf16_k_blk_step_) {
-            vpxord(src_reg_next, src_reg_next, src_reg_next);
+            uni_vpxor(src_reg_next, src_reg_next, src_reg_next);
         } else {
             const auto is_tail = columns_tail > 0;
             auto src_next_masked = maybe_mask(src_reg_next, is_tail);
@@ -4570,12 +4655,12 @@ void jit_brgemm_matmul_copy_b_transposed_t<Vmm>::copy_row_x_col(
             jg(general_load); // i < dynamic nrows -> general load
 
             // i >= dynamic nrows -> zero out values in src_reg
-            vpxord(src_reg, src_reg, src_reg);
+            uni_vpxor(src_reg, src_reg, src_reg);
             jmp(load_done);
 
             L(general_load);
         } else if (i >= nrows) {
-            vpxord(src_reg, src_reg, src_reg);
+            uni_vpxor(src_reg, src_reg, src_reg);
             return;
         }
 
@@ -4772,7 +4857,7 @@ void jit_brgemm_matmul_copy_b_transposed_t<Vmm>::copy_row_x_col(
             jg(general_load); // i < dynamic nrows -> general load
 
             // i >= dynamic nrows -> zero out values in src_reg
-            vpxord(vmm_src, vmm_src, vmm_src);
+            uni_vpxor(vmm_src, vmm_src, vmm_src);
             jmp(load_done);
 
             L(general_load);
@@ -4781,7 +4866,34 @@ void jit_brgemm_matmul_copy_b_transposed_t<Vmm>::copy_row_x_col(
         return;
     }
 
-    if (columns_tail > 0) {
+    const bool is_tail = columns_tail > 0;
+
+    if (conf_->is_f32_with_int_wei) {
+        const auto src_offset = (i * src_stride_) / src_elems_per_byte_;
+        const auto addr = maybe_EVEX_compress_addr(reg_src, src_offset);
+        const auto xmm_preload = Xmm(vmm_src.getIdx());
+        MAYBE_UNUSED(xmm_preload);
+        const bool preloaded_int4 = preload_int4(
+                xmm_preload, i, columns_tail, is_tail, src_offset);
+        const auto &src_op = preloaded_int4
+                ? static_cast<const Xbyak::Operand &>(xmm_preload)
+                : static_cast<const Xbyak::Operand &>(addr);
+
+        if (is_tail && !preloaded_int4) {
+            load_bytes(vmm_src, addr,
+                    columns_tail * typesize_ / src_elems_per_byte_);
+            load_value(
+                    vmm_src, vmm_src, vmm_permd, conf_->orig_wei_dt, false);
+        } else {
+            load_value(vmm_src, src_op, vmm_permd, conf_->orig_wei_dt,
+                    is_tail);
+        }
+
+        load_zero_point(i, is_tail);
+        load_scales(i, is_tail);
+        decompress_reg(
+                vmm_src, vmm_zp_b_val, vmm_wei_scales, conf_->orig_wei_dt);
+    } else if (is_tail) {
         load_bytes(vmm_src, reg_src, i * src_stride_,
                 columns_tail * typesize_);
         if (use_fp16_instructions_) {
@@ -4883,7 +4995,10 @@ void jit_brgemm_matmul_copy_b_transposed_t<Vmm>::copy_row_x_col(
         const auto src0 = src_vmm(i);
         if (do_compute_compensation_)
             dot_product(vmm_comp_acc, vmm_comp_mul, src0);
-        uni_vmovups(ptr[reg_tr_src + i * tr_src_stride_], src0);
+        const auto addr
+                = maybe_EVEX_compress_addr(reg_tr_src, i * tr_src_stride_);
+        if (columns_tail > 0 && i >= columns_tail) break;
+        uni_vmovups(addr, src0);
     }
 }
 
@@ -5051,11 +5166,18 @@ void jit_brgemm_matmul_copy_b_transposed_t<Vmm>::generate() {
         kmovw(k0F0F, 0x0f0f);
         kmovw(kF0F0, 0xf0f0);
     }
-    if (is_src_int4_ && is_superset(conf_->isa, avx512_core)) {
-        alignas(64) static constexpr const uint32_t int4_permute[16]
-                = {0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15};
-        mov(regq_tmp, reinterpret_cast<size_t>(int4_permute));
-        vmovdqa32(vmm_permd, ptr[regq_tmp]);
+    if (is_src_int4_) {
+        if (is_superset(conf_->isa, avx512_core)) {
+            alignas(64) static constexpr const uint32_t int4_permute[16]
+                    = {0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15};
+            mov(regq_tmp, reinterpret_cast<size_t>(int4_permute));
+            vmovdqa32(vmm_permd, ptr[regq_tmp]);
+        } else if (is_superset(conf_->isa, avx2)) {
+            alignas(64) static constexpr const uint32_t int4_permute_avx2[8]
+                    = {0, 4, 1, 5, 2, 6, 3, 7};
+            mov(regq_tmp, reinterpret_cast<size_t>(int4_permute_avx2));
+            vmovdqa(vmm_permd, ptr[regq_tmp]);
+        }
     }
 
     load_common_zp_value(vmm_zp_b_val, reg_zp_ptr);
diff --git a/src/cpu/x64/matmul/brgemm_matmul_utils.cpp b/src/cpu/x64/matmul/brgemm_matmul_utils.cpp
index b8397d865c..9227ab8b67 100644
--- a/src/cpu/x64/matmul/brgemm_matmul_utils.cpp
+++ b/src/cpu/x64/matmul/brgemm_matmul_utils.cpp
@@ -1413,6 +1413,12 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc,
                         || IMPLICATION(bgmmc.is_wei_scale_per_k,
                                 bgmmc.with_wei_decompression),
                 VERBOSE_UNSUPPORTED_SCALES_CFG);
+
+        // Check whether the isa supports f16/bf16 weight scales
+        VCONDCHECK_BG(IMPLICATION(bgmmc.wei_scales_dt == f16, isa_has_f16(isa))
+                        && IMPLICATION(
+                                bgmmc.wei_scales_dt == bf16, isa_has_bf16(isa)),
+                VERBOSE_UNSUPPORTED_SCALES_CFG);
     }
 
     const auto &dst_scales = attr.scales_.get(DNNL_ARG_DST);
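
-- 
Reviewer note: a minimal sketch of how this path is reached from the
oneDNN C++ API, assuming a v3.x build with int4 support; the shapes,
group size, scale mask, and the fpmath-mode choice are illustrative
only, so consult the weights-decompression example in the oneDNN
repository for canonical usage.

    #include "dnnl.hpp"
    using namespace dnnl;

    void f32_matmul_with_s4_weights() {
        engine eng(engine::kind::cpu, 0);
        const memory::dim M = 32, K = 4096, N = 4096;
        memory::desc src_md({M, K}, memory::data_type::f32,
                memory::format_tag::ab);
        memory::desc wei_md({K, N}, memory::data_type::s4,
                memory::format_tag::any);
        memory::desc dst_md({M, N}, memory::data_type::f32,
                memory::format_tag::ab);

        primitive_attr attr;
        // Keep f32 math but allow integer weights to be decompressed.
        attr.set_fpmath_mode(fpmath_mode::strict, /* apply_to_int = */ true);
        // f32 weight scales, grouped along K (group size 128), per-N.
        attr.set_scales(DNNL_ARG_WEIGHTS, /* mask = */ (1 << 0) + (1 << 1),
                /* groups = */ {128, 1}, memory::data_type::f32);

        matmul::primitive_desc pd(eng, src_md, wei_md, dst_md, attr);
        matmul prim(pd); // execution with real buffers omitted for brevity
    }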