Mirror of https://github.com/uxlfoundation/oneDNN.git, synced 2025-10-20 18:43:49 +08:00
x64: matmul: Enable AVX512 f32:int4/int8:f32 case
commit 051b020bb1 · parent f92a20983c · committed by Dmitriy Ovchinnikov
@@ -171,6 +171,8 @@ status_t brgemm_matmul_t<isa>::pd_t::init(engine_t *engine) {
             && one_of(wei_dt, s8, u8, s4, u4) && one_of(dst_dt, f16, f32);
     const bool is_f4
             = utils::one_of(wei_dt, data_type::f4_e2m1, data_type::f4_e3m0);
+    const bool is_f32_with_int_wei
+            = src_dt == f32 && one_of(wei_dt, s8, u8, s4, u4) && dst_dt == f32;
 
     auto check_bias = [&]() -> bool {
         const auto bia_dt = weights_md(1)->data_type;
@@ -214,8 +216,9 @@ status_t brgemm_matmul_t<isa>::pd_t::init(engine_t *engine) {
         // This case requires scratchpad
         if (N() == DNNL_RUNTIME_DIM_VAL) ok = false;
     }
-    // This impl supports only f32 scales for non-weight decompression.
-    if (!(is_bf16_with_int_wei || is_f16_with_int_wei)) {
+    // Impl supports f32 scales only for non-weight decompression
+    if (!(is_bf16_with_int_wei || is_f16_with_int_wei
+                || is_f32_with_int_wei)) {
         ok = ok && one_of(asc.get_data_type(DNNL_ARG_SRC), undef, f32);
         ok = ok && one_of(asc.get_data_type(DNNL_ARG_WEIGHTS), undef, f32);
         ok = ok && one_of(asc.get_data_type(DNNL_ARG_DST), undef, f32);
@@ -260,7 +263,7 @@ status_t brgemm_matmul_t<isa>::pd_t::init(engine_t *engine) {
     };
     const bool problem_dt_correct = one_of(true, is_f4, is_int8, is_f8, is_bf16,
             is_f32, is_f16, is_f32_f16, is_f32_bf16, is_bf16_with_int_wei,
-            is_f16_with_int_wei);
+            is_f16_with_int_wei, is_f32_with_int_wei);
 
     auto src_d = memory_desc_wrapper(src_md_);
     auto weights_d = memory_desc_wrapper(weights_md_);
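The net effect of the new is_f32_with_int_wei gate is a plain f32 matmul whose weights arrive as int8/int4 and are dequantized on the fly. A minimal scalar sketch of the intended semantics, assuming a per-tensor scale and zero point (names and layout are illustrative, not library code):

    #include <cstdint>

    // Scalar reference for the f32:int8:f32 case: C = A * dequant(B),
    // with dequant(b) = (b - zp) * scale applied per element of B.
    void matmul_f32_int_wei_ref(const float *A, const int8_t *B, float *C,
            int M, int N, int K, float scale, int32_t zp) {
        for (int m = 0; m < M; ++m)
            for (int n = 0; n < N; ++n) {
                float acc = 0.f;
                for (int k = 0; k < K; ++k)
                    acc += A[m * K + k]
                            * ((float)B[k * N + n] - (float)zp) * scale;
                C[m * N + n] = acc;
            }
    }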
@@ -2134,10 +2134,13 @@ protected:
      * @param vmm_permd Vector register containing permutation indices for INT4 processing
      * @param dt Data type being loaded
      * @param is_tail Flag indicating if tail processing is needed
+     * @param vmm_f4_lut Vector register containing lookup table for FP4 conversion
+     *        (default is Vmm(4) for kernel jit_brgemm_matmul_copy_b_f32_t)
      */
     template <typename Vmm>
     void load_value(const Vmm &reg, const Xbyak::Operand &op,
-            const Vmm &vmm_permd, data_type_t dt, bool is_tail = false) {
+            const Vmm &vmm_permd, data_type_t dt, bool is_tail = false,
+            const Vmm &vmm_f4_lut = Vmm(4)) {
         using Vmm_lower_t = typename vreg_traits_t<Vmm>::Vmm_lower_t;
         const auto vmm_in = maybe_mask(reg, is_tail);
         const auto vmm_lower = Vmm_lower_t(vmm_in.getIdx());
@@ -2196,6 +2199,16 @@ protected:
                 vpsrld(reg | k5555, reg, 28);
                 vpsrld(reg | kAAAA, reg, 4);
                 break;
+            case data_type::f4_e2m1:
+            case data_type::f4_e3m0:
+                uni_vpmovzxbd(maybe_mask(vmm_lower, is_tail), op);
+                copy_half_int4(vmm_in, vmm_lower);
+                vpermd(vmm_in, vmm_permd, vmm_in);
+                uni_vpslld(vmm_in | k5555, vmm_in, 28);
+                vpsrld(vmm_in | k5555, vmm_in, 28);
+                vpsrld(vmm_in | kAAAA, vmm_in, 4);
+                vpermps(vmm_in, vmm_in, vmm_f4_lut);
+                break;
             default: assert(!"unsupported data type");
         }
     }
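The f4 cases reuse the int4 unpack pattern: each source byte holds two 4-bit codes; the zero-extended bytes are spread across dword lanes, the k5555 write-mask isolates the low nibble in even lanes (shift left by 28, then logical shift right by 28) while kAAAA shifts the high nibble into place in odd lanes, and the final vpermps acts as a 16-entry table lookup mapping each code to its f32 value. A scalar sketch of the same steps; the lookup table contents are whatever generate() loads, not shown here:

    #include <cstdint>

    // Scalar model of the SIMD f4 decode: split each byte into two
    // 4-bit codes, then map each code through a 16-entry f32 lookup
    // table (the kernel's vpermps does this lookup, one lane per code).
    void decode_f4(const uint8_t *packed, float *out, int n,
            const float lut[16]) {
        for (int i = 0; i < n; ++i) {
            const uint8_t byte = packed[i / 2];
            // Even outputs take the low nibble (k5555 lanes in the
            // kernel), odd outputs the high nibble (kAAAA lanes).
            const uint8_t code = (i % 2 == 0) ? (byte & 0xF) : (byte >> 4);
            out[i] = lut[code];
        }
    }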
@@ -3665,13 +3678,12 @@ template struct jit_brgemm_matmul_copy_b_bf16_t<Zmm>;
 template struct jit_brgemm_matmul_copy_b_bf16_t<Ymm>;
 
 template <typename Vmm>
-struct jit_brgemm_matmul_copy_b_f32_t : public jit_brgemm_matmul_copy_b_t,
-                                        public jit_generator_t {
+struct jit_brgemm_matmul_copy_b_f32_t
+    : public jit_brgemm_matmul_copy_b_common_t {
     DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_brgemm_matmul_copy_b_f32_t)
 
     jit_brgemm_matmul_copy_b_f32_t(const brgemm_matmul_conf_t *conf)
-        : jit_brgemm_matmul_copy_b_t(conf)
-        , jit_generator_t(jit_name())
+        : jit_brgemm_matmul_copy_b_common_t(conf)
         , dt_in_(conf->orig_wei_dt)
         , simd_w_(vreg_traits_t<Vmm>::vlen / sizeof(float))
         , is_src_f4_(one_of(
@@ -3710,12 +3722,6 @@ private:
     const size_t typesize_out_ = sizeof(float);
     dim_t src_stride_, tr_src_stride_, wei_scales_N_stride_;
 
-    opmask_t kTail = k7;
-    opmask_t kFFFF = k6;
-    opmask_t k5555 = k5;
-    opmask_t kAAAA = k4;
-    opmask_t kTail4bit = k3;
-
     reg64_t reg_src = rax;
     reg64_t reg_tr_src = rbx;
 
@@ -3725,7 +3731,6 @@ private:
     reg64_t reg_tmp = r15;
     reg32_t regw_tmp = r15d;
     reg64_t reg_wei_scales = rdx;
-    reg64_t reg_f4_lut = r11;
 
     Vmm vmm_zero = Vmm(0);
     Vmm vmm_wei_scales = Vmm(1);
@@ -3739,88 +3744,12 @@ private:
         mov(regw_tmp, w);
         jit_generator_t::kmovd(k, regw_tmp);
     }
-    void insert_high_half(const Zmm &zmm, const Ymm &ymm_half) {
-        vinserti64x4(zmm, zmm, ymm_half, 1);
-    }
-    void insert_high_half(const Ymm &ymm, const Xmm &xmm_half) {
-        vinserti128(ymm, ymm, xmm_half, 1);
-    }
-    Vmm_lower_t maybe_mask(Vmm_lower_t vmm_lower, bool is_tail) {
-        assert(is_src_int4_ || is_src_f4_);
-        return is_tail && isa_has_masks(conf_->isa)
-                ? vmm_lower | kTail4bit | T_z
-                : vmm_lower;
-    }
-    Vmm maybe_mask(Vmm vmm, bool is_tail) {
-        return is_tail && isa_has_masks(conf_->isa) ? vmm | kTail | T_z : vmm;
-    }
-    void load_data(const Vmm vmm_in, const Xbyak::Operand &op, bool is_tail);
 
     void copy_16_x_n_block(int nrows, int ncolumns);
     void compute_k_loop(int ncolumns);
     void generate() override;
 };
 
-template <typename Vmm>
-void jit_brgemm_matmul_copy_b_f32_t<Vmm>::load_data(
-        const Vmm vmm_in, const Xbyak::Operand &op, bool is_tail) {
-    const auto vmm = maybe_mask(vmm_in, is_tail);
-    const auto vmm_lower = Vmm_lower_t(vmm.getIdx());
-    MAYBE_UNUSED(vmm_lower);
-
-    switch (dt_in_) {
-        case data_type::f32: uni_vmovups(vmm, op); break;
-        case data_type::bf16:
-            // Upconvert: load 16 bits and move them 16 bits left.
-            uni_vpmovzxwd(vmm, op);
-            uni_vpslld(vmm, vmm, 16);
-            break;
-        case data_type::f16:
-            if (is_superset(conf_->isa, avx512_core_fp16)) {
-                vcvtph2psx(vmm, op);
-            } else {
-                vcvtph2ps(vmm, op);
-            }
-            break;
-        case data_type::s8: uni_vpmovsxbd(vmm, op); break;
-        case data_type::u8: uni_vpmovzxbd(vmm, op); break;
-        // For int4, we treat two int4 values as one int8 and extend them
-        // to int32; the low half lands in the lower bytes of vmm and the
-        // high half in the higher bytes, then we permute them into the
-        // correct order and process the extended bytes for s4/u4 accordingly.
-        case data_type::s4:
-            uni_vpmovsxbd(maybe_mask(vmm_lower, is_tail), op);
-            insert_high_half(vmm_in, vmm_lower);
-            vpermd(vmm_in, vmm_permd, vmm_in);
-            uni_vpslld(vmm_in | k5555, vmm_in, 28);
-            vpsrad(vmm_in | k5555, vmm_in, 28);
-            vpsrad(vmm_in | kAAAA, vmm_in, 4);
-            break;
-        case data_type::u4:
-            uni_vpmovzxbd(maybe_mask(vmm_lower, is_tail), op);
-            insert_high_half(vmm_in, vmm_lower);
-            vpermd(vmm_in, vmm_permd, vmm_in);
-            uni_vpslld(vmm_in | k5555, vmm_in, 28);
-            vpsrld(vmm_in | k5555, vmm_in, 28);
-            vpsrld(vmm_in | kAAAA, vmm_in, 4);
-            break;
-        case data_type::f4_e2m1:
-        case data_type::f4_e3m0:
-            uni_vpmovzxbd(maybe_mask(vmm_lower, is_tail), op);
-            insert_high_half(vmm_in, vmm_lower);
-            vpermd(vmm_in, vmm_permd, vmm_in);
-            uni_vpslld(vmm_in | k5555, vmm_in, 28);
-            vpsrld(vmm_in | k5555, vmm_in, 28);
-            vpsrld(vmm_in | kAAAA, vmm_in, 4);
-            vpermps(vmm_in, vmm_in, vmm_f4_lut);
-            break;
-        default: assert(!"unsupported data type");
-    }
-
-    if (one_of(dt_in_, data_type::s8, data_type::u8, data_type::s4,
-                data_type::u4))
-        uni_vcvtdq2ps(vmm_in, vmm_in);
-}
-
 template <typename Vmm>
 void jit_brgemm_matmul_copy_b_f32_t<Vmm>::copy_16_x_n_block(
         int nrows, int ncolumns) {
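The removed load_data body (now subsumed by the shared load_value helper) documents the bf16 upconvert trick: bf16 is the upper half of an f32 bit pattern, so zero-extending each 16-bit value to 32 bits and shifting left by 16 is an exact conversion. A scalar sketch of that step:

    #include <cstdint>
    #include <cstring>

    // bf16 -> f32: bf16 keeps f32's sign, exponent, and top mantissa
    // bits, so placing the 16 payload bits in the high half of a 32-bit
    // word (kernel: uni_vpmovzxwd + uni_vpslld by 16) is exact.
    float bf16_to_f32(uint16_t bf16_bits) {
        const uint32_t f32_bits = (uint32_t)bf16_bits << 16;
        float f;
        std::memcpy(&f, &f32_bits, sizeof(f));
        return f;
    }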
@@ -3846,31 +3775,12 @@ void jit_brgemm_matmul_copy_b_f32_t<Vmm>::copy_16_x_n_block(
         if (is_tail && !isa_has_masks(conf_->isa))
             vmaskmovps(src_vmm, ymm_tail_mask, addr);
         else
-            load_data(src_vmm, addr, is_tail);
+            load_value(src_vmm, addr, vmm_permd, conf_->orig_wei_dt, is_tail);
 
-        if (req_zp_b_shift_)
-            uni_vsubps(maybe_mask(src_vmm, is_tail), src_vmm, vmm_zp_b_shift);
-        if (req_apply_wei_scales_) {
-            const auto wei_scales_addr = maybe_EVEX_compress_addr(
-                    reg_wei_scales,
-                    k * wei_scales_N_stride_ + n * wei_scales_typesize_);
-            const auto vmm_wei_scales_masked
-                    = maybe_mask(vmm_wei_scales, is_tail);
-            switch (conf_->wei_scales_dt) {
-                case data_type::f32:
-                    uni_vmovups(vmm_wei_scales_masked, wei_scales_addr);
-                    break;
-                case data_type::bf16:
-                    uni_vpmovzxwd(vmm_wei_scales_masked, wei_scales_addr);
-                    uni_vpslld(vmm_wei_scales, vmm_wei_scales, 16);
-                    break;
-                case data_type::f16:
-                    vcvtph2ps(vmm_wei_scales_masked, wei_scales_addr);
-                    break;
-                default: assert(!"unsupported wei_scales data type");
-            }
-            vmulps(src_vmm, src_vmm, vmm_wei_scales);
-        }
+        const auto scales_addr = maybe_EVEX_compress_addr(reg_wei_scales,
+                k * wei_scales_N_stride_ + n * wei_scales_typesize_);
+        decompress_reg(maybe_mask(src_vmm, is_tail), vmm_zp_b_shift,
+                scales_addr, conf_->orig_wei_dt);
     };
 
     const int columns_tail = ncolumns % simd_w_;
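The zero-point subtract and dtype-dispatched scale multiply spelled out in the removed lines are now folded into the decompress_reg helper; its per-element effect, in scalar form (illustrative helper name):

    // Scalar equivalent of the decompress_reg call: the already-converted
    // f32 value is shifted by the broadcast zero point, then multiplied by
    // the per-column scale addressed as k * N_stride + n * typesize.
    inline float decompress_elem(float x, float zp, float scale) {
        return (x - zp) * scale;
    }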
@@ -3881,7 +3791,7 @@ void jit_brgemm_matmul_copy_b_f32_t<Vmm>::copy_16_x_n_block(
         if (is_src_int4_ || is_src_f4_) {
             const auto tail_mask_4bit
                     = (1 << (columns_tail / src_elems_per_byte_)) - 1;
-            kmovw(kTail4bit, tail_mask_4bit);
+            kmovw(kTail_int4, tail_mask_4bit);
         }
     } else {
         init_f32_avx2_mask_ymm(ymm_tail_mask, reg_tmp, columns_tail);
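Two tail masks are in play here: the full mask covers columns_tail f32 lanes, while the 4-bit variant covers only the bytes actually read, i.e. columns_tail divided by the number of packed elements per byte (2 for int4/f4). A sketch of the mask arithmetic, assuming columns_tail <= 16 as in the 16-lane kernel:

    #include <cstdint>

    // kmovw-style 16-bit tail masks: one bit per output lane, and one
    // bit per source byte for 4-bit-packed inputs.
    uint16_t tail_mask(unsigned columns_tail) {
        return (uint16_t)((1u << columns_tail) - 1);
    }
    uint16_t tail_mask_4bit(unsigned columns_tail,
            unsigned src_elems_per_byte) {
        return (uint16_t)((1u << (columns_tail / src_elems_per_byte)) - 1);
    }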
@@ -3967,21 +3877,20 @@ void jit_brgemm_matmul_copy_b_f32_t<Vmm>::generate() {
                 -.5f, -1.0f, -2.0f, -4.0f, -8.0f, -16.0f};
         switch (dt_in_) {
             case data_type::f4_e2m1:
-                mov(reg_f4_lut, reinterpret_cast<size_t>(f4_e2m1_table));
+                mov(reg_tmp, reinterpret_cast<size_t>(f4_e2m1_table));
                 break;
             case data_type::f4_e3m0:
-                mov(reg_f4_lut, reinterpret_cast<size_t>(f4_e3m0_table));
+                mov(reg_tmp, reinterpret_cast<size_t>(f4_e3m0_table));
                 break;
-
             default: break;
         }
-        vmovdqa32(vmm_f4_lut, ptr[reg_f4_lut]);
+        vmovdqa32(vmm_f4_lut, ptr[reg_tmp]);
     }
 
     if (req_zp_b_shift_) {
         mov(reg_tmp, ptr[param1 + GET_OFF(zp_b_value_ptr)]);
         uni_vpbroadcastd(vmm_zp_b_shift, ptr[reg_tmp]);
         uni_vcvtdq2ps(vmm_zp_b_shift, vmm_zp_b_shift);
     }
 
     Label done;
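Note the cost model in generate(): the f4 lookup table and the zero point are materialized in registers once at kernel entry, so the inner copy loop pays only a subtract and multiply per element. A scalar equivalent of the zero-point setup (zp_b_value_ptr is the kernel's runtime argument):

    #include <cstdint>

    // One-time setup mirrored from generate(): read the int32 zero
    // point and convert it to f32 (uni_vpbroadcastd + uni_vcvtdq2ps).
    float load_zp_shift(const int32_t *zp_b_value_ptr) {
        return (float)(*zp_b_value_ptr);
    }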
@@ -4441,7 +4350,8 @@ void jit_brgemm_matmul_copy_b_transposed_t<Vmm>::copy_row_x_col(
         auto src_load = is_tail ? src_reg | kTail | T_z : src_reg;
         const auto src_offset = (i * src_stride_) / src_elems_per_byte_;
         const auto addr = EVEX_compress_addr(reg_src, src_offset);
-        if (conf_->is_f16_with_int_wei && conf_->wei_dt == data_type::f32) {
+        if ((conf_->is_f16_with_int_wei || conf_->is_f32_with_int_wei)
+                && conf_->wei_dt == data_type::f32) {
             const auto xmm_preload = Xmm(src_reg.getIdx());
 
             MAYBE_UNUSED(xmm_preload);
@@ -250,6 +250,8 @@ status_t check_isa_with_datatype(
                     is_superset(isa, avx512_core_bf16))
             && IMPLICATION(bm_conf_utils.is_f16_with_int_wei(),
                     one_of(isa, avx512_core_amx_fp16, avx512_core_fp16))
+            && IMPLICATION(bm_conf_utils.is_f32_with_int_wei(),
+                    one_of(isa, avx512_core, avx2))
             && IMPLICATION(bm_conf_utils.is_f8(),
                     is_superset(isa, avx512_core_amx_fp16)
                             || is_superset(isa, avx10_2_512))
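The ISA checks read as material implications: if the problem is f32 with int weights, the ISA must be avx512_core or avx2, and so on; oneDNN's IMPLICATION macro is plain logical implication. A self-contained equivalent:

    // Material implication, as used in the checks above: the condition
    // fails only when `cause` holds and `effect` does not.
    constexpr bool implication(bool cause, bool effect) {
        return !cause || effect;
    }

    static_assert(implication(false, true), "vacuously true");
    static_assert(!implication(true, false), "the only failing case");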
@@ -269,7 +271,8 @@ status_t check_datatype_cfg(const brgemm_matmul_conf_utils_t &bm_conf_utils) {
                     bm_conf_utils.is_f4_via_convert(),
                     bm_conf_utils.is_tf32(),
                     bm_conf_utils.is_bf16_with_int_wei(),
-                    bm_conf_utils.is_f16_with_int_wei())
+                    bm_conf_utils.is_f16_with_int_wei(),
+                    bm_conf_utils.is_f32_with_int_wei())
             && IMPLICATION(bm_conf_utils.is_bf16_with_int_wei()
                             || bm_conf_utils.is_f16_with_int_wei(),
                     bm_conf_utils.with_weights_decompression());
@@ -304,11 +307,13 @@ brgemm_matmul_conf_utils_t::brgemm_matmul_conf_utils_t(
             && isa == avx10_2_512_amx_2)
     , weights_decompression_support(one_of(bgmmc.wei_dt, u8, s8, u4, s4)
               && one_of(attr.fpmath_.mode_, fpmath_mode::bf16, fpmath_mode::f16,
-                      fpmath_mode::any)
+                      fpmath_mode::strict, fpmath_mode::any)
               && IMPLICATION(attr.fpmath_.mode_ == fpmath_mode::f16,
                       bgmmc.src_dt == f16)
               && IMPLICATION(attr.fpmath_.mode_ == fpmath_mode::bf16,
                       bgmmc.src_dt == bf16)
+              && IMPLICATION(attr.fpmath_.mode_ == fpmath_mode::strict,
+                      bgmmc.src_dt == f32)
               && attr.fpmath_.apply_to_int_)
     , bf16_with_int_wei_dt(weights_decompression_support && bgmmc.src_dt == bf16
               && one_of(bgmmc.dst_dt, bf16, f32))
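From the user's side, this configuration corresponds to an f32 matmul with integer weights plus an fpmath attribute that opts integers in; with strict mode the source must stay f32, per the new IMPLICATION above. A sketch of how such a primitive would be requested through the C++ API (shapes illustrative; error handling omitted):

    #include "dnnl.hpp"

    // Request the f32:s8:f32 weight-decompression path: strict fpmath
    // with apply_to_int=true is the combination the new check accepts
    // for f32 sources.
    dnnl::matmul make_f32_int_wei_matmul(const dnnl::engine &eng) {
        using namespace dnnl;
        memory::desc src_md({1, 4096}, memory::data_type::f32,
                memory::format_tag::ab);
        memory::desc wei_md({4096, 4096}, memory::data_type::s8,
                memory::format_tag::any);
        memory::desc dst_md({1, 4096}, memory::data_type::f32,
                memory::format_tag::ab);

        primitive_attr attr;
        attr.set_fpmath_mode(fpmath_mode::strict, /*apply_to_int=*/true);

        matmul::primitive_desc pd(eng, src_md, wei_md, dst_md, attr);
        return matmul(pd);
    }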
@@ -322,6 +327,8 @@ brgemm_matmul_conf_utils_t::brgemm_matmul_conf_utils_t(
               && one_of(bgmmc.dst_dt, bf16, f32))
     , f16_with_int_wei_dt(weights_decompression_support && bgmmc.src_dt == f16
               && one_of(bgmmc.dst_dt, f16, f32))
+    , f32_with_int_wei_dt(weights_decompression_support
+              && everyone_is(f32, bgmmc.src_dt, bgmmc.dst_dt))
     , A_any_layout(A_any_layout)
     , B_any_layout(B_any_layout)
     , C_any_layout(C_any_layout)
@@ -541,7 +548,8 @@ status_t brgemm_matmul_conf_utils_t::set_or_check_tags(memory_desc_t &A_md,
                     || this->is_f16() || this->is_f32_f16()
                     || this->is_f32_bf16()
                     || this->is_bf16_with_int_wei()
-                    || this->is_f16_with_int_wei() || this->is_tf32())
+                    || this->is_f16_with_int_wei() || this->is_tf32()
+                    || this->is_f32_with_int_wei())
             && !xf16_avx2_vnni_2;
     bgmmc.src_tag = is_adbc_allowed ? memory_desc_matches_one_of_tag(
                             A_md, plain_tensor_layout_tag,
@@ -653,7 +661,7 @@ format_tag_t brgemm_matmul_conf_utils_t::pick_blocked_B_layout(
     const bool is_amx_or_avx2_vnni_2 = is_superset(bgmmc.isa, avx512_core_amx)
             || is_superset(bgmmc.isa, avx2_vnni_2);
     const bool prefer_amx_or_avx2_vnni_2 = is_f16() || is_f32_f16()
-            || is_f32_bf16() || is_f16_with_int_wei();
+            || is_f32_bf16() || is_f16_with_int_wei() || is_f32_with_int_wei();
 
     if ((prefer_amx_or_avx2_vnni_2 && is_amx_or_avx2_vnni_2) || is_bf16()
             || is_bf16_with_int_wei()) {
@@ -668,7 +676,7 @@ format_tag_t brgemm_matmul_conf_utils_t::pick_blocked_B_layout(
 
     // Note: bf32 assumes f32 blocking
     if (is_f32() || is_bf32() || is_f16() || is_f32_f16() || is_f32_bf16()
-            || is_f16_with_int_wei() || is_tf32()) {
+            || is_f16_with_int_wei() || is_tf32() || is_f32_with_int_wei()) {
         switch (n_blk) {
             case 64: return bgmmc.ndims == 3 ? aCB16b64c : BA16a64b;
             case 48: return bgmmc.ndims == 3 ? aCB16b48c : BA16a48b;
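For reference, a tag such as BA16a64b names a blocked layout: blocks of 64 along N (b) and 16 along K (a), with N-blocks outermost and a dense 16x64 inner tile. A sketch of the element offset this tag implies, per my reading of the oneDNN tag grammar (not library code):

    #include <cstddef>

    // Offset of element (k, n) in a BA16a64b-blocked K x N tensor
    // (K padded to a multiple of 16): N-blocks outermost, then
    // K-blocks, then the dense 16x64 tile.
    size_t ba16a64b_offset(size_t k, size_t n, size_t K_padded) {
        const size_t K_blocks = K_padded / 16;
        return (((n / 64) * K_blocks + k / 16) * 16 + k % 16) * 64
                + n % 64;
    }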
@@ -1327,6 +1335,7 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc,
     bgmmc.is_tf32 = bm_conf_utils.is_tf32();
     bgmmc.is_bf16_with_int_wei = bm_conf_utils.is_bf16_with_int_wei();
     bgmmc.is_f16_with_int_wei = bm_conf_utils.is_f16_with_int_wei();
+    bgmmc.is_f32_with_int_wei = bm_conf_utils.is_f32_with_int_wei();
     bgmmc.is_f32_f16 = bm_conf_utils.is_f32_f16();
     bgmmc.is_f32_bf16 = bm_conf_utils.is_f32_bf16();
     bgmmc.with_wei_decompression = bm_conf_utils.with_weights_decompression();
@@ -1351,6 +1360,11 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc,
         bgmmc.wei_dt = f32;
         bgmmc.tr_a_dt_sz = types::data_type_size(f32);
         bgmmc.tr_b_dt_sz = types::data_type_size(f32);
+    } else if (bm_conf_utils.is_f32_with_int_wei()) {
+        bgmmc.src_dt = f32;
+        bgmmc.wei_dt = f32;
+        bgmmc.tr_a_dt_sz = types::data_type_size(f32);
+        bgmmc.tr_b_dt_sz = types::data_type_size(f32);
     } else if ((bm_conf_utils.is_f32_f16() || bm_conf_utils.is_f32_bf16())
             && is_superset(bgmmc.isa, avx2)) {
         // Note 1: Keep this branch separately from f16 one to have different
@@ -1809,7 +1823,8 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc,
     if (bm_conf_utils.is_bf16() || bm_conf_utils.is_f16()
             || bm_conf_utils.is_f32_f16() || bm_conf_utils.is_f32_bf16()
             || bm_conf_utils.is_bf16_with_int_wei()
-            || bm_conf_utils.is_f16_with_int_wei()) {
+            || bm_conf_utils.is_f16_with_int_wei()
+            || bm_conf_utils.is_f32_with_int_wei()) {
         // empirical observation for performance breakpoint between amx and vnni
         // bf16/f16
         const dim_t buffer_a_chunk_sz_limit = 126;
@@ -217,6 +217,7 @@ struct brgemm_matmul_conf_t {
     bool is_bf32 = false;
     bool is_bf16_with_int_wei = false;
     bool is_f16_with_int_wei = false;
+    bool is_f32_with_int_wei = false;
     bool is_f32_f16 = false;
     bool is_f32_bf16 = false;
     bool is_int4_weights = false;
@@ -282,6 +283,7 @@ struct brgemm_matmul_conf_utils_t {
         if (bgmmc.is_runtime_N) return true;
         if (bgmmc.is_bf16_with_int_wei) return true;
         if (bgmmc.is_f16_with_int_wei) return true;
+        if (bgmmc.is_f32_with_int_wei) return true;
         if (bgmmc.apply_scales_in_buffer_b) return true;
         if (bgmmc.is_gemv) return false;
 
@@ -371,6 +373,8 @@ struct brgemm_matmul_conf_utils_t {
 
     inline bool is_f16_with_int_wei() const { return f16_with_int_wei_dt; }
 
+    inline bool is_f32_with_int_wei() const { return f32_with_int_wei_dt; }
+
     inline bool with_weights_decompression() const {
         return !utils::one_of(bgmmc.src_dt, data_type::s8, data_type::u8,
                 data_type::s4, data_type::u4)
@@ -412,7 +416,7 @@ private:
     const bool f32_dt, bf16_dt, f16_dt, f4_via_convert_dt, f8_dt, bf8_dt,
             int8_dt, bf32_dt, tf32_dt;
     const bool weights_decompression_support, bf16_with_int_wei_dt, f32_f16_dt,
-            f32_bf16_dt, f16_with_int_wei_dt;
+            f32_bf16_dt, f16_with_int_wei_dt, f32_with_int_wei_dt;
     const bool A_any_layout;
     const bool B_any_layout;
     const bool C_any_layout;
@@ -297,3 +297,38 @@
 --attr-zero-points=wei:per_tensor:s4:128x1,src:per_ocic:u4:1x256+wei:per_tensor:s4:128x1
 2x3x4x256:2x3x256x64
 2x3x6x256:2x3x256x100
+
+# f32 + weight decompression
+--reset
+--skip-impl=ref
+--dt=f32:s4:f32,f32:u4:f32,f32:s8:f32,f32:u8:f32
+--wtag=any,ab,ba
+--attr-fpmath=strict:true
+1x4096:4096x4096
+
+--skip-impl=ref
+--dt=f32:s4:f32,f32:u4:f32,f32:s8:f32,f32:u8:f32
+--wtag=any,abc,acb
+--attr-fpmath=strict:true
+2x40x256:2x256x64
+7x41x256:1x256x64
+3x96x512:3x512x64
+3x6x512:1x512x62
+
+--skip-impl=ref
+--dt=f32:s4:f32,f32:u4:f32
+--wtag=any,ab,ba
+--attr-scales=,wei:common:0.5:f32,wei:per_ocic:f32:32x1
+--attr-zero-points=,wei:common:2
+# ,wei:per_oc:s4:4096x1 # PER_OC skips for all matmul implementations
+--attr-fpmath=strict:true
+1x4096:4096x4096
+
+--skip-impl=ref
+--dt=f32:s8:f32,f32:u8:f32
+--wtag=any,ab,ba
+--attr-scales=,wei:common:0.5:f32,wei:per_ocic:f32:32x1
+--attr-zero-points=,wei:common:2
+# ,wei:per_ocic:s8:32x1 # PER_OC skips for all matmul implementations
+--attr-fpmath=strict:true
+1x4096:4096x4096
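Any of these cases can also be run standalone; for instance, one combination pulled from the batches above (assuming a benchdnn build in the current directory):

    ./benchdnn --matmul --skip-impl=ref --dt=f32:s4:f32 --wtag=ba \
            --attr-fpmath=strict:true 1x4096:4096x4096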