x64: matmul: Refactoring matmul copy kernels

Ovchinnikov Dmitriy
2025-10-07 04:37:20 -07:00
committed by Dmitriy Ovchinnikov
parent 794caa04ba
commit f92a20983c


@ -2063,6 +2063,327 @@ void jit_brgemm_matmul_copy_a_transposed_int8_impl_t::generate() {
template struct jit_brgemm_matmul_copy_a_transposed_impl_t<Zmm>;
template struct jit_brgemm_matmul_copy_a_transposed_impl_t<Ymm>;
/**
* @brief Common class for BRGEMM B matrix copy operations
*
* This class contains common methods and properties for all copy B kernels.
* It currently provides `load_value` and `decompress_reg` and is intended
* to hold all methods shared by the copy B kernels.
*/
struct jit_brgemm_matmul_copy_b_common_t : public jit_brgemm_matmul_copy_b_t,
public jit_generator_t {
DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_brgemm_matmul_copy_b_common_t)
jit_brgemm_matmul_copy_b_common_t(const brgemm_matmul_conf_t *conf)
: jit_brgemm_matmul_copy_b_t(conf), jit_generator_t(jit_name()) {}
protected:
/**
* @brief Conditionally applies a mask to a vector register for tail or specialized processing
*
* @tparam Vmm Vector register type (Zmm, Ymm, etc.)
* @param vmm The vector register to potentially mask
* @param is_tail Flag indicating if this is tail processing that requires masking
* @param is_int4 Flag indicating if this is INT4 data requiring specialized masking
* @return The potentially masked vector register (original if no masking needed)
*/
template <typename Vmm>
Vmm maybe_mask(Vmm vmm, bool is_tail, bool is_int4 = false) {
// Transposed kernel uses the same kTail mask for both cases
const auto tail_mask
= is_int4 && !conf_->transposed_B ? kTail_int4 : kTail;
const auto unmask_tail
= one_of(conf_->wei_dt, data_type::bf16, data_type::f16)
&& !conf_->transposed_B;
if (isa_has_masks(conf_->isa))
return is_tail ? vmm | tail_mask | T_z
// bf16 and f16 require masking for the tail
// to avoid AVX512F issues with zeroing upper bits
// of ZMM registers when using vpmovzxwd/vpmovzxbd
// instructions
: unmask_tail ? vmm | kFFFF | T_z
: vmm;
else
return vmm;
}
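// Example (assumed 16 dword lanes, 5-element tail): kTail has its five low
// bits set, so `vmm | kTail | T_z` keeps lanes 0..4 and zeroes lanes 5..15.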
/**
* @brief Inserts a YMM register into the upper half of a ZMM register
* @param zmm Destination ZMM register where data will be inserted
* @param ymm_half Source YMM register containing data for the upper half
*/
void copy_half_int4(const Zmm &zmm, const Ymm &ymm_half) {
vinserti64x4(zmm, zmm, ymm_half, 1);
}
/**
* @brief Inserts an XMM register into the upper half of a YMM register
* @param ymm Destination YMM register where data will be inserted
* @param xmm_half Source XMM register containing data for the upper half
*/
void copy_half_int4(const Ymm &ymm, const Xmm &xmm_half) {
vinserti128(ymm, ymm, xmm_half, 1);
}
/**
* @brief Loads and converts data of various types into vector registers with appropriate handling
*
* @tparam Vmm Vector register type for computation
* @param reg Destination vector register
* @param op Source memory operand to load from
* @param vmm_permd Vector register containing permutation indices for INT4 processing
* @param dt Data type being loaded
* @param is_tail Flag indicating if tail processing is needed
*/
template <typename Vmm>
void load_value(const Vmm &reg, const Xbyak::Operand &op,
const Vmm &vmm_permd, data_type_t dt, bool is_tail = false) {
using Vmm_lower_t = typename vreg_traits_t<Vmm>::Vmm_lower_t;
const auto vmm_in = maybe_mask(reg, is_tail);
const auto vmm_lower = Vmm_lower_t(vmm_in.getIdx());
MAYBE_UNUSED(vmm_lower);
const bool is_xf16
= one_of(conf_->wei_dt, data_type::bf16, data_type::f16);
switch (dt) {
case data_type::s32: vmovdqu32(vmm_in, op); break;
case data_type::f32: {
if (conf_->transposed_B)
vmovdqu8(vmm_in, op);
else
uni_vmovups(vmm_in, op);
break;
}
case data_type::f16:
if (!is_xf16) {
uni_vcvtph2psx(vmm_in, op);
break;
}
case data_type::bf16:
if (is_xf16) {
vmovdqu16(vmm_in, op);
} else {
uni_vpmovzxwd(vmm_in, op);
uni_vpslld(vmm_in, vmm_in, 16);
}
break;
case data_type::s8: uni_vpmovsxbd(vmm_in, op); break;
case data_type::u8: uni_vpmovzxbd(vmm_in, op); break;
// For int4, two int4 values are treated as one int8 and extended to int32.
// The low half is stored in the lower bytes of vmm and the high half in the
// higher bytes of vmm; they are then permuted into the correct order.
// Finally, the extended bytes are processed for s4/u4 accordingly.
case data_type::s4:
uni_vpmovsxbd(
maybe_mask(vmm_lower, is_tail, /* is_int4 = */ true),
op);
copy_half_int4(reg, vmm_lower);
vpermd(reg, vmm_permd, reg);
uni_vpslld(reg | k5555, reg, 28);
vpsrad(reg | k5555, reg, 28);
vpsrad(reg | kAAAA, reg, 4);
break;
case data_type::u4:
uni_vpmovzxbd(
maybe_mask(vmm_lower, is_tail, /* is_int4 = */ true),
op);
copy_half_int4(reg, vmm_lower);
vpermd(reg, vmm_permd, reg);
uni_vpslld(reg | k5555, reg, 28);
vpsrld(reg | k5555, reg, 28);
vpsrld(reg | kAAAA, reg, 4);
break;
default: assert(!"unsupported data type");
}
}
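// Scalar view of the s4/u4 path: each packed byte b yields two dwords,
// sign/zero-extended (b & 0xF) in the k5555 lanes and sign/zero-extended
// (b >> 4) in the kAAAA lanes, which is what the shift/permute sequence
// above computes.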
/**
* @brief Applies zero point shift to vector register
* Shifts input values by subtracting zero point values.
*
* @tparam Vmm Vector register type (Zmm, Ymm, etc.)
* @param input Vector register containing input values to be shifted
* @param zp Vector register containing zero point values to be subtracted
* @param src_dt Source data type that determines which instruction to use
*/
template <typename Vmm>
void apply_shift(const Vmm &input, const Vmm &zp, data_type_t src_dt) {
if (!conf_->has_zero_point_b) return;
const bool is_int_shift = one_of(src_dt, data_type::s8, data_type::u8,
data_type::s4, data_type::u4);
const bool is_fp_shift = one_of(
src_dt, data_type::bf16, data_type::f16, data_type::f32);
if (is_int_shift)
uni_vpsubd(input, input, zp);
else if (is_fp_shift)
uni_vsubps(input, input, zp);
// should be unreachable
else
assert(!"Unable to shift input for zero-point");
}
/**
* @brief Converts integer data types to floating-point format
*
* @tparam Vmm Vector register type (Zmm, Ymm, etc.)
* @param input Vector register containing data to be converted
* @param src_dt Source data type of the values to be converted
*/
template <typename Vmm>
void upconvert_to_f32(const Vmm &input, data_type_t src_dt) {
switch (src_dt) {
case data_type::s8:
case data_type::u8:
case data_type::s4:
case data_type::u4: uni_vcvtdq2ps(input, input); break;
case data_type::bf16:
case data_type::f16:
case data_type::f32:
// bf16 and f16 already converted into f32 while loading
break;
default: assert(!"Unsupported source data type for decompression");
}
}
/**
* @brief Applies scale factors to input values when configured
* The operation is only performed if apply_scales_in_buffer_b flag is set
* in the configuration.
*
* @tparam Vmm Vector register type (Zmm, Ymm, etc.)
* @param input Vector register containing values to be scaled
* @param scale_op Operand containing scale factors to apply
*/
template <typename Vmm>
void apply_scales(const Vmm &input, const Xbyak::Operand &scale_op) {
if (conf_->apply_scales_in_buffer_b) uni_vmulps(input, input, scale_op);
}
/**
* @brief Converts floating-point data from F32 to specified destination data type
*
* @tparam Vmm Vector register type (Zmm, Ymm, etc.)
* @param input Vector register containing F32 values to be converted
* @param dst_dt Destination data type for conversion
*/
template <typename Vmm>
void downconvert_to_dst_dt(const Vmm &input, data_type_t dst_dt) {
using Vmm_lower_t = typename vreg_traits_t<Vmm>::Vmm_lower_t;
switch (dst_dt) {
case data_type::bf16:
vcvtneps2bf16(Ymm(input.getIdx()), input);
break;
case data_type::f16:
vcvtps2phx(Vmm_lower_t(input.getIdx()), input);
break;
case data_type::f32:
// f32 is already in the correct format
break;
default:
assert(!"Unsupported destination data type for decompression");
}
}
template <typename Vmm>
void downconvert_to_dst_dt(
const Vmm &reg1, const Vmm &reg2, data_type_t dst_dt) {
using Vmm_lower_t = typename vreg_traits_t<Vmm>::Vmm_lower_t;
switch (dst_dt) {
case data_type::bf16: vcvtne2ps2bf16(reg1, reg2, reg1); break;
case data_type::f16: {
const auto src_vmm_lower0 = Vmm_lower_t(reg1.getIdx());
const auto src_vmm_lower1 = Vmm_lower_t(reg2.getIdx());
vcvtps2phx(src_vmm_lower0, reg1);
vcvtps2phx(src_vmm_lower1, reg2);
vinsertf64x4(reg1, reg1, src_vmm_lower1, 1);
break;
}
case data_type::f32:
// f32 is already in the correct format
break;
default:
assert(!"Unsupported destination data type for decompression");
}
}
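// The two-register overload packs both results into reg1: lower half =
// converted reg1, upper half = converted reg2 (f32 inputs are left untouched).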
/**
* @brief Decompresses values by applying zero point shift, data type conversion, and scaling.
* @tparam Vmm Vector register type used for computation
* @param input Vector register containing input values to be decompressed
* @param zp Vector register containing zero point values
* @param scale_op Operand containing scales to be applied
* @param src_dt Source data type of the values being decompressed
*/
template <typename Vmm>
void decompress_reg(const Vmm &input, const Vmm &zp,
const Xbyak::Operand &scale_op, data_type_t src_dt) {
if (src_dt == data_type::f32)
return; // Decompression doesn't support f32
apply_shift(input, zp, src_dt);
upconvert_to_f32(input, src_dt);
apply_scales(input, scale_op);
}
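// Net effect when a zero point and scales apply (scalar view):
// out = to_f32(in - zp) * scale; the conversion is a no-op for floating-point
// sources, which load_value has already widened to f32.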
template <typename Vmm>
void decompress_and_downcvt_reg(const Vmm &input, const Vmm &zp,
const Xbyak::Operand &scale_op, data_type_t src_dt,
data_type_t dst_dt) {
if (src_dt == dst_dt) return;
decompress_reg(input, zp, scale_op, src_dt);
downconvert_to_dst_dt(input, dst_dt);
}
template <typename Vmm>
void decompress_and_downcvt_2reg(const Vmm &input1, const Vmm &input2,
const Vmm &zp1, const Vmm &zp2, const Xbyak::Operand &scale_op1,
const Xbyak::Operand &scale_op2, data_type_t src_dt,
data_type_t dst_dt) {
if (src_dt == dst_dt) return;
decompress_reg(input1, zp1, scale_op1, src_dt);
decompress_reg(input2, zp2, scale_op2, src_dt);
downconvert_to_dst_dt(input1, input2, dst_dt);
}
/** @brief Helper method to load scale values into a vector register.
* Supports f32, bf16 and f16 data types.
* @tparam Vmm Vector register type (Zmm, Ymm, etc.)
* @param vmm Vector register to load scale values into
* @param op Operand to load scale values from
* @param dt Data type of the scales in memory
* @param is_tail Flag indicating if tail processing is needed
*/
template <typename Vmm>
void load_scale_value(const Vmm &vmm, const Xbyak::Operand &op,
data_type_t dt, bool is_tail = false) {
const auto masked_vmm = maybe_mask(vmm, is_tail);
switch (dt) {
case data_type::f32: uni_vmovups(masked_vmm, op); break;
case data_type::bf16:
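// bf16 -> f32: zero-extend the 16-bit pattern, then shift it into the
// upper half of each dword so the bit pattern becomes a valid f32.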
uni_vpmovzxwd(masked_vmm, op);
uni_vpslld(vmm, vmm, 16);
break;
case data_type::f16: vcvtph2ps(masked_vmm, op); break;
default: assert(!"unsupported wei_scales data type");
}
}
// Commonly used masks for data permutation and tail handling
using opmask_t = const Xbyak::Opmask;
opmask_t k3333 = k1;
opmask_t k5555 = k2;
opmask_t kAAAA = k3;
opmask_t kCCCC = k4;
opmask_t k0F0F = k5;
opmask_t kF0F0 = k6;
opmask_t kFFFF = k6;
opmask_t kTail = k7;
opmask_t kTail_int4 = k5; // TODO: refactor: use kTail for both cases
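// Note: kTail_int4 aliases k0F0F (both map to k5) and kFFFF aliases kF0F0 (both map to k6).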
};
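For reference, the decompression implemented by this class reduces, per element, to unpacking (for int4), a zero-point shift, an f32 conversion, and a scale. A minimal scalar sketch, assuming s4 weights (illustrative only; the helper names below are hypothetical, not oneDNN API):

#include <cstdint>

// Hypothetical scalar reference for the JIT decompression above.
static inline void unpack_s4(uint8_t byte, int32_t &lo, int32_t &hi) {
    lo = ((byte & 0x0F) ^ 0x08) - 0x08; // sign-extend low nibble
    hi = ((byte >> 4) ^ 0x08) - 0x08;   // sign-extend high nibble
}

static inline float decompress_scalar(int32_t q, int32_t zp, float scale) {
    return static_cast<float>(q - zp) * scale; // shift, widen, scale
}

The vectorized kernels above perform the same math a full register at a time, with the zero-point and scale steps gated by has_zero_point_b and apply_scales_in_buffer_b.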
template <typename Vmm>
struct jit_brgemm_matmul_copy_b_int8_t : public jit_brgemm_matmul_copy_b_t,
public jit_generator_t {
@ -2892,13 +3213,12 @@ void jit_brgemm_matmul_copy_b_int8_t<Vmm>::generate() {
}
template <typename Vmm>
-struct jit_brgemm_matmul_copy_b_bf16_t : public jit_brgemm_matmul_copy_b_t,
-public jit_generator_t {
+struct jit_brgemm_matmul_copy_b_bf16_t
+: public jit_brgemm_matmul_copy_b_common_t {
DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_brgemm_matmul_copy_b_bf16_t)
jit_brgemm_matmul_copy_b_bf16_t(const brgemm_matmul_conf_t *conf)
-: jit_brgemm_matmul_copy_b_t(conf)
-, jit_generator_t(jit_name())
+: jit_brgemm_matmul_copy_b_common_t(conf)
, typesize(conf->b_dt_sz)
, tr_typesize(conf->tr_b_dt_sz)
, wei_scales_typesize(conf->wei_scales_dt_sz)
@ -2950,12 +3270,6 @@ private:
constexpr static int reg_current_K_pad_offs_ = 16;
constexpr static int stack_space_needed = 24;
-opmask_t kTail = k7;
-opmask_t kFFFF = k6;
-opmask_t kTail_int4 = k5;
-opmask_t kAAAA = k4;
-opmask_t k5555 = k3;
reg64_t reg_src = rax;
reg64_t reg_tr_src = rbx;
@ -2995,72 +3309,13 @@ private:
else
jit_generator_t::kmovd(k, regw_tmp);
}
-void copy_half_int4(const Zmm &zmm, const Ymm &ymm_half) {
-vinserti64x4(zmm, zmm, ymm_half, 1);
-}
-void copy_half_int4(const Ymm &ymm, const Xmm &xmm_half) {
-vinserti128(ymm, ymm, xmm_half, 1);
-}
-Vmm_lower_t maybe_mask(Vmm_lower_t vmm_lower, bool is_tail) {
-assert(is_src_int4);
-if (isa_has_masks(conf_->isa)) {
-return is_tail ? vmm_lower | kTail_int4 | T_z
-: vmm_lower | kFFFF | T_z;
-} else {
-return vmm_lower;
-}
-}
-Vmm maybe_mask(Vmm vmm, bool is_tail) {
-if (isa_has_masks(conf_->isa)) {
-return is_tail ? vmm | kTail | T_z : vmm | kFFFF | T_z;
-} else {
-return vmm;
-}
-}
-void load_data(const Vmm vmm_in, const Xbyak::Operand &op, bool is_tail);
void copy_block(int nrows, int ncolumns, bool n_tail, bool zeropad);
void copy_2x32(int nrows, int ncolumns, bool zeropad);
void init_masks();
void generate() override;
};
-template <typename Vmm>
-void jit_brgemm_matmul_copy_b_bf16_t<Vmm>::load_data(
-const Vmm vmm_in, const Xbyak::Operand &op, bool is_tail) {
-const auto vmm = maybe_mask(vmm_in, is_tail);
-const auto vmm_lower = Vmm_lower_t(vmm.getIdx());
-MAYBE_UNUSED(vmm_lower);
-switch (conf_->orig_wei_dt) {
-case data_type::f32: uni_vmovups(vmm, op); break;
-case data_type::f16:
-case data_type::bf16: vmovdqu16(vmm, op); break;
-case data_type::s8: uni_vpmovsxbd(vmm, op); break;
-case data_type::u8: uni_vpmovzxbd(vmm, op); break;
-// For int4, we see two int4 as one int8 and extend them int32
-// low half stores in lower bytes of vmm and high half in higher
-// bytes of vmm, then permute them into correct order
-// Finally, we process the extend bytes for s4/u4 accordingly
-case data_type::s4:
-uni_vpmovsxbd(maybe_mask(vmm_lower, is_tail), op);
-copy_half_int4(vmm_in, vmm_lower);
-vpermd(vmm_in, vmm_permd, vmm_in);
-uni_vpslld(vmm_in | k5555, vmm_in, 28);
-vpsrad(vmm_in | k5555, vmm_in, 28);
-vpsrad(vmm_in | kAAAA, vmm_in, 4);
-break;
-case data_type::u4:
-uni_vpmovzxbd(maybe_mask(vmm_lower, is_tail), op);
-copy_half_int4(vmm_in, vmm_lower);
-vpermd(vmm_in, vmm_permd, vmm_in);
-uni_vpslld(vmm_in | k5555, vmm_in, 28);
-vpsrld(vmm_in | k5555, vmm_in, 28);
-vpsrld(vmm_in | kAAAA, vmm_in, 4);
-break;
-default: assert(!"unsupported data type");
-}
-}
template <typename Vmm>
void jit_brgemm_matmul_copy_b_bf16_t<Vmm>::copy_2x32(
int nrows, int ncolumns, bool zeropad) {
@ -3108,40 +3363,20 @@ void jit_brgemm_matmul_copy_b_bf16_t<Vmm>::copy_2x32(
else
uni_vmovups(src_load, load_addr);
} else {
-load_data(src_reg, load_addr, is_tail);
+load_value(
+src_reg, load_addr, vmm_permd, conf_->orig_wei_dt, is_tail);
}
-if (utils::one_of(conf_->orig_wei_dt, data_type::s8, data_type::u8,
-data_type::s4, data_type::u4)) {
-if (req_zp_b_shift) uni_vpsubd(src_load, src_load, vmm_zp_b_shift);
-uni_vcvtdq2ps(src_load, src_load);
-if (req_apply_wei_scales) {
-const auto wei_scales_offset
-= (is_dynamic_stride ? 0 : k * wei_scales_N_stride)
-+ n * wei_scales_typesize;
-const auto wei_scales_addr = maybe_EVEX_compress_addr(
-reg_wei_scales, wei_scales_offset);
-const auto vmm_wei_scales_masked
-= maybe_mask(vmm_wei_scales, is_tail);
-switch (conf_->wei_scales_dt) {
-case data_type::f32:
-uni_vmovups(vmm_wei_scales_masked, wei_scales_addr);
-break;
-case data_type::bf16:
-uni_vpmovzxwd(vmm_wei_scales_masked, wei_scales_addr);
-uni_vpslld(vmm_wei_scales, vmm_wei_scales, 16);
-break;
-case data_type::f16:
-vcvtph2ps(vmm_wei_scales_masked, wei_scales_addr);
-break;
-default: assert(!"unsupported wei_scales data type");
-}
-uni_vmulps(src_load, src_load, vmm_wei_scales);
-}
-if (conf_->wei_dt == data_type::f16)
-vcvtps2phx(Vmm_lower_t(src_reg.getIdx()), src_reg);
-}
+const auto scales_offset
+= (is_dynamic_stride ? 0 : k * wei_scales_N_stride)
++ n * wei_scales_typesize;
+const auto scales_addr
+= maybe_EVEX_compress_addr(reg_wei_scales, scales_offset);
+if (req_apply_wei_scales)
+load_scale_value(
+vmm_wei_scales, scales_addr, conf_->wei_scales_dt, is_tail);
+decompress_and_downcvt_reg(src_load, vmm_zp_b_shift, vmm_wei_scales,
+conf_->orig_wei_dt, conf_->wei_dt);
};
int iter = 0;
@ -3185,15 +3420,12 @@ void jit_brgemm_matmul_copy_b_bf16_t<Vmm>::copy_2x32(
if (nrows - k >= k_blk_step) {
load(blk_idx, k + 1, n);
-if (req_cvtps2bf16) {
-vcvtne2ps2bf16(src_vmm0, src_vmm1, src_vmm0);
-} else if (is_superset(conf_->isa, avx512_core)) {
+if (is_superset(conf_->isa, avx512_core)) {
const auto src_ymm1 = ymm(src_vmm1.getIdx());
vinsertf64x4(src_zmm0, src_zmm0, src_ymm1, 1);
}
-} else if (req_cvtps2bf16) {
-vcvtneps2bf16(ymm(src_vmm0.getIdx()), src_vmm0);
-} else if (!is_superset(conf_->isa, avx512_core)) {
+}
+if (!is_superset(conf_->isa, avx512_core)) {
uni_vxorps(src_vmm1, src_vmm1, src_vmm1);
}
@ -3774,13 +4006,11 @@ template struct jit_brgemm_matmul_copy_b_f32_t<Ymm>;
template <typename Vmm>
struct jit_brgemm_matmul_copy_b_transposed_t
-: public jit_brgemm_matmul_copy_b_t,
-public jit_generator_t {
+: public jit_brgemm_matmul_copy_b_common_t {
DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_brgemm_matmul_copy_b_transposed_t)
jit_brgemm_matmul_copy_b_transposed_t(const brgemm_matmul_conf_t *conf)
-: jit_brgemm_matmul_copy_b_t(conf)
-, jit_generator_t(jit_name())
+: jit_brgemm_matmul_copy_b_common_t(conf)
, typesize_(conf_->b_dt_sz)
, tr_typesize_(conf_->tr_b_dt_sz)
, wei_scales_typesize_(conf_->wei_scales_dt_sz)
@ -3874,16 +4104,6 @@ private:
constexpr static int ldb_step_idx_offs = 0;
constexpr static int stack_space_needed = 8;
-opmask_t k3333 = k1;
-opmask_t k5555 = k2;
-opmask_t kAAAA = k3;
-opmask_t kCCCC = k4;
-opmask_t k0F0F = k5;
-opmask_t kF0F0 = k6;
-opmask_t kTail = k7;
-// reuse k7 for int4 and restore the value after use
-opmask_t kTail_int4 = k7;
reg64_t reg_src_base = rax;
reg64_t reg_tr_src_base = rbx;
reg64_t reg_comp_ptr = rdx;
@ -3943,31 +4163,9 @@ private:
return Vmm(n_blk_step_ + i);
}
-void copy_half_int4(const Zmm &zmm, const Ymm &ymm_half) {
-vinserti64x4(zmm, zmm, ymm_half, 1);
-}
-void copy_half_int4(const Ymm &ymm, const Xmm &xmm_half) {
-vinserti128(ymm, ymm, xmm_half, 1);
-}
-Vmm_lower_t maybe_mask(Vmm_lower_t vmm_lower, bool is_tail) {
-assert(is_src_int4_);
-return isa_has_masks(conf_->isa) && is_tail
-? vmm_lower | kTail_int4 | T_z
-: vmm_lower;
-}
-Vmm maybe_mask(Vmm vmm, bool is_tail) {
-return isa_has_masks(conf_->isa) && is_tail ? vmm | kTail | T_z : vmm;
-}
void init_tail_mask(const int columns_tail, const bool use_int4_mask);
-void maybe_apply_wei_scales(
-const Vmm vmm_in, const size_t offset, const bool is_tail);
-void maybe_apply_zp_b_shift(const Vmm vmm_in, const bool is_tail);
-void load_int(const Vmm vmm_in, const dim_t offset, const int i,
-const int columns_tail, bool is_tail);
+bool preload_int4(const Xmm &xmm_in, const int i, const int columns_tail,
+const bool is_tail, const dim_t offset);
void copy_row_x_col(int nrows, int ncolumns);
void compute_K_loop(bool is_N_tail, int curr_K_tail, bool is_first_K_iter,
bool is_last_K_iter);
@ -3989,23 +4187,25 @@ private:
return next_row_idx < num_rows || dynamic_tail;
}
-void generate() override;
-};
-// This method applies scales for weights decompression scenario. Given it's the
-// transposed kernel, B is in column-major format, but scales in the library are
-// in row-major format.
-template <typename Vmm>
-void jit_brgemm_matmul_copy_b_transposed_t<Vmm>::maybe_apply_wei_scales(
-const Vmm vmm_in, const size_t offset, const bool is_tail) {
-if (!req_apply_wei_scales_) return;
+/**
+* Loads scales and broadcasts them over a Vmm register.
+* Supported data types: f32, bf16, f16.
+*
+* @param n N-dimension local index.
+* @param is_tail Flag indicating if tail processing is needed.
+*/
+void load_scales(int n, bool is_tail) {
+if (!conf_->is_wei_scale_per_n || !conf_->apply_scales_in_buffer_b)
+return;
+const auto offset = n * wei_scales_K_stride_;
if (wei_scales_K_stride_ == wei_scales_typesize_) {
// A single scale per kernel case.
// Enable broadcast address for f16 to avoid vmm manipulations.
-const auto wei_scales_addr = EVEX_compress_addr(
-reg_wei_scales, offset, conf_->wei_scales_dt == data_type::f16);
+const auto wei_scales_addr = EVEX_compress_addr(reg_wei_scales,
+offset, conf_->wei_scales_dt == data_type::f16);
switch (conf_->wei_scales_dt) {
case data_type::f32:
uni_vbroadcastss(vmm_wei_scales, wei_scales_addr);
@ -4019,9 +4219,6 @@ void jit_brgemm_matmul_copy_b_transposed_t<Vmm>::maybe_apply_wei_scales(
break;
default: assert(!"unsupported wei_scales data type");
}
-const auto vmm = maybe_mask(vmm_in, is_tail);
-vmulps(vmm, vmm_in, vmm_wei_scales);
} else {
// A broadcasted ahead-of-time scales case.
@ -4033,8 +4230,10 @@ void jit_brgemm_matmul_copy_b_transposed_t<Vmm>::maybe_apply_wei_scales(
// memory allocated is KxN, thus, there's an over-use of memory, but
// such usage allows us to simplify the kernel logic and just load
// weights into a full vector register.
-const auto wei_scales_addr = EVEX_compress_addr(reg_wei_scales, offset);
-const auto vmm_wei_scales_masked = maybe_mask(vmm_wei_scales, is_tail);
+const auto wei_scales_addr
+= EVEX_compress_addr(reg_wei_scales, offset);
+const auto vmm_wei_scales_masked
+= maybe_mask(vmm_wei_scales, is_tail);
switch (conf_->wei_scales_dt) {
case data_type::f32:
@ -4049,21 +4248,15 @@ void jit_brgemm_matmul_copy_b_transposed_t<Vmm>::maybe_apply_wei_scales(
break;
default: assert(!"unsupported wei_scales data type");
}
-const auto vmm = maybe_mask(vmm_in, is_tail);
-vmulps(vmm, vmm_in, vmm_wei_scales);
}
}
-template <typename Vmm>
-void jit_brgemm_matmul_copy_b_transposed_t<Vmm>::maybe_apply_zp_b_shift(
-const Vmm vmm_in, const bool is_tail) {
-if (!req_zp_b_shift_) return;
-const auto vmm = maybe_mask(vmm_in, is_tail);
-vpsubd(vmm, vmm, vmm_zp_b_val);
-}
+void generate() override;
+};
+// This method applies scales for weights decompression scenario. Given it's the
+// transposed kernel, B is in column-major format, but scales in the library are
+// in row-major format.
template <typename Vmm>
void jit_brgemm_matmul_copy_b_transposed_t<Vmm>::init_tail_mask(
const int columns_tail, const bool use_int4_mask) {
@ -4083,23 +4276,32 @@ void jit_brgemm_matmul_copy_b_transposed_t<Vmm>::init_tail_mask(
}
}
+/**
+* @brief Handles special case when loading INT4 data with unaligned src_stride
+*
+* When processing INT4 data and (i * src_stride_) % 2 != 0, we need to perform additional operations
+* to handle the unaligned half-byte boundaries:
+*
+* - For small loads (< 8 bytes or tail cases): We perform a simple right shift by 4 bits
+* to eliminate the unnecessary half-byte at the front
+*
+* - For large loads (8 bytes): We need two registers to correctly reconstruct the data:
+* 1. Shift the first register right by 4 bits to remove leading half-byte
+* 2. Shift the second register left by 4 bits to preserve trailing half-byte
+* 3. Combine both registers with logical OR to get the correctly aligned result
+*
+* @param xmm_in The target XMM register to store the aligned data
+* @param i Current row index
+* @param columns_tail Number of remaining columns in the tail case
+* @param is_tail Flag indicating if processing a tail section
+* @param offset Memory offset for loading from source
+* @return True if INT4 preloading was performed, false otherwise
+*/
template <typename Vmm>
-void jit_brgemm_matmul_copy_b_transposed_t<Vmm>::load_int(const Vmm vmm_in,
-const dim_t offset, const int i, int columns_tail, bool is_tail) {
-const auto vmm = maybe_mask(vmm_in, is_tail);
-const auto vmm_lower = Vmm_lower_t(vmm.getIdx());
-const auto xmm_in = Xmm(vmm_in.getIdx());
+bool jit_brgemm_matmul_copy_b_transposed_t<Vmm>::preload_int4(const Xmm &xmm_in,
+const int i, const int columns_tail, const bool is_tail,
+const dim_t offset) {
const auto addr = EVEX_compress_addr(reg_src, offset);
-MAYBE_UNUSED(xmm_in);
-MAYBE_UNUSED(vmm_lower);
-if (is_src_int4_) init_tail_mask(columns_tail, true);
-// Two additional operations are needed for int4 when i * src_stride_ % 2 != 0.
-// The maximum data size for a bitwise shift is 8 bytes (quadwords).
-// If the loaded data size is smaller than 8, we can directly perform a right
-// shift to eliminate the unnecessary half-byte at the front.
-// If the loaded data size is 8, we need two registers to handle the
-// unnecessary half-byte at the front and back, respectively.
const bool need_preload_int4 = is_src_int4_ && (i * src_stride_) % 2 != 0;
const auto max_shift_sz = 8;
if (need_preload_int4) {
@ -4119,41 +4321,9 @@ void jit_brgemm_matmul_copy_b_transposed_t<Vmm>::load_int(const Vmm vmm_in,
vpsllq(xmm_tmp, xmm_tmp, 4);
vpord(xmm_in, xmm_in, xmm_tmp);
}
+return true;
}
-switch (conf_->orig_wei_dt) {
-case data_type::s8: uni_vpmovsxbd(vmm, addr); break;
-case data_type::u8: uni_vpmovzxbd(vmm, addr); break;
-// For int4, we see two int4 as one int8 and extend them int32
-// low half stores in lower bytes of vmm and high half in higher
-// bytes of vmm, then permute them into correct order
-// Finally, we process the extend bytes for s4/u4 accordingly
-case data_type::s4:
-if (need_preload_int4)
-uni_vpmovsxbd(maybe_mask(vmm_lower, is_tail), xmm_in);
-else
-uni_vpmovsxbd(maybe_mask(vmm_lower, is_tail), addr);
-copy_half_int4(vmm_in, vmm_lower);
-vpermd(vmm_in, vmm_permd, vmm_in);
-uni_vpslld(vmm_in | k5555, vmm_in, 28);
-vpsrad(vmm_in | k5555, vmm_in, 28);
-vpsrad(vmm_in | kAAAA, vmm_in, 4);
-break;
-case data_type::u4:
-if (need_preload_int4)
-uni_vpmovzxbd(maybe_mask(vmm_lower, is_tail), xmm_in);
-else
-uni_vpmovzxbd(maybe_mask(vmm_lower, is_tail), addr);
-copy_half_int4(vmm_in, vmm_lower);
-vpermd(vmm_in, vmm_permd, vmm_in);
-uni_vpslld(vmm_in | k5555, vmm_in, 28);
-vpsrld(vmm_in | k5555, vmm_in, 28);
-vpsrld(vmm_in | kAAAA, vmm_in, 4);
-break;
-default: assert(!"unsupported data type");
-}
-// restore the tail_mask
-if (is_src_int4_) init_tail_mask(columns_tail, false);
+return false;
}
template <typename Vmm>
@ -4163,8 +4333,11 @@ void jit_brgemm_matmul_copy_b_transposed_t<Vmm>::copy_row_x_col(
&& ncolumns <= k_blk_step_);
if (!nrows) return;
-const int columns_tail = ncolumns
-% (req_cvtps2xf16_ ? req_cvt_bf16_k_blk_step_ : k_blk_step_);
+const auto cur_k_blk_step
+= req_cvtps2xf16_ ? req_cvt_bf16_k_blk_step_ : k_blk_step_;
+const int columns_tail = ncolumns % cur_k_blk_step;
init_tail_mask(columns_tail, false);
auto load2bf16 = [this, nrows, columns_tail, ncolumns](int i) {
@ -4188,59 +4361,60 @@ void jit_brgemm_matmul_copy_b_transposed_t<Vmm>::copy_row_x_col(
}
// check if k_tail exists and it's in the first zmm
-auto zmm_src = columns_tail > 0 && ncolumns < req_cvt_bf16_k_blk_step_
-? src_reg | kTail | T_z
-: src_reg;
+const auto is_tail
+= columns_tail > 0 && ncolumns < req_cvt_bf16_k_blk_step_;
+auto src_reg_masked = maybe_mask(src_reg, is_tail);
const auto src_offset = (i * src_stride_) / src_elems_per_byte_;
const auto addr = EVEX_compress_addr(reg_src, src_offset);
if (is_bf32_)
-vmovups(zmm_src, addr);
+vmovups(src_reg_masked, addr);
else if (is_bf16_with_int_wei_ || conf_->is_f16_with_int_wei) {
-const bool is_tail
-= columns_tail > 0 && ncolumns < req_cvt_bf16_k_blk_step_;
-load_int(src_reg, src_offset, i, columns_tail, is_tail);
-maybe_apply_zp_b_shift(src_reg, is_tail);
-vcvtdq2ps(zmm_src, zmm_src);
-maybe_apply_wei_scales(src_reg, i * wei_scales_K_stride_, is_tail);
+const auto xmm_preload = Xmm(src_reg.getIdx());
+MAYBE_UNUSED(xmm_preload);
+const bool preloaded_int4 = preload_int4(
+xmm_preload, i, columns_tail, is_tail, src_offset);
+const auto &src_op = preloaded_int4
+? static_cast<const Xbyak::Operand &>(xmm_preload)
+: static_cast<const Xbyak::Operand &>(addr);
+if (is_src_int4_) init_tail_mask(columns_tail, true);
+load_value(src_reg, src_op, vmm_permd, conf_->orig_wei_dt, is_tail);
+if (is_src_int4_) init_tail_mask(columns_tail, false);
+load_scales(i, is_tail);
+decompress_reg(src_reg_masked, vmm_zp_b_val, vmm_wei_scales,
+conf_->orig_wei_dt);
} else
assert(!"Unsupported data type in loading");
if (ncolumns <= req_cvt_bf16_k_blk_step_) {
vpxord(src_reg_next, src_reg_next, src_reg_next);
} else {
-auto zmm_src_next = columns_tail > 0 ? src_reg_next | kTail | T_z
-: src_reg_next;
+const auto is_tail = columns_tail > 0;
+auto src_next_masked = maybe_mask(src_reg_next, is_tail);
const auto next_src_offset
= (i * src_stride_ + req_cvt_bf16_k_blk_step_ * typesize_)
/ src_elems_per_byte_;
const auto next_addr = EVEX_compress_addr(reg_src, next_src_offset);
if (is_bf32_)
-vmovups(zmm_src_next, next_addr);
+vmovups(src_next_masked, next_addr);
else if (is_bf16_with_int_wei_ || conf_->is_f16_with_int_wei) {
-const auto is_tail = columns_tail > 0;
-load_int(src_reg_next, next_src_offset, i, columns_tail,
-columns_tail > 0);
-maybe_apply_zp_b_shift(src_reg_next, is_tail);
-vcvtdq2ps(zmm_src_next, zmm_src_next);
-maybe_apply_wei_scales(src_reg_next,
-i * wei_scales_K_stride_
-+ !single_wei_scales_value_
-* req_cvt_bf16_k_blk_step_
-* wei_scales_typesize_,
-is_tail);
+const auto xmm_preload = Xmm(src_reg_next.getIdx());
+MAYBE_UNUSED(xmm_preload);
+const bool preloaded_int4 = preload_int4(
+xmm_preload, i, columns_tail, is_tail, src_offset);
+const auto &src_op = preloaded_int4
+? static_cast<const Xbyak::Operand &>(xmm_preload)
+: static_cast<const Xbyak::Operand &>(next_addr);
+if (is_src_int4_) init_tail_mask(columns_tail, true);
+load_value(src_reg_next, src_op, vmm_permd, conf_->orig_wei_dt,
+is_tail);
+if (is_src_int4_) init_tail_mask(columns_tail, false);
+load_scales(i, is_tail);
+decompress_reg(src_next_masked, vmm_zp_b_val, vmm_wei_scales,
+conf_->orig_wei_dt);
} else
assert(!"Unsupported data type in loading");
}
-if (conf_->wei_dt == data_type::bf16) {
-vcvtne2ps2bf16(src_reg, src_reg_next, src_reg);
-} else {
-const auto src_vmm_lower0 = Vmm_lower_t(src_reg.getIdx());
-const auto src_vmm_lower1 = Vmm_lower_t(src_reg_next.getIdx());
-vcvtps2phx(src_vmm_lower0, src_reg);
-vcvtps2phx(src_vmm_lower1, src_reg_next);
-vinsertf64x4(src_reg, src_reg, src_vmm_lower1, 1);
-}
+downconvert_to_dst_dt(src_reg, src_reg_next, conf_->wei_dt);
L(load_done);
};
@ -4268,10 +4442,20 @@ void jit_brgemm_matmul_copy_b_transposed_t<Vmm>::copy_row_x_col(
const auto src_offset = (i * src_stride_) / src_elems_per_byte_;
const auto addr = EVEX_compress_addr(reg_src, src_offset);
if (conf_->is_f16_with_int_wei && conf_->wei_dt == data_type::f32) {
-load_int(src_reg, src_offset, i, columns_tail, is_tail);
-maybe_apply_zp_b_shift(src_reg, is_tail);
-vcvtdq2ps(src_load, src_load);
-maybe_apply_wei_scales(src_reg, i * wei_scales_K_stride_, is_tail);
+const auto xmm_preload = Xmm(src_reg.getIdx());
+MAYBE_UNUSED(xmm_preload);
+const bool preloaded_int4 = preload_int4(
+xmm_preload, i, columns_tail, is_tail, src_offset);
+const auto &src_op = preloaded_int4
+? static_cast<const Xbyak::Operand &>(xmm_preload)
+: static_cast<const Xbyak::Operand &>(addr);
+if (is_src_int4_) init_tail_mask(columns_tail, true);
+load_value(src_reg, src_op, vmm_permd, conf_->orig_wei_dt, is_tail);
+if (is_src_int4_) init_tail_mask(columns_tail, false);
+load_scales(i, is_tail);
+decompress_reg(maybe_mask(src_reg, is_tail), vmm_zp_b_val,
+vmm_wei_scales, conf_->orig_wei_dt);
} else if (use_fp16_instructions_) {
if (conf_->isa == avx512_core_fp16) {
vcvtph2psx(src_load, addr);
@ -4801,13 +4985,12 @@ template struct jit_brgemm_matmul_copy_b_transposed_t<Zmm>;
template struct jit_brgemm_matmul_copy_b_transposed_t<Ymm>;
template <typename Vmm>
-struct jit_brgemm_matmul_copy_b_cvt_bf16_t : public jit_brgemm_matmul_copy_b_t,
-public jit_generator_t {
+struct jit_brgemm_matmul_copy_b_cvt_bf16_t
+: public jit_brgemm_matmul_copy_b_common_t {
DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_brgemm_matmul_copy_b_cvt_bf16_t)
jit_brgemm_matmul_copy_b_cvt_bf16_t(const brgemm_matmul_conf_t *conf)
-: jit_brgemm_matmul_copy_b_t(conf)
-, jit_generator_t(jit_name())
+: jit_brgemm_matmul_copy_b_common_t(conf)
, typesize_(conf->b_dt_sz)
, tr_typesize_(conf->tr_b_dt_sz)
, wei_scales_typesize_(conf_->wei_scales_dt_sz)
@ -4852,11 +5035,6 @@ private:
const bool req_apply_wei_scales_;
const int reserved_regs_;
-opmask_t kTail = k7;
-opmask_t kFFFF = k6;
-opmask_t kAAAA = k5;
-opmask_t k5555 = k4;
reg64_t reg_src = rax;
reg64_t reg_tr_src = rbx;
@ -4875,17 +5053,6 @@ private:
Vmm vmm_wei_scales1 = Vmm(3);
Vmm vmm_tmp = Vmm(4);
-void copy_half_int4(const Zmm &zmm, const Ymm &ymm_half) {
-vinserti64x4(zmm, zmm, ymm_half, 1);
-}
-Vmm maybe_mask(Vmm vmm, bool is_tail) {
-if (isa_has_masks(conf_->isa)) {
-return is_tail ? vmm | kTail | T_z : vmm | kFFFF | T_z;
-} else {
-return vmm;
-}
-}
Vmm get_vmm(const int blk, const int idx) {
const int max_isa_regs = isa_num_vregs(conf_->isa);
const int max_unroll = (max_isa_regs - reserved_regs_) / k_blk_step;
@ -4897,8 +5064,7 @@ private:
}
void init_masks();
-void load_int(const Vmm vmm_in, const Xbyak::Operand &op);
-void get_wei_scales(const int blk, const int k, const int n,
-const bool is_n_tail, const bool is_k_tail);
+void get_scales(const int blk, const int k, const int n,
+const bool is_n_tail, const bool is_k_tail);
void copy_block(const int nrows, const int ncolumns, bool zeropad);
void generate() override;
@ -4923,41 +5089,11 @@ void jit_brgemm_matmul_copy_b_cvt_bf16_t<Vmm>::init_masks() {
}
template <typename Vmm>
-void jit_brgemm_matmul_copy_b_cvt_bf16_t<Vmm>::load_int(
-const Vmm vmm_in, const Xbyak::Operand &op) {
-const auto vmm_lower = Vmm_lower_t(vmm_in.getIdx());
-MAYBE_UNUSED(vmm_lower);
-switch (conf_->orig_wei_dt) {
-case data_type::s8: uni_vpmovsxbd(vmm_in, op); break;
-case data_type::u8: uni_vpmovzxbd(vmm_in, op); break;
-// For int4, we see two int4 as one int8 and extend them int32
-// low half stores in lower bytes of vmm and high half in higher
-// bytes of vmm, then permute them into correct order
-// Finally, we process the extend bytes for s4/u4 accordingly
-case data_type::s4:
-uni_vpmovsxbd(vmm_lower, op);
-copy_half_int4(vmm_in, vmm_lower);
-vpermd(vmm_in, vmm_permd, vmm_in);
-uni_vpslld(vmm_in | k5555, vmm_in, 28);
-vpsrad(vmm_in | k5555, vmm_in, 28);
-vpsrad(vmm_in | kAAAA, vmm_in, 4);
-break;
-case data_type::u4:
-uni_vpmovzxbd(vmm_lower, op);
-copy_half_int4(vmm_in, vmm_lower);
-vpermd(vmm_in, vmm_permd, vmm_in);
-uni_vpslld(vmm_in | k5555, vmm_in, 28);
-vpsrld(vmm_in | k5555, vmm_in, 28);
-vpsrld(vmm_in | kAAAA, vmm_in, 4);
-break;
-default: assert(!"unsupported data type");
-}
-}
-template <typename Vmm>
-void jit_brgemm_matmul_copy_b_cvt_bf16_t<Vmm>::get_wei_scales(const int blk,
+void jit_brgemm_matmul_copy_b_cvt_bf16_t<Vmm>::get_scales(const int blk,
const int k, const int n, const bool is_n_tail, const bool is_k_tail) {
-if (!req_apply_wei_scales_) return;
const auto zmm_wei_scales1 = maybe_mask(vmm_wei_scales1, is_n_tail);
const auto zmm_tmp = maybe_mask(vmm_tmp, is_n_tail);
const auto base_offset = [&](int k) {
@ -5018,31 +5154,15 @@ void jit_brgemm_matmul_copy_b_cvt_bf16_t<Vmm>::copy_block(
const auto stride = (n_blk_step * typesize_) / src_elems_per_byte_;
auto load_addr0 = maybe_EVEX_compress_addr(reg_src, offset);
auto load_addr1 = maybe_EVEX_compress_addr(reg_src, offset + stride);
-load_int(src_vmm0, load_addr0);
-load_int(src_vmm1, load_addr1);
-if (req_zp_b_shift_) {
-vpsubd(src_vmm0, src_vmm0, vmm_zp_b_val);
-vpsubd(src_vmm1, src_vmm1, vmm_zp_b_val);
-}
-vcvtdq2ps(src_vmm0, src_vmm0);
-vcvtdq2ps(src_vmm1, src_vmm1);
-if (req_apply_wei_scales_) {
+load_value(src_vmm0, load_addr0, vmm_permd, conf_->orig_wei_dt);
+load_value(src_vmm1, load_addr1, vmm_permd, conf_->orig_wei_dt);
const bool is_n_tail = ncolumns - n < n_blk_step;
const bool is_k_tail = nrows - k < k_blk_step;
-get_wei_scales(blk, k, n, is_n_tail, is_k_tail);
-vmulps(src_vmm0, src_vmm0, vmm_wei_scales0);
-vmulps(src_vmm1, src_vmm1, vmm_wei_scales1);
-}
-if (conf_->wei_dt == data_type::bf16) {
-vcvtne2ps2bf16(src_vmm0, src_vmm1, src_vmm0);
-} else {
-const auto src_vmm_lower0 = Vmm_lower_t(src_vmm0.getIdx());
-const auto src_vmm_lower1 = Vmm_lower_t(src_vmm1.getIdx());
-vcvtps2phx(src_vmm_lower0, src_vmm0);
-vcvtps2phx(src_vmm_lower1, src_vmm1);
-vinsertf64x4(src_vmm0, src_vmm0, src_vmm_lower1, 1);
-}
+get_scales(blk, k, n, is_n_tail, is_k_tail);
+decompress_and_downcvt_2reg(src_vmm0, src_vmm1, vmm_zp_b_val,
+vmm_zp_b_val, vmm_wei_scales0, vmm_wei_scales0,
+conf_->orig_wei_dt, conf_->wei_dt);
};
int iter = 0;