mirror of
https://github.com/uxlfoundation/oneDNN.git
synced 2025-10-20 18:43:49 +08:00
x64: brgemm: initial fp8 enabling
This commit is contained in:
committed by
Manuel Soto
parent
b5a0f3a7ce
commit
28e5f46197
@ -249,8 +249,8 @@ status_t brgemm_desc_init(brgemm_t *brg, cpu_isa_t isa,
|
||||
|
||||
if (M <= 0 || N <= 0 || K <= 0) return status::invalid_arguments;
|
||||
|
||||
if (utils::everyone_is(
|
||||
false, brg->is_int8, brg->is_bf16, brg->is_f32, brg->is_f16))
|
||||
if (utils::everyone_is(false, brg->is_int8, brg->is_bf16, brg->is_f32,
|
||||
brg->is_f16, brg->is_fp8))
|
||||
return status::unimplemented;
|
||||
|
||||
// Only amx_int8 kernel supports u8 weights.
|
||||
@ -319,6 +319,10 @@ status_t brgemm_desc_set_postops(brgemm_t *brg, const primitive_attr_t *attr,
|
||||
is_superset(brg->isa_impl, avx512_core_fp16)
|
||||
|| is_superset(brg->isa_impl, avx2_vnni_2)))
|
||||
return status::unimplemented;
|
||||
if (!IMPLICATION(one_of(data_type::f8_e5m2, dt_bias, dt_d)
|
||||
|| one_of(data_type::f8_e4m3, dt_bias, dt_d),
|
||||
mayiuse(avx512_core_amx_fp16)))
|
||||
return status::unimplemented;
|
||||
// check that combination of data types is allowed
|
||||
if ((brg->dt_a == data_type::u8 && brg->dt_b == data_type::s8)
|
||||
&& (!one_of(dt_d, data_type::u8, data_type::s8, data_type::s32,
|
||||
@ -340,6 +344,17 @@ status_t brgemm_desc_set_postops(brgemm_t *brg, const primitive_attr_t *attr,
|
||||
&& one_of(dt_bias, data_type::undef, data_type::f32,
|
||||
data_type::f16)))
|
||||
return status::unimplemented;
|
||||
const auto bias_f8_e5m2_compatible
|
||||
= one_of(dt_d, data_type::f32, data_type::f8_e5m2)
|
||||
&& one_of(dt_bias, data_type::undef, data_type::f32,
|
||||
data_type::f8_e5m2);
|
||||
const auto bias_f8_e4m3_compatible
|
||||
= one_of(dt_d, data_type::f32, data_type::f8_e4m3)
|
||||
&& one_of(dt_bias, data_type::undef, data_type::f32,
|
||||
data_type::f8_e4m3);
|
||||
if (!IMPLICATION(brg->is_fp8,
|
||||
bias_f8_e5m2_compatible || bias_f8_e4m3_compatible))
|
||||
return status::unimplemented;
|
||||
|
||||
brg->dt_d = dt_d;
|
||||
brg->typesize_D = types::data_type_size(brg->dt_d);
|
||||
@ -541,6 +556,10 @@ status_t brgemm_desc_set_attr(brgemm_t *brg, const brgemm_attr_t &brgattr) {
|
||||
&& brg->prfC.dist2 < 0)
|
||||
brg->prfC.dist2 = 0;
|
||||
|
||||
// TODO: update conditions once other brgemm implementations are enabled
|
||||
// Currently, fp8 via AMX f16 convert only supported in non-unrolled kernel
|
||||
if (brg->is_fp8 && brg->brgattr.use_uker) return status::unimplemented;
|
||||
|
||||
return status::success;
|
||||
}
|
||||
|
||||
@ -597,7 +616,8 @@ status_t brgemm_init_tiles(const brgemm_t &brg, char palette[64]) {
|
||||
|
||||
//TODO: Add support of tail processing by reduction dimension
|
||||
auto rd_block = (!brg.rdb && brg.rdb_tail) ? brg.rdb_tail : brg.rd_block;
|
||||
if (brg.is_bf32) rd_block = utils::rnd_up(rd_block, 2 /*vnni_granularity*/);
|
||||
if (brg.is_input_convert())
|
||||
rd_block = utils::rnd_up(rd_block, 2 /*vnni_granularity*/);
|
||||
|
||||
palette_config_t *buff = (palette_config_t *)(palette);
|
||||
|
||||
@ -605,8 +625,10 @@ status_t brgemm_init_tiles(const brgemm_t &brg, char palette[64]) {
|
||||
for (int i = 0; i < max_palette_size_in_bytes; i++)
|
||||
_tc[i] = 0;
|
||||
|
||||
const int typesize_A = brg.is_bf32 ? sizeof(bfloat16_t) : brg.typesize_A;
|
||||
const int typesize_B = brg.is_bf32 ? sizeof(bfloat16_t) : brg.typesize_B;
|
||||
const int typesize_A
|
||||
= brg.is_input_convert() ? sizeof(int16_t) : brg.typesize_A;
|
||||
const int typesize_B
|
||||
= brg.is_input_convert() ? sizeof(int16_t) : brg.typesize_B;
|
||||
|
||||
const int rd_step = 4 / typesize_A;
|
||||
|
||||
|
@ -265,6 +265,7 @@ struct brgemm_t {
|
||||
bool is_tmm = false;
|
||||
bool is_int8 = false, is_int8_tmm = false;
|
||||
bool is_bf16 = false, is_bf16_tmm = false, is_bf16_emu = false;
|
||||
bool is_fp8 = false, is_fp8_tmm = false;
|
||||
bool is_f16 = false, is_f16_tmm = false;
|
||||
bool is_f32 = false;
|
||||
bool is_bf32 = false;
|
||||
@ -295,6 +296,13 @@ struct brgemm_t {
|
||||
const primitive_attr_t *attr() const { return attr_; };
|
||||
const memory_desc_t *dst_md() const { return dst_md_; };
|
||||
|
||||
// return 'true' when FP8 MAC is not natively supported by the CPU ISA
|
||||
bool is_fp8_via_convert() const {
|
||||
return is_fp8 && utils::one_of(isa_impl, avx10_1_512_amx_fp16);
|
||||
}
|
||||
|
||||
bool is_input_convert() const { return is_bf32 || is_fp8_via_convert(); }
|
||||
|
||||
bool is_row_major() const {
|
||||
assert(layout != brgemm_layout_undef);
|
||||
return layout == brgemm_row_major;
|
||||
@ -355,7 +363,7 @@ struct brgemm_t {
|
||||
if (is_tmm) {
|
||||
constexpr int tilesize = 1024;
|
||||
sz = get_num_C_tiles() * tilesize; // postops buffer
|
||||
if (is_bf32) {
|
||||
if (is_input_convert()) {
|
||||
const int n_bdb = bd_block2;
|
||||
const int n_rdb = rdb + (rdb_tail != 0);
|
||||
const int n_ldb = ldb + (ldb_tail != 0);
|
||||
|
@ -42,7 +42,8 @@ enum {
|
||||
impl::data_type_t get_accum_datatype(brgemm_t *brg) {
|
||||
// this assert should check if 'init_kernel_datatype()' was previously
|
||||
// called.
|
||||
assert(brg->is_int8 || brg->is_bf16 || brg->is_f32 || brg->is_f16);
|
||||
assert(brg->is_int8 || brg->is_bf16 || brg->is_f32 || brg->is_f16
|
||||
|| brg->is_fp8);
|
||||
return brg->is_int8 ? data_type::s32 : data_type::f32;
|
||||
}
|
||||
|
||||
@ -54,7 +55,10 @@ void init_kernel_datatype(
|
||||
brg->is_bf16 = (dt_a == data_type::bf16) && (dt_b == data_type::bf16);
|
||||
brg->is_f32 = (dt_a == data_type::f32) && (dt_b == data_type::f32);
|
||||
brg->is_f16 = utils::one_of(data_type::f16, dt_a, dt_b);
|
||||
assert(brg->is_int8 || brg->is_bf16 || brg->is_f32 || brg->is_f16);
|
||||
brg->is_fp8 = one_of(dt_a, data_type::f8_e5m2, data_type::f8_e4m3)
|
||||
&& one_of(dt_b, data_type::f8_e5m2, data_type::f8_e4m3);
|
||||
assert(brg->is_int8 || brg->is_bf16 || brg->is_f32 || brg->is_f16
|
||||
|| brg->is_fp8);
|
||||
}
|
||||
|
||||
void init_common_conf(brgemm_t *brg, brgemm_batch_kind_t type, float alpha,
|
||||
@ -145,12 +149,15 @@ void set_isa_impl(brgemm_t *brg) {
|
||||
avx512_core_amx, is_isa_ok(avx512_core_vnni), avx512_core_vnni,
|
||||
is_isa_ok(avx512_core), avx512_core, is_isa_ok(avx2_vnni_2),
|
||||
avx2_vnni_2, is_isa_ok(avx2_vnni), avx2_vnni);
|
||||
} else if (brg->is_fp8) {
|
||||
brg->isa_impl = utils::map(true, isa_undef,
|
||||
is_isa_ok(avx10_1_512_amx_fp16), avx10_1_512_amx_fp16);
|
||||
}
|
||||
}
|
||||
|
||||
void set_brg_vmm(brgemm_t *brg) {
|
||||
brg->is_tmm = brg->is_int8_tmm || brg->is_bf16_tmm || brg->is_f16_tmm
|
||||
|| brg->is_bf32;
|
||||
|| brg->is_bf32 || brg->is_fp8_tmm;
|
||||
brg->is_zmm = !brg->is_tmm && mayiuse(avx512_core)
|
||||
&& is_superset(brg->isa_impl, avx512_core);
|
||||
brg->is_ymm
|
||||
@ -672,11 +679,10 @@ status_t brgemm_blocking(brgemm_t *brg) {
|
||||
brg->load_nt_B
|
||||
= (brg->brgattr.hint_load_nt_B == brgemm_hint_nt_true);
|
||||
|
||||
const auto max_rd_block
|
||||
= (brg->is_bf16_tmm || brg->is_f16_tmm || brg->is_bf32) ? 32
|
||||
: 64;
|
||||
const auto rd_block_step
|
||||
= (brg->is_bf16_tmm || brg->is_f16_tmm || brg->is_bf32) ? 2 : 4;
|
||||
const bool reduce_by_words = brg->is_bf16_tmm || brg->is_f16_tmm
|
||||
|| brg->is_input_convert();
|
||||
const auto max_rd_block = reduce_by_words ? 32 : 64;
|
||||
const auto rd_block_step = reduce_by_words ? 2 : 4;
|
||||
// TODO: if rd_block calculated is very small then maybe it makes
|
||||
// sense to use 1x2 or 2x1 blocking with supporting rd_block
|
||||
// and rdb_tail
|
||||
@ -692,14 +698,18 @@ status_t brgemm_blocking(brgemm_t *brg) {
|
||||
|
||||
// Remove these guards in the future (add tail processing by reduction
|
||||
// dimension)
|
||||
if (!IMPLICATION(brg->rdb > 0 && brg->rdb_tail, brg->is_bf32))
|
||||
// TODO: these checks do not work for fp8-f16 and f16-fp8 cfgs
|
||||
if (!IMPLICATION(
|
||||
brg->rdb > 0 && brg->rdb_tail, brg->is_input_convert())) {
|
||||
return status::unimplemented;
|
||||
}
|
||||
if (!IMPLICATION(
|
||||
(brg->rdb_tail
|
||||
% ((brg->is_bf16_tmm || brg->is_f16_tmm) ? 2 : 4))
|
||||
!= 0,
|
||||
brg->is_bf32))
|
||||
brg->is_input_convert())) {
|
||||
return status::unimplemented;
|
||||
}
|
||||
|
||||
//TODO: check this condition
|
||||
brg->interleave_tilestores_ = brg->beta == 0
|
||||
@ -822,6 +832,8 @@ void init_brgemm_conf(brgemm_t *brg, cpu_isa_t isa, brgemm_batch_kind_t type,
|
||||
brg->is_bf32 = is_bf32
|
||||
&& utils::one_of(brg->isa_user, isa_undef, avx512_core_amx)
|
||||
&& mayiuse(avx512_core_amx);
|
||||
brg->is_fp8_tmm
|
||||
= brg->is_fp8 && one_of(brg->isa_impl, avx512_core_amx_fp16);
|
||||
|
||||
brg->has_int8_vnni = isa_has_int8_vnni(brg->isa_impl);
|
||||
|
||||
|
Reference in New Issue
Block a user