x64: brgemm: initial fp8 enabling

This commit is contained in:
Soto Flores, Manuel
2024-03-05 23:53:25 -08:00
committed by Manuel Soto
parent b5a0f3a7ce
commit 28e5f46197
3 changed files with 58 additions and 16 deletions

View File

@ -249,8 +249,8 @@ status_t brgemm_desc_init(brgemm_t *brg, cpu_isa_t isa,
if (M <= 0 || N <= 0 || K <= 0) return status::invalid_arguments;
if (utils::everyone_is(
false, brg->is_int8, brg->is_bf16, brg->is_f32, brg->is_f16))
if (utils::everyone_is(false, brg->is_int8, brg->is_bf16, brg->is_f32,
brg->is_f16, brg->is_fp8))
return status::unimplemented;
// Only amx_int8 kernel supports u8 weights.
@ -319,6 +319,10 @@ status_t brgemm_desc_set_postops(brgemm_t *brg, const primitive_attr_t *attr,
is_superset(brg->isa_impl, avx512_core_fp16)
|| is_superset(brg->isa_impl, avx2_vnni_2)))
return status::unimplemented;
if (!IMPLICATION(one_of(data_type::f8_e5m2, dt_bias, dt_d)
|| one_of(data_type::f8_e4m3, dt_bias, dt_d),
mayiuse(avx512_core_amx_fp16)))
return status::unimplemented;
// check that combination of data types is allowed
if ((brg->dt_a == data_type::u8 && brg->dt_b == data_type::s8)
&& (!one_of(dt_d, data_type::u8, data_type::s8, data_type::s32,
@ -340,6 +344,17 @@ status_t brgemm_desc_set_postops(brgemm_t *brg, const primitive_attr_t *attr,
&& one_of(dt_bias, data_type::undef, data_type::f32,
data_type::f16)))
return status::unimplemented;
const auto bias_f8_e5m2_compatible
= one_of(dt_d, data_type::f32, data_type::f8_e5m2)
&& one_of(dt_bias, data_type::undef, data_type::f32,
data_type::f8_e5m2);
const auto bias_f8_e4m3_compatible
= one_of(dt_d, data_type::f32, data_type::f8_e4m3)
&& one_of(dt_bias, data_type::undef, data_type::f32,
data_type::f8_e4m3);
if (!IMPLICATION(brg->is_fp8,
bias_f8_e5m2_compatible || bias_f8_e4m3_compatible))
return status::unimplemented;
brg->dt_d = dt_d;
brg->typesize_D = types::data_type_size(brg->dt_d);
@ -541,6 +556,10 @@ status_t brgemm_desc_set_attr(brgemm_t *brg, const brgemm_attr_t &brgattr) {
&& brg->prfC.dist2 < 0)
brg->prfC.dist2 = 0;
// TODO: update conditions once other brgemm implementations are enabled
// Currently, fp8 via AMX f16 convert only supported in non-unrolled kernel
if (brg->is_fp8 && brg->brgattr.use_uker) return status::unimplemented;
return status::success;
}
@ -597,7 +616,8 @@ status_t brgemm_init_tiles(const brgemm_t &brg, char palette[64]) {
//TODO: Add support of tail processing by reduction dimension
auto rd_block = (!brg.rdb && brg.rdb_tail) ? brg.rdb_tail : brg.rd_block;
if (brg.is_bf32) rd_block = utils::rnd_up(rd_block, 2 /*vnni_granularity*/);
if (brg.is_input_convert())
rd_block = utils::rnd_up(rd_block, 2 /*vnni_granularity*/);
palette_config_t *buff = (palette_config_t *)(palette);
@ -605,8 +625,10 @@ status_t brgemm_init_tiles(const brgemm_t &brg, char palette[64]) {
for (int i = 0; i < max_palette_size_in_bytes; i++)
_tc[i] = 0;
const int typesize_A = brg.is_bf32 ? sizeof(bfloat16_t) : brg.typesize_A;
const int typesize_B = brg.is_bf32 ? sizeof(bfloat16_t) : brg.typesize_B;
const int typesize_A
= brg.is_input_convert() ? sizeof(int16_t) : brg.typesize_A;
const int typesize_B
= brg.is_input_convert() ? sizeof(int16_t) : brg.typesize_B;
const int rd_step = 4 / typesize_A;

View File

@ -265,6 +265,7 @@ struct brgemm_t {
bool is_tmm = false;
bool is_int8 = false, is_int8_tmm = false;
bool is_bf16 = false, is_bf16_tmm = false, is_bf16_emu = false;
bool is_fp8 = false, is_fp8_tmm = false;
bool is_f16 = false, is_f16_tmm = false;
bool is_f32 = false;
bool is_bf32 = false;
@ -295,6 +296,13 @@ struct brgemm_t {
const primitive_attr_t *attr() const { return attr_; };
const memory_desc_t *dst_md() const { return dst_md_; };
// return 'true' when FP8 MAC is not natively supported by the CPU ISA
bool is_fp8_via_convert() const {
return is_fp8 && utils::one_of(isa_impl, avx10_1_512_amx_fp16);
}
bool is_input_convert() const { return is_bf32 || is_fp8_via_convert(); }
bool is_row_major() const {
assert(layout != brgemm_layout_undef);
return layout == brgemm_row_major;
@ -355,7 +363,7 @@ struct brgemm_t {
if (is_tmm) {
constexpr int tilesize = 1024;
sz = get_num_C_tiles() * tilesize; // postops buffer
if (is_bf32) {
if (is_input_convert()) {
const int n_bdb = bd_block2;
const int n_rdb = rdb + (rdb_tail != 0);
const int n_ldb = ldb + (ldb_tail != 0);

View File

@ -42,7 +42,8 @@ enum {
impl::data_type_t get_accum_datatype(brgemm_t *brg) {
// this assert should check if 'init_kernel_datatype()' was previously
// called.
assert(brg->is_int8 || brg->is_bf16 || brg->is_f32 || brg->is_f16);
assert(brg->is_int8 || brg->is_bf16 || brg->is_f32 || brg->is_f16
|| brg->is_fp8);
return brg->is_int8 ? data_type::s32 : data_type::f32;
}
@ -54,7 +55,10 @@ void init_kernel_datatype(
brg->is_bf16 = (dt_a == data_type::bf16) && (dt_b == data_type::bf16);
brg->is_f32 = (dt_a == data_type::f32) && (dt_b == data_type::f32);
brg->is_f16 = utils::one_of(data_type::f16, dt_a, dt_b);
assert(brg->is_int8 || brg->is_bf16 || brg->is_f32 || brg->is_f16);
brg->is_fp8 = one_of(dt_a, data_type::f8_e5m2, data_type::f8_e4m3)
&& one_of(dt_b, data_type::f8_e5m2, data_type::f8_e4m3);
assert(brg->is_int8 || brg->is_bf16 || brg->is_f32 || brg->is_f16
|| brg->is_fp8);
}
void init_common_conf(brgemm_t *brg, brgemm_batch_kind_t type, float alpha,
@ -145,12 +149,15 @@ void set_isa_impl(brgemm_t *brg) {
avx512_core_amx, is_isa_ok(avx512_core_vnni), avx512_core_vnni,
is_isa_ok(avx512_core), avx512_core, is_isa_ok(avx2_vnni_2),
avx2_vnni_2, is_isa_ok(avx2_vnni), avx2_vnni);
} else if (brg->is_fp8) {
brg->isa_impl = utils::map(true, isa_undef,
is_isa_ok(avx10_1_512_amx_fp16), avx10_1_512_amx_fp16);
}
}
void set_brg_vmm(brgemm_t *brg) {
brg->is_tmm = brg->is_int8_tmm || brg->is_bf16_tmm || brg->is_f16_tmm
|| brg->is_bf32;
|| brg->is_bf32 || brg->is_fp8_tmm;
brg->is_zmm = !brg->is_tmm && mayiuse(avx512_core)
&& is_superset(brg->isa_impl, avx512_core);
brg->is_ymm
@ -672,11 +679,10 @@ status_t brgemm_blocking(brgemm_t *brg) {
brg->load_nt_B
= (brg->brgattr.hint_load_nt_B == brgemm_hint_nt_true);
const auto max_rd_block
= (brg->is_bf16_tmm || brg->is_f16_tmm || brg->is_bf32) ? 32
: 64;
const auto rd_block_step
= (brg->is_bf16_tmm || brg->is_f16_tmm || brg->is_bf32) ? 2 : 4;
const bool reduce_by_words = brg->is_bf16_tmm || brg->is_f16_tmm
|| brg->is_input_convert();
const auto max_rd_block = reduce_by_words ? 32 : 64;
const auto rd_block_step = reduce_by_words ? 2 : 4;
// TODO: if rd_block calculated is very small then maybe it makes
// sense to use 1x2 or 2x1 blocking with supporting rd_block
// and rdb_tail
@ -692,14 +698,18 @@ status_t brgemm_blocking(brgemm_t *brg) {
// Remove these guards in the future (add tail processing by reduction
// dimension)
if (!IMPLICATION(brg->rdb > 0 && brg->rdb_tail, brg->is_bf32))
// TODO: these checks do not work for fp8-f16 and f16-fp8 cfgs
if (!IMPLICATION(
brg->rdb > 0 && brg->rdb_tail, brg->is_input_convert())) {
return status::unimplemented;
}
if (!IMPLICATION(
(brg->rdb_tail
% ((brg->is_bf16_tmm || brg->is_f16_tmm) ? 2 : 4))
!= 0,
brg->is_bf32))
brg->is_input_convert())) {
return status::unimplemented;
}
//TODO: check this condition
brg->interleave_tilestores_ = brg->beta == 0
@ -822,6 +832,8 @@ void init_brgemm_conf(brgemm_t *brg, cpu_isa_t isa, brgemm_batch_kind_t type,
brg->is_bf32 = is_bf32
&& utils::one_of(brg->isa_user, isa_undef, avx512_core_amx)
&& mayiuse(avx512_core_amx);
brg->is_fp8_tmm
= brg->is_fp8 && one_of(brg->isa_impl, avx512_core_amx_fp16);
brg->has_int8_vnni = isa_has_int8_vnni(brg->isa_impl);