x64: matmul: Enable AVX512 f32:int4/int8:f32 case

Ovchinnikov Dmitriy
2025-10-14 08:31:58 -07:00
committed by Dmitriy Ovchinnikov
parent f92a20983c
commit 051b020bb1
5 changed files with 96 additions and 129 deletions
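
This change lets f32 matmul with int8/int4 weights and an f32 destination dispatch into the AVX512 brgemm kernels via weight decompression. A minimal usage sketch follows, assuming the standard oneDNN C++ API with the two-argument set_fpmath_mode overload; the shapes and names are illustrative and not taken from the patch.

#include "oneapi/dnnl/dnnl.hpp"

int main() {
    using namespace dnnl;
    engine eng(engine::kind::cpu, 0);
    stream strm(eng);

    const memory::dim M = 1, K = 4096, N = 4096;
    // f32 activations/destination, int8 weights to be decompressed to f32.
    memory::desc src_md({M, K}, memory::data_type::f32, memory::format_tag::ab);
    memory::desc wei_md({K, N}, memory::data_type::s8, memory::format_tag::any);
    memory::desc dst_md({M, N}, memory::data_type::f32, memory::format_tag::ab);

    primitive_attr attr;
    // Keep f32 (strict) math but allow integer weights to be decompressed.
    attr.set_fpmath_mode(fpmath_mode::strict, /*apply_to_int=*/true);

    matmul::primitive_desc pd(eng, src_md, wei_md, dst_md, attr);
    matmul prim(pd);
    // ... allocate memory objects and call prim.execute(strm, {...}) as usual.
    return 0;
}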

View File

@@ -171,6 +171,8 @@ status_t brgemm_matmul_t<isa>::pd_t::init(engine_t *engine) {
&& one_of(wei_dt, s8, u8, s4, u4) && one_of(dst_dt, f16, f32);
const bool is_f4
= utils::one_of(wei_dt, data_type::f4_e2m1, data_type::f4_e3m0);
const bool is_f32_with_int_wei
= src_dt == f32 && one_of(wei_dt, s8, u8, s4, u4) && dst_dt == f32;
auto check_bias = [&]() -> bool {
const auto bia_dt = weights_md(1)->data_type;
@@ -214,8 +216,9 @@ status_t brgemm_matmul_t<isa>::pd_t::init(engine_t *engine) {
// This case requires scratchpad
if (N() == DNNL_RUNTIME_DIM_VAL) ok = false;
}
// This impl supports only f32 scales for non-weight decompression.
if (!(is_bf16_with_int_wei || is_f16_with_int_wei)) {
// Impl supports f32 scales only for non-weight decompression
if (!(is_bf16_with_int_wei || is_f16_with_int_wei
|| is_f32_with_int_wei)) {
ok = ok && one_of(asc.get_data_type(DNNL_ARG_SRC), undef, f32);
ok = ok && one_of(asc.get_data_type(DNNL_ARG_WEIGHTS), undef, f32);
ok = ok && one_of(asc.get_data_type(DNNL_ARG_DST), undef, f32);
@@ -260,7 +263,7 @@ status_t brgemm_matmul_t<isa>::pd_t::init(engine_t *engine) {
};
const bool problem_dt_correct = one_of(true, is_f4, is_int8, is_f8, is_bf16,
is_f32, is_f16, is_f32_f16, is_f32_bf16, is_bf16_with_int_wei,
is_f16_with_int_wei);
is_f16_with_int_wei, is_f32_with_int_wei);
auto src_d = memory_desc_wrapper(src_md_);
auto weights_d = memory_desc_wrapper(weights_md_);

View File

@@ -2134,10 +2134,13 @@ protected:
* @param vmm_permd Vector register containing permutation indices for INT4 processing
* @param dt Data type being loaded
* @param is_tail Flag indicating if tail processing is needed
* @param vmm_f4_lut Vector register containing lookup table for FP4 conversion
* (default is Vmm(4) for kernel jit_brgemm_matmul_copy_b_f32_t)
*/
template <typename Vmm>
void load_value(const Vmm &reg, const Xbyak::Operand &op,
const Vmm &vmm_permd, data_type_t dt, bool is_tail = false) {
const Vmm &vmm_permd, data_type_t dt, bool is_tail = false,
const Vmm &vmm_f4_lut = Vmm(4)) {
using Vmm_lower_t = typename vreg_traits_t<Vmm>::Vmm_lower_t;
const auto vmm_in = maybe_mask(reg, is_tail);
const auto vmm_lower = Vmm_lower_t(vmm_in.getIdx());
@@ -2196,6 +2199,16 @@ protected:
vpsrld(reg | k5555, reg, 28);
vpsrld(reg | kAAAA, reg, 4);
break;
case data_type::f4_e2m1:
case data_type::f4_e3m0:
uni_vpmovzxbd(maybe_mask(vmm_lower, is_tail), op);
copy_half_int4(vmm_in, vmm_lower);
vpermd(vmm_in, vmm_permd, vmm_in);
uni_vpslld(vmm_in | k5555, vmm_in, 28);
vpsrld(vmm_in | k5555, vmm_in, 28);
vpsrld(vmm_in | kAAAA, vmm_in, 4);
vpermps(vmm_in, vmm_in, vmm_f4_lut);
break;
default: assert(!"unsupported data type");
}
}
@@ -3665,13 +3678,12 @@ template struct jit_brgemm_matmul_copy_b_bf16_t<Zmm>;
template struct jit_brgemm_matmul_copy_b_bf16_t<Ymm>;
template <typename Vmm>
struct jit_brgemm_matmul_copy_b_f32_t : public jit_brgemm_matmul_copy_b_t,
public jit_generator_t {
struct jit_brgemm_matmul_copy_b_f32_t
: public jit_brgemm_matmul_copy_b_common_t {
DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_brgemm_matmul_copy_b_f32_t)
jit_brgemm_matmul_copy_b_f32_t(const brgemm_matmul_conf_t *conf)
: jit_brgemm_matmul_copy_b_t(conf)
, jit_generator_t(jit_name())
: jit_brgemm_matmul_copy_b_common_t(conf)
, dt_in_(conf->orig_wei_dt)
, simd_w_(vreg_traits_t<Vmm>::vlen / sizeof(float))
, is_src_f4_(one_of(
@@ -3710,12 +3722,6 @@ private:
const size_t typesize_out_ = sizeof(float);
dim_t src_stride_, tr_src_stride_, wei_scales_N_stride_;
opmask_t kTail = k7;
opmask_t kFFFF = k6;
opmask_t k5555 = k5;
opmask_t kAAAA = k4;
opmask_t kTail4bit = k3;
reg64_t reg_src = rax;
reg64_t reg_tr_src = rbx;
@@ -3725,7 +3731,6 @@ private:
reg64_t reg_tmp = r15;
reg32_t regw_tmp = r15d;
reg64_t reg_wei_scales = rdx;
reg64_t reg_f4_lut = r11;
Vmm vmm_zero = Vmm(0);
Vmm vmm_wei_scales = Vmm(1);
@@ -3739,88 +3744,12 @@ private:
mov(regw_tmp, w);
jit_generator_t::kmovd(k, regw_tmp);
}
void insert_high_half(const Zmm &zmm, const Ymm &ymm_half) {
vinserti64x4(zmm, zmm, ymm_half, 1);
}
void insert_high_half(const Ymm &ymm, const Xmm &xmm_half) {
vinserti128(ymm, ymm, xmm_half, 1);
}
Vmm_lower_t maybe_mask(Vmm_lower_t vmm_lower, bool is_tail) {
assert(is_src_int4_ || is_src_f4_);
return is_tail && isa_has_masks(conf_->isa)
? vmm_lower | kTail4bit | T_z
: vmm_lower;
}
Vmm maybe_mask(Vmm vmm, bool is_tail) {
return is_tail && isa_has_masks(conf_->isa) ? vmm | kTail | T_z : vmm;
}
void load_data(const Vmm vmm_in, const Xbyak::Operand &op, bool is_tail);
void copy_16_x_n_block(int nrows, int ncolumns);
void compute_k_loop(int ncolumns);
void generate() override;
};
template <typename Vmm>
void jit_brgemm_matmul_copy_b_f32_t<Vmm>::load_data(
const Vmm vmm_in, const Xbyak::Operand &op, bool is_tail) {
const auto vmm = maybe_mask(vmm_in, is_tail);
const auto vmm_lower = Vmm_lower_t(vmm.getIdx());
MAYBE_UNUSED(vmm_lower);
switch (dt_in_) {
case data_type::f32: uni_vmovups(vmm, op); break;
case data_type::bf16:
// Upconvert: load 16 bits and move them 16 bits left.
uni_vpmovzxwd(vmm, op);
uni_vpslld(vmm, vmm, 16);
break;
case data_type::f16:
if (is_superset(conf_->isa, avx512_core_fp16)) {
vcvtph2psx(vmm, op);
} else {
vcvtph2ps(vmm, op);
}
break;
case data_type::s8: uni_vpmovsxbd(vmm, op); break;
case data_type::u8: uni_vpmovzxbd(vmm, op); break;
// For int4, we see two int4 as one int8 and extend them int32
// low half stores in lower bytes of vmm and high half in higher
// bytes of vmm, then permute them into correct order
// Finally, we process the extend bytes for s4/u4 accordingly
case data_type::s4:
uni_vpmovsxbd(maybe_mask(vmm_lower, is_tail), op);
insert_high_half(vmm_in, vmm_lower);
vpermd(vmm_in, vmm_permd, vmm_in);
uni_vpslld(vmm_in | k5555, vmm_in, 28);
vpsrad(vmm_in | k5555, vmm_in, 28);
vpsrad(vmm_in | kAAAA, vmm_in, 4);
break;
case data_type::u4:
uni_vpmovzxbd(maybe_mask(vmm_lower, is_tail), op);
insert_high_half(vmm_in, vmm_lower);
vpermd(vmm_in, vmm_permd, vmm_in);
uni_vpslld(vmm_in | k5555, vmm_in, 28);
vpsrld(vmm_in | k5555, vmm_in, 28);
vpsrld(vmm_in | kAAAA, vmm_in, 4);
break;
case data_type::f4_e2m1:
case data_type::f4_e3m0:
uni_vpmovzxbd(maybe_mask(vmm_lower, is_tail), op);
insert_high_half(vmm_in, vmm_lower);
vpermd(vmm_in, vmm_permd, vmm_in);
uni_vpslld(vmm_in | k5555, vmm_in, 28);
vpsrld(vmm_in | k5555, vmm_in, 28);
vpsrld(vmm_in | kAAAA, vmm_in, 4);
vpermps(vmm_in, vmm_in, vmm_f4_lut);
break;
default: assert(!"unsupported data type");
}
if (one_of(dt_in_, data_type::s8, data_type::u8, data_type::s4,
data_type::u4))
uni_vcvtdq2ps(vmm_in, vmm_in);
}
template <typename Vmm>
void jit_brgemm_matmul_copy_b_f32_t<Vmm>::copy_16_x_n_block(
int nrows, int ncolumns) {
@@ -3846,31 +3775,12 @@ void jit_brgemm_matmul_copy_b_f32_t<Vmm>::copy_16_x_n_block(
if (is_tail && !isa_has_masks(conf_->isa))
vmaskmovps(src_vmm, ymm_tail_mask, addr);
else
load_data(src_vmm, addr, is_tail);
load_value(src_vmm, addr, vmm_permd, conf_->orig_wei_dt, is_tail);
if (req_zp_b_shift_)
uni_vsubps(maybe_mask(src_vmm, is_tail), src_vmm, vmm_zp_b_shift);
if (req_apply_wei_scales_) {
const auto wei_scales_addr = maybe_EVEX_compress_addr(
reg_wei_scales,
k * wei_scales_N_stride_ + n * wei_scales_typesize_);
const auto vmm_wei_scales_masked
= maybe_mask(vmm_wei_scales, is_tail);
switch (conf_->wei_scales_dt) {
case data_type::f32:
uni_vmovups(vmm_wei_scales_masked, wei_scales_addr);
break;
case data_type::bf16:
uni_vpmovzxwd(vmm_wei_scales_masked, wei_scales_addr);
uni_vpslld(vmm_wei_scales, vmm_wei_scales, 16);
break;
case data_type::f16:
vcvtph2ps(vmm_wei_scales_masked, wei_scales_addr);
break;
default: assert(!"unsupported wei_scales data type");
}
vmulps(src_vmm, src_vmm, vmm_wei_scales);
}
const auto scales_addr = maybe_EVEX_compress_addr(reg_wei_scales,
k * wei_scales_N_stride_ + n * wei_scales_typesize_);
decompress_reg(maybe_mask(src_vmm, is_tail), vmm_zp_b_shift,
scales_addr, conf_->orig_wei_dt);
};
const int columns_tail = ncolumns % simd_w_;
@@ -3881,7 +3791,7 @@ void jit_brgemm_matmul_copy_b_f32_t<Vmm>::copy_16_x_n_block(
if (is_src_int4_ || is_src_f4_) {
const auto tail_mask_4bit
= (1 << (columns_tail / src_elems_per_byte_)) - 1;
kmovw(kTail4bit, tail_mask_4bit);
kmovw(kTail_int4, tail_mask_4bit);
}
} else {
init_f32_avx2_mask_ymm(ymm_tail_mask, reg_tmp, columns_tail);
@@ -3967,21 +3877,20 @@ void jit_brgemm_matmul_copy_b_f32_t<Vmm>::generate() {
-.5f, -1.0f, -2.0f, -4.0f, -8.0f, -16.0f};
switch (dt_in_) {
case data_type::f4_e2m1:
mov(reg_f4_lut, reinterpret_cast<size_t>(f4_e2m1_table));
mov(reg_tmp, reinterpret_cast<size_t>(f4_e2m1_table));
break;
case data_type::f4_e3m0:
mov(reg_f4_lut, reinterpret_cast<size_t>(f4_e3m0_table));
mov(reg_tmp, reinterpret_cast<size_t>(f4_e3m0_table));
break;
default: break;
}
vmovdqa32(vmm_f4_lut, ptr[reg_f4_lut]);
vmovdqa32(vmm_f4_lut, ptr[reg_tmp]);
}
if (req_zp_b_shift_) {
mov(reg_tmp, ptr[param1 + GET_OFF(zp_b_value_ptr)]);
uni_vpbroadcastd(vmm_zp_b_shift, ptr[reg_tmp]);
uni_vcvtdq2ps(vmm_zp_b_shift, vmm_zp_b_shift);
}
Label done;
@@ -4441,7 +4350,8 @@ void jit_brgemm_matmul_copy_b_transposed_t<Vmm>::copy_row_x_col(
auto src_load = is_tail ? src_reg | kTail | T_z : src_reg;
const auto src_offset = (i * src_stride_) / src_elems_per_byte_;
const auto addr = EVEX_compress_addr(reg_src, src_offset);
if (conf_->is_f16_with_int_wei && conf_->wei_dt == data_type::f32) {
if ((conf_->is_f16_with_int_wei || conf_->is_f32_with_int_wei)
&& conf_->wei_dt == data_type::f32) {
const auto xmm_preload = Xmm(src_reg.getIdx());
MAYBE_UNUSED(xmm_preload);
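
For reference, below is a scalar sketch of what the s4/u4 paths handled by load_value above compute per weight element: unpack two 4-bit values per byte, sign-extend, convert to f32, subtract the zero-point shift, and multiply by the weight scale. It is illustrative only; the low-nibble-first packing and the helper name are assumptions, and the kernel performs the same steps with masked vector instructions.

#include <cstdint>

// Hypothetical scalar reference for s4 weight decompression.
static float decompress_s4(uint8_t packed, int idx, float zp, float scale) {
    // Two s4 values per byte; assume the low nibble holds the first element.
    const int nibble = (idx == 0) ? (packed & 0x0f) : (packed >> 4);
    // Sign-extend the 4-bit value to a signed integer in [-8, 7].
    const int val = (nibble >= 8) ? nibble - 16 : nibble;
    // Convert to f32, apply the zero-point shift, then the weight scale.
    return (static_cast<float>(val) - zp) * scale;
}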

View File

@@ -250,6 +250,8 @@ status_t check_isa_with_datatype(
is_superset(isa, avx512_core_bf16))
&& IMPLICATION(bm_conf_utils.is_f16_with_int_wei(),
one_of(isa, avx512_core_amx_fp16, avx512_core_fp16))
&& IMPLICATION(bm_conf_utils.is_f32_with_int_wei(),
one_of(isa, avx512_core, avx2))
&& IMPLICATION(bm_conf_utils.is_f8(),
is_superset(isa, avx512_core_amx_fp16)
|| is_superset(isa, avx10_2_512))
@@ -269,7 +271,8 @@ status_t check_datatype_cfg(const brgemm_matmul_conf_utils_t &bm_conf_utils) {
bm_conf_utils.is_f4_via_convert(),
bm_conf_utils.is_tf32(),
bm_conf_utils.is_bf16_with_int_wei(),
bm_conf_utils.is_f16_with_int_wei())
bm_conf_utils.is_f16_with_int_wei(),
bm_conf_utils.is_f32_with_int_wei())
&& IMPLICATION(bm_conf_utils.is_bf16_with_int_wei()
|| bm_conf_utils.is_f16_with_int_wei(),
bm_conf_utils.with_weights_decompression());
@@ -304,11 +307,13 @@ brgemm_matmul_conf_utils_t::brgemm_matmul_conf_utils_t(
&& isa == avx10_2_512_amx_2)
, weights_decompression_support(one_of(bgmmc.wei_dt, u8, s8, u4, s4)
&& one_of(attr.fpmath_.mode_, fpmath_mode::bf16, fpmath_mode::f16,
fpmath_mode::any)
fpmath_mode::strict, fpmath_mode::any)
&& IMPLICATION(attr.fpmath_.mode_ == fpmath_mode::f16,
bgmmc.src_dt == f16)
&& IMPLICATION(attr.fpmath_.mode_ == fpmath_mode::bf16,
bgmmc.src_dt == bf16)
&& IMPLICATION(attr.fpmath_.mode_ == fpmath_mode::strict,
bgmmc.src_dt == f32)
&& attr.fpmath_.apply_to_int_)
, bf16_with_int_wei_dt(weights_decompression_support && bgmmc.src_dt == bf16
&& one_of(bgmmc.dst_dt, bf16, f32))
@@ -322,6 +327,8 @@ brgemm_matmul_conf_utils_t::brgemm_matmul_conf_utils_t(
&& one_of(bgmmc.dst_dt, bf16, f32))
, f16_with_int_wei_dt(weights_decompression_support && bgmmc.src_dt == f16
&& one_of(bgmmc.dst_dt, f16, f32))
, f32_with_int_wei_dt(weights_decompression_support
&& everyone_is(f32, bgmmc.src_dt, bgmmc.dst_dt))
, A_any_layout(A_any_layout)
, B_any_layout(B_any_layout)
, C_any_layout(C_any_layout)
@@ -541,7 +548,8 @@ status_t brgemm_matmul_conf_utils_t::set_or_check_tags(memory_desc_t &A_md,
|| this->is_f16() || this->is_f32_f16()
|| this->is_f32_bf16()
|| this->is_bf16_with_int_wei()
|| this->is_f16_with_int_wei() || this->is_tf32())
|| this->is_f16_with_int_wei() || this->is_tf32()
|| this->is_f32_with_int_wei())
&& !xf16_avx2_vnni_2;
bgmmc.src_tag = is_adbc_allowed ? memory_desc_matches_one_of_tag(
A_md, plain_tensor_layout_tag,
@@ -653,7 +661,7 @@ format_tag_t brgemm_matmul_conf_utils_t::pick_blocked_B_layout(
const bool is_amx_or_avx2_vnni_2 = is_superset(bgmmc.isa, avx512_core_amx)
|| is_superset(bgmmc.isa, avx2_vnni_2);
const bool prefer_amx_or_avx2_vnni_2 = is_f16() || is_f32_f16()
|| is_f32_bf16() || is_f16_with_int_wei();
|| is_f32_bf16() || is_f16_with_int_wei() || is_f32_with_int_wei();
if ((prefer_amx_or_avx2_vnni_2 && is_amx_or_avx2_vnni_2) || is_bf16()
|| is_bf16_with_int_wei()) {
@@ -668,7 +676,7 @@ format_tag_t brgemm_matmul_conf_utils_t::pick_blocked_B_layout(
// Note: bf32 assumes f32 blocking
if (is_f32() || is_bf32() || is_f16() || is_f32_f16() || is_f32_bf16()
|| is_f16_with_int_wei() || is_tf32()) {
|| is_f16_with_int_wei() || is_tf32() || is_f32_with_int_wei()) {
switch (n_blk) {
case 64: return bgmmc.ndims == 3 ? aCB16b64c : BA16a64b;
case 48: return bgmmc.ndims == 3 ? aCB16b48c : BA16a48b;
@@ -1327,6 +1335,7 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc,
bgmmc.is_tf32 = bm_conf_utils.is_tf32();
bgmmc.is_bf16_with_int_wei = bm_conf_utils.is_bf16_with_int_wei();
bgmmc.is_f16_with_int_wei = bm_conf_utils.is_f16_with_int_wei();
bgmmc.is_f32_with_int_wei = bm_conf_utils.is_f32_with_int_wei();
bgmmc.is_f32_f16 = bm_conf_utils.is_f32_f16();
bgmmc.is_f32_bf16 = bm_conf_utils.is_f32_bf16();
bgmmc.with_wei_decompression = bm_conf_utils.with_weights_decompression();
@@ -1351,6 +1360,11 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc,
bgmmc.wei_dt = f32;
bgmmc.tr_a_dt_sz = types::data_type_size(f32);
bgmmc.tr_b_dt_sz = types::data_type_size(f32);
} else if (bm_conf_utils.is_f32_with_int_wei()) {
bgmmc.src_dt = f32;
bgmmc.wei_dt = f32;
bgmmc.tr_a_dt_sz = types::data_type_size(f32);
bgmmc.tr_b_dt_sz = types::data_type_size(f32);
} else if ((bm_conf_utils.is_f32_f16() || bm_conf_utils.is_f32_bf16())
&& is_superset(bgmmc.isa, avx2)) {
// Note 1: Keep this branch separately from f16 one to have different
@@ -1809,7 +1823,8 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc,
if (bm_conf_utils.is_bf16() || bm_conf_utils.is_f16()
|| bm_conf_utils.is_f32_f16() || bm_conf_utils.is_f32_bf16()
|| bm_conf_utils.is_bf16_with_int_wei()
|| bm_conf_utils.is_f16_with_int_wei()) {
|| bm_conf_utils.is_f16_with_int_wei()
|| bm_conf_utils.is_f32_with_int_wei()) {
// empirical observation for performance breakpoint between amx and vnni
// bf16/f16
const dim_t buffer_a_chunk_sz_limit = 126;

View File

@@ -217,6 +217,7 @@ struct brgemm_matmul_conf_t {
bool is_bf32 = false;
bool is_bf16_with_int_wei = false;
bool is_f16_with_int_wei = false;
bool is_f32_with_int_wei = false;
bool is_f32_f16 = false;
bool is_f32_bf16 = false;
bool is_int4_weights = false;
@@ -282,6 +283,7 @@ struct brgemm_matmul_conf_utils_t {
if (bgmmc.is_runtime_N) return true;
if (bgmmc.is_bf16_with_int_wei) return true;
if (bgmmc.is_f16_with_int_wei) return true;
if (bgmmc.is_f32_with_int_wei) return true;
if (bgmmc.apply_scales_in_buffer_b) return true;
if (bgmmc.is_gemv) return false;
@@ -371,6 +373,8 @@ struct brgemm_matmul_conf_utils_t {
inline bool is_f16_with_int_wei() const { return f16_with_int_wei_dt; }
inline bool is_f32_with_int_wei() const { return f32_with_int_wei_dt; }
inline bool with_weights_decompression() const {
return !utils::one_of(bgmmc.src_dt, data_type::s8, data_type::u8,
data_type::s4, data_type::u4)
@@ -412,7 +416,7 @@ private:
const bool f32_dt, bf16_dt, f16_dt, f4_via_convert_dt, f8_dt, bf8_dt,
int8_dt, bf32_dt, tf32_dt;
const bool weights_decompression_support, bf16_with_int_wei_dt, f32_f16_dt,
f32_bf16_dt, f16_with_int_wei_dt;
f32_bf16_dt, f16_with_int_wei_dt, f32_with_int_wei_dt;
const bool A_any_layout;
const bool B_any_layout;
const bool C_any_layout;

View File

@@ -297,3 +297,38 @@
--attr-zero-points=wei:per_tensor:s4:128x1,src:per_ocic:u4:1x256+wei:per_tensor:s4:128x1
2x3x4x256:2x3x256x64
2x3x6x256:2x3x256x100
# f32 + weight decompression
--reset
--skip-impl=ref
--dt=f32:s4:f32,f32:u4:f32,f32:s8:f32,f32:u8:f32
--wtag=any,ab,ba
--attr-fpmath=strict:true
1x4096:4096x4096
--skip-impl=ref
--dt=f32:s4:f32,f32:u4:f32,f32:s8:f32,f32:u8:f32
--wtag=any,abc,acb
--attr-fpmath=strict:true
2x40x256:2x256x64
7x41x256:1x256x64
3x96x512:3x512x64
3x6x512:1x512x62
--skip-impl=ref
--dt=f32:s4:f32,f32:u4:f32
--wtag=any,ab,ba
--attr-scales=,wei:common:0.5:f32,wei:per_ocic:f32:32x1
--attr-zero-points=,wei:common:2
# ,wei:per_oc:s4:4096x1 # PER_OC skips for all matmul implementations
--attr-fpmath=strict:true
1x4096:4096x4096
--skip-impl=ref
--dt=f32:s8:f32,f32:u8:f32
--wtag=any,ab,ba
--attr-scales=,wei:common:0.5:f32,wei:per_ocic:f32:32x1
--attr-zero-points=,wei:common:2
# ,wei:per_ocic:s8:32x1 # PER_OC skips for all matmul implementations
--attr-fpmath=strict:true
1x4096:4096x4096