Mirror of https://github.com/uxlfoundation/oneDNN.git (synced 2025-10-20 18:43:49 +08:00)

x64: matmul: Refactoring matmul copy kernels

Committed by Dmitriy Ovchinnikov
parent 794caa04ba
commit f92a20983c
@@ -2063,6 +2063,327 @@ void jit_brgemm_matmul_copy_a_transposed_int8_impl_t::generate() {
 template struct jit_brgemm_matmul_copy_a_transposed_impl_t<Zmm>;
 template struct jit_brgemm_matmul_copy_a_transposed_impl_t<Ymm>;
 
+/**
+ * @brief Common class for BRGEMM B matrix copy operations
+ *
+ * This class contains common methods and properties for all copy B kernels.
+ * Currently it provides `load_value` and `decompress_reg`, and it is intended
+ * to hold all methods shared by the copy B kernels.
+ */
+struct jit_brgemm_matmul_copy_b_common_t : public jit_brgemm_matmul_copy_b_t,
+                                           public jit_generator_t {
+    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_brgemm_matmul_copy_b_common_t)
+
+    jit_brgemm_matmul_copy_b_common_t(const brgemm_matmul_conf_t *conf)
+        : jit_brgemm_matmul_copy_b_t(conf), jit_generator_t(jit_name()) {}
+
+protected:
+    /**
+     * @brief Conditionally applies a mask to a vector register for tail or specialized processing
+     *
+     * @tparam Vmm Vector register type (Zmm, Ymm, etc.)
+     * @param vmm The vector register to potentially mask
+     * @param is_tail Flag indicating if this is tail processing that requires masking
+     * @param is_int4 Flag indicating if this is INT4 data requiring specialized masking
+     * @return The potentially masked vector register (original if no masking needed)
+     */
+    template <typename Vmm>
+    Vmm maybe_mask(Vmm vmm, bool is_tail, bool is_int4 = false) {
+        // Transposed kernel uses the same kTail mask for both cases
+        const auto tail_mask
+                = is_int4 && !conf_->transposed_B ? kTail_int4 : kTail;
+        const auto unmask_tail
+                = one_of(conf_->wei_dt, data_type::bf16, data_type::f16)
+                && !conf_->transposed_B;
+        if (isa_has_masks(conf_->isa))
+            return is_tail ? vmm | tail_mask | T_z
+                    // bf16 and f16 require masking for the tail to avoid
+                    // AVX512F issues with zeroing upper bits of ZMM registers
+                    // when using vpmovzxwd/vpmovzxbd instructions
+                    : unmask_tail ? vmm | kFFFF | T_z
+                                  : vmm;
+        else
+            return vmm;
+    }
+
+    /**
+     * @brief Inserts a YMM register into the upper half of a ZMM register
+     * @param zmm Destination ZMM register where data will be inserted
+     * @param ymm_half Source YMM register containing data for the upper half
+     */
+    void copy_half_int4(const Zmm &zmm, const Ymm &ymm_half) {
+        vinserti64x4(zmm, zmm, ymm_half, 1);
+    }
+
+    /**
+     * @brief Inserts an XMM register into the upper half of a YMM register
+     * @param ymm Destination YMM register where data will be inserted
+     * @param xmm_half Source XMM register containing data for the upper half
+     */
+    void copy_half_int4(const Ymm &ymm, const Xmm &xmm_half) {
+        vinserti128(ymm, ymm, xmm_half, 1);
+    }
+
+    /**
+     * @brief Loads and converts data of various types into vector registers with appropriate handling
+     *
+     * @tparam Vmm Vector register type for computation
+     * @param reg Destination vector register
+     * @param op Source memory operand to load from
+     * @param vmm_permd Vector register containing permutation indices for INT4 processing
+     * @param dt Data type being loaded
+     * @param is_tail Flag indicating if tail processing is needed
+     */
+    template <typename Vmm>
+    void load_value(const Vmm &reg, const Xbyak::Operand &op,
+            const Vmm &vmm_permd, data_type_t dt, bool is_tail = false) {
+        using Vmm_lower_t = typename vreg_traits_t<Vmm>::Vmm_lower_t;
+        const auto vmm_in = maybe_mask(reg, is_tail);
+        const auto vmm_lower = Vmm_lower_t(vmm_in.getIdx());
+        MAYBE_UNUSED(vmm_lower);
+
+        const bool is_xf16
+                = one_of(conf_->wei_dt, data_type::bf16, data_type::f16);
+
+        switch (dt) {
+            case data_type::s32: vmovdqu32(vmm_in, op); break;
+            case data_type::f32: {
+                if (conf_->transposed_B)
+                    vmovdqu8(vmm_in, op);
+                else
+                    uni_vmovups(vmm_in, op);
+                break;
+            }
+            case data_type::f16:
+                if (!is_xf16) {
+                    uni_vcvtph2psx(vmm_in, op);
+                    break;
+                }
+            case data_type::bf16:
+                if (is_xf16) {
+                    vmovdqu16(vmm_in, op);
+                } else {
+                    uni_vpmovzxwd(vmm_in, op);
+                    uni_vpslld(vmm_in, vmm_in, 16);
+                }
+                break;
+            case data_type::s8: uni_vpmovsxbd(vmm_in, op); break;
+            case data_type::u8: uni_vpmovzxbd(vmm_in, op); break;
+            // For int4, two int4 values are treated as one int8 and extended
+            // to int32: the low half is stored in the lower bytes of vmm and
+            // the high half in the higher bytes, then they are permuted into
+            // the correct order. Finally, the extended bytes are processed
+            // for s4/u4 accordingly.
+            case data_type::s4:
+                uni_vpmovsxbd(
+                        maybe_mask(vmm_lower, is_tail, /* is_int4 = */ true),
+                        op);
+                copy_half_int4(reg, vmm_lower);
+
+                vpermd(reg, vmm_permd, reg);
+                uni_vpslld(reg | k5555, reg, 28);
+                vpsrad(reg | k5555, reg, 28);
+                vpsrad(reg | kAAAA, reg, 4);
+                break;
+            case data_type::u4:
+                uni_vpmovzxbd(
+                        maybe_mask(vmm_lower, is_tail, /* is_int4 = */ true),
+                        op);
+                copy_half_int4(reg, vmm_lower);
+
+                vpermd(reg, vmm_permd, reg);
+                uni_vpslld(reg | k5555, reg, 28);
+                vpsrld(reg | k5555, reg, 28);
+                vpsrld(reg | kAAAA, reg, 4);
+                break;
+            default: assert(!"unsupported data type");
+        }
+    }
+
+    /**
+     * @brief Applies zero point shift to vector register
+     * Shifts input values by subtracting zero point values.
+     *
+     * @tparam Vmm Vector register type (Zmm, Ymm, etc.)
+     * @param input Vector register containing input values to be shifted
+     * @param zp Vector register containing zero point values to be subtracted
+     * @param src_dt Source data type that determines which instruction to use
+     */
+    template <typename Vmm>
+    void apply_shift(const Vmm &input, const Vmm &zp, data_type_t src_dt) {
+        if (!conf_->has_zero_point_b) return;
+
+        const bool is_int_shift = one_of(src_dt, data_type::s8, data_type::u8,
+                data_type::s4, data_type::u4);
+        const bool is_fp_shift = one_of(
+                src_dt, data_type::bf16, data_type::f16, data_type::f32);
+
+        if (is_int_shift)
+            uni_vpsubd(input, input, zp);
+        else if (is_fp_shift)
+            uni_vsubps(input, input, zp);
+        // should be unreachable
+        else
+            assert(!"Unable to shift input for zero-point");
+    }
+
+    /**
+     * @brief Converts integer data types to floating-point format
+     *
+     * @tparam Vmm Vector register type (Zmm, Ymm, etc.)
+     * @param input Vector register containing data to be converted
+     * @param src_dt Source data type of the values to be converted
+     */
+    template <typename Vmm>
+    void upconvert_to_f32(const Vmm &input, data_type_t src_dt) {
+        switch (src_dt) {
+            case data_type::s8:
+            case data_type::u8:
+            case data_type::s4:
+            case data_type::u4: uni_vcvtdq2ps(input, input); break;
+            case data_type::bf16:
+            case data_type::f16:
+            case data_type::f32:
+                // bf16 and f16 are already converted into f32 while loading
+                break;
+            default: assert(!"Unsupported source data type for decompression");
+        }
+    }
+
+    /**
+     * @brief Applies scale factors to input values when configured
+     * The operation is only performed if the apply_scales_in_buffer_b flag is
+     * set in the configuration.
+     *
+     * @tparam Vmm Vector register type (Zmm, Ymm, etc.)
+     * @param input Vector register containing values to be scaled
+     * @param scale_op Operand containing scale factors to apply
+     */
+    template <typename Vmm>
+    void apply_scales(const Vmm &input, const Xbyak::Operand &scale_op) {
+        if (conf_->apply_scales_in_buffer_b) uni_vmulps(input, input, scale_op);
+    }
+
+    /**
+     * @brief Converts floating-point data from F32 to specified destination data type
+     *
+     * @tparam Vmm Vector register type (Zmm, Ymm, etc.)
+     * @param input Vector register containing F32 values to be converted
+     * @param dst_dt Destination data type for conversion
+     */
+    template <typename Vmm>
+    void downconvert_to_dst_dt(const Vmm &input, data_type_t dst_dt) {
+        using Vmm_lower_t = typename vreg_traits_t<Vmm>::Vmm_lower_t;
+        switch (dst_dt) {
+            case data_type::bf16:
+                vcvtneps2bf16(Ymm(input.getIdx()), input);
+                break;
+            case data_type::f16:
+                vcvtps2phx(Vmm_lower_t(input.getIdx()), input);
+                break;
+            case data_type::f32:
+                // f32 is already in the correct format
+                break;
+            default:
+                assert(!"Unsupported destination data type for decompression");
+        }
+    }
+
+    template <typename Vmm>
+    void downconvert_to_dst_dt(
+            const Vmm &reg1, const Vmm &reg2, data_type_t dst_dt) {
+        using Vmm_lower_t = typename vreg_traits_t<Vmm>::Vmm_lower_t;
+        switch (dst_dt) {
+            case data_type::bf16: vcvtne2ps2bf16(reg1, reg2, reg1); break;
+            case data_type::f16: {
+                const auto src_vmm_lower0 = Vmm_lower_t(reg1.getIdx());
+                const auto src_vmm_lower1 = Vmm_lower_t(reg2.getIdx());
+                vcvtps2phx(src_vmm_lower0, reg1);
+                vcvtps2phx(src_vmm_lower1, reg2);
+                vinsertf64x4(reg1, reg1, src_vmm_lower1, 1);
+                break;
+            }
+            case data_type::f32:
+                // f32 is already in the correct format
+                break;
+            default:
+                assert(!"Unsupported destination data type for decompression");
+        }
+    }
+
+    /**
+     * @brief Decompresses values by applying zero point shift, data type conversion, and scaling.
+     * @tparam Vmm Vector register type used for computation
+     * @param input Vector register containing input values to be decompressed
+     * @param zp Vector register containing zero point values
+     * @param scale_op Operand containing scales to be applied
+     * @param src_dt Source data type of the values being decompressed
+     */
+    template <typename Vmm>
+    void decompress_reg(const Vmm &input, const Vmm &zp,
+            const Xbyak::Operand &scale_op, data_type_t src_dt) {
+        if (src_dt == data_type::f32)
+            return; // Decompression doesn't support f32
+        apply_shift(input, zp, src_dt);
+        upconvert_to_f32(input, src_dt);
+        apply_scales(input, scale_op);
+    }
+
+    template <typename Vmm>
+    void decompress_and_downcvt_reg(const Vmm &input, const Vmm &zp,
+            const Xbyak::Operand &scale_op, data_type_t src_dt,
+            data_type_t dst_dt) {
+        if (src_dt == dst_dt) return;
+        decompress_reg(input, zp, scale_op, src_dt);
+        downconvert_to_dst_dt(input, dst_dt);
+    }
+
+    template <typename Vmm>
+    void decompress_and_downcvt_2reg(const Vmm &input1, const Vmm &input2,
+            const Vmm &zp1, const Vmm &zp2, const Xbyak::Operand &scale_op1,
+            const Xbyak::Operand &scale_op2, data_type_t src_dt,
+            data_type_t dst_dt) {
+        if (src_dt == dst_dt) return;
+        decompress_reg(input1, zp1, scale_op1, src_dt);
+        decompress_reg(input2, zp2, scale_op2, src_dt);
+        downconvert_to_dst_dt(input1, input2, dst_dt);
+    }
+
+    /** @brief Helper method to load scales into a vector register.
+     * Supports f32, bf16 and f16 data types.
+     * @tparam Vmm Vector register type (Zmm, Ymm, etc.)
+     * @param vmm Vector register to load the scale value into
+     * @param op Operand to load the scale value from
+     **/
+    template <typename Vmm>
+    void load_scale_value(const Vmm &vmm, const Xbyak::Operand &op,
+            data_type_t dt, bool is_tail = false) {
+        const auto masked_vmm = maybe_mask(vmm, is_tail);
+        switch (dt) {
+            case data_type::f32: uni_vmovups(masked_vmm, op); break;
+            case data_type::bf16:
+                uni_vpmovzxwd(masked_vmm, op);
+                uni_vpslld(vmm, vmm, 16);
+                break;
+            case data_type::f16: vcvtph2ps(masked_vmm, op); break;
+            default: assert(!"unsupported wei_scales data type");
+        }
+    }
+
+    // Commonly used masks to permute data
+    using opmask_t = const Xbyak::Opmask;
+    opmask_t k3333 = k1;
+    opmask_t k5555 = k2;
+    opmask_t kAAAA = k3;
+    opmask_t kCCCC = k4;
+    opmask_t k0F0F = k5;
+    opmask_t kF0F0 = k6;
+    opmask_t kFFFF = k6;
+    opmask_t kTail = k7;
+    opmask_t kTail_int4 = k5; // TODO: refactor: use kTail for both cases
+};
+
 template <typename Vmm>
 struct jit_brgemm_matmul_copy_b_int8_t : public jit_brgemm_matmul_copy_b_t,
                                          public jit_generator_t {
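The class above only supplies building blocks; the derived kernels in the later hunks wire them together. As orientation, a minimal sketch of the intended call order, borrowing the register names of the bf16 copy kernel further down purely for illustration (the surrounding loop and operand setup are assumed, not shown in this excerpt):

    // 1) load the raw weights and widen them (int4 nibbles get permuted/extended)
    load_value(src_reg, load_addr, vmm_permd, conf_->orig_wei_dt, is_tail);
    // 2) fetch the per-N weight scales (f32/bf16/f16)
    load_scale_value(vmm_wei_scales, scales_addr, conf_->wei_scales_dt, is_tail);
    // 3) zero-point shift -> up-convert to f32 -> apply scales,
    //    then down-convert f32 to the buffer type (bf16/f16/f32)
    decompress_and_downcvt_reg(src_load, vmm_zp_b_shift, vmm_wei_scales,
            conf_->orig_wei_dt, conf_->wei_dt);

Since decompress_and_downcvt_reg() returns immediately when the source and destination types already match, kernels copying plain bf16/f16 weights pay nothing for going through the shared path.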
@@ -2892,13 +3213,12 @@ void jit_brgemm_matmul_copy_b_int8_t<Vmm>::generate() {
 }
 
 template <typename Vmm>
-struct jit_brgemm_matmul_copy_b_bf16_t : public jit_brgemm_matmul_copy_b_t,
-                                         public jit_generator_t {
+struct jit_brgemm_matmul_copy_b_bf16_t
+    : public jit_brgemm_matmul_copy_b_common_t {
     DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_brgemm_matmul_copy_b_bf16_t)
 
     jit_brgemm_matmul_copy_b_bf16_t(const brgemm_matmul_conf_t *conf)
-        : jit_brgemm_matmul_copy_b_t(conf)
-        , jit_generator_t(jit_name())
+        : jit_brgemm_matmul_copy_b_common_t(conf)
         , typesize(conf->b_dt_sz)
         , tr_typesize(conf->tr_b_dt_sz)
         , wei_scales_typesize(conf->wei_scales_dt_sz)
@@ -2950,12 +3270,6 @@ private:
     constexpr static int reg_current_K_pad_offs_ = 16;
     constexpr static int stack_space_needed = 24;
 
-    opmask_t kTail = k7;
-    opmask_t kFFFF = k6;
-    opmask_t kTail_int4 = k5;
-    opmask_t kAAAA = k4;
-    opmask_t k5555 = k3;
-
     reg64_t reg_src = rax;
     reg64_t reg_tr_src = rbx;
 
@@ -2995,72 +3309,13 @@ private:
         else
             jit_generator_t::kmovd(k, regw_tmp);
     }
-    void copy_half_int4(const Zmm &zmm, const Ymm &ymm_half) {
-        vinserti64x4(zmm, zmm, ymm_half, 1);
-    }
-    void copy_half_int4(const Ymm &ymm, const Xmm &xmm_half) {
-        vinserti128(ymm, ymm, xmm_half, 1);
-    }
-    Vmm_lower_t maybe_mask(Vmm_lower_t vmm_lower, bool is_tail) {
-        assert(is_src_int4);
-        if (isa_has_masks(conf_->isa)) {
-            return is_tail ? vmm_lower | kTail_int4 | T_z
-                           : vmm_lower | kFFFF | T_z;
-        } else {
-            return vmm_lower;
-        }
-    }
-    Vmm maybe_mask(Vmm vmm, bool is_tail) {
-        if (isa_has_masks(conf_->isa)) {
-            return is_tail ? vmm | kTail | T_z : vmm | kFFFF | T_z;
-        } else {
-            return vmm;
-        }
-    }
-    void load_data(const Vmm vmm_in, const Xbyak::Operand &op, bool is_tail);
     void copy_block(int nrows, int ncolumns, bool n_tail, bool zeropad);
     void copy_2x32(int nrows, int ncolumns, bool zeropad);
     void init_masks();
     void generate() override;
 };
 
-template <typename Vmm>
-void jit_brgemm_matmul_copy_b_bf16_t<Vmm>::load_data(
-        const Vmm vmm_in, const Xbyak::Operand &op, bool is_tail) {
-    const auto vmm = maybe_mask(vmm_in, is_tail);
-    const auto vmm_lower = Vmm_lower_t(vmm.getIdx());
-    MAYBE_UNUSED(vmm_lower);
-
-    switch (conf_->orig_wei_dt) {
-        case data_type::f32: uni_vmovups(vmm, op); break;
-        case data_type::f16:
-        case data_type::bf16: vmovdqu16(vmm, op); break;
-        case data_type::s8: uni_vpmovsxbd(vmm, op); break;
-        case data_type::u8: uni_vpmovzxbd(vmm, op); break;
-        // For int4, we see two int4 as one int8 and extend them int32
-        // low half stores in lower bytes of vmm and high half in higher
-        // bytes of vmm, then permute them into correct order
-        // Finally, we process the extend bytes for s4/u4 accordingly
-        case data_type::s4:
-            uni_vpmovsxbd(maybe_mask(vmm_lower, is_tail), op);
-            copy_half_int4(vmm_in, vmm_lower);
-            vpermd(vmm_in, vmm_permd, vmm_in);
-            uni_vpslld(vmm_in | k5555, vmm_in, 28);
-            vpsrad(vmm_in | k5555, vmm_in, 28);
-            vpsrad(vmm_in | kAAAA, vmm_in, 4);
-            break;
-        case data_type::u4:
-            uni_vpmovzxbd(maybe_mask(vmm_lower, is_tail), op);
-            copy_half_int4(vmm_in, vmm_lower);
-            vpermd(vmm_in, vmm_permd, vmm_in);
-            uni_vpslld(vmm_in | k5555, vmm_in, 28);
-            vpsrld(vmm_in | k5555, vmm_in, 28);
-            vpsrld(vmm_in | kAAAA, vmm_in, 4);
-            break;
-        default: assert(!"unsupported data type");
-    }
-}
-
 template <typename Vmm>
 void jit_brgemm_matmul_copy_b_bf16_t<Vmm>::copy_2x32(
         int nrows, int ncolumns, bool zeropad) {
@@ -3108,40 +3363,20 @@ void jit_brgemm_matmul_copy_b_bf16_t<Vmm>::copy_2x32(
             else
                 uni_vmovups(src_load, load_addr);
         } else {
-            load_data(src_reg, load_addr, is_tail);
+            load_value(
+                    src_reg, load_addr, vmm_permd, conf_->orig_wei_dt, is_tail);
         }
 
-        if (utils::one_of(conf_->orig_wei_dt, data_type::s8, data_type::u8,
-                    data_type::s4, data_type::u4)) {
-            if (req_zp_b_shift) uni_vpsubd(src_load, src_load, vmm_zp_b_shift);
-            uni_vcvtdq2ps(src_load, src_load);
-            if (req_apply_wei_scales) {
-                const auto wei_scales_offset
-                        = (is_dynamic_stride ? 0 : k * wei_scales_N_stride)
-                        + n * wei_scales_typesize;
-                const auto wei_scales_addr = maybe_EVEX_compress_addr(
-                        reg_wei_scales, wei_scales_offset);
-                const auto vmm_wei_scales_masked
-                        = maybe_mask(vmm_wei_scales, is_tail);
-                switch (conf_->wei_scales_dt) {
-                    case data_type::f32:
-                        uni_vmovups(vmm_wei_scales_masked, wei_scales_addr);
-                        break;
-                    case data_type::bf16:
-                        uni_vpmovzxwd(vmm_wei_scales_masked, wei_scales_addr);
-                        uni_vpslld(vmm_wei_scales, vmm_wei_scales, 16);
-                        break;
-                    case data_type::f16:
-                        vcvtph2ps(vmm_wei_scales_masked, wei_scales_addr);
-                        break;
-                    default: assert(!"unsupported wei_scales data type");
-                }
-                uni_vmulps(src_load, src_load, vmm_wei_scales);
-            }
-
-            if (conf_->wei_dt == data_type::f16)
-                vcvtps2phx(Vmm_lower_t(src_reg.getIdx()), src_reg);
-        }
+        const auto scales_offset
+                = (is_dynamic_stride ? 0 : k * wei_scales_N_stride)
+                + n * wei_scales_typesize;
+        const auto scales_addr
+                = maybe_EVEX_compress_addr(reg_wei_scales, scales_offset);
+        if (req_apply_wei_scales)
+            load_scale_value(
+                    vmm_wei_scales, scales_addr, conf_->wei_scales_dt, is_tail);
+        decompress_and_downcvt_reg(src_load, vmm_zp_b_shift, vmm_wei_scales,
+                conf_->orig_wei_dt, conf_->wei_dt);
     };
 
     int iter = 0;
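For readability, the change above boils down to replacing the open-coded per-type handling

    if (req_zp_b_shift) uni_vpsubd(src_load, src_load, vmm_zp_b_shift); // zero point
    uni_vcvtdq2ps(src_load, src_load);                                  // int -> f32
    uni_vmulps(src_load, src_load, vmm_wei_scales);                     // weight scales
    vcvtps2phx(Vmm_lower_t(src_reg.getIdx()), src_reg);                 // f32 -> f16 if needed

with a single decompress_and_downcvt_reg() call, which performs the same steps through apply_shift(), upconvert_to_f32(), apply_scales() and downconvert_to_dst_dt(), and skips them entirely when orig_wei_dt already equals wei_dt. This summary only restates lines visible in the hunk above.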
@@ -3185,15 +3420,12 @@ void jit_brgemm_matmul_copy_b_bf16_t<Vmm>::copy_2x32(
 
         if (nrows - k >= k_blk_step) {
             load(blk_idx, k + 1, n);
-            if (req_cvtps2bf16) {
-                vcvtne2ps2bf16(src_vmm0, src_vmm1, src_vmm0);
-            } else if (is_superset(conf_->isa, avx512_core)) {
+            if (is_superset(conf_->isa, avx512_core)) {
                 const auto src_ymm1 = ymm(src_vmm1.getIdx());
                 vinsertf64x4(src_zmm0, src_zmm0, src_ymm1, 1);
             }
-        } else if (req_cvtps2bf16) {
-            vcvtneps2bf16(ymm(src_vmm0.getIdx()), src_vmm0);
-        } else if (!is_superset(conf_->isa, avx512_core)) {
+        }
+        if (!is_superset(conf_->isa, avx512_core)) {
             uni_vxorps(src_vmm1, src_vmm1, src_vmm1);
         }
 
@@ -3774,13 +4006,11 @@ template struct jit_brgemm_matmul_copy_b_f32_t<Ymm>;
 
 template <typename Vmm>
 struct jit_brgemm_matmul_copy_b_transposed_t
-    : public jit_brgemm_matmul_copy_b_t,
-      public jit_generator_t {
+    : public jit_brgemm_matmul_copy_b_common_t {
     DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_brgemm_matmul_copy_b_transposed_t)
 
     jit_brgemm_matmul_copy_b_transposed_t(const brgemm_matmul_conf_t *conf)
-        : jit_brgemm_matmul_copy_b_t(conf)
-        , jit_generator_t(jit_name())
+        : jit_brgemm_matmul_copy_b_common_t(conf)
         , typesize_(conf_->b_dt_sz)
        , tr_typesize_(conf_->tr_b_dt_sz)
         , wei_scales_typesize_(conf_->wei_scales_dt_sz)
@@ -3874,16 +4104,6 @@ private:
     constexpr static int ldb_step_idx_offs = 0;
     constexpr static int stack_space_needed = 8;
 
-    opmask_t k3333 = k1;
-    opmask_t k5555 = k2;
-    opmask_t kAAAA = k3;
-    opmask_t kCCCC = k4;
-    opmask_t k0F0F = k5;
-    opmask_t kF0F0 = k6;
-    opmask_t kTail = k7;
-    // reuse k7 for int4 and restore the value after use
-    opmask_t kTail_int4 = k7;
-
     reg64_t reg_src_base = rax;
     reg64_t reg_tr_src_base = rbx;
     reg64_t reg_comp_ptr = rdx;
@@ -3943,31 +4163,9 @@ private:
         return Vmm(n_blk_step_ + i);
     }
 
-    void copy_half_int4(const Zmm &zmm, const Ymm &ymm_half) {
-        vinserti64x4(zmm, zmm, ymm_half, 1);
-    }
-
-    void copy_half_int4(const Ymm &ymm, const Xmm &xmm_half) {
-        vinserti128(ymm, ymm, xmm_half, 1);
-    }
-
-    Vmm_lower_t maybe_mask(Vmm_lower_t vmm_lower, bool is_tail) {
-        assert(is_src_int4_);
-        return isa_has_masks(conf_->isa) && is_tail
-                ? vmm_lower | kTail_int4 | T_z
-                : vmm_lower;
-    }
-
-    Vmm maybe_mask(Vmm vmm, bool is_tail) {
-        return isa_has_masks(conf_->isa) && is_tail ? vmm | kTail | T_z : vmm;
-    }
-
     void init_tail_mask(const int columns_tail, const bool use_int4_mask);
-    void maybe_apply_wei_scales(
-            const Vmm vmm_in, const size_t offset, const bool is_tail);
-    void maybe_apply_zp_b_shift(const Vmm vmm_in, const bool is_tail);
-    void load_int(const Vmm vmm_in, const dim_t offset, const int i,
-            const int columns_tail, bool is_tail);
+    bool preload_int4(const Xmm &xmm_in, const int i, const int columns_tail,
+            const bool is_tail, const dim_t offset);
     void copy_row_x_col(int nrows, int ncolumns);
     void compute_K_loop(bool is_N_tail, int curr_K_tail, bool is_first_K_iter,
             bool is_last_K_iter);
@@ -3989,23 +4187,25 @@ private:
         return next_row_idx < num_rows || dynamic_tail;
     }
 
-    void generate() override;
-};
-
-// This method applies scales for weights decompression scenario. Given it's the
-// transposed kernel, B is in column-major format, but scales in the library are
-// in row-major format.
-template <typename Vmm>
-void jit_brgemm_matmul_copy_b_transposed_t<Vmm>::maybe_apply_wei_scales(
-        const Vmm vmm_in, const size_t offset, const bool is_tail) {
-    if (!req_apply_wei_scales_) return;
+    /**
+     * Loads scales and broadcasts them over the Vmm register.
+     * Supported data types: f32, bf16, f16.
+     *
+     * @param n N-dimension local index.
+     * @param is_tail Bool flag indicating if a tail is being processed.
+     */
+    void load_scales(int n, bool is_tail) {
+        if (!conf_->is_wei_scale_per_n || !conf_->apply_scales_in_buffer_b)
+            return;
+
+        const auto offset = n * wei_scales_K_stride_;
 
     if (wei_scales_K_stride_ == wei_scales_typesize_) {
         // A single scale per kernel case.
 
         // Enable broadcast address for f16 to avoid vmm manipulations.
-        const auto wei_scales_addr = EVEX_compress_addr(
-                reg_wei_scales, offset, conf_->wei_scales_dt == data_type::f16);
+        const auto wei_scales_addr = EVEX_compress_addr(reg_wei_scales,
+                offset, conf_->wei_scales_dt == data_type::f16);
         switch (conf_->wei_scales_dt) {
             case data_type::f32:
                 uni_vbroadcastss(vmm_wei_scales, wei_scales_addr);
@@ -4019,9 +4219,6 @@ void jit_brgemm_matmul_copy_b_transposed_t<Vmm>::maybe_apply_wei_scales(
                 break;
             default: assert(!"unsupported wei_scales data type");
         }
-
-        const auto vmm = maybe_mask(vmm_in, is_tail);
-        vmulps(vmm, vmm_in, vmm_wei_scales);
     } else {
         // A broadcasted ahead-of-time scales case.
 
@@ -4033,8 +4230,10 @@ void jit_brgemm_matmul_copy_b_transposed_t<Vmm>::maybe_apply_wei_scales(
         // memory allocated is KxN, thus, there's an over-use of memory, but
         // such usage allows us to simplify the kernel logic and just load
         // weights into a full vector register.
-        const auto wei_scales_addr = EVEX_compress_addr(reg_wei_scales, offset);
-        const auto vmm_wei_scales_masked = maybe_mask(vmm_wei_scales, is_tail);
+        const auto wei_scales_addr
+                = EVEX_compress_addr(reg_wei_scales, offset);
+        const auto vmm_wei_scales_masked
+                = maybe_mask(vmm_wei_scales, is_tail);
 
         switch (conf_->wei_scales_dt) {
             case data_type::f32:
@@ -4049,21 +4248,15 @@ void jit_brgemm_matmul_copy_b_transposed_t<Vmm>::maybe_apply_wei_scales(
                 break;
             default: assert(!"unsupported wei_scales data type");
         }
 
-        const auto vmm = maybe_mask(vmm_in, is_tail);
-        vmulps(vmm, vmm_in, vmm_wei_scales);
     }
 }
 
-template <typename Vmm>
-void jit_brgemm_matmul_copy_b_transposed_t<Vmm>::maybe_apply_zp_b_shift(
-        const Vmm vmm_in, const bool is_tail) {
-    if (!req_zp_b_shift_) return;
-
-    const auto vmm = maybe_mask(vmm_in, is_tail);
-    vpsubd(vmm, vmm, vmm_zp_b_val);
-}
-
+    void generate() override;
+};
+
+// This method applies scales for weights decompression scenario. Given it's the
+// transposed kernel, B is in column-major format, but scales in the library are
+// in row-major format.
 template <typename Vmm>
 void jit_brgemm_matmul_copy_b_transposed_t<Vmm>::init_tail_mask(
         const int columns_tail, const bool use_int4_mask) {
@@ -4083,23 +4276,32 @@ void jit_brgemm_matmul_copy_b_transposed_t<Vmm>::init_tail_mask(
     }
 }
 
+/**
+ * @brief Handles special case when loading INT4 data with unaligned src_stride
+ *
+ * When processing INT4 data and (i * src_stride_) % 2 != 0, we need to perform
+ * additional operations to handle the unaligned half-byte boundaries:
+ *
+ * - For small loads (< 8 bytes or tail cases): We perform a simple right shift
+ *   by 4 bits to eliminate the unnecessary half-byte at the front
+ *
+ * - For large loads (8 bytes): We need two registers to correctly reconstruct the data:
+ *   1. Shift the first register right by 4 bits to remove the leading half-byte
+ *   2. Shift the second register left by 4 bits to preserve the trailing half-byte
+ *   3. Combine both registers with a logical OR to get the correctly aligned result
+ *
+ * @param xmm_in The target XMM register to store the aligned data
+ * @param i Current row index
+ * @param columns_tail Number of remaining columns in the tail case
+ * @param is_tail Flag indicating if processing a tail section
+ * @param offset Memory offset for loading from source
+ * @return True if INT4 preloading was performed, false otherwise
+ */
 template <typename Vmm>
-void jit_brgemm_matmul_copy_b_transposed_t<Vmm>::load_int(const Vmm vmm_in,
-        const dim_t offset, const int i, int columns_tail, bool is_tail) {
-    const auto vmm = maybe_mask(vmm_in, is_tail);
-    const auto vmm_lower = Vmm_lower_t(vmm.getIdx());
-    const auto xmm_in = Xmm(vmm_in.getIdx());
+bool jit_brgemm_matmul_copy_b_transposed_t<Vmm>::preload_int4(const Xmm &xmm_in,
+        const int i, const int columns_tail, const bool is_tail,
+        const dim_t offset) {
     const auto addr = EVEX_compress_addr(reg_src, offset);
-    MAYBE_UNUSED(xmm_in);
-    MAYBE_UNUSED(vmm_lower);
-    if (is_src_int4_) init_tail_mask(columns_tail, true);
-
-    // Two additional operations are needed for int4 when i * src_stride_ % 2 != 0.
-    // The maximum data size for a bitwise shift is 8 bytes (quadwords).
-    // If the loaded data size is smaller than 8, we can directly perform a right
-    // shift to eliminate the unnecessary half-byte at the front.
-    // If the loaded data size is 8, we need two registers to handle the
-    // unnecessary half-byte at the front and back, respectively.
     const bool need_preload_int4 = is_src_int4_ && (i * src_stride_) % 2 != 0;
     const auto max_shift_sz = 8;
     if (need_preload_int4) {
@@ -4119,41 +4321,9 @@ void jit_brgemm_matmul_copy_b_transposed_t<Vmm>::load_int(const Vmm vmm_in,
             vpsllq(xmm_tmp, xmm_tmp, 4);
             vpord(xmm_in, xmm_in, xmm_tmp);
         }
+        return true;
     }
+    return false;
-
-    switch (conf_->orig_wei_dt) {
-        case data_type::s8: uni_vpmovsxbd(vmm, addr); break;
-        case data_type::u8: uni_vpmovzxbd(vmm, addr); break;
-        // For int4, we see two int4 as one int8 and extend them int32
-        // low half stores in lower bytes of vmm and high half in higher
-        // bytes of vmm, then permute them into correct order
-        // Finally, we process the extend bytes for s4/u4 accordingly
-        case data_type::s4:
-            if (need_preload_int4)
-                uni_vpmovsxbd(maybe_mask(vmm_lower, is_tail), xmm_in);
-            else
-                uni_vpmovsxbd(maybe_mask(vmm_lower, is_tail), addr);
-            copy_half_int4(vmm_in, vmm_lower);
-            vpermd(vmm_in, vmm_permd, vmm_in);
-            uni_vpslld(vmm_in | k5555, vmm_in, 28);
-            vpsrad(vmm_in | k5555, vmm_in, 28);
-            vpsrad(vmm_in | kAAAA, vmm_in, 4);
-            break;
-        case data_type::u4:
-            if (need_preload_int4)
-                uni_vpmovzxbd(maybe_mask(vmm_lower, is_tail), xmm_in);
-            else
-                uni_vpmovzxbd(maybe_mask(vmm_lower, is_tail), addr);
-            copy_half_int4(vmm_in, vmm_lower);
-            vpermd(vmm_in, vmm_permd, vmm_in);
-            uni_vpslld(vmm_in | k5555, vmm_in, 28);
-            vpsrld(vmm_in | k5555, vmm_in, 28);
-            vpsrld(vmm_in | kAAAA, vmm_in, 4);
-            break;
-        default: assert(!"unsupported data type");
-    }
-    // restore the tail_mask
-    if (is_src_int4_) init_tail_mask(columns_tail, false);
 }
 
 template <typename Vmm>
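The nibble arithmetic used throughout these kernels is easier to follow with a scalar model. The snippet below is an illustration only, not library code; it assumes the usual x86 behavior of arithmetic right shifts on signed integers. Each "lane" stands for one 32-bit element produced by uni_vpmovsxbd/uni_vpmovzxbd and holds one packed byte (two nibbles); after the vpermd step, lanes selected by k5555 keep the low nibble and lanes selected by kAAAA keep the high one:

    #include <cstdint>

    // s4: the byte was sign-extended by vpmovsxbd, so arithmetic shifts recover signs.
    int32_t s4_low(int32_t lane) {                  // uni_vpslld(.., 28) + vpsrad(.., 28)
        return int32_t(uint32_t(lane) << 28) >> 28;
    }
    int32_t s4_high(int32_t lane) { return lane >> 4; } // vpsrad(.., 4)

    // u4: identical structure with logical shifts (vpsrld).
    uint32_t u4_low(uint32_t lane) { return (lane << 28) >> 28; }
    uint32_t u4_high(uint32_t lane) { return lane >> 4; }

For the unaligned case handled by preload_int4(), the raw bytes are first realigned roughly as bytes = (bytes >> 4) | (next_bytes << 4); the left-shift/OR half of that is visible in the vpsllq/vpord lines above, while the right-shift half falls outside this hunk. Only after that realignment does the extend/permute/shift sequence apply.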
@@ -4163,8 +4333,11 @@ void jit_brgemm_matmul_copy_b_transposed_t<Vmm>::copy_row_x_col(
             && ncolumns <= k_blk_step_);
     if (!nrows) return;
 
-    const int columns_tail = ncolumns
-            % (req_cvtps2xf16_ ? req_cvt_bf16_k_blk_step_ : k_blk_step_);
+    const auto cur_k_blk_step
+            = req_cvtps2xf16_ ? req_cvt_bf16_k_blk_step_ : k_blk_step_;
+
+    const int columns_tail = ncolumns % cur_k_blk_step;
+
     init_tail_mask(columns_tail, false);
 
     auto load2bf16 = [this, nrows, columns_tail, ncolumns](int i) {
@@ -4188,59 +4361,60 @@ void jit_brgemm_matmul_copy_b_transposed_t<Vmm>::copy_row_x_col(
         }
 
         // check if k_tail exists and it's in the first zmm
-        auto zmm_src = columns_tail > 0 && ncolumns < req_cvt_bf16_k_blk_step_
-                ? src_reg | kTail | T_z
-                : src_reg;
+        const auto is_tail
+                = columns_tail > 0 && ncolumns < req_cvt_bf16_k_blk_step_;
+        auto src_reg_masked = maybe_mask(src_reg, is_tail);
         const auto src_offset = (i * src_stride_) / src_elems_per_byte_;
         const auto addr = EVEX_compress_addr(reg_src, src_offset);
         if (is_bf32_)
-            vmovups(zmm_src, addr);
+            vmovups(src_reg_masked, addr);
         else if (is_bf16_with_int_wei_ || conf_->is_f16_with_int_wei) {
-            const bool is_tail
-                    = columns_tail > 0 && ncolumns < req_cvt_bf16_k_blk_step_;
-            load_int(src_reg, src_offset, i, columns_tail, is_tail);
-            maybe_apply_zp_b_shift(src_reg, is_tail);
-            vcvtdq2ps(zmm_src, zmm_src);
-            maybe_apply_wei_scales(src_reg, i * wei_scales_K_stride_, is_tail);
+            const auto xmm_preload = Xmm(src_reg.getIdx());
+            MAYBE_UNUSED(xmm_preload);
+            const bool preloaded_int4 = preload_int4(
+                    xmm_preload, i, columns_tail, is_tail, src_offset);
+            const auto &src_op = preloaded_int4
+                    ? static_cast<const Xbyak::Operand &>(xmm_preload)
+                    : static_cast<const Xbyak::Operand &>(addr);
+            if (is_src_int4_) init_tail_mask(columns_tail, true);
+            load_value(src_reg, src_op, vmm_permd, conf_->orig_wei_dt, is_tail);
+            if (is_src_int4_) init_tail_mask(columns_tail, false);
+            load_scales(i, is_tail);
+            decompress_reg(src_reg_masked, vmm_zp_b_val, vmm_wei_scales,
+                    conf_->orig_wei_dt);
         } else
             assert(!"Unsupported data type in loading");
 
         if (ncolumns <= req_cvt_bf16_k_blk_step_) {
             vpxord(src_reg_next, src_reg_next, src_reg_next);
         } else {
-            auto zmm_src_next = columns_tail > 0 ? src_reg_next | kTail | T_z
-                                                 : src_reg_next;
+            const auto is_tail = columns_tail > 0;
+            auto src_next_masked = maybe_mask(src_reg_next, is_tail);
             const auto next_src_offset
                     = (i * src_stride_ + req_cvt_bf16_k_blk_step_ * typesize_)
                     / src_elems_per_byte_;
             const auto next_addr = EVEX_compress_addr(reg_src, next_src_offset);
             if (is_bf32_)
-                vmovups(zmm_src_next, next_addr);
+                vmovups(src_next_masked, next_addr);
             else if (is_bf16_with_int_wei_ || conf_->is_f16_with_int_wei) {
-                const auto is_tail = columns_tail > 0;
-                load_int(src_reg_next, next_src_offset, i, columns_tail,
-                        columns_tail > 0);
-                maybe_apply_zp_b_shift(src_reg_next, is_tail);
-                vcvtdq2ps(zmm_src_next, zmm_src_next);
-                maybe_apply_wei_scales(src_reg_next,
-                        i * wei_scales_K_stride_
-                                + !single_wei_scales_value_
-                                        * req_cvt_bf16_k_blk_step_
-                                        * wei_scales_typesize_,
-                        is_tail);
+                const auto xmm_preload = Xmm(src_reg_next.getIdx());
+                MAYBE_UNUSED(xmm_preload);
+                const bool preloaded_int4 = preload_int4(
+                        xmm_preload, i, columns_tail, is_tail, src_offset);
+                const auto &src_op = preloaded_int4
+                        ? static_cast<const Xbyak::Operand &>(xmm_preload)
+                        : static_cast<const Xbyak::Operand &>(next_addr);
+                if (is_src_int4_) init_tail_mask(columns_tail, true);
+                load_value(src_reg_next, src_op, vmm_permd, conf_->orig_wei_dt,
+                        is_tail);
+                if (is_src_int4_) init_tail_mask(columns_tail, false);
+                load_scales(i, is_tail);
+                decompress_reg(src_next_masked, vmm_zp_b_val, vmm_wei_scales,
+                        conf_->orig_wei_dt);
             } else
                 assert(!"Unsupported data type in loading");
         }
-        if (conf_->wei_dt == data_type::bf16) {
-            vcvtne2ps2bf16(src_reg, src_reg_next, src_reg);
-        } else {
-            const auto src_vmm_lower0 = Vmm_lower_t(src_reg.getIdx());
-            const auto src_vmm_lower1 = Vmm_lower_t(src_reg_next.getIdx());
-            vcvtps2phx(src_vmm_lower0, src_reg);
-            vcvtps2phx(src_vmm_lower1, src_reg_next);
-            vinsertf64x4(src_reg, src_reg, src_vmm_lower1, 1);
-        }
+        downconvert_to_dst_dt(src_reg, src_reg_next, conf_->wei_dt);
         L(load_done);
     };
 
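One detail of the rewritten load path: int4 packs two values per byte, so while load_value() is reading, the tail mask has to describe the nibble-packed bytes rather than full elements, and it has to be restored right afterwards (the transposed kernel shares one k-register for both masks, as the removed "reuse k7 for int4" comment noted). That is the reason for the bracketing pattern that repeats in every int-weight branch above and below:

    if (is_src_int4_) init_tail_mask(columns_tail, /*use_int4_mask=*/true);
    load_value(src_reg, src_op, vmm_permd, conf_->orig_wei_dt, is_tail);
    if (is_src_int4_) init_tail_mask(columns_tail, /*use_int4_mask=*/false);

This is the same toggle the old load_int() performed internally; after the refactor it sits at the call site.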
@@ -4268,10 +4442,20 @@ void jit_brgemm_matmul_copy_b_transposed_t<Vmm>::copy_row_x_col(
         const auto src_offset = (i * src_stride_) / src_elems_per_byte_;
         const auto addr = EVEX_compress_addr(reg_src, src_offset);
         if (conf_->is_f16_with_int_wei && conf_->wei_dt == data_type::f32) {
-            load_int(src_reg, src_offset, i, columns_tail, is_tail);
-            maybe_apply_zp_b_shift(src_reg, is_tail);
-            vcvtdq2ps(src_load, src_load);
-            maybe_apply_wei_scales(src_reg, i * wei_scales_K_stride_, is_tail);
+            const auto xmm_preload = Xmm(src_reg.getIdx());
+            MAYBE_UNUSED(xmm_preload);
+            const bool preloaded_int4 = preload_int4(
+                    xmm_preload, i, columns_tail, is_tail, src_offset);
+            const auto &src_op = preloaded_int4
+                    ? static_cast<const Xbyak::Operand &>(xmm_preload)
+                    : static_cast<const Xbyak::Operand &>(addr);
+            if (is_src_int4_) init_tail_mask(columns_tail, true);
+            load_value(src_reg, src_op, vmm_permd, conf_->orig_wei_dt, is_tail);
+            if (is_src_int4_) init_tail_mask(columns_tail, false);
+            load_scales(i, is_tail);
+            decompress_reg(maybe_mask(src_reg, is_tail), vmm_zp_b_val,
+                    vmm_wei_scales, conf_->orig_wei_dt);
         } else if (use_fp16_instructions_) {
             if (conf_->isa == avx512_core_fp16) {
                 vcvtph2psx(src_load, addr);
@@ -4801,13 +4985,12 @@ template struct jit_brgemm_matmul_copy_b_transposed_t<Zmm>;
 template struct jit_brgemm_matmul_copy_b_transposed_t<Ymm>;
 
 template <typename Vmm>
-struct jit_brgemm_matmul_copy_b_cvt_bf16_t : public jit_brgemm_matmul_copy_b_t,
-                                             public jit_generator_t {
+struct jit_brgemm_matmul_copy_b_cvt_bf16_t
+    : public jit_brgemm_matmul_copy_b_common_t {
     DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_brgemm_matmul_copy_b_cvt_bf16_t)
 
     jit_brgemm_matmul_copy_b_cvt_bf16_t(const brgemm_matmul_conf_t *conf)
-        : jit_brgemm_matmul_copy_b_t(conf)
-        , jit_generator_t(jit_name())
+        : jit_brgemm_matmul_copy_b_common_t(conf)
         , typesize_(conf->b_dt_sz)
         , tr_typesize_(conf->tr_b_dt_sz)
         , wei_scales_typesize_(conf_->wei_scales_dt_sz)
@@ -4852,11 +5035,6 @@ private:
     const bool req_apply_wei_scales_;
     const int reserved_regs_;
 
-    opmask_t kTail = k7;
-    opmask_t kFFFF = k6;
-    opmask_t kAAAA = k5;
-    opmask_t k5555 = k4;
-
     reg64_t reg_src = rax;
     reg64_t reg_tr_src = rbx;
 
@@ -4875,17 +5053,6 @@ private:
     Vmm vmm_wei_scales1 = Vmm(3);
     Vmm vmm_tmp = Vmm(4);
 
-    void copy_half_int4(const Zmm &zmm, const Ymm &ymm_half) {
-        vinserti64x4(zmm, zmm, ymm_half, 1);
-    }
-    Vmm maybe_mask(Vmm vmm, bool is_tail) {
-        if (isa_has_masks(conf_->isa)) {
-            return is_tail ? vmm | kTail | T_z : vmm | kFFFF | T_z;
-        } else {
-            return vmm;
-        }
-    }
-
     Vmm get_vmm(const int blk, const int idx) {
         const int max_isa_regs = isa_num_vregs(conf_->isa);
         const int max_unroll = (max_isa_regs - reserved_regs_) / k_blk_step;
@@ -4897,8 +5064,7 @@ private:
     }
 
     void init_masks();
-    void load_int(const Vmm vmm_in, const Xbyak::Operand &op);
-    void get_wei_scales(const int blk, const int k, const int n,
+    void get_scales(const int blk, const int k, const int n,
             const bool is_n_tail, const bool is_k_tail);
     void copy_block(const int nrows, const int ncolumns, bool zeropad);
     void generate() override;
@@ -4923,41 +5089,11 @@ void jit_brgemm_matmul_copy_b_cvt_bf16_t<Vmm>::init_masks() {
 }
 
 template <typename Vmm>
-void jit_brgemm_matmul_copy_b_cvt_bf16_t<Vmm>::load_int(
-        const Vmm vmm_in, const Xbyak::Operand &op) {
-    const auto vmm_lower = Vmm_lower_t(vmm_in.getIdx());
-    MAYBE_UNUSED(vmm_lower);
-
-    switch (conf_->orig_wei_dt) {
-        case data_type::s8: uni_vpmovsxbd(vmm_in, op); break;
-        case data_type::u8: uni_vpmovzxbd(vmm_in, op); break;
-        // For int4, we see two int4 as one int8 and extend them int32
-        // low half stores in lower bytes of vmm and high half in higher
-        // bytes of vmm, then permute them into correct order
-        // Finally, we process the extend bytes for s4/u4 accordingly
-        case data_type::s4:
-            uni_vpmovsxbd(vmm_lower, op);
-            copy_half_int4(vmm_in, vmm_lower);
-            vpermd(vmm_in, vmm_permd, vmm_in);
-            uni_vpslld(vmm_in | k5555, vmm_in, 28);
-            vpsrad(vmm_in | k5555, vmm_in, 28);
-            vpsrad(vmm_in | kAAAA, vmm_in, 4);
-            break;
-        case data_type::u4:
-            uni_vpmovzxbd(vmm_lower, op);
-            copy_half_int4(vmm_in, vmm_lower);
-            vpermd(vmm_in, vmm_permd, vmm_in);
-            uni_vpslld(vmm_in | k5555, vmm_in, 28);
-            vpsrld(vmm_in | k5555, vmm_in, 28);
-            vpsrld(vmm_in | kAAAA, vmm_in, 4);
-            break;
-        default: assert(!"unsupported data type");
-    }
-}
-
-template <typename Vmm>
-void jit_brgemm_matmul_copy_b_cvt_bf16_t<Vmm>::get_wei_scales(const int blk,
+void jit_brgemm_matmul_copy_b_cvt_bf16_t<Vmm>::get_scales(const int blk,
         const int k, const int n, const bool is_n_tail, const bool is_k_tail) {
 
+    if (!req_apply_wei_scales_) return;
+
     const auto zmm_wei_scales1 = maybe_mask(vmm_wei_scales1, is_n_tail);
     const auto zmm_tmp = maybe_mask(vmm_tmp, is_n_tail);
     const auto base_offset = [&](int k) {
@@ -5018,31 +5154,15 @@ void jit_brgemm_matmul_copy_b_cvt_bf16_t<Vmm>::copy_block(
         const auto stride = (n_blk_step * typesize_) / src_elems_per_byte_;
         auto load_addr0 = maybe_EVEX_compress_addr(reg_src, offset);
         auto load_addr1 = maybe_EVEX_compress_addr(reg_src, offset + stride);
-        load_int(src_vmm0, load_addr0);
-        load_int(src_vmm1, load_addr1);
-        if (req_zp_b_shift_) {
-            vpsubd(src_vmm0, src_vmm0, vmm_zp_b_val);
-            vpsubd(src_vmm1, src_vmm1, vmm_zp_b_val);
-        }
-        vcvtdq2ps(src_vmm0, src_vmm0);
-        vcvtdq2ps(src_vmm1, src_vmm1);
-        if (req_apply_wei_scales_) {
-            const bool is_n_tail = ncolumns - n < n_blk_step;
-            const bool is_k_tail = nrows - k < k_blk_step;
-            get_wei_scales(blk, k, n, is_n_tail, is_k_tail);
-            vmulps(src_vmm0, src_vmm0, vmm_wei_scales0);
-            vmulps(src_vmm1, src_vmm1, vmm_wei_scales1);
-        }
-
-        if (conf_->wei_dt == data_type::bf16) {
-            vcvtne2ps2bf16(src_vmm0, src_vmm1, src_vmm0);
-        } else {
-            const auto src_vmm_lower0 = Vmm_lower_t(src_vmm0.getIdx());
-            const auto src_vmm_lower1 = Vmm_lower_t(src_vmm1.getIdx());
-            vcvtps2phx(src_vmm_lower0, src_vmm0);
-            vcvtps2phx(src_vmm_lower1, src_vmm1);
-            vinsertf64x4(src_vmm0, src_vmm0, src_vmm_lower1, 1);
-        }
+        load_value(src_vmm0, load_addr0, vmm_permd, conf_->orig_wei_dt);
+        load_value(src_vmm1, load_addr1, vmm_permd, conf_->orig_wei_dt);
+        const bool is_n_tail = ncolumns - n < n_blk_step;
+        const bool is_k_tail = nrows - k < k_blk_step;
+        get_scales(blk, k, n, is_n_tail, is_k_tail);
+        decompress_and_downcvt_2reg(src_vmm0, src_vmm1, vmm_zp_b_val,
+                vmm_zp_b_val, vmm_wei_scales0, vmm_wei_scales0,
+                conf_->orig_wei_dt, conf_->wei_dt);
 
     };
 
     int iter = 0;
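A closing note on the pair-wise helper used just above: decompress_and_downcvt_2reg() exists because the bf16 down-conversion consumes two f32 vectors at once (vcvtne2ps2bf16 packs the pair into one bf16 vector), and the f16 path mimics the pairing with two vcvtps2phx plus a vinsertf64x4. Per pair of source vectors it expands to the following, which simply restates the body defined in the common class above rather than describing new behavior:

    decompress_reg(input1, zp1, scale_op1, src_dt);      // shift, int -> f32, scale
    decompress_reg(input2, zp2, scale_op2, src_dt);
    downconvert_to_dst_dt(input1, input2, dst_dt);       // packs both results into input1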