/*******************************************************************************
* Copyright 2021-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <unordered_set>
#include "common/dnnl_thread.hpp"
#include "cpu/binary_injector_utils.hpp"
#include "cpu/matmul/gemm_based_common.hpp"
#include "cpu/matmul/matmul_utils.hpp"
#include "cpu/platform.hpp"
#include "cpu/x64/injectors/jit_uni_postops_injector.hpp"
#include "cpu/x64/matmul/amx_blocking_heuristics.hpp"
#include "cpu/x64/matmul/brgemm_matmul_utils.hpp"
#include "cpu/x64/matmul/postops_estimator.hpp"
#include "oneapi/dnnl/dnnl_debug.h"
// TODO add a method to print brgemm conf info
#define VCONDCHECK_BG(cond, msg, ...) \
VCONDCHECK(primitive, create, dispatch, brgemm_matmul, (cond), \
status::unimplemented, msg, ##__VA_ARGS__);
#define VCHECK_BG(f, msg, ...) \
VCHECK(primitive, create, dispatch, brgemm_matmul, f, msg, ##__VA_ARGS__);
namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {
namespace matmul {
using namespace dnnl::impl::cpu::matmul;
using namespace dnnl::impl::memory_tracking::names;
using namespace dnnl::impl::utils;
using namespace data_type;
using namespace format_tag;
int get_n_block_from_tag(format_tag_t matrix_b_tag) {
// Note: consider using weights mem_descriptor 'inner_blks' to
// return B's inner block for non-default cases.
switch (matrix_b_tag) {
case aCB16b64c:
case aCB16b64c2b:
case aCB16b64c4b:
case BA16a64b4a:
case BA16a64b2a:
case BA16a64b: return 64;
case aCB16b48c:
case aCB16b48c2b:
case aCB16b48c4b:
case BA16a48b:
case BA16a48b2a:
case BA16a48b4a: return 48;
case aCB16b32c:
case aCB16b32c2b:
case aCB16b32c4b:
case BA16a32b:
case BA16a32b2a:
case BA16a32b4a: return 32;
case aCB16b16c:
case aCB16b16c2b:
case aCB16b16c4b:
case BA16a16b:
case BA16a16b2a:
case BA16a16b4a: return 16;
default: return 0;
}
}
int get_wei_k_blk(data_type_t wei_dt) {
// Fixed outer block size.
const int k_outer_block = 16;
// VNNI granularity determines the inner block size along K.
const int k_inner_block = data_type_vnni_granularity(wei_dt);
return k_outer_block * k_inner_block;
}
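// Illustrative values (following data_type_vnni_granularity()): f32 has a
// granularity of 1, so wei_k_blk = 16; bf16/f16 have 2 -> 32; int8 has
// 4 -> 64.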
void mem_advice_init(brgemm_matmul_conf_t &bgmmc) {
dim_t parallel_work_amount = bgmmc.batch * bgmmc.M_chunks * bgmmc.N_chunks;
int nthr_bmn = bgmmc.nthr / bgmmc.nthr_k;
dim_t start {0}, end {0};
balance211(parallel_work_amount, nthr_bmn, 0, start, end);
dim_t nchunks_per_thread = end - start;
// The memory advice heuristic is based on performance tests done on a
// simulator. It lets tile loads snoop other cores' caches when the A/B
// matrices are shared, so already-shared data need not be fetched from
// lower memory levels. The assumption is that if the C matrix row chunks
// do not divide evenly across threads, there will be sharing, so setting
// the memory advice is worthwhile.
if (bgmmc.is_thread_chunks_exec_order_horizontal) {
bgmmc.mem_advice
= brgemm_kernel_hint_mem_advice_t::brgemm_hint_mem_advice_B;
if (nchunks_per_thread % bgmmc.N_chunks && bgmmc.is_amx)
bgmmc.mem_advice = brgemm_kernel_hint_mem_advice_t::
brgemm_hint_mem_advice_A_B;
} else {
assert(bgmmc.is_thread_chunks_exec_order_horizontal
&& "this mode is not operational at the moment");
bgmmc.mem_advice
= brgemm_kernel_hint_mem_advice_t::brgemm_hint_mem_advice_A;
if (nchunks_per_thread % bgmmc.M_chunks)
bgmmc.mem_advice = brgemm_kernel_hint_mem_advice_t::
brgemm_hint_mem_advice_A_B;
}
}
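// Illustrative (hypothetical) example: with nthr = 8, nthr_k = 1,
// batch * M_chunks * N_chunks = 12 and N_chunks = 3, balance211() gives the
// busiest threads 2 chunks each; since 2 % 3 != 0 a thread's chunks do not
// align with full C rows, rows of A end up shared across threads, and on
// AMX the advice is upgraded to brgemm_hint_mem_advice_A_B.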
// TODO: add support of post-ops with multiple binary and eltwise execution
bool post_ops_ok(brgemm_matmul_conf_t &bgmmc, const primitive_attr_t &attr,
const memory_desc_wrapper &dst_d,
bool limit_bcast_strategies_set = false) {
using namespace injector;
const auto &post_ops = attr.post_ops_;
const auto ndims = dst_d.ndims();
bool is_binary_po_per_oc_sp_bcast {};
bool is_binary_po_per_oc_d_bcast {};
bool is_binary_po_channel_bcast {};
bool is_binary_po_per_mb_bcast {};
bool is_binary_po_per_mb_w_bcast {};
bool is_binary_po_per_w_bcast {};
bool is_binary_po_per_hw_bcast {};
bool is_binary_po_batch_bcast {};
std::tie(is_binary_po_per_oc_sp_bcast, is_binary_po_per_oc_d_bcast,
is_binary_po_channel_bcast, is_binary_po_per_mb_bcast,
is_binary_po_per_mb_w_bcast, is_binary_po_per_w_bcast,
is_binary_po_per_hw_bcast, is_binary_po_batch_bcast)
= binary_injector_utils::bcast_strategies_present_tup(
post_ops.entry_, dst_d,
broadcasting_strategy_t::per_oc_spatial,
broadcasting_strategy_t::per_oc_d,
broadcasting_strategy_t::per_mb,
broadcasting_strategy_t::per_mb_spatial,
broadcasting_strategy_t::per_mb_w,
broadcasting_strategy_t::per_hw,
broadcasting_strategy_t::per_w,
broadcasting_strategy_t::batch);
const bool supported_binary_bcast
= IMPLICATION(is_binary_po_per_oc_sp_bcast, ndims < 4)
&& IMPLICATION(is_binary_po_per_oc_d_bcast, ndims == 4)
&& IMPLICATION(
is_binary_po_channel_bcast, utils::one_of(ndims, 3, 4))
&& IMPLICATION(
is_binary_po_per_mb_w_bcast, utils::one_of(ndims, 3, 4))
&& IMPLICATION(is_binary_po_per_w_bcast, utils::one_of(ndims, 3, 4))
&& IMPLICATION(
is_binary_po_per_mb_bcast, utils::one_of(ndims, 3, 4))
&& IMPLICATION(is_binary_po_batch_bcast, utils::one_of(ndims, 3, 4))
&& IMPLICATION(is_binary_po_per_hw_bcast, ndims == 4);
const bcast_set_t default_bcast_set = {broadcasting_strategy_t::per_oc,
broadcasting_strategy_t::per_oc_spatial,
broadcasting_strategy_t::per_oc_d, broadcasting_strategy_t::scalar,
broadcasting_strategy_t::per_mb,
broadcasting_strategy_t::per_mb_spatial,
broadcasting_strategy_t::per_mb_w, broadcasting_strategy_t::per_hw,
broadcasting_strategy_t::per_w, broadcasting_strategy_t::batch,
broadcasting_strategy_t::no_broadcast};
const bcast_set_t limited_bcast_set = {broadcasting_strategy_t::scalar,
broadcasting_strategy_t::no_broadcast};
const bcast_set_t bcast_set
= limit_bcast_strategies_set || bgmmc.is_runtime_N
? limited_bcast_set
: default_bcast_set;
// binary post-ops are disabled for runtime N due to issues in
// injector::post_ops_ok() with identification of per_oc strategy
// for some cases
std::vector<post_op_type> accepted_post_ops;
accepted_post_ops.push_back(sum);
accepted_post_ops.push_back(eltwise);
if (!bgmmc.is_runtime_N) accepted_post_ops.push_back(binary);
return supported_binary_bcast
&& injector::post_ops_ok(
post_ops_ok_args_t(get_max_cpu_isa(), accepted_post_ops,
post_ops, &dst_d, false /*sum_at_pos_0_only*/,
false /*sum_requires_scale_one*/,
false /*sum_requires_zp_zero*/,
true /*sum_requires_same_params*/, bcast_set));
}
// Trivial: The motivation is to compute batch offset for a memory
// (src or wei or dst) with minimal overhead, by using
// `batch_offset = b * b_stride`.
// This is possible when the batch layout is contiguous.
bool is_batch_layout_trivial(
const memory_desc_wrapper &mdw, const dim_t batch) {
const int ndims = mdw.ndims();
if (ndims <= 3) return true;
const auto &strides = mdw.strides();
const int batch_start_idx = ndims - 3;
dim_t cur_stride = strides[batch_start_idx];
dim_t min_batch_stride = cur_stride;
dim_t max_batch_stride = cur_stride;
for (int d = batch_start_idx - 1; d >= 0; --d) {
cur_stride = strides[d];
min_batch_stride = nstl::min(min_batch_stride, cur_stride);
max_batch_stride = nstl::max(max_batch_stride, cur_stride);
}
if (min_batch_stride == 0) return false;
return max_batch_stride / min_batch_stride == batch;
}
status_t check_isa_with_datatype(
const cpu_isa_t isa, const brgemm_matmul_conf_utils_t &bm_conf_utils) {
const bool ok
= IMPLICATION(bm_conf_utils.is_f32(),
one_of(isa, avx512_core, avx2) || bm_conf_utils.is_bf32()
|| bm_conf_utils.is_tf32())
&& IMPLICATION(bm_conf_utils.is_int8(),
is_superset(isa, avx512_core)
|| is_superset(isa, avx2_vnni))
&& IMPLICATION(bm_conf_utils.is_bf16(),
one_of(isa, avx512_core_amx, avx512_core_bf16, avx2_vnni_2))
&& IMPLICATION(bm_conf_utils.is_f16(),
one_of(isa, avx10_2_512, avx10_2_512_amx_2,
avx512_core_amx_fp16, avx512_core_fp16,
avx2_vnni_2))
// `avx512_core_amx_fp16` is not supported for plain upconversion
// as HW supports native compute.
&& IMPLICATION(bm_conf_utils.is_f32_f16(),
one_of(isa, avx512_core_fp16, avx2_vnni_2, avx512_core,
avx2))
// `avx512_core_amx` is not supported for plain upconversion as HW
// supports native compute.
&& IMPLICATION(bm_conf_utils.is_f32_bf16(),
one_of(isa, avx512_core_bf16, avx2_vnni_2, avx512_core,
avx2))
&& IMPLICATION(bm_conf_utils.is_int8_with_bf16_dst(),
is_superset(isa, avx512_core) || isa == avx2_vnni_2)
&& IMPLICATION(bm_conf_utils.is_bf16_with_int_wei(),
is_superset(isa, avx512_core_bf16))
&& IMPLICATION(bm_conf_utils.is_f16_with_int_wei(),
one_of(isa, avx512_core_amx_fp16, avx512_core_fp16))
&& IMPLICATION(bm_conf_utils.is_f8(),
is_superset(isa, avx512_core_amx_fp16)
|| is_superset(isa, avx10_2_512))
&& IMPLICATION(bm_conf_utils.is_bf8(),
is_superset(isa, avx512_core_amx_fp16))
&& IMPLICATION(bm_conf_utils.is_f4_via_convert(),
one_of(isa, avx10_2_512));
return ok ? status::success : status::unimplemented;
}
status_t check_datatype_cfg(const brgemm_matmul_conf_utils_t &bm_conf_utils) {
const bool ok
= one_of(true, bm_conf_utils.is_f32(), bm_conf_utils.is_bf16(),
bm_conf_utils.is_f16(), bm_conf_utils.is_f32_f16(),
bm_conf_utils.is_f32_bf16(), bm_conf_utils.is_bf32(),
bm_conf_utils.is_f8(), bm_conf_utils.is_int8(),
bm_conf_utils.is_f4_via_convert(),
bm_conf_utils.is_tf32(),
bm_conf_utils.is_bf16_with_int_wei(),
bm_conf_utils.is_f16_with_int_wei())
&& IMPLICATION(bm_conf_utils.is_bf16_with_int_wei()
|| bm_conf_utils.is_f16_with_int_wei(),
bm_conf_utils.with_weights_decompression());
return ok ? status::success : status::unimplemented;
}
brgemm_matmul_conf_utils_t::brgemm_matmul_conf_utils_t(
brgemm_matmul_conf_t &bgmmc, const cpu_isa_t isa,
const primitive_attr_t &attr, bool A_any_layout, bool B_any_layout,
bool C_any_layout, bool bias_any_layout)
: bgmmc(bgmmc)
, f32_dt(utils::everyone_is(f32, bgmmc.src_dt, bgmmc.wei_dt, bgmmc.dst_dt))
, bf16_dt(utils::everyone_is(bf16, bgmmc.src_dt, bgmmc.wei_dt)
&& one_of(bgmmc.dst_dt, bf16, f32))
, f16_dt(utils::everyone_is(f16, bgmmc.src_dt, bgmmc.wei_dt)
&& one_of(bgmmc.dst_dt, f16, f32))
, f4_via_convert_dt(utils::one_of(bgmmc.wei_dt, data_type::f4_e2m1,
data_type::f4_e3m0)
&& isa == avx10_1_512)
, f8_dt(one_of(bgmmc.src_dt, f8_e5m2, f8_e4m3)
&& one_of(bgmmc.wei_dt, f8_e5m2, f8_e4m3)
&& one_of(bgmmc.dst_dt, f16, f32, bf16, f8_e5m2, f8_e4m3))
, bf8_dt(everyone_is(f8_e5m2, bgmmc.src_dt, bgmmc.wei_dt)
&& one_of(bgmmc.dst_dt, f16, f32, bf16, f8_e5m2, f8_e4m3))
, int8_dt(utils::one_of(bgmmc.src_dt, u8, s8) && bgmmc.wei_dt == s8
&& one_of(bgmmc.dst_dt, u8, s8, s32, f32, f16, bf16))
, bf32_dt(f32_dt
&& one_of(attr.fpmath_.mode_, fpmath_mode::bf16, fpmath_mode::any)
&& isa == avx512_core_amx)
, tf32_dt(f32_dt
&& one_of(attr.fpmath_.mode_, fpmath_mode::tf32, fpmath_mode::any)
&& isa == avx10_2_512_amx_2)
, weights_decompression_support(one_of(bgmmc.wei_dt, u8, s8, u4, s4)
&& one_of(attr.fpmath_.mode_, fpmath_mode::bf16, fpmath_mode::f16,
fpmath_mode::any)
&& IMPLICATION(attr.fpmath_.mode_ == fpmath_mode::f16,
bgmmc.src_dt == f16)
&& IMPLICATION(attr.fpmath_.mode_ == fpmath_mode::bf16,
bgmmc.src_dt == bf16)
&& attr.fpmath_.apply_to_int_)
, bf16_with_int_wei_dt(weights_decompression_support && bgmmc.src_dt == bf16
&& one_of(bgmmc.dst_dt, bf16, f32))
// Keep this var separate from f16_dt so that f16:f16 does not slip through
// on avx512_core and avx2, as there is no kernel for such a combination.
, f32_f16_dt(bgmmc.src_dt == f32 && bgmmc.wei_dt == f16
&& one_of(bgmmc.dst_dt, f16, f32))
// Keep this var separate from bf16_dt so that bf16:bf16 does not slip through
// on avx512_core and avx2, as there is no kernel for such a combination.
, f32_bf16_dt(bgmmc.src_dt == f32 && bgmmc.wei_dt == bf16
&& one_of(bgmmc.dst_dt, bf16, f32))
, f16_with_int_wei_dt(weights_decompression_support && bgmmc.src_dt == f16
&& one_of(bgmmc.dst_dt, f16, f32))
, A_any_layout(A_any_layout)
, B_any_layout(B_any_layout)
, C_any_layout(C_any_layout)
, bias_any_layout(bias_any_layout)
, plain_tensor_layout_tag(utils::pick(bgmmc.ndims - 2, ab, abc, abcd, abcde,
abcdef, abcdefg, abcdefgh, abcdefghi, abcdefghij, abcdefghijk,
abcdefghijkl))
, transposed_tensor_layout_tag(utils::pick(bgmmc.ndims - 2, ba, acb, abdc,
abced, abcdfe, abcdegf, abcdefhg, abcdefgih, abcdefghji,
abcdefghikj, abcdefghijlk))
, blocked_64n_B_layout_tag(pick_blocked_B_layout(64))
, blocked_48n_B_layout_tag(pick_blocked_B_layout(48))
, blocked_32n_B_layout_tag(pick_blocked_B_layout(32))
, blocked_16n_B_layout_tag(pick_blocked_B_layout(16))
, blocked_B_layouts_allowed(!utils::one_of(format_tag::undef,
blocked_64n_B_layout_tag, blocked_48n_B_layout_tag,
blocked_32n_B_layout_tag, blocked_16n_B_layout_tag))
, n_blk_fixed((!B_any_layout) && blocked_B_layouts_allowed)
, isa_(isa) {}
int brgemm_matmul_conf_utils_t::get_default_n_block(
format_tag_t matrix_b_tag) const {
if (bgmmc.is_gemv) return 1;
const int n_blk = get_n_block_from_tag(matrix_b_tag);
if (n_blk > 0) return n_blk;
if (matmul_amx_blocking_params_macro_t::is_supported(bgmmc, *this)) {
return 32;
}
return 64;
}
/**
* This function selects a compatible format for A if its format is "any".
* Otherwise, it checks if the provided format is compatible.
* It returns a valid format tag on success, or format_tag::undef otherwise.
*/
format_tag_t brgemm_matmul_conf_utils_t::get_gemv_A_tag(
const memory_desc_t &A_md) const {
if (A_any_layout)
return plain_tensor_layout_tag;
else
return memory_desc_matches_one_of_tag(A_md, plain_tensor_layout_tag);
}
/**
* This function selects a compatible format for B if its format is "any".
* Otherwise, it checks if the provided format is compatible.
* It returns a valid format tag on success, or format_tag::undef otherwise.
*/
format_tag_t brgemm_matmul_conf_utils_t::get_gemv_B_tag(
const memory_desc_t &B_md) const {
if (B_any_layout) {
// Plain and transposed layouts are identical for B in GEMV cases,
// so we simply choose the plain one.
return plain_tensor_layout_tag;
} else {
if (B_md.format_kind != format_kind::blocked) return format_tag::undef;
// Elements of B, which is a vector in the case of GEMV, must be
// contiguous in memory.
const bool wei_format_compatible
= B_md.format_desc.blocking.strides[bgmmc.ndims - 2] == 1;
if (!wei_format_compatible) return format_tag::undef;
// TODO: The current matmul design requires us to infer wei_tag, so
// we still need to do that even though the provided format is
// compatible. For now, allow both plain and trans formats. Consider
// removing the need to infer the wei_tag in the future.
return memory_desc_matches_one_of_tag(
B_md, plain_tensor_layout_tag, transposed_tensor_layout_tag);
}
}
/**
* This function checks if a dedicated code path for GEMV is applicable.
* All relevant checks must be performed here to determine whether we
* should fall back to the GEMM code path before initializing format tags
* and other relevant parameters.
*/
bool is_gemv_applicable(const brgemm_matmul_conf_t &bgmmc,
const brgemm_matmul_conf_utils_t &bm_conf_utils,
const memory_desc_t &A_md, const memory_desc_t &B_md) {
// Only the N=1 case is supported currently.
// TODO: The existing GEMV code path could also support the M=1 case when
// B tensor has a transposed format. Enable this by swapping A and B.
if (bgmmc.N != 1) return false;
// Reduction is not supported for GEMV code path.
if (bgmmc.with_reduce) return false;
// BRGEMV (the dedicated GEMV code path) currently supports only f32 on
// avx2 and avx512_core.
if (utils::one_of(false, bm_conf_utils.is_f32(),
bgmmc.isa == avx2 || bgmmc.isa == avx512_core))
return false;
if (utils::one_of(format_tag::undef, bm_conf_utils.get_gemv_A_tag(A_md),
bm_conf_utils.get_gemv_B_tag(B_md)))
return false;
return true;
}
status_t brgemm_matmul_conf_utils_t::set_or_check_B_tag(memory_desc_t &B_md,
const matmul_helper_t &helper, bool init_n_tag) const {
const memory_desc_wrapper B_d(&B_md);
if (B_any_layout) {
if (bgmmc.is_gemv) {
bgmmc.wei_tag = get_gemv_B_tag(B_md);
assert(bgmmc.wei_tag != format_tag::undef
&& "if bgmmc.is_gemv is true the format tag must be defined");
} else {
const int default_n_block = init_n_tag
? get_default_n_block(format_tag::undef)
: bgmmc.N_blk;
bgmmc.wei_tag = blocked_B_layouts_allowed && !bgmmc.is_runtime_N
&& !bgmmc.is_int4_weights
? this->pick_blocked_B_layout(default_n_block)
: bgmmc.is_int4_weights && bgmmc.N % 2 != 0
? transposed_tensor_layout_tag
: plain_tensor_layout_tag;
}
VCONDCHECK_BG(
format_tag::undef != bgmmc.wei_tag, VERBOSE_UNSUPPORTED_TAG)
VCHECK_BG(memory_desc_init_by_tag(B_md, bgmmc.wei_tag),
VERBOSE_UNSUPPORTED_TAG);
const int dmax = nstl::min(bgmmc.ndims, 3);
for (int d = 0; d < dmax; d++) {
int dim = bgmmc.ndims - 1 - d;
bgmmc.B_strides[d]
= bgmmc.b_dt_sz * B_d.blocking_desc().strides[dim];
}
} else {
if (bgmmc.is_gemv) {
bgmmc.wei_tag = get_gemv_B_tag(B_md);
assert(bgmmc.wei_tag != format_tag::undef
&& "if bgmmc.is_gemv is true the format tag must be defined");
} else {
bgmmc.wei_tag = blocked_B_layouts_allowed && !bgmmc.is_runtime_N
&& !bgmmc.is_int4_weights
? memory_desc_matches_one_of_tag(B_md,
plain_tensor_layout_tag,
transposed_tensor_layout_tag,
blocked_64n_B_layout_tag, blocked_48n_B_layout_tag,
blocked_32n_B_layout_tag, blocked_16n_B_layout_tag)
: memory_desc_matches_one_of_tag(B_md,
plain_tensor_layout_tag,
transposed_tensor_layout_tag, acbd, adbc);
const bool plain_transposed_matched
= memory_desc_matches_tag(B_md, plain_tensor_layout_tag)
&& memory_desc_matches_tag(
B_md, transposed_tensor_layout_tag);
if (bgmmc.wei_tag == format_tag::undef
|| plain_transposed_matched) {
if (gemm_based::check_gemm_input_format(B_md)) {
// Note: the batch layout here may not be accurately represented by the
// wei_tag string due to all the possible permutations of the batch
// dimensions. Only the gemm dimensions "n, k" are accurately represented
// by the tag (transposed or not).
// TODO: update helper.transB() to handle the case where the wei tag
// matches both the plain and the transposed layout.
bgmmc.wei_tag = helper.transB() == 'N'
|| (plain_transposed_matched
&& B_d.blocking_desc()
.strides[bgmmc.ndims
- 1]
== 1)
? plain_tensor_layout_tag
: transposed_tensor_layout_tag;
}
}
}
VCONDCHECK_BG(
format_tag::undef != bgmmc.wei_tag, VERBOSE_UNSUPPORTED_TAG)
}
return status::success;
}
status_t brgemm_matmul_conf_utils_t::update_and_check_B_tag(memory_desc_t &B_md,
int n_blk_size, const matmul_helper_t &helper) const {
if (n_blk_fixed && n_blk_size != bgmmc.wei_n_blk)
return status::unimplemented;
if (!(B_any_layout && blocked_B_layouts_allowed)) return status::success;
return set_or_check_B_tag(B_md, helper, false);
}
status_t brgemm_matmul_conf_utils_t::set_or_check_tags(memory_desc_t &A_md,
memory_desc_t &C_md, memory_desc_t &bias_md,
const matmul_helper_t &helper) const {
if (A_any_layout) {
const format_tag_t desired_A_tag = bgmmc.is_gemv
? get_gemv_A_tag(A_md)
: plain_tensor_layout_tag;
VCHECK_BG(memory_desc_init_by_tag(A_md, desired_A_tag),
VERBOSE_UNSUPPORTED_TAG);
bgmmc.src_tag = desired_A_tag;
assert(IMPLICATION(bgmmc.is_gemv, bgmmc.src_tag != format_tag::undef)
&& "if bgmmc.is_gemv is true the format tag must be defined");
} else {
if (bgmmc.is_gemv) {
bgmmc.src_tag = get_gemv_A_tag(A_md);
assert(bgmmc.src_tag != format_tag::undef
&& "if bgmmc.is_gemv is true the format tag must be defined");
} else {
const bool xf16_avx2_vnni_2 = (this->is_bf16() || this->is_f16())
&& bgmmc.isa == avx2_vnni_2;
const bool is_int8_avx512_core
= this->is_int8() && is_superset(bgmmc.isa, avx512_core);
const bool is_adbc_allowed
= (this->is_bf16() || this->is_f32() || this->is_bf32()
|| this->is_f16() || this->is_f32_f16()
|| this->is_f32_bf16()
|| this->is_bf16_with_int_wei()
|| this->is_f16_with_int_wei() || this->is_tf32())
&& !xf16_avx2_vnni_2;
bgmmc.src_tag = is_adbc_allowed ? memory_desc_matches_one_of_tag(
A_md, plain_tensor_layout_tag,
transposed_tensor_layout_tag, acbd, adbc)
: is_int8_avx512_core
? memory_desc_matches_one_of_tag(A_md,
plain_tensor_layout_tag,
transposed_tensor_layout_tag, acbd)
: memory_desc_matches_one_of_tag(
A_md, plain_tensor_layout_tag, acbd);
if (bgmmc.src_tag == format_tag::undef
|| (memory_desc_matches_tag(
A_md, transposed_tensor_layout_tag)
&& memory_desc_matches_tag(
A_md, plain_tensor_layout_tag)
&& IMPLICATION(
!is_adbc_allowed, is_int8_avx512_core))) {
if (gemm_based::check_gemm_input_format(A_md)) {
// Note: the batch layout here may not be accurately represented by the
// src_tag string due to all the possible permutations of the batch
// dimensions. Only the gemm dimensions "m, k" are accurately represented
// by the tag (transposed or not).
bgmmc.src_tag = helper.transA() == 'N'
? plain_tensor_layout_tag
: transposed_tensor_layout_tag;
}
if (!IMPLICATION(bgmmc.src_tag == transposed_tensor_layout_tag,
is_adbc_allowed || is_int8_avx512_core
|| bgmmc.is_gemv))
bgmmc.src_tag = format_tag::undef;
}
}
}
if (C_any_layout) {
const format_tag_t desired_C_tag = plain_tensor_layout_tag;
VCHECK_BG(memory_desc_init_by_tag(C_md, desired_C_tag),
VERBOSE_UNSUPPORTED_TAG);
bgmmc.dst_tag = desired_C_tag;
} else {
const memory_desc_wrapper C_mdw(C_md);
// If one of dims is `1` then `ba` is identical to `ab`.
format_tag_t allowed_transposed_tensor_layout_tag
= C_mdw.ndims() == 2 && C_mdw.count_non_unit_dims(1)
? ba
: plain_tensor_layout_tag;
bgmmc.dst_tag
= memory_desc_matches_one_of_tag(C_md, plain_tensor_layout_tag,
allowed_transposed_tensor_layout_tag, acbd);
if (bgmmc.dst_tag == format_tag::undef) {
if (gemm_based::check_gemm_output_format(C_md)) {
// Note: the batch layout here may not be accurately represented by the
// dst_tag string due to all the possible permutations of the batch
// dimensions. Only the gemm dimensions "m, n" are accurately represented
// by the tag.
bgmmc.dst_tag = plain_tensor_layout_tag;
}
}
}
VCONDCHECK_BG(!one_of(format_tag::undef, bgmmc.src_tag, bgmmc.dst_tag),
VERBOSE_UNSUPPORTED_TAG)
if (bgmmc.with_bias && bias_any_layout)
VCHECK_BG(memory_desc_init_by_tag(bias_md, plain_tensor_layout_tag),
VERBOSE_UNSUPPORTED_TAG);
return status::success;
}
status_t brgemm_matmul_conf_utils_t::set_B_flags(memory_desc_t &B_md) const {
memory_desc_t want_B_md = B_md;
// Set bits for all dimensions except k dimension
const int compensation_mask
= ((1 << bgmmc.ndims) - 1 - (1 << (bgmmc.ndims - 2)));
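// For illustration: with ndims == 2 (weights dims {K, N}) the mask is
// (1 << 2) - 1 - (1 << 0) = 2, i.e. only the N-dimension bit; with
// ndims == 3 it is 7 - 2 = 5, covering the batch and N dimensions.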
if (bgmmc.s8s8_compensation_required && bgmmc.blocked_B) {
want_B_md.extra.flags |= memory_extra_flags::compensation_conv_s8s8;
want_B_md.extra.compensation_mask = compensation_mask;
}
if (bgmmc.src_zp_type != brgemm_broadcast_t::none && bgmmc.blocked_B) {
want_B_md.extra.flags
|= memory_extra_flags::compensation_conv_asymmetric_src;
want_B_md.extra.asymm_compensation_mask = compensation_mask;
}
if (B_any_layout) {
B_md = want_B_md;
return status::success;
}
return B_md == want_B_md ? status::success : status::unimplemented;
}
format_tag_t brgemm_matmul_conf_utils_t::pick_blocked_B_layout(
int n_blk) const {
if (bgmmc.ndims > 3) return format_tag::undef;
if (is_int8() || is_f8()) {
switch (n_blk) {
case 64: return bgmmc.ndims == 3 ? aCB16b64c4b : BA16a64b4a;
case 48: return bgmmc.ndims == 3 ? aCB16b48c4b : BA16a48b4a;
case 32: return bgmmc.ndims == 3 ? aCB16b32c4b : BA16a32b4a;
case 16: return bgmmc.ndims == 3 ? aCB16b16c4b : BA16a16b4a;
default: return format_tag::undef;
}
}
const bool is_amx_or_avx2_vnni_2 = is_superset(bgmmc.isa, avx512_core_amx)
|| is_superset(bgmmc.isa, avx2_vnni_2);
const bool prefer_amx_or_avx2_vnni_2 = is_f16() || is_f32_f16()
|| is_f32_bf16() || is_f16_with_int_wei();
if ((prefer_amx_or_avx2_vnni_2 && is_amx_or_avx2_vnni_2) || is_bf16()
|| is_bf16_with_int_wei()) {
switch (n_blk) {
case 64: return bgmmc.ndims == 3 ? aCB16b64c2b : BA16a64b2a;
case 48: return bgmmc.ndims == 3 ? aCB16b48c2b : BA16a48b2a;
case 32: return bgmmc.ndims == 3 ? aCB16b32c2b : BA16a32b2a;
case 16: return bgmmc.ndims == 3 ? aCB16b16c2b : BA16a16b2a;
default: return format_tag::undef;
}
}
// Note: bf32 assumes f32 blocking
if (is_f32() || is_bf32() || is_f16() || is_f32_f16() || is_f32_bf16()
|| is_f16_with_int_wei() || is_tf32()) {
switch (n_blk) {
case 64: return bgmmc.ndims == 3 ? aCB16b64c : BA16a64b;
case 48: return bgmmc.ndims == 3 ? aCB16b48c : BA16a48b;
case 32: return bgmmc.ndims == 3 ? aCB16b32c : BA16a32b;
case 16: return bgmmc.ndims == 3 ? aCB16b16c : BA16a16b;
default: return format_tag::undef;
}
}
return format_tag::undef;
}
brgemm_broadcast_t get_zp_type(const primitive_attr_t &attr, int arg) {
return attr.zero_points_.has_default_values(arg)
? brgemm_broadcast_t::none
: brgemm_broadcast_t::per_tensor;
}
struct matmul_avx512_blocking_params_t {
struct matmul_params_t {
matmul_params_t(int m, int n, int k, int od)
: M(m), N(n), K(k), batch(od) {}
const int M;
const int N;
const int K;
const int batch;
};
matmul_avx512_blocking_params_t(const matmul_params_t &m, const int nthr)
: mp(m)
, m_chunks(1)
, m_blk(1)
, m_tail(0)
, n_chunks(1)
, n_blk(1)
, n_tail(0)
, batch_size(1)
, k_blk(1)
, k_tail(0)
, nthr_k(1)
, nthr(nthr) {}
matmul_avx512_blocking_params_t &operator=(
const matmul_avx512_blocking_params_t &brgemm_params) {
m_chunks = brgemm_params.m_chunks;
m_blk = brgemm_params.m_blk;
m_tail = brgemm_params.m_tail;
n_chunks = brgemm_params.n_chunks;
n_blk = brgemm_params.n_blk;
n_tail = brgemm_params.n_tail;
batch_size = brgemm_params.batch_size;
k_blk = brgemm_params.k_blk;
k_tail = brgemm_params.k_tail;
nthr_k = brgemm_params.nthr_k;
return *this;
}
const matmul_params_t &mp;
int m_chunks, m_blk, m_tail;
int n_chunks, n_blk, n_tail;
int batch_size, k_blk, k_tail;
int nthr_k;
const int nthr;
void update_params(int m_chunks_, int m_blk_, int n_chunks_, int n_blk_,
int batch_size_, int k_blk_, int nthr_k_) {
m_chunks = m_chunks_;
m_blk = m_blk_;
m_tail = mp.M % m_blk;
n_chunks = n_chunks_;
n_blk = n_blk_;
n_tail = mp.N % n_blk;
batch_size = batch_size_;
k_blk = k_blk_;
k_tail = mp.K % k_blk;
nthr_k = nthr_k_;
}
float calculate_spatial_disbalance(size_t work, size_t thread_block) const {
size_t mod = work % thread_block;
size_t scalar = work < thread_block
? thread_block - mod
: nstl::min(thread_block - mod, mod);
return static_cast<float>(scalar) / thread_block;
}
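// For illustration: work = 10, thread_block = 4 gives mod = 2 and
// scalar = min(4 - 2, 2) = 2, i.e. a disbalance of 0.5; work = 3,
// thread_block = 4 gives scalar = 4 - 3 = 1, i.e. 0.25.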
float get_imbalance() const {
const size_t cur_nthr = nthr / nthr_k;
size_t parallel_work = get_parallel_work();
const float parallel_work_disb
= calculate_spatial_disbalance(parallel_work, cur_nthr);
int m_work = (m_blk * div_up(mp.M, m_blk)) % mp.M;
const float m_blk_disbalance = static_cast<float>(m_work) / mp.M;
int num_n_blk = div_up(mp.N, n_blk);
int par_n_chunks = div_up(num_n_blk, n_chunks);
const float n_chunk_disbalance
= (static_cast<float>(par_n_chunks) * n_chunks - num_n_blk)
/ num_n_blk;
const float disbalance_nthr_k
= calculate_spatial_disbalance(mp.K, nthr_k * k_blk);
const float thread_allocation_disb
= (cur_nthr * nthr_k) != static_cast<size_t>(nthr)
? (static_cast<float>(nthr) - cur_nthr * nthr_k) / nthr
: 0;
const float score
= (parallel_work_disb + m_blk_disbalance + n_chunk_disbalance
+ thread_allocation_disb + disbalance_nthr_k)
/ 5;
return score;
}
size_t get_parallel_work() const {
int m_elems = div_up(mp.M, m_blk * m_chunks);
int n_elems = div_up(mp.N, n_blk * n_chunks);
return static_cast<size_t>(m_elems) * n_elems * mp.batch;
}
inline dim_t get_actual_lda(bool use_buffer_a, dim_t a_dt_sz) const {
if (!use_buffer_a) return mp.K;
constexpr int bytes_in_cacheline = 64;
const int elems_in_cacheline = bytes_in_cacheline / a_dt_sz;
dim_t lda = rnd_up(k_blk, elems_in_cacheline);
const bool is_big_pow_2 = lda >= 512 && math::is_pow2(lda);
if (is_big_pow_2) lda += elems_in_cacheline;
return lda;
}
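// For illustration (f32, a_dt_sz = 4, so elems_in_cacheline = 16):
// k_blk = 500 gives lda = rnd_up(500, 16) = 512, a large power of two,
// so one extra cache line is added and lda becomes 528, presumably to
// avoid cache-set conflicts between consecutive rows of the copied A.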
inline bool is_buffer_c_required(const brgemm_matmul_conf_t &bgmmc) const {
const size_t k_chunk_elems = bgmmc.K_chunk_elems;
if (nthr_k > 1 && static_cast<size_t>(mp.K) > k_chunk_elems)
return true;
return ((bgmmc.acc_dt != bgmmc.dst_dt || bgmmc.with_sum)
&& (static_cast<size_t>(mp.K) > k_chunk_elems
|| mp.K % k_blk > 0));
}
void update_configuration(brgemm_matmul_conf_t &bgmmc) const {
bgmmc.M_blk = m_blk;
bgmmc.M_chunk_size = m_chunks;
bgmmc.N_blk = n_blk;
bgmmc.N_chunk_size = n_chunks;
bgmmc.K_blk = rnd_up(k_blk, bgmmc.required_k_granularity);
bgmmc.brgemm_batch_size = batch_size;
bgmmc.nthr_k = nthr_k;
bgmmc.use_buffer_c = is_buffer_c_required(bgmmc);
bgmmc.LDA = bgmmc.adjust_a_strides || bgmmc.use_buffer_a
|| bgmmc.treat_A_as_plain
? get_actual_lda(bgmmc.use_buffer_a, bgmmc.tr_a_dt_sz)
: bgmmc.A_strides[1] / bgmmc.a_dt_sz;
}
};
float compute_blocking_heuristic_avx512(brgemm_matmul_conf_t &bgmmc,
const brgemm_matmul_conf_utils_t &bm_conf_utils,
const matmul_avx512_blocking_params_t::matmul_params_t &matmul,
matmul_avx512_blocking_params_t &best_blocking) {
const int nthr = bgmmc.nthr;
const int max_m_blk = nstl::min(256, matmul.M);
int min_m_blk = nstl::min(32, matmul.M);
dim_t min_m_chunks = div_up(matmul.M, max_m_blk);
int n_blk = bgmmc.N_blk;
const int n_chunks = div_up(matmul.N, n_blk);
const int max_n_chunks = bgmmc.use_buffer_a ? 16 : 1;
const int n_chunks_start = nstl::min(max_n_chunks, n_chunks);
// Note: do not extend K_blk for 'bwd_w' cases
const bool use_extended_k_blk = matmul.K > 1024
&& (!bm_conf_utils.check_is_transposed(bgmmc.src_tag));
int default_k_blk = use_extended_k_blk ? 1024 : 512;
int k_blk = nstl::min(matmul.K, default_k_blk);
int start_nthr_k = 1;
int last_nthr_k = 1;
// for cases with low parallel work, reduce 'min_m_blk' to
// increase potential parallelization balance.
const dim_t max_parallel = static_cast<dim_t>(matmul.batch) * n_chunks;
const dim_t max_bmn_parallel = max_parallel * min_m_chunks;
const bool low_parallel_work = nthr > max_parallel;
if (low_parallel_work) {
min_m_blk = nstl::min(matmul.M, 16);
// 2nd level tuning for low parallel work cases:
bool bwd_w_low_spatial_work
= bm_conf_utils.check_is_transposed(bgmmc.src_tag)
&& matmul.M <= 512;
bool low_spatial_work = matmul.M <= 40;
if (low_spatial_work || bwd_w_low_spatial_work) {
// Reduce n_blk size to increase parallel space
// note: over-reducing the n_blk size on 2D shapes when n_chunks == 1
// showed significant performance degradation
if (!bm_conf_utils.check_n_blk_fixed()
&& IMPLICATION(n_chunks == 1, bgmmc.batch_ndims > 0))
n_blk = nstl::min(matmul.N, 32);
// force plain B (wei) for small spatial sizes in FWD:
// note: this showed a significant performance gain in WnD shapes
bool is_FWD = !(bm_conf_utils.check_is_transposed(bgmmc.wei_tag)
|| bm_conf_utils.check_is_transposed(bgmmc.src_tag));
if (bgmmc.use_buffer_b && is_FWD) {
bgmmc.use_buffer_b = bm_conf_utils.use_buffer_b(false);
}
}
// Parallelize across K for shapes with big 'K' dimension
bool bwd_w_par_k_blk = bgmmc.batch == 1
&& bm_conf_utils.check_is_transposed(bgmmc.src_tag)
&& !bm_conf_utils.is_int8()
&& IMPLICATION(bm_conf_utils.is_bf16(), math::is_pow2(matmul.K))
&& matmul.K >= 2048;
if (bwd_w_par_k_blk) {
start_nthr_k = nstl::min(nthr, 4);
assert(k_blk == nstl::min(matmul.K, 512));
}
// Enable k-partitioning for huge k and small m/n dimensions.
bool is_huge_k = matmul.K >= 20000;
bool is_small_mn = matmul.M <= 512 && matmul.N <= 512;
bool use_k_partitioning = is_huge_k && is_small_mn;
// TODO: expand to other data types.
use_k_partitioning = use_k_partitioning && bm_conf_utils.is_f32();
if (use_k_partitioning) {
auto least_prime_factor = [](int n) {
assert(n > 0);
if (n == 1) return 1;
for (int factor = 2; factor < n; factor++)
if (n % factor == 0) return factor;
return n;
};
int nthr_bmn = max_div(max_bmn_parallel, nthr);
int nthr_k = nstl::max(nthr / nthr_bmn, 1);
int nthr_remainder = nthr % nthr_bmn;
// Choose number of threads in k-dim to allow a larger block size
// for m-dim.
while (nthr_k <= nthr_remainder) {
nthr_bmn /= least_prime_factor(nthr_bmn);
nthr_k = nstl::max(nthr / nthr_bmn, 1);
nthr_remainder = nthr % nthr_bmn;
}
// Reduce the number of threads in k-dim to balance the work.
dim_t k_chunks = div_up(matmul.K, k_blk);
while (k_chunks <= 5 * nthr_k && nthr_k > 1)
nthr_k /= least_prime_factor(nthr_k);
// Fix number of threads for k-dim.
start_nthr_k = nthr_k;
last_nthr_k = nthr_k;
}
}
// Use large m-blocking if possible.
const bool is_huge_n = matmul.N >= 20000;
const bool large_bmn_parallelism = max_bmn_parallel > 10 * nthr;
const bool has_m_tail = matmul.M % max_m_blk != 0;
const bool use_k_partitioning = start_nthr_k > 1;
bool use_large_m_blk = is_huge_n && large_bmn_parallelism && !has_m_tail;
use_large_m_blk &= !use_k_partitioning;
// TODO: Expand to other data types.
use_large_m_blk = use_large_m_blk && bm_conf_utils.is_f32();
if (use_large_m_blk) min_m_blk = max_m_blk;
matmul_avx512_blocking_params_t cur_params(matmul, nthr);
float best_imbalance = 1.f; // reduce
for (int nthr_k = start_nthr_k; nthr_k >= last_nthr_k; --nthr_k) {
bool found_best_blocking = false;
for_(int n_chunk_size = n_chunks_start; n_chunk_size >= 1;
--n_chunk_size)
for (int m_blk = max_m_blk; m_blk >= min_m_blk; --m_blk) {
cur_params.update_params(
1, m_blk, n_chunk_size, n_blk, 1, k_blk, nthr_k);
float cur_imbalance = cur_params.get_imbalance();
const int m_chunk_size = 1;
int m_chunks = div_up(bgmmc.M, m_blk * m_chunk_size);
int n_chunks = div_up(bgmmc.N, n_blk * n_chunk_size);
int work_amount = bgmmc.batch * m_chunks * n_chunks;
int nthr_bmn = nthr / nthr_k;
bool skip_config = work_amount < nthr_bmn * 3
&& work_amount % nthr_bmn != 0 && start_nthr_k == 1;
if (skip_config) continue;
if (cur_imbalance < best_imbalance) {
best_imbalance = cur_imbalance;
best_blocking = cur_params;
found_best_blocking = true;
}
}
if (!found_best_blocking) {
cur_params.update_params(1, min_m_blk, 1, n_blk, 1, k_blk, nthr_k);
float cur_imbalance = cur_params.get_imbalance();
if (cur_imbalance < best_imbalance) {
best_imbalance = cur_imbalance;
best_blocking = cur_params;
}
}
}
return best_imbalance;
}
float compute_blocking_heuristic_avx2(brgemm_matmul_conf_t &bgmmc,
const brgemm_matmul_conf_utils_t &bm_conf_utils,
const matmul_avx512_blocking_params_t::matmul_params_t &matmul,
matmul_avx512_blocking_params_t &best_blocking) {
const int nthr = bgmmc.nthr;
const int max_m_blk = nstl::min(/*64*/ 256, matmul.M);
int min_m_blk = nstl::min(32, matmul.M); // max_m_blk
int n_blk = bgmmc.N_blk;
const int n_chunks = div_up(matmul.N, n_blk);
const int max_n_chunks = bgmmc.use_buffer_a ? 16 : 1;
const int n_chunks_start = nstl::min(max_n_chunks, n_chunks);
int default_k_blk = 1024;
int k_blk = nstl::min(matmul.K, default_k_blk);
int start_nthr_k = 1;
// for cases with low parallel work, reduce 'min_m_blk' to
// increase potential parallelization balance.
const size_t max_parallel = matmul.batch * n_chunks;
const bool low_parallel_work = static_cast<size_t>(nthr) > max_parallel;
if (low_parallel_work) {
min_m_blk = nstl::min(matmul.M, 16);
bool low_spatial_work = matmul.M <= 40;
if (low_spatial_work) {
// Reduce n_blk size to increase parallel space
// note: over-reducing the n_blk size on 2D shapes when n_chunks == 1
// showed significant performance degradation
if (!bm_conf_utils.check_n_blk_fixed()
&& IMPLICATION(n_chunks == 1, bgmmc.batch_ndims > 0))
n_blk = nstl::min(matmul.N, 32);
}
}
float best_imbalance = 1.f; // reduce
for_(int nthr_k = start_nthr_k; nthr_k >= 1; --nthr_k)
for_(int n_chunk_size = n_chunks_start; n_chunk_size >= 1; --n_chunk_size)
for (int m_blk = max_m_blk; m_blk >= min_m_blk; --m_blk) {
matmul_avx512_blocking_params_t cur_params(matmul, nthr);
cur_params.update_params(
1, m_blk, n_chunk_size, n_blk, 1, k_blk, nthr_k);
float cur_imbalance = cur_params.get_imbalance();
if (cur_imbalance < best_imbalance) {
best_imbalance = cur_imbalance;
best_blocking = cur_params;
}
}
return best_imbalance;
}
float compute_blocking_heuristic_avx2_f32(brgemm_matmul_conf_t &bgmmc,
const brgemm_matmul_conf_utils_t &bm_conf_utils,
const matmul_avx512_blocking_params_t::matmul_params_t &matmul,
matmul_avx512_blocking_params_t &best_blocking) {
float best_imbalance = 1.f; // reduce
const int nthr = bgmmc.nthr;
dim_t max_m_blk = nstl::min(256, matmul.M);
dim_t min_m_blk = nstl::min(32, matmul.M);
int n_blk = bgmmc.N_blk;
const int n_chunks = div_up(matmul.N, n_blk);
const int max_n_chunks = bgmmc.use_buffer_a ? 16 : 1;
const int n_chunks_start = nstl::min(max_n_chunks, n_chunks);
int default_k_blk = 1024;
int k_blk = nstl::min(matmul.K, default_k_blk);
int start_nthr_k = 1;
// for cases with low parallel work, reduce 'min_m_blk' to
// increase potential parallelization balance.
size_t max_parallel = matmul.batch * n_chunks;
const float req_additional_parallel = nthr / max_parallel;
if (req_additional_parallel > 1) {
min_m_blk = saturate<int>(
16, max_m_blk, matmul.M / req_additional_parallel);
max_parallel *= div_up(matmul.M, min_m_blk);
} else if (bm_conf_utils.check_is_transposed(bgmmc.src_tag)
&& matmul.K >= 4096) {
min_m_blk = nstl::max(16, matmul.M / 4);
default_k_blk = 192;
}
bool low_parallel_work = max_parallel % nthr != 0
&& (static_cast<float>(max_parallel) / nthr) < 2;
if (low_parallel_work) {
// Reduce n_blk size to increase parallel space
// note: over-reducing the n_blk size on 2D shapes when n_chunks == 1
// showed significant performance degradation
if (!bm_conf_utils.check_n_blk_fixed()
&& IMPLICATION(n_chunks == 1, bgmmc.batch_ndims > 0)) {
n_blk = nstl::min(matmul.N, 16);
}
}
max_m_blk = nstl::max(max_m_blk, min_m_blk);
for_(int nthr_k = start_nthr_k; nthr_k >= 1; --nthr_k)
for_(int n_chunk_size = n_chunks_start; n_chunk_size >= 1; --n_chunk_size)
for (int m_blk = max_m_blk; m_blk >= min_m_blk; --m_blk) {
matmul_avx512_blocking_params_t cur_params(matmul, nthr);
cur_params.update_params(
1, m_blk, n_chunk_size, n_blk, 1, k_blk, nthr_k);
float cur_imbalance = cur_params.get_imbalance();
if (cur_imbalance < best_imbalance) {
best_imbalance = cur_imbalance;
best_blocking = cur_params;
}
}
return best_imbalance;
}
status_t compute_blocking_heuristic(brgemm_matmul_conf_t &bgmmc,
const brgemm_matmul_conf_utils_t &bm_conf_utils) {
bgmmc.N_blk = bgmmc.wei_n_blk;
if (!bgmmc.is_runtime_N) bgmmc.N_blk = nstl::min(bgmmc.N_blk, bgmmc.N);
bgmmc.M_chunk_size = bgmmc.N_chunk_size = bgmmc.K_chunk_size = 1;
bool prefer_copy_a
= one_of(true, bm_conf_utils.is_f32() && bgmmc.isa == avx2,
bm_conf_utils.is_bf16(),
bm_conf_utils.is_bf16_with_int_wei(),
(bgmmc.is_amx
&& (bm_conf_utils.is_f16()
|| bm_conf_utils.is_f16_with_int_wei())))
&& (bgmmc.isa != avx2_vnni_2) // no perf study yet.
&& bgmmc.lda_big_pow2() && bgmmc.M >= 1024;
// Avoiding the copy of A for small N gives better performance.
// TODO: Expand for other precisions and cases.
if (bgmmc.is_amx && bm_conf_utils.is_int8())
prefer_copy_a &= bgmmc.N >= 256;
if (bgmmc.is_amx) {
if (matmul_amx_blocking_params_macro_t::is_supported(
bgmmc, bm_conf_utils)) {
// The macro (grid) blocking heuristic is applicable; use its result if it
// finds a valid blocking.
matmul_amx_blocking_params_macro_t best_blocking(bgmmc);
matmul_amx_blocking_params_macro_t::find_best_blocking(
bgmmc, bm_conf_utils, best_blocking);
if (best_blocking.get_blocking_scores() != 0.0f) {
best_blocking.update_configuration(bgmmc);
return status::success;
}
}
bgmmc.use_buffer_a |= prefer_copy_a;
// Configure matrix sizes
if (bgmmc.is_runtime_M) {
bgmmc.M_blk = 64; // use fixed block size for runtime M case
} else {
auto get_block_candidate = [&]() -> dim_t {
// for AMX prefer block sizes which utilize at least 13 tile
// rows
const dim_t tile_rows_min = 13;
const dim_t tile_rows_max = 16;
const dim_t scale_rows_min = 2;
const dim_t scale_rows_max = 4;
for_(dim_t r = tile_rows_max; r >= tile_rows_min; r--)
for (dim_t s = scale_rows_max; s >= scale_rows_min; s--) {
const dim_t m_blk = s * r;
if (bgmmc.M % m_blk == 0) return m_blk;
}
const dim_t max_M = scale_rows_max * tile_rows_max;
return nstl::min(bgmmc.M, max_M);
};
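// For illustration: M = 96 is first checked against r = 16, s = 4
// (m_blk = 64, 96 % 64 != 0), then s = 3 (m_blk = 48, 96 % 48 == 0),
// so M_blk = 48; an M matching no (r, s) pair falls back to
// min(M, 4 * 16) = min(M, 64).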
bgmmc.M_blk = get_block_candidate();
}
// AMX BRGEMM kernel requires (K_brgemm % 64 == 0 || K_brgemm < 64)
// for K_brgemm reduction value to avoid AMX tiles re-configuration.
// To satisfy this condition K_tail value is fixed to K % wei_k_blk here.
const bool fixed_K_tail_size
= bgmmc.K % bgmmc.wei_k_blk > 0 && bgmmc.K > bgmmc.wei_k_blk;
bgmmc.K_blk = bgmmc.K < bgmmc.wei_k_blk
? rnd_up(bgmmc.K, bgmmc.required_k_granularity)
: fixed_K_tail_size ? bgmmc.wei_k_blk
: bgmmc.K;
bgmmc.brgemm_batch_size
= nstl::max(bgmmc.K / bgmmc.K_blk, static_cast<dim_t>(1));
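// For illustration (bf16 on AMX, hypothetical shape): wei_k_blk = 32 and
// required_k_granularity = 2. K = 1000 has a tail (1000 % 32 = 8), so
// K_blk = 32 and brgemm_batch_size = 1000 / 32 = 31, with the tail of 8
// handled separately; a small K = 24 gives K_blk = rnd_up(24, 2) = 24 and
// brgemm_batch_size = 1.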
matmul_amx_blocking_params_micro_t best_blocking(bgmmc);
matmul_amx_blocking_params_micro_t::find_best_blocking(
bgmmc, bm_conf_utils, best_blocking);
VCONDCHECK_BG(best_blocking.get_blocking_scores() != 0.0f,
VERBOSE_BLOCKING_FAIL, "");
best_blocking.update_configuration(bgmmc);
} else if (is_superset(bm_conf_utils.get_isa(), avx512_core)) {
// TODO:
// *) adjust K_BLK using 'rnd_up(bgmmc.K, bgmmc.required_k_granularity)'
// for non-f32 datatypes.
// *) optimize param search complexity
// Approach for selecting ideal 'blocking parameters':
// M_blk:
// - main param for having parallel_work optimally distributed.
// - 'br_block' is a BRGeMM uKernel parameter derived from 'M_Blk',
// however, there is no measured performance impact from small
// variations in 'br_block' size.
//
// M_Chunks:
// - no noticeable performance impact, i.e. with 'M_blk = M_Chunks * M_Blk'
// and M_Chunks > 1, brgemm shows the same performance results. Instead,
// choose a larger 'M_blk'.
//
// N_blk:
// - ideally 64 (from 'get_default_n_block()').
// - can be reduced to 32 to improve performance for some shapes, as
// well as increasing parallelization search space.
//
// N_Chunks:
// - No difference as long as the thread/work balance is the same.
// - Note: for A_Transposed cases using A_buffer (i.e. bwd-w), select
// a higher count to increase performance - better for transposed data
// reuse.
//
// K_blk:
// - block size variation within '512 <= K_blk < 1024' makes a negligible
// performance difference. However, some cases benefit from a higher
// block size.
// - can parallelize if there is not enough work; note: this requires a
// reduction!
//
// Batch_Size:
// - unused.
bgmmc.use_buffer_a |= prefer_copy_a;
const matmul_avx512_blocking_params_t::matmul_params_t matmul(
bgmmc.M, bgmmc.N, bgmmc.K, bgmmc.batch);
matmul_avx512_blocking_params_t best_blocking(matmul, bgmmc.nthr);
const float best_imbalance = compute_blocking_heuristic_avx512(
bgmmc, bm_conf_utils, matmul, best_blocking);
VCONDCHECK_BG(best_imbalance != 1.f, VERBOSE_BLOCKING_FAIL, "")
best_blocking.update_configuration(bgmmc);
} else {
bgmmc.use_buffer_a |= prefer_copy_a;
VCONDCHECK_BG(is_superset(bm_conf_utils.get_isa(), avx2),
VERBOSE_UNSUPPORTED_ISA)
const bool is_f32 = bm_conf_utils.is_f32() && bgmmc.isa == avx2;
const matmul_avx512_blocking_params_t::matmul_params_t matmul(
bgmmc.M, bgmmc.N, bgmmc.K, bgmmc.batch);
matmul_avx512_blocking_params_t best_blocking(matmul, bgmmc.nthr);
const float best_imbalance = is_f32
? compute_blocking_heuristic_avx2_f32(
bgmmc, bm_conf_utils, matmul, best_blocking)
: compute_blocking_heuristic_avx2(
bgmmc, bm_conf_utils, matmul, best_blocking);
VCONDCHECK_BG(best_imbalance != 1.f, VERBOSE_BLOCKING_FAIL, "")
best_blocking.update_configuration(bgmmc);
}
return status::success;
}
status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc,
const matmul_desc_t &mmd, memory_desc_t &src_md,
memory_desc_t &weights_md, memory_desc_t &dst_md,
memory_desc_t &bias_md, primitive_attr_t &attr,
const std::function<bool()> &can_use_gemm_fallback) {
const memory_desc_wrapper src_d(&src_md);
const memory_desc_wrapper weights_d(&weights_md);
const memory_desc_wrapper dst_d(&dst_md);
bgmmc = zero<decltype(bgmmc)>();
bgmmc.isa = isa;
bgmmc.nthr = dnnl_get_max_threads();
bgmmc.brg_type = brgemm_addr;
bgmmc.src_dt = src_d.data_type();
bgmmc.orig_src_dt = src_d.data_type();
bgmmc.dst_dt = dst_d.data_type();
bgmmc.wei_dt = weights_d.data_type();
bgmmc.orig_wei_dt = weights_d.data_type();
bgmmc.with_reduce = mmd.reduce_desc.format_kind != format_kind::undef;
bgmmc.reduce_dt
= bgmmc.with_reduce ? mmd.reduce_desc.data_type : data_type::undef;
bgmmc.reduce_kind = mmd.reduce_kind;
bgmmc.with_bias = mmd.bias_desc.format_kind != format_kind::undef;
bgmmc.bia_dt = bgmmc.with_bias ? mmd.bias_desc.data_type : data_type::undef;
bgmmc.s8s8_compensation_required = bgmmc.src_dt == s8 && !isa_has_s8s8(isa);
bgmmc.ndims = dst_d.ndims();
bgmmc.is_thread_chunks_exec_order_horizontal = true;
bgmmc.mem_advice
= brgemm_kernel_hint_mem_advice_t::brgemm_hint_mem_advice_undef;
const bool is_wei_any = weights_d.format_kind() == format_kind::any
|| weights_d.is_sparse_packed_desc();
brgemm_matmul_conf_utils_t bm_conf_utils(bgmmc, isa, attr,
src_d.format_kind() == format_kind::any, is_wei_any,
dst_d.format_kind() == format_kind::any,
bias_md.format_kind == format_kind::any);
VCHECK_BG(check_datatype_cfg(bm_conf_utils), VERBOSE_UNSUPPORTED_DT_CFG);
VCHECK_BG(check_isa_with_datatype(isa, bm_conf_utils),
VERBOSE_ISA_DT_MISMATCH);
bgmmc.is_amx = is_superset(isa, avx512_core_amx);
bgmmc.a_dt_sz = bgmmc.tr_a_dt_sz = types::data_type_size(bgmmc.src_dt);
bgmmc.b_dt_sz = bgmmc.tr_b_dt_sz = types::data_type_size(bgmmc.wei_dt);
bgmmc.packed_sparse_weights = weights_d.is_sparse_packed_desc();
if (bgmmc.packed_sparse_weights) {
VCONDCHECK_BG(bgmmc.is_amx, VERBOSE_ISA_SPARSE_ENCODING_MISMATCH);
VCONDCHECK_BG(bgmmc.wei_dt == s8, VERBOSE_UNSUPPORTED_DT);
}
bgmmc.is_bf32 = bm_conf_utils.is_bf32();
bgmmc.is_tf32 = bm_conf_utils.is_tf32();
bgmmc.is_bf16_with_int_wei = bm_conf_utils.is_bf16_with_int_wei();
bgmmc.is_f16_with_int_wei = bm_conf_utils.is_f16_with_int_wei();
bgmmc.is_f32_f16 = bm_conf_utils.is_f32_f16();
bgmmc.is_f32_bf16 = bm_conf_utils.is_f32_bf16();
bgmmc.with_wei_decompression = bm_conf_utils.with_weights_decompression();
bgmmc.is_int4_weights = one_of(bgmmc.wei_dt, data_type::s4, data_type::u4);
bgmmc.is_f4_via_convert = bm_conf_utils.is_f4_via_convert();
if (bgmmc.is_f4_via_convert) {
bgmmc.wei_dt = f32;
bgmmc.tr_b_dt_sz = types::data_type_size(f32);
}
// Make BRGeMM compute the MatMul as if it were bfloat16, while the
// down-conversion happens during the copy-buffer computations.
if (bgmmc.is_bf32 || bgmmc.is_bf16_with_int_wei) {
bgmmc.src_dt = bf16;
bgmmc.wei_dt = bf16;
bgmmc.tr_a_dt_sz = types::data_type_size(bf16);
bgmmc.tr_b_dt_sz = types::data_type_size(bf16);
} else if ((bm_conf_utils.is_f16() || bgmmc.is_f16_with_int_wei)
&& bgmmc.isa == avx512_core_fp16) {
// Similar to bf32, convert input data before compute
bgmmc.src_dt = f32;
bgmmc.wei_dt = f32;
bgmmc.tr_a_dt_sz = types::data_type_size(f32);
bgmmc.tr_b_dt_sz = types::data_type_size(f32);
} else if ((bm_conf_utils.is_f32_f16() || bm_conf_utils.is_f32_bf16())
&& is_superset(bgmmc.isa, avx2)) {
// Note 1: Keep this branch separately from f16 one to have different
// ISA conditions (f16 includes f16:f32 and f16:f16 combinations). Same
// applies for bf16 (which includes bf16:bf16).
// Note 2: If `use_buffer_b()` is false, let the kernel perform the
// conversion. Otherwise, make the copy_b routine handle the conversion
// and set kernel data types to f32.
// Note 3: Since `use_buffer_b()` depends on `bgmmc.wei_tag`, which is
// set later in the code due to its dependencies, the update of data
// types to f32 happens below in ANCHOR: `CONVERT_F32_XF16_DATA_TYPES`.
} else if (bgmmc.is_f16_with_int_wei && bgmmc.isa != avx512_core_fp16) {
bgmmc.src_dt = f16;
bgmmc.wei_dt = f16;
bgmmc.tr_a_dt_sz = types::data_type_size(f16);
bgmmc.tr_b_dt_sz = types::data_type_size(f16);
}
bgmmc.acc_dt = bm_conf_utils.is_int8() ? s32 : f32;
bgmmc.c_dt_sz = types::data_type_size(bgmmc.dst_dt);
bgmmc.acc_dt_sz = types::data_type_size(bgmmc.acc_dt);
if (bgmmc.with_bias) bgmmc.bias_dt_sz = types::data_type_size(bgmmc.bia_dt);
if (bgmmc.with_reduce)
bgmmc.reduce_dt_sz = types::data_type_size(bgmmc.reduce_dt);
const auto &src_scales = attr.scales_.get(DNNL_ARG_SRC);
const auto &wei_scales = attr.scales_.get(DNNL_ARG_WEIGHTS);
bgmmc.with_src_scales = !src_scales.has_default_values();
bgmmc.with_wei_scales = !wei_scales.has_default_values();
if (bgmmc.with_wei_scales) {
const auto wei_qmask_N = 1 << (bgmmc.ndims - 1);
const auto wei_qmask_K = 1 << (bgmmc.ndims - 2);
bgmmc.is_wei_scale_per_k = wei_scales.get_mask() & wei_qmask_K;
bgmmc.is_wei_scale_per_n = wei_scales.get_mask() & wei_qmask_N;
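// For illustration: for a 2D matmul (ndims = 2, weights dims {K, N})
// wei_qmask_N = 1 << 1 = 2 and wei_qmask_K = 1 << 0 = 1, so a per-oc
// (per-N) scales mask of 2 sets is_wei_scale_per_n only.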
bgmmc.apply_scales_in_buffer_b = bgmmc.is_wei_scale_per_k
&& bgmmc.with_wei_decompression && bgmmc.N * bgmmc.K != 1;
bgmmc.wei_scales_dt = wei_scales.get_data_type();
bgmmc.wei_scales_dt_sz = types::data_type_size(bgmmc.wei_scales_dt);
bgmmc.wei_scales_k_group_size = wei_scales.get_group(0);
// Only common and per-oc-channel weight scales are supported;
// per-ic-channel (per-K) scales are supported only with weights
// decompression.
VCONDCHECK_BG(wei_scales.get_mask() == 0 || bgmmc.is_wei_scale_per_n
|| IMPLICATION(bgmmc.is_wei_scale_per_k,
bgmmc.with_wei_decompression),
VERBOSE_UNSUPPORTED_SCALES_CFG);
}
const auto &dst_scales = attr.scales_.get(DNNL_ARG_DST);
bgmmc.with_dst_scales = !dst_scales.has_default_values();
// only common scales are supported
VCONDCHECK_BG(!(bgmmc.with_dst_scales && dst_scales.get_mask() > 0),
VERBOSE_UNSUPPORTED_SCALES_CFG);
const auto &p = attr.post_ops_;
bgmmc.with_sum = p.find(primitive_kind::sum) != -1;
const int eltwise_ind = p.find(primitive_kind::eltwise);
bgmmc.with_eltwise = eltwise_ind != -1;
const int binary_ind = p.find(primitive_kind::binary);
const int prelu_ind = p.find(primitive_kind::prelu);
bgmmc.with_binary = !everyone_is(-1, binary_ind, prelu_ind);
bgmmc.src_zp_type = get_zp_type(attr, DNNL_ARG_SRC);
bgmmc.wei_zp_type = get_zp_type(attr, DNNL_ARG_WEIGHTS);
bgmmc.dst_zp_type = get_zp_type(attr, DNNL_ARG_DST);
VCONDCHECK_BG(
IMPLICATION(!(bm_conf_utils.is_int8()
|| bm_conf_utils.with_weights_decompression()),
everyone_is(brgemm_broadcast_t::none, bgmmc.src_zp_type,
bgmmc.wei_zp_type, bgmmc.dst_zp_type)),
VERBOSE_UNSUPPORTED_ZP_CFG);
matmul_helper_t helper(src_d, weights_d, dst_d);
bgmmc.batch_ndims = bgmmc.ndims - 2;
bgmmc.M = helper.M();
bgmmc.N = helper.N();
bgmmc.K = helper.K();
bgmmc.batch = helper.batch();
bgmmc.is_runtime_M = is_runtime_value(bgmmc.M);
bgmmc.is_runtime_N = is_runtime_value(bgmmc.N);
bgmmc.is_runtime_K = is_runtime_value(bgmmc.K);
bgmmc.is_gemv
= is_gemv_applicable(bgmmc, bm_conf_utils, src_md, weights_md);
if (!bgmmc.is_gemv && bm_conf_utils.is_f32() && bgmmc.isa == avx2
&& (bgmmc.N == 1 || bgmmc.M == 1)) {
// The brgemm matmul implementation for avx2 and f32 data type has
// some performance gaps compared to the autogenerated GEMM
// implementation in some N=1 and M=1 cases.
//
// The implementation has a dedicated code path optimized for some
// N=1 cases (bgmmc.is_gemv = true), which guarantees performance
// on par with or better than the GEMM implementation for applicable
// GEMV shapes.
//
// However, this dedicated code path does not cover all GEMV scenarios.
// For other GEMV cases, we fall back to the GEMM implementation,
// as it typically offers better performance as of now.
//
// For all non-GEMV cases, we use this brgemm matmul implementation
// by default.
//
// However, we must ensure that GEMM can handle the data formats.
// If it cannot (e.g., the weights format is blocked), we use
// this implementation to avoid falling back to the reference one.
VCONDCHECK_BG(!can_use_gemm_fallback(),
"Fall back to the GEMM implementation for cases not supported "
"by the GEMV code path for the N=1 and M=1 cases.");
}
VCHECK_BG(bm_conf_utils.set_or_check_tags(src_md, dst_md, bias_md, helper),
VERBOSE_UNSUPPORTED_TAG);
VCHECK_BG(attr.set_default_formats(&dst_md), VERBOSE_UNSUPPORTED_TAG);
VCONDCHECK_BG(post_ops_ok(bgmmc, attr, dst_d), VERBOSE_UNSUPPORTED_POSTOP);
// Runtime values are supported for the M and N dimensions only.
VCONDCHECK_BG((!(is_runtime_value(bgmmc.batch) || bgmmc.is_runtime_K)),
VERBOSE_RUNTIMEDIM_UNSUPPORTED)
// Only a single runtime dimension is supported for now.
VCONDCHECK_BG(!(bgmmc.is_runtime_M && bgmmc.is_runtime_N),
VERBOSE_RUNTIMEDIM_UNSUPPORTED)
// Runtime value for M dimension is supported for 2d AMX int8/bfloat16
// problems only.
const bool runtime_M_supported = bgmmc.is_amx && bgmmc.ndims == 2
&& one_of(true, bm_conf_utils.is_int8(), bm_conf_utils.is_bf16());
VCONDCHECK_BG(!(bgmmc.is_runtime_M && !runtime_M_supported),
VERBOSE_RUNTIMEDIM_UNSUPPORTED)
// Runtime N value is supported for 2d AMX int8/bfloat16 problems only.
const bool runtime_N_supported = bgmmc.is_amx && bgmmc.ndims == 2
&& one_of(true, bm_conf_utils.is_int8(), bm_conf_utils.is_bf16());
VCONDCHECK_BG(!(bgmmc.is_runtime_N && !runtime_N_supported),
VERBOSE_RUNTIMEDIM_UNSUPPORTED)
bgmmc.batch_without_first_dim
= bgmmc.batch_ndims > 1 ? helper.batch() / dst_d.dims()[0] : 0;
bgmmc.bcast_A_desc.set_params(
src_d.dims(), dst_d.dims(), bgmmc.batch_ndims, bgmmc.batch);
bgmmc.bcast_B_desc.set_params(
weights_d.dims(), dst_d.dims(), bgmmc.batch_ndims, bgmmc.batch);
// required granularity for k dimension
bgmmc.required_k_granularity
= bgmmc.is_amx ? data_type_vnni_granularity(bgmmc.wei_dt) : 1;
VCONDCHECK_BG(bgmmc.required_k_granularity > 0, VERBOSE_BLOCKING_FAIL, "");
bgmmc.wei_k_blk = get_wei_k_blk(bgmmc.wei_dt);
VCHECK_BG(bm_conf_utils.set_or_check_B_tag(weights_md, helper),
VERBOSE_UNSUPPORTED_TAG);
bgmmc.req_wei_vnni_downconvert = bm_conf_utils.wei_down_convert_to_vnni();
VCHECK_BG(attr.set_default_formats(&dst_md), VERBOSE_UNSUPPORTED_TAG);
bgmmc.wei_n_blk = bm_conf_utils.get_default_n_block(bgmmc.wei_tag);
bgmmc.blocked_B = bm_conf_utils.get_blocked_B();
bgmmc.transposed_B = bm_conf_utils.check_is_transposed(bgmmc.wei_tag)
|| bgmmc.wei_tag == adbc;
bgmmc.use_buffer_b = bm_conf_utils.use_buffer_b();
bgmmc.req_transpose_scales = bgmmc.apply_scales_in_buffer_b
&& bgmmc.is_wei_scale_per_k && bgmmc.is_wei_scale_per_n
&& bgmmc.transposed_B;
if ((bm_conf_utils.is_f32_f16() || bm_conf_utils.is_f32_bf16())
&& is_superset(bgmmc.isa, avx2) && bm_conf_utils.use_buffer_b()) {
// ANCHOR: `CONVERT_F32_XF16_DATA_TYPES`
bgmmc.src_dt = f32;
bgmmc.wei_dt = f32;
bgmmc.tr_a_dt_sz = types::data_type_size(f32);
bgmmc.tr_b_dt_sz = types::data_type_size(f32);
}
// int4 weights decompression only supports plain and transpose layouts
// TODO: enable int4 reorder and extend support to blocked weights
// layout when needed
if (bgmmc.with_wei_decompression && bgmmc.is_int4_weights)
VCONDCHECK_BG(bm_conf_utils.check_is_plain(bgmmc.wei_tag)
|| bm_conf_utils.check_is_transposed(bgmmc.wei_tag),
VERBOSE_UNSUPPORTED_TAG);
const bool transposed_A = bm_conf_utils.check_is_transposed(bgmmc.src_tag);
// When M == 1, MatMul always considers A to be non-transposed even if the
// A md was created with the "ba" tag. Note that A is not plain in the cab
// layout, hence the extra check for batched cases.
bgmmc.treat_A_as_plain = bgmmc.M == 1
&& IMPLICATION(bgmmc.batch != 1,
bm_conf_utils.check_is_plain(bgmmc.src_tag));
bgmmc.transposed_A = ((transposed_A && !bgmmc.treat_A_as_plain)
|| bgmmc.src_tag == adbc);
// For batched problems with plain A and C and fully broadcasted across B
// we can merge all the batch dimensions into M if broadcast strategies
// set is limited for binary post-ops
const bool plain_A_layout = bm_conf_utils.check_is_plain(bgmmc.src_tag)
|| bgmmc.treat_A_as_plain;
const bool merge_batch_dims_into_M = bgmmc.batch > 1
&& bgmmc.bcast_B_desc.bcast_across_all_batch_dims && plain_A_layout
&& helper.is_src_dst_layout_batch_fusable()
&& post_ops_ok(
bgmmc, attr, dst_d, true /* limit_bcast_strategies_set */);
if (merge_batch_dims_into_M) {
bgmmc.M *= bgmmc.batch;
bgmmc.batch = 1;
}
// runtime A stride wrt M dimension is not acceptable
VCONDCHECK_BG(!is_runtime_value(helper.get_a_stride(bgmmc.ndims - 2)),
VERBOSE_UNSUPPORTED_MEM_STRIDE);
// runtime A stride wrt K dimension is acceptable for transpose A and
// runtime M case only
const bool stride_A_wrt_K_dim_ok = IMPLICATION(
is_runtime_value(helper.get_a_stride(bgmmc.ndims - 1)),
bgmmc.transposed_A && bgmmc.is_runtime_M);
VCONDCHECK_BG(stride_A_wrt_K_dim_ok, VERBOSE_UNSUPPORTED_MEM_STRIDE);
// runtime A strides wrt batch dimensions are acceptable for runtime M case
// only
for (int b = 0; b < bgmmc.batch_ndims; b++) {
VCONDCHECK_BG(IMPLICATION(is_runtime_value(helper.get_a_stride(b)),
bgmmc.is_runtime_M),
VERBOSE_UNSUPPORTED_MEM_STRIDE);
}
// runtime B stride wrt N dimension is not acceptable
VCONDCHECK_BG(!is_runtime_value(helper.get_b_stride(bgmmc.ndims - 1)),
VERBOSE_UNSUPPORTED_MEM_STRIDE);
// runtime B stride wrt K dimension is acceptable for non-transposed B and
// runtime N case only
const bool stride_B_wrt_K_dim_ok = IMPLICATION(
is_runtime_value(helper.get_b_stride(bgmmc.ndims - 2)),
!bgmmc.transposed_B && bgmmc.is_runtime_N);
VCONDCHECK_BG(stride_B_wrt_K_dim_ok, VERBOSE_UNSUPPORTED_MEM_STRIDE);
// runtime B strides wrt batch dimensions are acceptable for runtime N case
// only
for (int b = 0; b < bgmmc.batch_ndims; b++) {
VCONDCHECK_BG(IMPLICATION(is_runtime_value(helper.get_b_stride(b)),
bgmmc.is_runtime_N),
VERBOSE_UNSUPPORTED_MEM_STRIDE);
}
// runtime C stride wrt N dimension is not acceptable
VCONDCHECK_BG(!is_runtime_value(helper.get_c_stride(bgmmc.ndims - 1)),
VERBOSE_UNSUPPORTED_MEM_STRIDE);
// runtime C stride wrt M dimension is acceptable for runtime N case only
    const bool stride_C_wrt_M_dim_ok = IMPLICATION(
            is_runtime_value(helper.get_c_stride(bgmmc.ndims - 2)),
            bgmmc.is_runtime_N);
    VCONDCHECK_BG(stride_C_wrt_M_dim_ok, VERBOSE_UNSUPPORTED_MEM_STRIDE);
// runtime C strides wrt batch dimensions are acceptable for runtime N case
// only
for (int b = 0; b < bgmmc.batch_ndims; b++) {
VCONDCHECK_BG(IMPLICATION(is_runtime_value(helper.get_c_stride(b)),
bgmmc.is_runtime_N),
VERBOSE_UNSUPPORTED_MEM_STRIDE);
}
const bool is_copy_a_required = !bgmmc.is_gemv
&& ((bgmmc.is_amx
&& (bm_conf_utils.is_bf32() || bm_conf_utils.is_tf32()))
|| ((bm_conf_utils.is_f16()
|| bm_conf_utils.is_f16_with_int_wei())
&& isa == avx512_core_fp16)
|| (bgmmc.wei_zp_type != brgemm_broadcast_t::none
&& !bm_conf_utils.with_weights_decompression())
|| bgmmc.transposed_A);
bgmmc.use_buffer_a = is_copy_a_required;
    // Copying only the part of A related to K_tail is supported when
    // is_copy_a_required == true, but current performance measurements show
    // it is slower than copying the whole A (especially for big K sizes).
bgmmc.use_buffer_a_tail_only = false;
const int dmax = nstl::min(bgmmc.ndims, 3);
for (int d = 0; d < dmax; d++) {
int dim = bgmmc.ndims - 1 - d;
bgmmc.A_strides[d] = bgmmc.a_dt_sz * src_d.blocking_desc().strides[dim];
bgmmc.B_strides[d]
= bgmmc.b_dt_sz * weights_d.blocking_desc().strides[dim];
bgmmc.C_strides[d] = bgmmc.c_dt_sz * dst_d.blocking_desc().strides[dim];
}
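    // E.g. for a plain 2D f32 src with K = 512: A_strides[0] = 4 * 1 (along
    // K, innermost) and A_strides[1] = 4 * 512 = 2048 bytes (along M).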
    // A_strides need to be corrected if the batch dimensions are merged into
    // M and the A layout is formally transposed but can be treated as plain.
bgmmc.adjust_a_strides = merge_batch_dims_into_M
&& (src_d.matches_tag(acbd) || bgmmc.treat_A_as_plain);
if (bgmmc.adjust_a_strides) bgmmc.A_strides[1] = bgmmc.A_strides[2];
    // C_strides need to be corrected if the batch dimensions are merged into
    // M and the C layout is formally transposed but can be treated as plain.
if (merge_batch_dims_into_M && dst_d.matches_tag(acbd)) {
bgmmc.C_strides[1] = bgmmc.C_strides[2];
}
// BF32 'Hint' Heuristic:
// Under the following conditions, F32 through AVX512_CORE performs better
// than using BF32 arithmetic.
VCONDCHECK_BG(!(bgmmc.is_bf32 && (bgmmc.M < 8)
&& ((bgmmc.wei_tag == abcd)
|| bm_conf_utils.is_any_B_layout())),
VERBOSE_UNSUPPORTED_FPMATH_MODE);
if (matmul_amx_blocking_params_macro_t::is_supported(bgmmc, bm_conf_utils))
if (postops_estimator_t::estimate_insts_per_cacheline(
dst_md, attr, bgmmc.postops_inst_count)
!= status::success) {
            // Failed to estimate the post-ops length; assume it has no
            // impact on the gemm execution.
bgmmc.postops_inst_count = 0;
}
// Heuristic tries to optimize the following parameters:
// - M_blk, M_Chunk
// - N_blk, N_Chunk
// - K_blk, batch_size
// - nthr_K
VCHECK_BG(compute_blocking_heuristic(bgmmc, bm_conf_utils),
VERBOSE_BLOCKING_FAIL, "");
if (bgmmc.wei_n_blk > bgmmc.N_blk && bgmmc.N != bgmmc.N_blk) {
assert(!bgmmc.is_runtime_N
&& "N_blk should not be adjusted for runtime N");
if (bgmmc.use_buffer_b) {
            // Copy kernels require the B buffer to be aligned: full ZMM
            // registers are used without masking (two YMMs with the same
            // granularity in the AVX2 case).
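            // E.g. bf16 weights (vnni granularity 2, 2-byte elements) and a
            // 64-byte cache line give 64 / 4 = 16 elements, so N_blk = 20
            // is rounded up to wei_n_blk = 32.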
size_t n_elements_in_wei_zmm = platform::get_cache_line_size()
/ (data_type_vnni_granularity(bgmmc.wei_dt)
* bgmmc.tr_b_dt_sz);
bgmmc.wei_n_blk = rnd_up(bgmmc.N_blk, n_elements_in_wei_zmm);
} else {
bgmmc.wei_n_blk = bgmmc.N_blk;
}
VCHECK_BG(bm_conf_utils.update_and_check_B_tag(
weights_md, bgmmc.wei_n_blk, helper),
VERBOSE_UNSUPPORTED_TAG);
bgmmc.req_wei_vnni_downconvert
= bm_conf_utils.wei_down_convert_to_vnni();
}
// This setting must be updated post blocking as it has a dependency on
// `bgmmc.K_blk`. See `gK_and_K_blk_are_divisible` comment.
if (bgmmc.is_wei_scale_per_k) {
const auto gK = bgmmc.wei_scales_k_group_size;
bgmmc.gK_and_K_blk_are_divisible = gK > 1
&& ((bgmmc.K_blk % gK == 0) || (gK % bgmmc.K_blk == 0));
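        // E.g. gK = 32 or gK = 128 with K_blk = 64 qualify, while gK = 48
        // with K_blk = 64 does not.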
}
VCHECK_BG(bm_conf_utils.set_B_flags(weights_md), VERBOSE_BLOCKING_FAIL, "");
bgmmc.M_tail = bgmmc.is_runtime_M ? 0 : bgmmc.M % bgmmc.M_blk;
bgmmc.N_tail = bgmmc.is_runtime_N ? 0 : bgmmc.N % bgmmc.N_blk;
bgmmc.K_tail = bgmmc.K > bgmmc.K_blk
? ((bgmmc.extendable_k || bgmmc.use_fused_copy_a)
? bgmmc.K % bgmmc.K_blk
: rnd_up(bgmmc.K % bgmmc.K_blk,
bgmmc.required_k_granularity))
: 0;
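    // E.g. K = 99, K_blk = 32, required_k_granularity = 2: K_tail is rounded
    // up to 4; with extendable_k or fused copy-A the raw remainder of 3 is
    // kept.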
bgmmc.LDB = bm_conf_utils.get_actual_LDB();
VCONDCHECK_BG(!(bgmmc.LDB < bgmmc.N_blk && bgmmc.N_blk % bgmmc.LDB != 0
&& bgmmc.N_blk != bgmmc.N),
"The first coordinate of every N_blk that is larger than LDB "
"needs to be divisible by LDB");
bgmmc.LDD = dst_d.ndims() == 2 && bgmmc.M == 1
? bgmmc.N
: dst_d.blocking_desc().strides[bgmmc.ndims - 2];
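    // E.g. with a C buffer, nthr_k == 1, amx and N_blk = 64 (static N) the
    // buffer leading dimension is LDC = min(32, 64) = 32; otherwise LDC
    // falls back to LDD.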
bgmmc.LDC = bgmmc.use_buffer_c && bgmmc.nthr_k <= 1
? (bgmmc.is_amx ? nstl::min((dim_t)32, bgmmc.N_blk) : bgmmc.N_blk)
* (bgmmc.is_runtime_N ? bgmmc.N_chunk_size : 1)
: bgmmc.LDD;
bgmmc.is_src_batch_layout_trivial
= is_batch_layout_trivial(src_d, bgmmc.batch);
bgmmc.is_wei_batch_layout_trivial
= is_batch_layout_trivial(weights_d, bgmmc.batch);
bgmmc.is_dst_batch_layout_trivial
= is_batch_layout_trivial(dst_d, bgmmc.batch);
    // Initialize chunk-related sizes and other auxiliary values.
init_aux_values(bgmmc, src_d, weights_d, dst_d);
if (!bgmmc.is_gemv && bm_conf_utils.is_f32()
&& is_superset(bgmmc.isa, avx512_core)) {
// Dispatch the shapes with small K to gemm for better performance
// The heuristic values are empirical
const bool small_K = bgmmc.N <= 14528
&& ((bgmmc.M <= 768 && bgmmc.K <= 128)
|| bgmmc.K * bgmmc.M <= 49152);
VCONDCHECK_BG(IMPLICATION(bgmmc.ndims == 2,
!small_K || !can_use_gemm_fallback()),
VERBOSE_SMALL_SHAPES);
}
bgmmc.use_buffer_reduce
= (bgmmc.reduce_dt != data_type::f32) || (bgmmc.nthr_k > 1);
const dim_t max_a_stride = bgmmc.M_blk
* (bgmmc.use_buffer_a ? bgmmc.copy_A_src_stride
: bgmmc.LDA * bgmmc.a_dt_sz);
const dim_t max_b_stride = bgmmc.K_blk
* (bgmmc.use_buffer_b ? bgmmc.copy_B_wei_stride
: bgmmc.LDB * bgmmc.b_dt_sz);
const dim_t max_c_stride = bgmmc.M_blk * bgmmc.LDC * bgmmc.c_dt_sz;
const dim_t max_d_stride = bgmmc.M_blk * bgmmc.LDD * bgmmc.acc_dt_sz;
const dim_t max_supported_stride = std::numeric_limits<int32_t>::max();
VCONDCHECK_BG(max_a_stride <= max_supported_stride,
VERBOSE_UNSUPPORTED_FEATURE,
"src stride > INT32_MAX is not supported");
VCONDCHECK_BG(max_b_stride <= max_supported_stride,
VERBOSE_UNSUPPORTED_FEATURE,
"weights stride > INT32_MAX is not supported");
VCONDCHECK_BG(std::max(max_c_stride, max_d_stride) <= max_supported_stride,
VERBOSE_UNSUPPORTED_FEATURE,
"dst stride > INT32_MAX is not supported");
    // When is_wei_batch_layout_trivial is true, only batch offsets divisible
    // by 2 are supported for int4 weights.
if (bgmmc.is_int4_weights) {
VCONDCHECK_BG(IMPLICATION(bgmmc.is_wei_batch_layout_trivial
&& bgmmc.batch > 1,
bgmmc.B_strides[2] % 2 == 0),
VERBOSE_BAD_PARAM, "B_strides");
}
    // Init the mem advice heuristic based on the bmn threads and the
    // execution scan order.
if (is_superset(isa, avx10_2_512)) mem_advice_init(bgmmc);
// Dispatch small shapes to VNNI for better performance
const bool runtime_dims
= bgmmc.is_runtime_M || bgmmc.is_runtime_N || bgmmc.is_runtime_K;
bool is_small_shapes = bgmmc.is_amx && !runtime_dims;
// Disable 'small_shape' heuristic for amx_fp16 until it is validated with
// performance measurements.
is_small_shapes = is_small_shapes && (bgmmc.isa != avx512_core_amx_fp16);
if (bm_conf_utils.is_bf16() || bm_conf_utils.is_f16()
|| bm_conf_utils.is_f32_f16() || bm_conf_utils.is_f32_bf16()
|| bm_conf_utils.is_bf16_with_int_wei()
|| bm_conf_utils.is_f16_with_int_wei()) {
        // Empirically observed performance breakpoint between amx and vnni
        // for bf16/f16.
const dim_t buffer_a_chunk_sz_limit = 126;
is_small_shapes = is_small_shapes
&& bgmmc.buffer_a_gb_stride <= buffer_a_chunk_sz_limit;
} else if (bm_conf_utils.is_f8() || bm_conf_utils.is_bf8()) {
is_small_shapes = false;
} else {
is_small_shapes = is_small_shapes && bgmmc.ndims < 3
&& ((bgmmc.M == 1 && bgmmc.K == 256)
|| (bgmmc.M <= 32 && bgmmc.M * bgmmc.N <= 256)
|| bgmmc.K <= 16);
}
    // This is the only implementation that supports the packed_sparse_weights
    // case, therefore there is no fallback for it.
is_small_shapes = is_small_shapes && !bgmmc.packed_sparse_weights;
VCONDCHECK_BG(!is_small_shapes, VERBOSE_SMALL_SHAPES);
if (bgmmc.use_buffer_b) {
// If B is copied to a temporary buffer then the layout of B is
// [n = n_blk / LDB][k = k_blk / wei_k_blk][k = wei_k_blk / vnni][n = LDB][k = vnni]
bgmmc.LDB2 = rnd_up(bgmmc.K_blk, bgmmc.wei_k_blk) * bgmmc.LDB;
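        // E.g. K_blk = 96, wei_k_blk = 32, LDB = 64: LDB2 = 96 * 64 = 6144
        // elements, i.e. the stride of the outermost [n = n_blk / LDB] dim.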
} else {
// [n = N / LDB][k = K / wei_k_blk][k = wei_k_blk / vnni][n = LDB][k = vnni]
bgmmc.LDB2 = rnd_up(bgmmc.K, bgmmc.wei_k_blk) * bgmmc.LDB;
}
return status::success;
}
status_t init_conf(brgemm_matmul_conf_t &conf, dim_t batch, dim_t M, dim_t K,
dim_t N, dim_t in_ld, dim_t n_blk, data_type_t in_type,
data_type_t out_type, format_tag_t in_tag) {
if (n_blk <= 0 && M <= 0) return status::invalid_arguments;
const auto vnni_granularity = data_type_vnni_granularity(out_type);
if (vnni_granularity <= 0) return status::invalid_arguments;
// Zero initialize the `conf` to avoid access to 'garbage' in members.
conf = brgemm_matmul_conf_t();
const bool is_bf16_with_int_wei = out_type == data_type::bf16
&& utils::one_of(in_type, data_type::s8, data_type::u8,
data_type::s4, data_type::u4);
const bool with_wei_decompression = in_type != out_type
&& utils::one_of(in_type, data_type::s8, data_type::u8,
data_type::s4, data_type::u4);
const bool is_copyB = N > 0;
conf.isa = get_max_cpu_isa(); // Just use the best ISA possible.
conf.is_bf32 = false;
conf.batch = batch;
conf.src_dt = conf.wei_dt = out_type;
conf.orig_src_dt = conf.orig_wei_dt = in_type;
// Note: will need to change `tr_a_dt_sz` for copyA in cases where src_dt != dst_dt
conf.a_dt_sz = conf.tr_a_dt_sz = types::data_type_size(conf.src_dt);
conf.N = N;
conf.M = M;
conf.K = K;
const dim_t copyA_K_blk = isa_num_vregs(conf.isa) / 2;
const dim_t copyB_K_blk = 16 * vnni_granularity;
conf.K_blk = is_copyB ? copyB_K_blk : copyA_K_blk;
conf.K_tail = conf.K % conf.K_blk;
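    // E.g. copy-A on avx512 (32 vregs) gives K_blk = 16; copy-B for bf16
    // output (vnni granularity 2) gives K_blk = 32, so K = 100 leaves
    // K_tail = 4.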
if (!is_copyB) {
// Note: current implementation always calls the transposed kernel.
conf.transposed_A = true;
conf.M_blk = (dim_t)isa_max_vlen(conf.isa) / conf.a_dt_sz;
conf.M_tail = conf.M % conf.M_blk;
conf.copy_A_src_stride = in_ld * conf.a_dt_sz;
// setting LDA parameter required for plain transpose
conf.LDA = conf.K;
// jit_brgemm_matmul_copy_a_tranposed_impl_t::dst_stride
dim_t dst_stride = conf.LDA * conf.tr_a_dt_sz;
dim_t max_src_encode_stride = conf.K_blk * conf.copy_A_src_stride;
dim_t max_dst_encode_stride = conf.M_blk * dst_stride;
// Cannot encode EVEX compressed addresses
VCONDCHECK_BG(std::max(max_src_encode_stride, max_dst_encode_stride)
<= std::numeric_limits<int32_t>::max(),
VERBOSE_UNSUPPORTED_MEM_STRIDE);
} else {
conf.blocked_B = !utils::one_of(in_tag, ab, ba, abc, acb);
conf.transposed_B = utils::one_of(in_tag, ba, acb);
conf.is_bf16_with_int_wei = is_bf16_with_int_wei;
conf.with_wei_decompression = with_wei_decompression;
conf.wei_tag = in_tag;
conf.wei_n_blk = conf.N_blk = conf.LDB = n_blk;
conf.N_tail = conf.N % conf.N_blk;
conf.b_dt_sz = types::data_type_size(in_type);
conf.tr_b_dt_sz = types::data_type_size(conf.wei_dt);
conf.copy_B_wei_stride = in_ld * conf.b_dt_sz;
        conf.N_chunk_elems = conf.N; // To satisfy a seemingly unneeded assert.
conf.s8s8_comp_b_str = utils::rnd_up(conf.N, conf.wei_n_blk);
conf.s8s8_comp_n_str = conf.wei_n_blk;
dim_t max_wei_encode_off = conf.K_blk * conf.copy_B_wei_stride
+ conf.wei_n_blk * conf.b_dt_sz;
dim_t max_dst_encode_off
= (conf.K_blk * conf.LDB + conf.wei_n_blk) * conf.tr_b_dt_sz;
// Cannot encode EVEX compressed addresses
VCONDCHECK_BG(std::max(max_wei_encode_off, max_dst_encode_off)
<= std::numeric_limits<int32_t>::max(),
VERBOSE_UNSUPPORTED_MEM_STRIDE);
}
    // The following members differ from the upper-level `init_conf()` call
    // from the reorder implementation because there is no memory descriptor
    // to hint at compensation.
// TODO: re-consider an interface change to enable these members.
conf.s8s8_compensation_required = false;
conf.src_zp_type = brgemm_broadcast_t::none;
conf.has_zero_point_a = false;
conf.has_zero_point_b = false;
conf.is_thread_chunks_exec_order_horizontal = true;
conf.mem_advice
= brgemm_kernel_hint_mem_advice_t::brgemm_hint_mem_advice_undef;
return status::success;
}
void init_aux_values(brgemm_matmul_conf_t &bgmmc,
const memory_desc_wrapper &src_d, const memory_desc_wrapper &wei_d,
const memory_desc_wrapper &dst_d) {
bgmmc.M_chunk_elems = bgmmc.M_blk * bgmmc.M_chunk_size;
bgmmc.N_chunk_elems = bgmmc.N_blk * bgmmc.N_chunk_size;
bgmmc.K_chunk_elems
= bgmmc.K_blk * bgmmc.K_chunk_size * bgmmc.brgemm_batch_size;
bgmmc.M_chunks = div_up(bgmmc.M, bgmmc.M_chunk_elems);
bgmmc.N_chunks = div_up(bgmmc.N, bgmmc.N_chunk_elems);
bgmmc.K_chunks = div_up(bgmmc.K, bgmmc.K_chunk_elems);
bgmmc.num_M_blocks = div_up(bgmmc.M, bgmmc.M_blk);
bgmmc.num_N_blocks = div_up(bgmmc.N, bgmmc.N_blk);
bgmmc.num_K_blocks = div_up(bgmmc.K, bgmmc.K_blk * bgmmc.brgemm_batch_size);
    const int last_chunk_batch_size
= (nstl::max(bgmmc.K, bgmmc.K_blk)
- (bgmmc.K_chunks - 1) * bgmmc.K_chunk_elems)
/ bgmmc.K_blk;
bgmmc.brgemm_batch_tail_size
            = last_chunk_batch_size % bgmmc.brgemm_batch_size;
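    // E.g. K = 320, K_blk = 32, brgemm_batch_size = 4, K_chunk_size = 2:
    // K_chunk_elems = 256, K_chunks = 2, the last chunk holds 2 K blocks,
    // so brgemm_batch_tail_size = 2.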
if (!bgmmc.is_runtime_N && bgmmc.is_amx && bgmmc.nthr_k == 1) {
bgmmc.buffer_c_chunk_sz = rnd_up(bgmmc.N_blk, bgmmc.LDC) * bgmmc.M_blk
* bgmmc.acc_dt_sz;
} else {
bgmmc.buffer_c_chunk_sz = bgmmc.acc_dt_sz
* (bgmmc.is_runtime_N ? bgmmc.N_blk : bgmmc.LDC)
* (bgmmc.nthr_k > 1 ? bgmmc.M : bgmmc.M_blk);
}
if (!bgmmc.use_buffer_c) {
// No need for C buffer
bgmmc.buffer_c_per_thread_sz = 0;
} else if (bgmmc.nthr_k > 1) {
// c size == M * N (for reduction)
bgmmc.buffer_c_per_thread_sz = bgmmc.buffer_c_chunk_sz;
} else if (!bgmmc.is_runtime_N && !bgmmc.is_runtime_M
&& bgmmc.K_chunk_elems >= bgmmc.K) {
// c size == BRGEMM size
bgmmc.buffer_c_per_thread_sz = bgmmc.buffer_c_chunk_sz;
} else {
// c size == chunk size
bgmmc.buffer_c_per_thread_sz = bgmmc.buffer_c_chunk_sz
* bgmmc.M_chunk_size * bgmmc.N_chunk_size;
}
bgmmc.buffer_a_gb_stride = bgmmc.tr_a_dt_sz * bgmmc.M_blk
* (bgmmc.use_buffer_a_tail_only ? bgmmc.wei_k_blk : bgmmc.LDA);
bgmmc.buffer_a_k_stride
= bgmmc.buffer_a_gb_stride * bgmmc.brgemm_batch_size;
bgmmc.buffer_a_m_stride = bgmmc.buffer_a_k_stride * bgmmc.K_chunk_size;
bgmmc.buffer_a_per_thread_sz = bgmmc.buffer_a_m_stride * bgmmc.M_chunk_size;
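    // E.g. bf16 copy-A (tr_a_dt_sz = 2), M_blk = 32, LDA = 64,
    // brgemm_batch_size = 4, K_chunk_size = 2, M_chunk_size = 4:
    // gb / k / m strides are 4 / 16 / 32 KiB and the per-thread size 128 KiB.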
// Layout of a single GB in packed format:
// [n = n_blk / LDB][k = k_blk / wei_k_blk][k = wei_k_blk / vnni][n = LDB][k = vnni]
// For n = 0 and K aligned to vnni granularity, the starting position is:
// [K / vnni][n = 0][k_vnni = 0]
// Each copy kernel processes wei_scales_gK * n elements.
// The kernel accounts for LDB2 and writes data in a strided manner.
// The following value is multiplied by K to compute the starting offset:
bgmmc.buffer_b_k_stride = bgmmc.tr_b_dt_sz * bgmmc.LDB;
// The total usable data in one GB is:
// n_blk * k_blk, though both values require rounding.
bgmmc.buffer_b_gb_stride = bgmmc.tr_b_dt_sz * rnd_up(bgmmc.N_blk, bgmmc.LDB)
* rnd_up(bgmmc.K_blk, bgmmc.wei_k_blk);
// Each BRGEMM operation consumes k_blk * n_blk * brgemm_batch_size elements from B.
bgmmc.buffer_b_k_brg_stride
= bgmmc.buffer_b_gb_stride * bgmmc.brgemm_batch_size;
// A full K chunk of B is stored in a temporary buffer for reuse.
// Often, the size k_blk * n_blk * brgemm_batch_size * K_chunk_size ≈ L2 cache size,
// enabling reuse across the M dimension.
bgmmc.buffer_b_per_thread_sz
= bgmmc.buffer_b_k_brg_stride * bgmmc.K_chunk_size;
bgmmc.buffer_reduce_per_thread_sz = 0;
if (bgmmc.reduce_kind == matmul_reduce_kind::src) {
assert(bgmmc.acc_dt == f32);
bgmmc.buffer_reduce_per_thread_sz = bgmmc.M * bgmmc.acc_dt_sz;
}
bgmmc.s8s8_comp_ithr_str
= bgmmc.use_buffer_b ? bgmmc.wei_n_blk * bgmmc.N_chunk_size : 0;
bgmmc.s8s8_comp_b_str = bgmmc.use_buffer_b
? 0
: div_up(bgmmc.N, bgmmc.wei_n_blk) * bgmmc.wei_n_blk;
bgmmc.s8s8_comp_n_str = bgmmc.wei_n_blk;
bgmmc.A_ptr_shift_b = 0;
bgmmc.copy_A_src_stride = bgmmc.a_dt_sz
* src_d.strides()[bgmmc.ndims - 2 + bgmmc.transposed_A];
    // If src has dimensions equal to 1, multiple tags can be matched, so we
    // need to make sure:
// - A_ptr_shift_b is set for acbd and adbc even if bgmmc.src_tag is abcd
// - Plain md that matches acbd or adbc does not dispatch into their codepath
if (src_d.matches_one_of_tag(acbd, adbc)) {
if (src_d.matches_one_of_tag(abcd, abdc) == format_tag::undef) {
const dim_t factor = bgmmc.src_dt == f32 ? 2 : 1;
const dim_t src_stride = src_d.matches_tag(acbd)
? bgmmc.A_strides[1]
: bgmmc.A_strides[0];
const dim_t copy_A_src_stride = src_d.matches_tag(dabc)
&& bgmmc.K * bgmmc.batch
== src_d.blocking_desc().strides[0]
? src_d.blocking_desc().strides[0]
: src_d.blocking_desc().strides[0] * bgmmc.K;
bgmmc.copy_A_src_stride
= nstl::min(copy_A_src_stride, src_stride / factor)
* factor;
}
const dim_t bcast_shift_b = src_d.matches_tag(acbd) ? bgmmc.K : bgmmc.M;
bgmmc.A_ptr_shift_b
= (bgmmc.bcast_A_desc.bcast_mask == 2
? bcast_shift_b
: src_d.blocking_desc().strides[0])
* bgmmc.a_dt_sz;
}
bgmmc.B_ptr_shift_b = 0;
bgmmc.copy_B_wei_stride = 0;
    // If weights have dimensions equal to 1, multiple tags can be matched,
    // so we need to make sure:
// - B_ptr_shift_b is set for acbd and adbc even if bgmmc.wei_tag is abcd
// - Plain md that matches acbd or adbc does not dispatch into their codepath
// - Plain md that matches transposed tag does not dispatch into its codepath
if (wei_d.matches_one_of_tag(acbd, adbc) != format_tag::undef) {
const dim_t bcast_shift_b = wei_d.matches_tag(acbd) ? bgmmc.N : bgmmc.K;
bgmmc.B_ptr_shift_b
= (bgmmc.bcast_B_desc.bcast_mask == 2
? bcast_shift_b
: wei_d.blocking_desc().strides[0])
* bgmmc.b_dt_sz;
}
if (wei_d.matches_one_of_tag(acbd, adbc) != format_tag::undef
&& wei_d.matches_one_of_tag(abcd, abdc) == format_tag::undef) {
const dim_t factor = bgmmc.wei_dt == f32 ? 2 : 1;
const dim_t wei_stride = wei_d.matches_tag(acbd) ? bgmmc.B_strides[1]
: bgmmc.B_strides[0];
bgmmc.copy_B_wei_stride = nstl::min(wei_d.blocking_desc().strides[0],
wei_stride / factor)
* factor;
} else if (bgmmc.transposed_B) {
if (wei_d.strides()[bgmmc.ndims - 1] == 1) {
const auto b_stride_elems
= bgmmc.req_wei_vnni_downconvert ? bgmmc.LDB : bgmmc.N;
bgmmc.copy_B_wei_stride = b_stride_elems * bgmmc.b_dt_sz;
} else {
bgmmc.copy_B_wei_stride
= wei_d.strides()[bgmmc.ndims - 1] * bgmmc.b_dt_sz;
}
} else if (bgmmc.is_runtime_N) {
bgmmc.copy_B_wei_stride = bgmmc.N;
} else if (bgmmc.blocked_B) {
bgmmc.copy_B_wei_stride = (bgmmc.LDB * bgmmc.b_dt_sz);
} else {
bgmmc.copy_B_wei_stride
= (wei_d.strides()[bgmmc.ndims - 2] * bgmmc.b_dt_sz);
}
bgmmc.C_ptr_shift_b = dst_d.matches_one_of_tag(acbd)
? dst_d.blocking_desc().strides[0] * bgmmc.c_dt_sz
: 0;
bgmmc.has_zero_point_a = bgmmc.src_zp_type != brgemm_broadcast_t::none;
bgmmc.has_zero_point_b = bgmmc.wei_zp_type != brgemm_broadcast_t::none;
bgmmc.has_zero_point_c = bgmmc.dst_zp_type != brgemm_broadcast_t::none;
bgmmc.post_ops_applicable = one_of(true, bgmmc.with_sum, bgmmc.with_bias,
(bgmmc.with_src_scales || bgmmc.with_wei_scales)
&& !bgmmc.apply_scales_in_buffer_b,
bgmmc.with_eltwise, bgmmc.with_binary, bgmmc.acc_dt != bgmmc.dst_dt,
bgmmc.s8s8_compensation_required, bgmmc.has_zero_point_a,
bgmmc.has_zero_point_b && !bgmmc.with_wei_decompression,
bgmmc.has_zero_point_c, bgmmc.with_dst_scales);
bgmmc.zp_a_comp_shift_n = bgmmc.wei_n_blk;
bgmmc.zp_a_comp_elems_per_thr
= bgmmc.N_chunk_size * bgmmc.zp_a_comp_shift_n;
const int s32_elems_in_cacheline = 16;
bgmmc.zp_b_comp_result_shift_m = bgmmc.M_blk;
bgmmc.zp_b_comp_buffer_start
= bgmmc.M_chunk_size * bgmmc.zp_b_comp_result_shift_m;
bgmmc.zp_b_comp_buffer_shift_m = s32_elems_in_cacheline * bgmmc.M_blk;
bgmmc.zp_b_comp_elems_per_thr = bgmmc.M_chunk_size
* (bgmmc.zp_b_comp_result_shift_m + bgmmc.zp_b_comp_buffer_shift_m);
bgmmc.brgemm_batch_element_per_thr_sz = 16 * bgmmc.brgemm_batch_size;
}
void init_scratchpad(memory_tracking::registrar_t &scratchpad,
const brgemm_matmul_conf_t &bgmmc) {
const size_t default_data_align = sizeof(char);
if (bgmmc.brg_type == brgemm_addr)
scratchpad.book(key_brgemm_primitive_batch,
static_cast<size_t>(bgmmc.nthr)
* bgmmc.brgemm_batch_element_per_thr_sz,
sizeof(brgemm_batch_element_t), 64);
if (bgmmc.use_buffer_a || bgmmc.use_buffer_a_tail_only)
scratchpad.book(key_brgemm_primitive_buffer_a,
bgmmc.nthr * bgmmc.buffer_a_per_thread_sz, default_data_align);
if (bgmmc.use_buffer_b) {
scratchpad.book(key_brgemm_primitive_buffer_b,
bgmmc.nthr * bgmmc.buffer_b_per_thread_sz, default_data_align);
if (bgmmc.s8s8_compensation_required && (!bgmmc.blocked_B))
scratchpad.book(key_brgemm_primitive_buffer_comp,
bgmmc.nthr * bgmmc.s8s8_comp_ithr_str,
types::data_type_size(f32));
}
if (bgmmc.use_buffer_c)
scratchpad.book(key_brgemm_primitive_buffer,
bgmmc.nthr * bgmmc.buffer_c_per_thread_sz, default_data_align);
if (bgmmc.use_buffer_reduce) {
const bool is_reduce_f32 = bgmmc.reduce_dt == f32;
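        // Books (nthr_k - 1) per-thread reduce buffers when the reduction
        // is already in f32 (presumably one partial result can go straight
        // to the destination), and nthr_k buffers otherwise.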
scratchpad.book(key_brgemm_primitive_buffer_reduce,
(bgmmc.nthr_k - is_reduce_f32)
* bgmmc.buffer_reduce_per_thread_sz,
default_data_align);
}
if (bgmmc.has_zero_point_a) {
const auto num_elems = bgmmc.nthr * bgmmc.zp_a_comp_elems_per_thr;
scratchpad.book(key_brgemm_primitive_zp_comp_a, num_elems,
types::data_type_size(s32));
}
if (bgmmc.has_zero_point_b)
scratchpad.book(key_brgemm_primitive_zp_comp_b,
bgmmc.nthr * bgmmc.zp_b_comp_elems_per_thr,
types::data_type_size(s32));
if (is_superset(bgmmc.isa, avx512_core_amx))
scratchpad.book(key_conv_amx_tile_buffer,
static_cast<size_t>(bgmmc.nthr) * bgmmc.wsp_tile_per_thr_bytes,
default_data_align);
if (bgmmc.is_runtime_M || bgmmc.is_runtime_N)
scratchpad.book(key_brgemm_primitive_buffer_d,
bgmmc.M_blk * bgmmc.N_blk * bgmmc.c_dt_sz * bgmmc.nthr,
default_data_align);
if (bgmmc.with_dst_scales) {
// See brgemm_types.hpp comment for `with_dst_scales`.
scratchpad.book(key_matmul_dst_scales,
static_cast<size_t>(bgmmc.nthr) * sizeof(float),
default_data_align);
}
}
} // namespace matmul
} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl