x64: matmul: Enable AVX2 matmul weight decompression

Committed by: Dmitriy Ovchinnikov
Parent: 0ba6d85080
Commit: 846ba7c1d4
@@ -477,6 +477,11 @@ inline bool isa_has_bf16(cpu_isa_t isa) {
     return is_superset(isa, avx512_core_bf16);
 }

+inline bool isa_has_f16(cpu_isa_t isa) {
+    return is_superset(isa, avx512_core_fp16)
+            || is_superset(isa, avx10_1_512_amx_fp16);
+}
+
 inline bool isa_has_masks(cpu_isa_t isa) {
     return is_superset(isa, avx512_core);
 }
@@ -2118,7 +2118,7 @@ protected:
      * @param zmm Destination ZMM register where data will be inserted
      * @param ymm_half Source YMM register containing data for the upper half
      */
-    void copy_half_int4(const Zmm &zmm, const Ymm &ymm_half) {
+    void copy_half_reg(const Zmm &zmm, const Ymm &ymm_half) {
         vinserti64x4(zmm, zmm, ymm_half, 1);
     }

@@ -2127,13 +2127,88 @@ protected:
      * @param ymm Destination YMM register where data will be inserted
      * @param xmm_half Source XMM register containing data for the upper half
      */
-    void copy_half_int4(const Ymm &ymm, const Xmm &xmm_half) {
+    void copy_half_reg(const Ymm &ymm, const Xmm &xmm_half) {
         vinserti128(ymm, ymm, xmm_half, 1);
     }

+    /** Restores the vmm register containing permute indices for int4 processing.
+     * Due to lack of registers this register is used as a temporary in `prepare_loaded_int4`.
+     */
+    void restore_vmm_permd(const Ymm &vmm_permd) {
+        alignas(64) static constexpr const uint32_t int4_permute_avx2[8]
+                = {0, 4, 1, 5, 2, 6, 3, 7};
+        const auto reg_tmp = r15;
+        mov(reg_tmp, reinterpret_cast<size_t>(int4_permute_avx2));
+        vmovdqa(vmm_permd, ptr[reg_tmp]);
+    }
+
+    /** Expands a loaded register (half of it) containing bytes (2 int4 values each).
+     * The idea is to duplicate each byte in two dwords
+     * using `copy_half_reg` and `vpermd`, then shift bytes left or right
+     * depending on the position (odd/even) and signed/unsigned type.
+     */
+    template <typename Vmm>
+    void prepare_loaded_int4(
+            const Vmm &reg, const Vmm &vmm_permd, const bool is_signed) {
+        using Vmm_lower_t = typename vreg_traits_t<Vmm>::Vmm_lower_t;
+        const auto vmm_lower = Vmm_lower_t(reg.getIdx());
+
+        copy_half_reg(reg, vmm_lower);
+        vpermd(reg, vmm_permd, reg);
+        // Without masks int4 is going to use 2 tmp registers
+        // to perform masking using the VPAND operator.
+        if (!isa_has_masks(conf_->isa)) {
+            // TODO: Unify register usage over kernels
+            const auto mask_vmm = Vmm(conf_->transposed_B ? 13 : 0);
+            const auto tmp_vmm = vmm_permd;
+            // const auto tmp_vmm2 = Vmm(conf_->transposed_B ? 13: 4);
+            // f32 and transposed use the same register for regq_tmp
+            const auto reg_tmp = r15;
+            alignas(64) static constexpr const uint32_t odd_indices[8] = {
+                    0, 0xffffffff, 0, 0xffffffff, 0, 0xffffffff, 0, 0xffffffff};
+            alignas(64) static constexpr const uint32_t even_indices[8] = {
+                    0xffffffff, 0, 0xffffffff, 0, 0xffffffff, 0, 0xffffffff, 0};
+            // Process odd indices
+            mov(reg_tmp, reinterpret_cast<size_t>(even_indices));
+            vmovdqa(mask_vmm, ptr[reg_tmp]);
+            uni_vpand(tmp_vmm, reg, mask_vmm);
+            uni_vpslld(tmp_vmm, tmp_vmm, 28);
+            if (is_signed) {
+                vpsrad(tmp_vmm, tmp_vmm, 28);
+            } else {
+                vpsrld(tmp_vmm, tmp_vmm, 28);
+            }
+
+            // Process even indices
+            mov(reg_tmp, reinterpret_cast<size_t>(odd_indices));
+            vmovdqa(mask_vmm, ptr[reg_tmp]);
+            uni_vpand(reg, reg, mask_vmm);
+            if (is_signed) {
+                vpsrad(reg, reg, 4);
+            } else {
+                vpsrld(reg, reg, 4);
+            }
+            // Store result to desired reg
+            vpaddd(reg, reg, tmp_vmm);
+            // Clean tmp regs
+            vpxor(mask_vmm, mask_vmm, mask_vmm);
+            restore_vmm_permd(vmm_permd);
+        } else {
+            uni_vpslld(reg | k5555, reg, 28);
+            if (is_signed) {
+                vpsrad(reg | k5555, reg, 28);
+                vpsrad(reg | kAAAA, reg, 4);
+            } else {
+                vpsrld(reg | k5555, reg, 28);
+                vpsrld(reg | kAAAA, reg, 4);
+            }
+        }
+    }
+
     /**
      * @brief Loads and converts data of various types into vector registers with appropriate handling
-     *
+     * Integer types s4, u4, s8, u8, s32 will be loaded and converted to s32 during the loading.
+     * Floating point types f16, bf16, f32 will be loaded and converted to f32 during the loading.
      * @tparam Vmm Vector register type for computation
      * @param reg Destination vector register
      * @param op Source memory operand to load from
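The following scalar sketch (illustration only, not part of the patch; the helper name is invented) restates the per-lane arithmetic of `prepare_loaded_int4` above. After `uni_vpmovsxbd`/`uni_vpmovzxbd` widen each packed byte into its own 32-bit lane, the low nibble is recovered by shifting left by 28 and back by 28, and the high nibble by shifting right by 4; arithmetic right shift of negative values is assumed, as on x86 compilers.

    #include <cstdint>

    // Two int4 values packed in one byte -> two int32 results (lo = bits 3:0, hi = bits 7:4).
    void unpack_int4_pair(uint8_t packed, bool is_signed, int32_t &lo, int32_t &hi) {
        if (is_signed) {
            const int32_t lane = (int32_t)(int8_t)packed; // uni_vpmovsxbd
            lo = (int32_t)((uint32_t)lane << 28) >> 28;   // uni_vpslld 28 + vpsrad 28
            hi = lane >> 4;                               // vpsrad 4
        } else {
            const uint32_t lane = packed;                 // uni_vpmovzxbd
            lo = (int32_t)((lane << 28) >> 28);           // uni_vpslld 28 + vpsrld 28
            hi = (int32_t)(lane >> 4);                    // vpsrld 4
        }
    }

For example, `packed = 0xAB` yields `lo = -5, hi = -6` for s4 and `lo = 11, hi = 10` for u4.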
@@ -2187,28 +2262,18 @@ protected:
                 uni_vpmovsxbd(
                         maybe_mask(vmm_lower, is_tail, /* is_int4 = */ true),
                         op);
-                copy_half_int4(reg, vmm_lower);
-                vpermd(reg, vmm_permd, reg);
-                uni_vpslld(reg | k5555, reg, 28);
-                vpsrad(reg | k5555, reg, 28);
-                vpsrad(reg | kAAAA, reg, 4);
+                prepare_loaded_int4(reg, vmm_permd, /* is_signed = */ true);
                 break;
             case data_type::u4:
                 uni_vpmovzxbd(
                         maybe_mask(vmm_lower, is_tail, /* is_int4 = */ true),
                         op);
-                copy_half_int4(reg, vmm_lower);
-                vpermd(reg, vmm_permd, reg);
-                uni_vpslld(reg | k5555, reg, 28);
-                vpsrld(reg | k5555, reg, 28);
-                vpsrld(reg | kAAAA, reg, 4);
+                prepare_loaded_int4(reg, vmm_permd, /* is_signed = */ false);
                 break;
             case data_type::f4_e2m1:
             case data_type::f4_e3m0:
                 uni_vpmovzxbd(maybe_mask(vmm_lower, is_tail), op);
-                copy_half_int4(vmm_in, vmm_lower);
+                copy_half_reg(vmm_in, vmm_lower);
                 vpermd(vmm_in, vmm_permd, vmm_in);
                 uni_vpslld(vmm_in | k5555, vmm_in, 28);
                 vpsrld(vmm_in | k5555, vmm_in, 28);
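The s4/u4 cases above now delegate to `prepare_loaded_int4`; the duplication step it performs before those shifts (`copy_half_reg` followed by `vpermd`) can be modeled in scalar code. The sketch below (illustration only, not part of the patch) uses the 8-lane AVX2 table `{0, 4, 1, 5, 2, 6, 3, 7}` from the patch: the four widened bytes in the lower half are mirrored into the upper half and then interleaved so every byte occupies two adjacent dword lanes, one per nibble.

    #include <array>
    #include <cstdint>

    // Scalar model of copy_half_reg + vpermd on a YMM register.
    std::array<uint32_t, 8> duplicate_lanes_avx2(const std::array<uint32_t, 4> &loaded) {
        std::array<uint32_t, 8> reg {};
        for (int i = 0; i < 4; ++i) {
            reg[i] = loaded[i];     // lower half: result of uni_vpmov{s,z}xbd
            reg[i + 4] = loaded[i]; // copy_half_reg: vinserti128(reg, reg, lower_half, 1)
        }
        static constexpr uint32_t perm[8] = {0, 4, 1, 5, 2, 6, 3, 7}; // int4_permute_avx2
        std::array<uint32_t, 8> out {};
        for (int i = 0; i < 8; ++i)
            out[i] = reg[perm[i]]; // vpermd(reg, vmm_permd, reg)
        return out;
    }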
@@ -2415,7 +2480,7 @@ protected:
             const auto src_vmm_lower1 = Vmm_lower_t(reg2.getIdx());
             vcvtps2phx(src_vmm_lower0, reg1);
             vcvtps2phx(src_vmm_lower1, reg2);
-            vinsertf64x4(reg1, reg1, src_vmm_lower1, 1);
+            copy_half_reg(reg1, src_vmm_lower1);
             break;
         }
         case data_type::f32:
@@ -3496,9 +3561,11 @@ void jit_brgemm_matmul_copy_b_bf16_t<Vmm>::copy_2x32(
     auto load_addr = maybe_EVEX_compress_addr(reg_src_load, offset);
     if (!isa_has_masks(conf_->isa)) {
         if (is_tail)
-            load_bytes(src_load, load_addr, columns_tail * tr_typesize);
+            load_bytes(src_load, load_addr,
+                    columns_tail * tr_typesize / elems_per_byte);
         else
             uni_vmovups(src_load, load_addr);
+        load_value(src_reg, src_reg, vmm_permd, conf_->orig_wei_dt, false);
     } else {
         load_value(
                 src_reg, load_addr, vmm_permd, conf_->orig_wei_dt, is_tail);
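Without opmask registers the tail load above falls back to `load_bytes`, so the byte count must account for packed 4-bit types, where two elements share one byte. A hedged restatement of that size computation (helper name is illustrative, not from the patch):

    #include <cstddef>

    // Bytes to load for a tail of `columns_tail` elements; elems_per_byte is 2
    // for int4/f4 weights and 1 otherwise.
    size_t tail_load_size(size_t columns_tail, size_t typesize, size_t elems_per_byte) {
        return columns_tail * typesize / elems_per_byte;
    }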
@@ -3867,7 +3934,6 @@ struct jit_brgemm_matmul_copy_b_f32_t
         , req_apply_wei_scales_(conf->apply_scales_in_buffer_b)
         , typesize_in_(types::data_type_size(dt_in_))
         , src_elems_per_byte_(is_src_int4_ || is_src_f4_ ? 2 : 1)
-        , wei_scales_typesize_(conf_->wei_scales_dt_sz)
         , src_stride_(conf_->copy_B_wei_stride)
         , tr_src_stride_(conf_->LDB * typesize_out_) {}

@@ -3882,10 +3948,11 @@ private:
     using opmask_t = const Xbyak::Opmask;
     using Vmm_lower_t = typename vreg_traits_t<Vmm>::Vmm_lower_t;

+    static constexpr bool is_ymm_ = std::is_same<Vmm, Xbyak::Ymm>::value;
     const data_type_t dt_in_;
     const int simd_w_;
     const bool is_src_f4_, is_src_int4_, req_zp_b_shift_, req_apply_wei_scales_;
-    const size_t typesize_in_, src_elems_per_byte_, wei_scales_typesize_;
+    const size_t typesize_in_, src_elems_per_byte_;
     const size_t typesize_out_ = sizeof(float);
     dim_t src_stride_, tr_src_stride_;

@@ -3940,9 +4007,11 @@ void jit_brgemm_matmul_copy_b_f32_t<Vmm>::copy_16_x_n_block(
         const bool is_tail = ncolumns - n < simd_w_;
         auto addr = maybe_EVEX_compress_addr(reg_src,
                 (k * src_stride_ + n * typesize_in_) / src_elems_per_byte_);
-        if (is_tail && !isa_has_masks(conf_->isa))
-            vmaskmovps(src_vmm, ymm_tail_mask, addr);
-        else
+        if (is_tail && !isa_has_masks(conf_->isa)) {
+            load_bytes(src_vmm, addr,
+                    (ncolumns % simd_w_) * typesize_in_ / src_elems_per_byte_);
+            load_value(src_vmm, src_vmm, vmm_permd, conf_->orig_wei_dt, false);
+        } else
             load_value(src_vmm, addr, vmm_permd, conf_->orig_wei_dt, is_tail);

         decompress_reg(maybe_mask(src_vmm, is_tail), vmm_zp_b_shift,
@@ -3962,6 +4031,11 @@ void jit_brgemm_matmul_copy_b_f32_t<Vmm>::copy_16_x_n_block(
                 = one_of(zp_dt, data_type::s4, data_type::u4) ? 2 : 1;
         const auto offset = n * zp_dt_sz / elems_per_byte;
         const auto addr = maybe_EVEX_compress_addr(reg_zp_ptr, offset);
+        if (is_tail && !isa_has_masks(conf_->isa)) {
+            load_bytes(vmm_zp_b_shift, addr,
+                    (ncolumns % simd_w_) * zp_dt_sz / elems_per_byte);
+            load_value(vmm_zp_b_shift, vmm_zp_b_shift, vmm_permd, zp_dt, false);
+        } else
             load_value(vmm_zp_b_shift, addr, vmm_permd, zp_dt, is_tail);
     };

@@ -3977,6 +4051,11 @@ void jit_brgemm_matmul_copy_b_f32_t<Vmm>::copy_16_x_n_block(
         const auto scales_dt_sz = types::data_type_size(scales_dt);
         const auto offset = n * scales_dt_sz;
         const auto addr = maybe_EVEX_compress_addr(reg_wei_scales, offset);
+        if (is_tail && !isa_has_masks(conf_->isa)) {
+            load_bytes(
+                    vmm_wei_scales, addr, (ncolumns % simd_w_) * scales_dt_sz);
+            load_scale_value(vmm_wei_scales, vmm_wei_scales, scales_dt, false);
+        } else
             load_scale_value(vmm_wei_scales, addr, scales_dt, is_tail);
     };

@@ -3990,8 +4069,6 @@ void jit_brgemm_matmul_copy_b_f32_t<Vmm>::copy_16_x_n_block(
                         = (1 << (columns_tail / src_elems_per_byte_)) - 1;
                 kmovw(kTail_int4, tail_mask_4bit);
             }
-        } else {
-            init_f32_avx2_mask_ymm(ymm_tail_mask, reg_tmp, columns_tail);
         }
     }

@@ -4056,17 +4133,24 @@ void jit_brgemm_matmul_copy_b_f32_t<Vmm>::generate() {
     mov(reg_wei_scales, ptr[param1 + GET_OFF(wei_scales_ptr)]);
     mov(reg_zp_ptr, ptr[param1 + GET_OFF(zp_b_value_ptr)]);
     kmovw(kFFFF, 0xffff); // 1111111111111111

     if (is_src_int4_ || is_src_f4_) {
-        alignas(64) static constexpr const uint32_t int4_permute[16]
-                = {0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15};
-        mov(reg_tmp, reinterpret_cast<size_t>(int4_permute));
-        vmovdqa32(vmm_permd, ptr[reg_tmp]);
-        kmovw(kAAAA, 0xaaaa);
-        kmovw(k5555, 0x5555);
+        kmovw(kAAAA, 0xaaaa);
+        kmovw(k5555, 0x5555);
+        if (is_superset(conf_->isa, avx512_core)) {
+            alignas(64) static constexpr const uint32_t int4_permute[16]
+                    = {0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15};
+            mov(reg_tmp, reinterpret_cast<size_t>(int4_permute));
+            vmovdqa32(vmm_permd, ptr[reg_tmp]);
+        } else if (is_superset(conf_->isa, avx2)) {
+            alignas(64) static constexpr const uint32_t int4_permute_avx2[8]
+                    = {0, 4, 1, 5, 2, 6, 3, 7};
+            mov(reg_tmp, reinterpret_cast<size_t>(int4_permute_avx2));
+            vmovdqa(vmm_permd, ptr[reg_tmp]);
+        }
     }
-    if (is_src_f4_) {

+    if (is_src_f4_ && is_superset(conf_->isa, avx512_core)) {
         alignas(64) static constexpr const float f4_e2m1_table[16]
                 = {0.0f, .5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f, -0.0f, -.5f,
                         -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f};
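The two permute tables above differ only in register width: the AVX-512 path interleaves 16 dword lanes, the AVX2 path 8. The generator below (illustration only, not part of the patch) reproduces both tables from the rule "pair lane i with lane i + n/2", which is exactly the duplication that `prepare_loaded_int4` relies on.

    #include <cstdio>
    #include <initializer_list>

    int main() {
        for (int n : {16, 8}) { // 16 dwords per ZMM (AVX-512), 8 per YMM (AVX2)
            for (int i = 0; i < n / 2; ++i)
                printf(i ? ", %d, %d" : "%d, %d", i, i + n / 2);
            printf("\n"); // prints 0, 8, 1, 9, ..., 7, 15 and 0, 4, 1, 5, 2, 6, 3, 7
        }
        return 0;
    }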
@@ -4347,9 +4431,9 @@ private:
         const auto &scales_dt_sz = conf_->wei_scales_dt_sz;
         const auto offset = n * scales_dt_sz;
         const auto masked_vmm = maybe_mask(vmm_wei_scales, is_tail);
-        const auto addr = EVEX_compress_addr(
+        const auto addr = maybe_EVEX_compress_addr(
                 reg_wei_scales, offset, scales_dt == data_type::f16);
-        vpxord(vmm_wei_scales, vmm_wei_scales, vmm_wei_scales);
+        uni_vpxor(vmm_wei_scales, vmm_wei_scales, vmm_wei_scales);
         switch (scales_dt) {
             case data_type::f32: uni_vbroadcastss(vmm_wei_scales, addr); break;
             case data_type::bf16:
@@ -4383,6 +4467,7 @@ template <typename Vmm>
 void jit_brgemm_matmul_copy_b_transposed_t<Vmm>::init_tail_mask(
         const int columns_tail, const bool use_int4_mask) {
     assert(IMPLICATION(use_int4_mask, is_src_int4_));
+    assert(isa_has_masks(conf_->isa));
     if (columns_tail > 0) {
         const int dt_step = req_cvtps2xf16_ || use_fp16_instructions_
                 || use_bf16_instructions_
@@ -4423,7 +4508,7 @@ template <typename Vmm>
 bool jit_brgemm_matmul_copy_b_transposed_t<Vmm>::preload_int4(const Xmm &xmm_in,
         const int i, const int columns_tail, const bool is_tail,
         const dim_t offset) {
-    const auto addr = EVEX_compress_addr(reg_src, offset);
+    const auto addr = maybe_EVEX_compress_addr(reg_src, offset);
     const bool need_preload_int4 = is_src_int4_ && (i * src_stride_) % 2 != 0;
     const auto max_shift_sz = 8;
     if (need_preload_int4) {
@@ -4437,8 +4522,8 @@ bool jit_brgemm_matmul_copy_b_transposed_t<Vmm>::preload_int4(const Xmm &xmm_in,
     } else {
         const auto xmm_tmp = Xmm(tmp_vmm(3).getIdx());
         load_bytes(xmm_in, addr, load_sz);
-        load_bytes(
-                xmm_tmp, EVEX_compress_addr(reg_src, offset + 1), load_sz);
+        load_bytes(xmm_tmp, maybe_EVEX_compress_addr(reg_src, offset + 1),
+                load_sz);
         vpsrlq(xmm_in, xmm_in, 4);
         vpsllq(xmm_tmp, xmm_tmp, 4);
         vpord(xmm_in, xmm_in, xmm_tmp);
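`preload_int4` handles rows that start on an odd nibble: the same region is loaded from `offset` and from `offset + 1`, the first copy is shifted right by 4, the second left by 4, and the OR yields the nibble stream advanced by one position. The same idea on a single 64-bit lane (illustration only, mirroring the `vpsrlq`/`vpsllq`/`vpord` sequence above):

    #include <cstdint>

    // at_offset: 8 bytes loaded at `offset`; at_offset_plus_1: 8 bytes at `offset + 1`.
    // Result: the packed int4 stream shifted by one nibble, re-packed two per byte.
    uint64_t realign_odd_nibble(uint64_t at_offset, uint64_t at_offset_plus_1) {
        return (at_offset >> 4) | (at_offset_plus_1 << 4);
    }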
@@ -4491,12 +4576,12 @@ void jit_brgemm_matmul_copy_b_transposed_t<Vmm>::copy_row_x_col(
             jg(general_load); // i < dynamic nrows -> general load

             // i >= dynamic nrows -> zero out values in src_reg
-            vpxord(src_reg, src_reg, src_reg);
+            uni_vpxor(src_reg, src_reg, src_reg);
             jmp(load_done);

             L(general_load);
         } else if (i >= nrows) {
-            vpxord(src_reg, src_reg, src_reg);
+            uni_vpxor(src_reg, src_reg, src_reg);
             return;
         }

@@ -4527,7 +4612,7 @@ void jit_brgemm_matmul_copy_b_transposed_t<Vmm>::copy_row_x_col(
             assert(!"Unsupported data type in loading");

         if (ncolumns <= req_cvt_bf16_k_blk_step_) {
-            uni_vpxord(src_reg_next, src_reg_next, src_reg_next);
+            uni_vpxor(src_reg_next, src_reg_next, src_reg_next);
         } else {
             const auto is_tail = columns_tail > 0;
             auto src_next_masked = maybe_mask(src_reg_next, is_tail);
@@ -4570,12 +4655,12 @@ void jit_brgemm_matmul_copy_b_transposed_t<Vmm>::copy_row_x_col(
             jg(general_load); // i < dynamic nrows -> general load

             // i >= dynamic nrows -> zero out values in src_reg
-            vpxord(src_reg, src_reg, src_reg);
+            uni_vpxor(src_reg, src_reg, src_reg);
             jmp(load_done);

             L(general_load);
         } else if (i >= nrows) {
-            vpxord(src_reg, src_reg, src_reg);
+            uni_vpxor(src_reg, src_reg, src_reg);
             return;
         }

@@ -4772,7 +4857,7 @@ void jit_brgemm_matmul_copy_b_transposed_t<Ymm>::copy_row_x_col(
             jg(general_load); // i < dynamic nrows -> general load

             // i >= dynamic nrows -> zero out values in src_reg
-            vpxord(vmm_src, vmm_src, vmm_src);
+            uni_vpxor(vmm_src, vmm_src, vmm_src);
             jmp(load_done);

             L(general_load);
@@ -4781,7 +4866,34 @@ void jit_brgemm_matmul_copy_b_transposed_t<Ymm>::copy_row_x_col(
             return;
         }

-        if (columns_tail > 0) {
+        const bool is_tail = columns_tail > 0;
+
+        if (conf_->is_f32_with_int_wei) {
+            const auto src_offset = (i * src_stride_) / src_elems_per_byte_;
+            const auto addr = maybe_EVEX_compress_addr(reg_src, src_offset);
+            const auto xmm_preload = Xmm(vmm_src.getIdx());
+            MAYBE_UNUSED(xmm_preload);
+            const bool preloaded_int4 = preload_int4(
+                    xmm_preload, i, columns_tail, is_tail, src_offset);
+            const auto &src_op = preloaded_int4
+                    ? static_cast<const Xbyak::Operand &>(xmm_preload)
+                    : static_cast<const Xbyak::Operand &>(addr);
+
+            if (is_tail && !preloaded_int4) {
+                load_bytes(vmm_src, addr,
+                        columns_tail * typesize_ / src_elems_per_byte_);
+                load_value(
+                        vmm_src, vmm_src, vmm_permd, conf_->orig_wei_dt, false);
+            } else {
+                load_value(vmm_src, src_op, vmm_permd, conf_->orig_wei_dt,
+                        is_tail);
+            }
+
+            load_zero_point(i, is_tail);
+            load_scales(i, is_tail);
+            decompress_reg(
+                    vmm_src, vmm_zp_b_val, vmm_wei_scales, conf_->orig_wei_dt);
+        } else if (is_tail) {
             load_bytes(vmm_src, reg_src, i * src_stride_,
                     columns_tail * typesize_);
             if (use_fp16_instructions_) {
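The new `is_f32_with_int_wei` branch chains the whole decompression for the transposed AVX2 path: load (or `preload_int4`) the packed weights, expand them with `load_value`, then apply the zero point and scale via `decompress_reg`. Per element this corresponds to the conventional weight-decompression formula; the one-liner below is a hedged scalar restatement under that assumption, not a definition taken from the code.

    // Illustrative only: assumes the usual (weight - zero_point) * scale convention.
    static inline float decompress_weight(int w, int zp, float scale) {
        return static_cast<float>(w - zp) * scale;
    }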
@@ -4883,7 +4995,10 @@ void jit_brgemm_matmul_copy_b_transposed_t<Ymm>::copy_row_x_col(
         const auto src0 = src_vmm(i);
         if (do_compute_compensation_)
             dot_product(vmm_comp_acc, vmm_comp_mul, src0);
-        uni_vmovups(ptr[reg_tr_src + i * tr_src_stride_], src0);
+        const auto addr
+                = maybe_EVEX_compress_addr(reg_tr_src, i * tr_src_stride_);
+        if (columns_tail > 0 && i >= columns_tail) break;
+        uni_vmovups(addr, src0);
     }
 }

@@ -5051,11 +5166,18 @@ void jit_brgemm_matmul_copy_b_transposed_t<Vmm>::generate() {
         kmovw(k0F0F, 0x0f0f);
         kmovw(kF0F0, 0xf0f0);
     }
-    if (is_src_int4_ && is_superset(conf_->isa, avx512_core)) {
-        alignas(64) static constexpr const uint32_t int4_permute[16]
-                = {0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15};
-        mov(regq_tmp, reinterpret_cast<size_t>(int4_permute));
-        vmovdqa32(vmm_permd, ptr[regq_tmp]);
+    if (is_src_int4_) {
+        if (is_superset(conf_->isa, avx512_core)) {
+            alignas(64) static constexpr const uint32_t int4_permute[16]
+                    = {0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15};
+            mov(regq_tmp, reinterpret_cast<size_t>(int4_permute));
+            vmovdqa32(vmm_permd, ptr[regq_tmp]);
+        } else if (is_superset(conf_->isa, avx2)) {
+            alignas(64) static constexpr const uint32_t int4_permute_avx2[8]
+                    = {0, 4, 1, 5, 2, 6, 3, 7};
+            mov(regq_tmp, reinterpret_cast<size_t>(int4_permute_avx2));
+            vmovdqa(vmm_permd, ptr[regq_tmp]);
+        }
     }

     load_common_zp_value(vmm_zp_b_val, reg_zp_ptr);
@@ -1413,6 +1413,12 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc,
                         || IMPLICATION(bgmmc.is_wei_scale_per_k,
                                 bgmmc.with_wei_decompression),
                 VERBOSE_UNSUPPORTED_SCALES_CFG);
+
+        // Check if isa has support for f16/bf16 weights scales
+        VCONDCHECK_BG(IMPLICATION(bgmmc.wei_scales_dt == f16, isa_has_f16(isa))
+                        && IMPLICATION(
+                                bgmmc.wei_scales_dt == bf16, isa_has_bf16(isa)),
+                VERBOSE_UNSUPPORTED_SCALES_CFG);
     }

     const auto &dst_scales = attr.scales_.get(DNNL_ARG_DST);
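The added check rejects configurations whose weights-scales data type needs an ISA feature that is missing. A plain-C++ restatement (illustrative only; in oneDNN `IMPLICATION(a, b)` expands to `!a || b`, and the types below are the library's internal ones):

    // Equivalent boolean condition to the new VCONDCHECK_BG above.
    bool wei_scales_dt_supported(data_type_t wei_scales_dt, cpu_isa_t isa) {
        const bool f16_ok = !(wei_scales_dt == data_type::f16) || isa_has_f16(isa);
        const bool bf16_ok = !(wei_scales_dt == data_type::bf16) || isa_has_bf16(isa);
        return f16_ok && bf16_ok;
    }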