/*******************************************************************************
 * Copyright 2022-2025 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *******************************************************************************/

#include <float.h>
#include <functional>
#include <math.h>
#include <random>
#include <stdio.h>
#include <stdlib.h>

#include "oneapi/dnnl/dnnl.h"

// TODO: refactor the driver to avoid using extra flags of a memory descriptor.
#include "src/common/memory_desc.hpp"

#include "tests/test_isa_common.hpp"

#include "utils/parallel.hpp"
#include "utils/parser.hpp"

#include "dnnl_common.hpp"
#include "dnnl_memory.hpp"

#include "brgemm/brgemm.hpp"

#if DNNL_CPU_RUNTIME != DNNL_RUNTIME_NONE
// Need these macros independently of the API.
#if defined(DNNL_X64) && DNNL_X64 == 1
#define brg_x64
#elif defined(DNNL_AARCH64) && DNNL_AARCH64 == 1
#define brg_aarch64
#endif

#if !defined(DNNL_EXPERIMENTAL_UKERNEL)
#if defined(DNNL_X64) && DNNL_X64 == 1
#define namespace_impl dnnl::impl::cpu::x64
#elif defined(DNNL_AARCH64) && DNNL_AARCH64 == 1
#define namespace_impl dnnl::impl::cpu::aarch64
// TODO: remove when `brgemm_t` type gets renamed.
using brgemm_desc_t = namespace_impl::brgemm_t;
#endif

#if defined(brg_x64) || defined(brg_aarch64)
template <>
struct dnnl_api_traits<namespace_impl::brgemm_kernel_t *> {
    static void destroy(namespace_impl::brgemm_kernel_t *t) {
        DNN_SAFE_V(namespace_impl::brgemm_kernel_destroy(t));
    }
};
#endif

#else // !defined(DNNL_EXPERIMENTAL_UKERNEL)

template <>
struct dnnl_api_traits<dnnl_brgemm_t> {
    static void destroy(dnnl_brgemm_t t) { DNN_SAFE_V(dnnl_brgemm_destroy(t)); }
};

template <>
struct dnnl_api_traits<dnnl_transform_t> {
    static void destroy(dnnl_transform_t t) {
        DNN_SAFE_V(dnnl_transform_destroy(t));
    }
};

template <>
struct dnnl_api_traits<dnnl_ukernel_attr_params_t> {
    static void destroy(dnnl_ukernel_attr_params_t t) {
        DNN_SAFE_V(dnnl_ukernel_attr_params_destroy(t));
    }
};

#endif

#endif // DNNL_CPU_RUNTIME != DNNL_RUNTIME_NONE

namespace brgemm {

#if defined(brg_x64) || defined(brg_aarch64)

#if !defined(DNNL_EXPERIMENTAL_UKERNEL)
/// Initializes BRGEMM attributes from an input string.
///
/// @param brgattr Output BRGEMM attributes.
/// @param prb Problem descriptor; its `brgemm_attr` member supplies the input
///     string of values in the format: KEY:VALUE[+KEY:VALUE[...]].
///     `KEY` matches the exact name of a `brgemm_attr_t` member, and `VALUE`
///     follows the member's type; enum and boolean types are treated as
///     integers.
///
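/// Example (illustrative combination of keys parsed below; not exhaustive):
///     use_uker:1+hint_bd_block:32+hint_prfA_dist1:4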
dnnl_status_t brgemm_attr_init(
        namespace_impl::brgemm_attr_t *brgattr, const prb_t *prb) {
    using namespace namespace_impl;

    // `max_bs` is handled directly through the driver interface.
    brgattr->max_bs = prb->batch_size;

    // `fpmath_mode` is handled directly through the driver interface.
    brgattr->fpmath_mode = prb->attr.fpmath_mode.mode;

    const auto &str = prb->brgemm_attr;
    if (str.empty()) return dnnl_success;

    size_t entry_pos = 0;
    while (entry_pos != std::string::npos) {
        auto key_value_str = parser::get_substr(str, entry_pos, '+');
        size_t value_pos = 0;
        auto key_str = parser::get_substr(key_value_str, value_pos, ':');
        auto value_str = parser::get_substr(key_value_str, value_pos, '\0');

#define PROCESS_SETTING_KEY_VAL(setting, key) \
    if (key_str.compare(STRINGIFY(key)) == 0) \
        brgattr->setting = std::stoi(value_str);

#define PROCESS_KEY_VAL(setting) PROCESS_SETTING_KEY_VAL(setting, setting)
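// E.g. (illustrative expansion) PROCESS_KEY_VAL(use_uker) becomes:
//     if (key_str.compare("use_uker") == 0)
//         brgattr->use_uker = std::stoi(value_str);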

        // TODO: `max_top_vpad` and `max_bottom_vpad` do not affect anything in
        // the kernel call and reference computation so far since
        // batch_element_t struct is not adjusted to incorporate different pad
        // values.
        // PROCESS_KEY_VAL(max_top_vpad);
        // PROCESS_KEY_VAL(max_bottom_vpad);
        PROCESS_KEY_VAL(hint_expected_A_size);
        PROCESS_KEY_VAL(hint_expected_B_size);
        PROCESS_KEY_VAL(hint_expected_C_size);
        PROCESS_KEY_VAL(wary_A_k_tail_read);
        PROCESS_KEY_VAL(extendable_k);
        PROCESS_KEY_VAL(generate_skip_accumulation);
        // TODO: `bd_mask` can't be passed to the kernel at this moment, that's
        // why `bd_mask_level` has to stay `0` for now until it's enabled.
        // PROCESS_KEY_VAL(bd_mask_level);
        PROCESS_KEY_VAL(use_uker);
        PROCESS_KEY_VAL(use_interleave_stores);
        PROCESS_KEY_VAL(b_is_vnni);
        PROCESS_KEY_VAL(postops_only);
        PROCESS_KEY_VAL(hint_bd_block);
        PROCESS_KEY_VAL(hint_bd_block2);
        PROCESS_KEY_VAL(hint_ld_block);
        PROCESS_KEY_VAL(hint_ld_block2);

        PROCESS_SETTING_KEY_VAL(hint_prfA.dist1, hint_prfA_dist1);
        PROCESS_SETTING_KEY_VAL(hint_prfA.dist2, hint_prfA_dist2);
        PROCESS_SETTING_KEY_VAL(hint_prfB.dist1, hint_prfB_dist1);
        PROCESS_SETTING_KEY_VAL(hint_prfB.dist2, hint_prfB_dist2);
        PROCESS_SETTING_KEY_VAL(hint_prfC.dist1, hint_prfC_dist1);
        PROCESS_SETTING_KEY_VAL(hint_prfC.dist2, hint_prfC_dist2);

#undef PROCESS_SETTING_KEY_VAL
#undef PROCESS_KEY_VAL

        if (key_str.find(STRINGIFY(hint_innermost_loop)) != std::string::npos)
            brgattr->hint_innermost_loop
                    = static_cast<brgemm_kernel_innermost_loop_t>(
                            std::stoi(value_str));
        if (key_str.find(STRINGIFY(hint_loop_order)) != std::string::npos)
            brgattr->hint_loop_order = static_cast<brgemm_kernel_loop_order_t>(
                    std::stoi(value_str));
        if (key_str.find(STRINGIFY(hint_prefetching)) != std::string::npos)
            brgattr->hint_prefetching
                    = static_cast<brgemm_kernel_prefetching_t>(
                            std::stoi(value_str));
        if (key_str.find(STRINGIFY(hint_load_nt_A)) != std::string::npos)
            brgattr->hint_load_nt_A = static_cast<brgemm_kernel_hint_nt_t>(
                    std::stoi(value_str));
        if (key_str.find(STRINGIFY(hint_load_nt_B)) != std::string::npos)
            brgattr->hint_load_nt_B = static_cast<brgemm_kernel_hint_nt_t>(
                    std::stoi(value_str));
    }

    return dnnl_success;
}

std::string prepare_wei_format_string(
        dnnl_data_type_t dt, int64_t ldb, bool is_vnni_layout) {
    // `dt` affects the choice of the last inner block (for VNNI-friendliness).
    // `ldb` affects the choice of the B block.
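    // E.g. (illustrative): ldb = 64 yields "BA16a64b" for f32, "BA16a64b2a"
    // for bf16/f16, and "BA16a64b4a" for s8/u8/f8 when the VNNI layout is
    // requested.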
    std::string wtag("BA16a");
    wtag += std::to_string(ldb) + "b";
    if (is_vnni_layout) {
        switch (dt) {
            case dnnl_f32: break;
            case dnnl_f16:
            case dnnl_bf16: wtag += "2a"; break;
            case dnnl_f8_e5m2:
            case dnnl_f8_e4m3:
            case dnnl_u8:
            case dnnl_s8: wtag += "4a"; break;
            default: assert(!"unsupported data type");
        }
    }

    return wtag;
}

namespace_impl::brgemm_batch_kind_t str2batch_kind(const std::string &str) {
    if (str == "addr")
        return namespace_impl::brgemm_batch_kind_t::brgemm_addr;
    else if (str == "offs")
        return namespace_impl::brgemm_batch_kind_t::brgemm_offs;
    assert(!"Unsupported batch kind value");
    return namespace_impl::brgemm_batch_kind_t::brgemm_batch_kind_undef;
}
#endif

int fill_data(data_kind_t kind, const prb_t *prb, const cfg_t &cfg,
        dnn_mem_t &mem_dt, dnn_mem_t &mem_fp, res_t *res) {

    const auto nelems = mem_dt.nelems();
    if (nelems == 0) return OK;

    assert(mem_dt.nelems() == mem_fp.nelems());

    if (has_bench_mode_bit(mode_bit_t::perf)) {
        return fill_random_real(
                mem_dt, mem_fp, res, get_perf_fill_cfg(mem_dt.dt()));
    }

    cfg_t::density_args_t density_args;
    density_args.data_kind = kind;
    density_args.n_acc = prb->k;
    const auto density = cfg.get_density(density_args);

    /* Do fixed partitioning to have same filling for any number of threads */
    const int64_t chunk_size = 64;
    const int64_t n_chunks = div_up(nelems, chunk_size);

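    // Illustrative partitioning (assumed numbers): nelems = 150 with
    // chunk_size = 64 gives chunks [0, 64), [64, 128), and [128, 150), each
    // seeded independently of the number of threads.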
    benchdnn_parallel_nd(n_chunks, [&](int64_t idx_chunk) {
        int64_t idx_start = idx_chunk * chunk_size;
        int64_t idx_end = MIN2(idx_start + chunk_size, nelems);
        // Note: we use a different seed for each chunk to avoid
        // repeating patterns. We could use discard(idx_start) too but
        // it has a complexity in O(idx_start). We also add 1 to avoid
        // seeding with 0.
        std::minstd_rand int_seed(kind * nelems + idx_start + 1);
        int_seed.discard(1);
        std::minstd_rand b_seed(kind * nelems + idx_start + 1);
        b_seed.discard(10);

        std::uniform_int_distribution<> gen(
                cfg.get_range_min(kind), cfg.get_range_max(kind));
        std::bernoulli_distribution b_dist(density);

        // make sure the first element is positive
        if (idx_start == 0) {
            float val = 0;
            while (val <= 0)
                val = gen(int_seed);
            mem_fp.set_f32_elem(
                    0, round_to_nearest_representable(cfg.get_dt(kind), val));
            idx_start += 1;
        }

        for (int64_t idx = idx_start; idx < idx_end; ++idx) {
            bool is_one = density == 1.f ? true : b_dist(b_seed);
            if (!is_one) {
                mem_fp.set_f32_elem(idx, 0.f);
                continue;
            }
            float val = gen(int_seed);
            mem_fp.set_f32_elem(
                    idx, round_to_nearest_representable(cfg.get_dt(kind), val));
        }
    });

    SAFE(mem_dt.reorder(mem_fp), WARN);

    return OK;
}

// An object to pass information between different modules of the flow.
struct kernel_args_t {
    kernel_args_t(const prb_t *prb, res_t *res)
        :
#if !defined(DNNL_EXPERIMENTAL_UKERNEL)
        brgemm_kernel_(nullptr)
        , palette()
        , is_b_data_layout_vnni_(false)
        , need_tile_config_(false)
        , original_wei_md_size_(0)
#else
        brgemm_(nullptr)
        , transform_(nullptr)
        , need_pack_(0)
#endif
        , scratchpad_size_(0)
        , generate_skip_accumulation_(false)
        , prb_(prb)
        , res_(res) {
    }

    // Output members
#if !defined(DNNL_EXPERIMENTAL_UKERNEL)
    namespace_impl::brgemm_kernel_t *brgemm_kernel_;
    char palette[/*dnnl::impl::cpu::x64::AMX_PALETTE_SIZE = */ 64];
    bool is_b_data_layout_vnni_;
    bool need_tile_config_;
    size_t original_wei_md_size_;
#else
    dnnl_brgemm_t brgemm_;
    dnnl_transform_t transform_;
    int need_pack_; // `int` to match C API
#endif
    size_t scratchpad_size_;
    bool generate_skip_accumulation_;

    // Input members
    const prb_t *prb_;
    res_t *res_;
};

int init_kernel(kernel_args_t &kernel_args) {
    const prb_t *prb = kernel_args.prb_;
    res_t *res = kernel_args.res_;

#if !defined(DNNL_EXPERIMENTAL_UKERNEL)
    using namespace namespace_impl;

    // Supports only the address model for now, as it only affects the way
    // memory is passed to the `brgemm_batch_element_t` object.
    brgemm_batch_kind_t batch_kind = str2batch_kind(prb->batch_kind);
    brgemm_layout_t layout = brgemm_layout_t::brgemm_row_major;

    // Pass `isa_undef` for now since internals work with isa bits rather than
    // isa values directly, which causes misalignment between the public enum
    // and internal values.
    // TODO: re-consider enabling isa values.
    const auto isa_undef = cpu_isa_t::isa_undef;

    brgemm_desc_t brgemm_desc;

    // Create BRGeMM descriptor, analogous to primitive descriptor creation.
    const auto status_init = brgemm_desc_init(&brgemm_desc, isa_undef,
            batch_kind, prb->src_dt(), prb->wei_dt(), false /* transA */,
            false /* transB */, layout, prb->alpha, prb->beta, prb->get_lda(),
            prb->get_ldb(), prb->get_ldc(), prb->m, prb->n, prb->k,
            nullptr /* strides */);
    SAFE(check_dnnl_status(status_init, prb, res), WARN);
    if (res->state == SKIPPED) return OK;

    attr_args_t attr_args;
    attr_args.prepare_post_ops_mds(prb->attr, prb->ndims, prb->dst_dims.data());
    const auto &wei_scale = prb->attr.scales.get(DNNL_ARG_WEIGHTS);
    if (wei_scale.policy == policy_t::PER_OC) {
        attr_args.prepare_quant(
                prb->attr, DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, 2);
    }
    auto dnnl_attr = make_benchdnn_dnnl_wrapper(
            create_dnnl_attr(prb->attr, attr_args));
    dims_t dst_strides = {prb->get_ldd(), 1};
    auto dst_md = dnn_mem_t::init_md(
            prb->ndims, prb->dst_dims.data(), prb->dst_dt(), "", dst_strides);

    SAFE(check_dnnl_status(brgemm_desc_set_postops(&brgemm_desc, dnnl_attr,
                                   dst_md, prb->get_ldd(), prb->bia_dt),
                 prb, res),
            WARN);
    if (res->state == SKIPPED) return OK;

    brgemm_attr_t brgemm_attr;
    DNN_SAFE(brgemm_attr_init(&brgemm_attr, prb), WARN);
    SAFE(check_dnnl_status(
                 brgemm_desc_set_attr(&brgemm_desc, brgemm_attr), prb, res),
            WARN);
    if (res->state == SKIPPED) return OK;

    SAFE(check_dnnl_status(brgemm_desc_finalize(&brgemm_desc), prb, res), WARN);
    if (res->state == SKIPPED) return OK;

    kernel_args.generate_skip_accumulation_
            = brgemm_attr.generate_skip_accumulation;

    // Create BRGeMM kernel, analogous to primitive creation.
    // ctx_init can be used here to select the core type on hetero ISA with TBB.
    brgemm_kernel_t **brgemm_kernel_addr = &kernel_args.brgemm_kernel_;
    DNN_SAFE(create_in_thr_ctx(prb->ctx_init, brgemm_kernel_create,
                     brgemm_kernel_addr, brgemm_desc),
            WARN);

#if defined(brg_x64)
    // Palette configuration is required here to have `kernel_args`
    // initialization consolidated in a single place.
    const auto init_tiles_st
            = brgemm_init_tiles(brgemm_desc, kernel_args.palette);
    if (init_tiles_st == dnnl_success) { kernel_args.need_tile_config_ = true; }
#endif

    kernel_args.is_b_data_layout_vnni_ = brgemm_desc.is_b_data_layout_vnni();
    kernel_args.scratchpad_size_ = brgemm_desc.get_wsp_buffer_size();
#else // !defined(DNNL_EXPERIMENTAL_UKERNEL)
    attr_args_t attr_args;
    attr_args.prepare_post_ops_mds(prb->attr, prb->ndims, prb->dst_dims.data());
    auto dnnl_attr = make_benchdnn_dnnl_wrapper(
            create_dnnl_attr(prb->attr, attr_args));
    auto dnnl_post_ops = query_post_ops(dnnl_attr);

    dnnl_status_t st = dnnl_success;
    auto &brgemm = kernel_args.brgemm_;
    DNN_SAFE(
            dnnl_brgemm_create(&brgemm, prb->m, prb->n, prb->k, prb->batch_size,
                    prb->get_lda(), prb->get_ldb(), prb->get_ldc(),
                    prb->src_dt(), prb->wei_dt(), prb->acc_dt()),
            WARN);
    // Only `beta` values of `0.f` and `1.f` are supported.
    DNN_SAFE(dnnl_brgemm_set_add_C(brgemm, static_cast<int>(prb->beta)), WARN);
    DNN_SAFE(dnnl_brgemm_set_post_ops(
                     brgemm, prb->get_ldd(), prb->dst_dt(), dnnl_post_ops),
            WARN);
    if (!prb->attr.scales.is_def(DNNL_ARG_SRC)) {
        DNN_SAFE(dnnl_brgemm_set_A_scales(
                         brgemm, prb->attr.scales.get_mask(DNNL_ARG_SRC)),
                WARN);
    }
    if (!prb->attr.scales.is_def(DNNL_ARG_WEIGHTS)) {
        DNN_SAFE(dnnl_brgemm_set_B_scales(
                         brgemm, prb->attr.scales.get_mask(DNNL_ARG_WEIGHTS)),
                WARN);
    }
    if (!prb->attr.scales.is_def(DNNL_ARG_DST)) {
        DNN_SAFE(dnnl_brgemm_set_D_scales(
                         brgemm, prb->attr.scales.get_mask(DNNL_ARG_DST)),
                WARN);
    }
    // This call determines whether the final configuration is supported.
    st = dnnl_brgemm_finalize(brgemm);
    SAFE(check_dnnl_status(st, prb, res), WARN);
    if (res->state == SKIPPED) return OK;

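    // The pack type reported for the given data type pair decides whether B
    // has to be repacked before execution; (illustrative reading)
    // `dnnl_pack_type_pack32` means the kernel expects B in a packed
    // 32-bit-group (VNNI-like) layout.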
    dnnl_pack_type_t pack_type = dnnl_pack_type_undef;
    DNN_SAFE(dnnl_brgemm_get_B_pack_type(
                     &pack_type, prb->src_dt(), prb->wei_dt()),
            WARN);
    kernel_args.need_pack_ = pack_type == dnnl_pack_type_pack32;

    DNN_SAFE(dnnl_brgemm_generate(brgemm), WARN);
    DNN_SAFE(dnnl_brgemm_get_scratchpad_size(
                     brgemm, &kernel_args.scratchpad_size_),
            WARN);

    if (kernel_args.need_pack_) {
        // Create a memory desc based on user inputs and query strides to use
        // them in a pack routine.
        const dnnl_dims_t wei_dims = {prb->k * prb->batch_size, prb->n};
        auto wei_md = dnn_mem_t::init_md(prb->ndims, wei_dims, prb->wei_dt(),
                prb->wtag, prb->strides[STRIDES_WEI]);
        const auto &wei_strides = query_md_strides(wei_md);
        assert(query_md_ndims(wei_md) == 2);

        auto &transform = kernel_args.transform_;
        // Choose `no_trans` for cases when K = 1 as less memory is required.
        auto in_pack_type = wei_strides[1] > wei_strides[0]
                ? dnnl_pack_type_trans
                : dnnl_pack_type_no_trans;
        // One of the strides implicitly equals `1`.
        auto in_ld = MAX2(wei_strides[0], wei_strides[1]);
        st = dnnl_transform_create(&transform, prb->k * prb->batch_size, prb->n,
                in_pack_type, in_ld, prb->get_ldb(), prb->wei_dt(),
                prb->wei_dt());
        SAFE(check_dnnl_status(st, prb, res), WARN);
        if (res->state == SKIPPED) return OK;

        DNN_SAFE(dnnl_transform_generate(transform), WARN);
    }

    // Not needed from the API perspective, but needed for the reference.
    kernel_args.generate_skip_accumulation_ = false;
#endif
    return OK;
}

void skip_unimplemented_prb(const prb_t *prb, res_t *res) {
    auto is_xf16 = [](dnnl_data_type_t dt) {
        return dt == dnnl_bf16 || dt == dnnl_f16;
    };
    if (!IMPLICATION(is_xf16(prb->bia_dt) || is_xf16(prb->dst_dt()),
                is_xf16(prb->wei_dt()))) {
        res->state = SKIPPED;
        res->reason = skip_reason::case_not_supported;
        return;
    }
    skip_unimplemented_data_type(
            {prb->src_dt(), prb->wei_dt(), prb->bia_dt, prb->dst_dt()},
            prb->dir, res);
    skip_unimplemented_sum_po(
            prb->attr, res, dnnl_gemm, prb->src_dt(), prb->dst_dt());
    skip_unimplemented_binary_po(prb->attr, res);
    skip_unimplemented_prelu_po(prb->attr, res, dnnl_gemm);

    // Unconditionally skip remaining unimplemented cases.
    // TODO: stop doing it.
    BENCHDNN_PRINT(
            2, "%s\n", "The kernel returned unimplemented for some reason.");
    res->state = SKIPPED;
    res->reason = skip_reason::case_not_supported;
}

void skip_invalid_prb(const prb_t *prb, res_t *res) {
#if !defined(DNNL_EXPERIMENTAL_UKERNEL)
    // Reorder does not support s8 and zp compensations for arbitrary shapes,
    // so skip unsupported cases.
    // Note: this check must be done here to avoid a runtime error in benchdnn
    // due to failed reorder creation.
    // TODO: enable this support and remove this check.
    const bool is_bad_ldb = prb->get_ldb() % 16 > 0 || prb->get_ldb() > 64;
    const bool req_s8_comp = prb->src_dt() == dnnl_s8;
    const bool req_zp_comp = !prb->attr.zero_points.is_def(DNNL_ARG_SRC);
    if (is_bad_ldb && (req_s8_comp || req_zp_comp)) {
        BENCHDNN_PRINT(2, "%s\n",
                "Reorder with compensation is not supported for a given LDB");
        res->state = SKIPPED;
        res->reason = skip_reason::case_not_supported;
        return;
    }

    if (!prb->attr.zero_points.is_def(DNNL_ARG_WEIGHTS)) {
        // TODO: weights zero point is not supported yet.
        // It requires enabling f32 -> u8 reorder with compensation on the
        // library side. When enabled, it produces incorrect results for cases
        // with K=1. Likely there's a bug inside. Postpone supporting it.
        BENCHDNN_PRINT(2, "%s\n",
                "Reorder with compensation is not supported for u8 destination "
                "data type");
        res->state = SKIPPED;
        res->reason = skip_reason::case_not_supported;
        return;
    }

    if (prb->wtag != tag::abx) {
        BENCHDNN_PRINT(
                2, "%s\n", "`wtag` option is supported for ukernel API only.");
        res->state = SKIPPED;
        res->reason = skip_reason::case_not_supported;
        return;
    }

    if (!prb->strides[STRIDES_WEI].empty()) {
        BENCHDNN_PRINT(2, "%s\n",
                "`strides` option for weights is supported for ukernel API "
                "only.");
        res->state = SKIPPED;
        res->reason = skip_reason::case_not_supported;
        return;
    }
#else
    if (!prb->attr.is_def()) {
        bool non_def_zps = !prb->attr.zero_points.is_def();
        bool non_def_fpmath = !prb->attr.fpmath_mode.is_def();
        if (non_def_zps || non_def_fpmath) {
            BENCHDNN_PRINT(2, "%s\n",
                    "Non-default scales/zero-points/fpmath attributes are not "
                    "supported");
            res->state = SKIPPED;
            res->reason = skip_reason::case_not_supported;
            return;
        }

        bool non_def_po = !prb->attr.post_ops.is_def();
        if (non_def_po) {
            const auto &po = prb->attr.post_ops;
            bool has_sum = po.find(attr_t::post_ops_t::kind_t::SUM) != -1;
            if (has_sum) {
                BENCHDNN_PRINT(2, "%s\n", "Sum post-op is not supported");
                res->state = SKIPPED;
                res->reason = skip_reason::case_not_supported;
                return;
            }
        }
    }

    const bool ldb_ok = prb->get_ldb() == 16 || prb->get_ldb() == 32
            || prb->get_ldb() == 48 || prb->get_ldb() == 64;
    if (!ldb_ok) {
        BENCHDNN_PRINT(2, "%s\n",
                "Unsupported leading B dimension. Only 16, 32, 48, and 64 are "
                "supported");
        res->state = SKIPPED;
        res->reason = skip_reason::case_not_supported;
        return;
    }

    if (prb->bia_dt != dnnl_data_type_undef) {
        BENCHDNN_PRINT(2, "%s\n", "Bias is not supported");
        res->state = SKIPPED;
        res->reason = skip_reason::case_not_supported;
        return;
    }

    if (prb->src_dt() == dnnl_s8 && prb->wei_dt() == dnnl_s8) {
        // Pre-AMX ISAs require an s8s8 compensation buffer to be passed. The
        // internals should check whether it was supplied and not blow up if it
        // wasn't provided.
        BENCHDNN_PRINT(2, "%s\n", "s8s8 support is temporarily disabled");
        res->state = SKIPPED;
        res->reason = skip_reason::case_not_supported;
        return;
    }

    if (prb->alpha != 1.f) {
        BENCHDNN_PRINT(2, "%s\n", "Alpha is purposely not supported");
        res->state = SKIPPED;
        res->reason = skip_reason::case_not_supported;
        return;
    }
#endif
}

void setup_cmp(compare::compare_t &cmp, const prb_t *prb, data_kind_t kind,
        const args_t &ref_args) {
    const auto dt = prb->get_dt(kind);
    const float trh = dt == dnnl_f32 ? 1e-6f : epsilon_dt(dt);
    cmp.set_threshold(trh);
    cmp.set_zero_trust_percent(90.f); // TODO: why so bad filling?
}

#if !defined(DNNL_EXPERIMENTAL_UKERNEL)
// A special wrapper needed to match internal benchdnn infrastructure.
dnnl_status_t brgemm_kernel_execute_postops_wrapper(
        const namespace_impl::brgemm_kernel_t *brgemm_kernel,
        const std::string &batch_kind, int batch_size, const void *src_ptr,
        const void *wei_ptr,
        const namespace_impl::brgemm_batch_element_t *batch_element,
        void *acc_ptr, void *dst_ptr,
        const namespace_impl::brgemm_post_ops_data_t &post_ops_data,
        void *scratchpad_ptr, const dnnl_stream_t &stream,
        const std::vector<dnnl_exec_arg_t> &dnnl_args) {

    if (batch_kind == "addr") {
        brgemm_kernel_execute_postops(brgemm_kernel, batch_size, batch_element,
                acc_ptr, dst_ptr, post_ops_data, scratchpad_ptr);
    } else if (batch_kind == "offs") {
        brgemm_kernel_execute_postops(brgemm_kernel, batch_size, src_ptr,
                wei_ptr, batch_element, acc_ptr, dst_ptr, post_ops_data,
                scratchpad_ptr);
    }
    return dnnl_success;
}
#else
// A special wrapper needed to match internal benchdnn infrastructure.
dnnl_status_t brgemm_kernel_execute_postops_wrapper(const_dnnl_brgemm_t brgemm,
        const bool use_dst_as_acc, const void *src_ptr,
        const void *wei_packed_ptr, const std::vector<dnnl_dim_t> &offsets,
        void *acc_ptr, void *dst_ptr, void *scratchpad_ptr,
        const_dnnl_ukernel_attr_params_t attr_params,
        const dnnl_stream_t &stream,
        const std::vector<dnnl_exec_arg_t> &dnnl_args) {

    dnnl_status_t st = dnnl_runtime_error;
    if (use_dst_as_acc) {
        st = dnnl_brgemm_execute(brgemm, src_ptr, wei_packed_ptr,
                offsets.data(), dst_ptr, scratchpad_ptr);
    } else {
        st = dnnl_brgemm_execute_postops(brgemm, src_ptr, wei_packed_ptr,
                offsets.data(), acc_ptr, dst_ptr, scratchpad_ptr, attr_params);
    }
    return st;
}
#endif

// `init_memory_args` is responsible for:
// * Constructing all necessary `dnn_mem_t` objects needed by the brgemm kernel
//   for the main operation and attributes.
// * Stashing them with a proper exec_arg ID in a `mem_map` object.
// See a common version of `init_memory_args` comment for more details.
void init_memory_args(
        dnn_mem_map_t &mem_map, const prb_t *prb, kernel_args_t &kernel_args) {
    // Fuse the batch size into the K dimension, which follows the library
    // usage of the kernel batch size setting.
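    // E.g. (illustrative): m = 16, k = 32, batch_size = 4, n = 64 give
    // src_dims = {16, 128} and wei_dims = {128, 64}.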
    const dnnl_dims_t src_dims = {prb->m, prb->k * prb->batch_size};
    const dnnl_dims_t wei_dims = {prb->k * prb->batch_size, prb->n};

    dims_t src_strides = {prb->get_lda(), 1};
    dims_t dst_strides = {prb->get_ldd(), 1};
    dims_t acc_strides = prb->use_dst_as_acc() ? dst_strides : dims_t();

    auto src_md = dnn_mem_t::init_md(
            prb->ndims, src_dims, prb->src_dt(), "", src_strides);

    auto dst_md = dnn_mem_t::init_md(
            prb->ndims, prb->dst_dims.data(), prb->dst_dt(), "", dst_strides);

    // Same as dst_md but with a pre-defined data type according to the doc.
    auto acc_md = dnn_mem_t::init_md(prb->ndims, prb->dst_dims.data(),
            prb->acc_dt(), tag::abx, acc_strides);

#if !defined(DNNL_EXPERIMENTAL_UKERNEL)
    // Create weights memory descriptor with VNNI-friendly format.
    // Note: LDB is not passed here. This is because it's super difficult to
    // incorporate a stride on top of blocking - the oneDNN API doesn't provide
    // any calls to support both options together. A submemory descriptor, the
    // only way to create such a memory desc, can't return the size of the
    // memory. Thus, it requires two memories, and we would need to pass a
    // memory handle from the bigger one (where LDB is an actual dim value) to
    // the smaller one, but there's some reorder bug resulting in an error.
    const auto wtag = prepare_wei_format_string(
            prb->wei_dt(), prb->get_ldb(), kernel_args.is_b_data_layout_vnni_);
    BENCHDNN_PRINT(6, "wtag: %s\n", wtag.c_str());

    auto wei_md = dnn_mem_t::init_md(prb->ndims, wei_dims, prb->wei_dt(), wtag);
    kernel_args.original_wei_md_size_ = dnnl_memory_desc_get_size(wei_md);

    // Prepare and assign extra for wei_md when s8s8 compensation or source
    // zero point reduction values are needed.
    dnnl::impl::memory_extra_desc_t wei_md_extra {};
    wei_md_extra.flags = dnnl::impl::memory_extra_flags::none;
    if (prb->get_dt(SRC) == dnnl_s8 && prb->get_dt(WEI) == dnnl_s8) {
        wei_md_extra.flags
                |= dnnl::impl::memory_extra_flags::compensation_conv_s8s8;
        wei_md_extra.compensation_mask = 2; // N dimension
    }
    static_cast<dnnl_memory_desc_t>(wei_md)->extra = wei_md_extra;

    const bool need_src_comp = !prb->attr.zero_points.is_def(DNNL_ARG_SRC);
    if (need_src_comp) {
        wei_md_extra.flags |= dnnl::impl::memory_extra_flags::
                compensation_conv_asymmetric_src;
        wei_md_extra.asymm_compensation_mask = 2; // N dimension
    }
    static_cast<dnnl_memory_desc_t>(wei_md)->extra = wei_md_extra;

    benchdnn_dnnl_wrapper_t<dnnl_memory_desc_t> bia_md {};
    if (prb->bia_dt != dnnl_data_type_undef) {
        const dnnl_dims_t bia_dims = {1, prb->n};
        bia_md = dnn_mem_t::init_md(
                prb->ndims, bia_dims, prb->bia_dt, tag::abx);
    }
#else
    auto wei_md = dnn_mem_t::init_md(prb->ndims, wei_dims, prb->wei_dt(),
            prb->wtag, prb->strides[STRIDES_WEI]);
    const auto &wei_strides = query_md_strides(wei_md);

    // Note: the packing routine transforms a plain user tensor into an
    // internal blocking format with various JIT kernels depending on user
    // inputs. While some kernels work fine with less memory, some don't, such
    // as the transposed `ba` format.
    //
    // To supply enough memory for the transformation, the following logic
    // adjusts the memory amount based on `simd_w` and `dt_multiplier` the same
    // way physical padding would do.
    //
    // `dt_multiplier` is required to form a 32-bit element, which is a basic
    // unit of BRGeMM computations. The only outlier is f16 on ISAs where
    // packing is not expected; it acts as f32 there.
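    // E.g. (illustrative): s8 (1 byte) gives dt_multiplier = 4, bf16 (2 bytes)
    // gives 2, and f16 without packing falls back to 1.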
    const int dt_multiplier
            = prb->wei_dt() == dnnl_f16 && !kernel_args.need_pack_
            ? 1
            : 4 / dnnl_data_type_size(prb->wei_dt());

    int multiplier = 1;
    if (kernel_args.need_pack_) {
        multiplier = dt_multiplier;
        // Note (impl detail): transposed kernel wants 64 bytes for K dim.
        if (wei_strides[0] < wei_strides[1]) {
            // Though `simd_w` is not necessarily `16` on all ISAs, it's for
            // simplicity.
            constexpr int simd_w = 16;
            multiplier *= simd_w;
        }
    }

    const dnnl_dim_t k_rounded = multiplier * div_up(prb->k, multiplier);
    const dnnl_dims_t wei_packed_dims = {k_rounded * prb->batch_size, prb->n};
    dims_t wei_packed_strides = {prb->get_ldb(), 1};
    auto wei_packed_md = dnn_mem_t::init_md(
            prb->ndims, wei_packed_dims, prb->wei_dt(), "", wei_packed_strides);
#endif

    dnnl_dim_t scratchpad_size
            = static_cast<dnnl_dim_t>(kernel_args.scratchpad_size_);
    int ndims = scratchpad_size ? 1 : 0;
    dnnl_data_type_t dt = scratchpad_size ? dnnl_u8 : dnnl_data_type_undef;
    dnnl_dims_t scratchpad_dims = {scratchpad_size};
    auto scratchpad_md
            = dnn_mem_t::init_md(ndims, scratchpad_dims, dt, tag::abx);

    const auto &test_engine = get_test_engine();

    mem_map.emplace(
            DNNL_ARG_SRC, dnn_mem_t(src_md, test_engine, /* prefill = */ true));
    mem_map.emplace(DNNL_ARG_WEIGHTS,
            dnn_mem_t(wei_md, test_engine, /* prefill = */ true));
#if !defined(DNNL_EXPERIMENTAL_UKERNEL)
    if (prb->bia_dt != dnnl_data_type_undef) {
        // The condition is needed to extract the bias pointer based on its
        // presence in the map.
        mem_map.emplace(DNNL_ARG_BIAS,
                dnn_mem_t(bia_md, test_engine, /* prefill = */ true));
    }
#else
    mem_map.emplace(DNNL_ARG_WEIGHTS_1,
            dnn_mem_t(wei_packed_md, test_engine, /* prefill = */ true));
#endif
    mem_map.emplace(DNNL_ARG_DST_1,
            dnn_mem_t(acc_md, test_engine, /* prefill = */ true));
    mem_map.emplace(
            DNNL_ARG_DST, dnn_mem_t(dst_md, test_engine, /* prefill = */ true));
    if (scratchpad_size > 0) {
        // The condition is needed to extract the scratchpad pointer based on
        // its presence in the map.
        mem_map.emplace(DNNL_ARG_SCRATCHPAD,
                dnn_mem_t(scratchpad_md, test_engine, /* prefill = */ true));
    }

    // Binary post-op.
    const auto &po = prb->attr.post_ops;
    for (int idx = 0; idx < po.len(); ++idx) {
        const auto &e = po.entry[idx];
        if (!e.is_binary_kind()) continue;

        int po_arg = DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) | DNNL_ARG_SRC_1;
        const auto &b = e.binary;
        int ndims = 2;
        dims_t dims = prb->dst_dims;

        using mask_input_t
                = attr_t::post_ops_t::entry_t::binary_t::mask_input_t;
        int mask = -1;
        if (b.mask_input == mask_input_t::mask) {
            mask = b.mask;
        } else if (b.mask_input == mask_input_t::policy) {
            mask = attr_t::policy2mask(po_arg, b.policy, dnnl_matmul, 2);
        } else {
            mask = attr_t::get_default_mask(b.policy);
        }

        switch (mask) {
            case 0: dims = {1, 1}; break;
            case 1: dims = {dims[0], 1}; break;
            case 2: dims = {1, dims[1]}; break;
            // Masks can be bigger than the values above depending on the
            // policy.
            default: break;
        }

        auto po_md
                = dnn_mem_t::init_md(ndims, dims.data(), b.src1_dt, tag::abx);
        mem_map.emplace(
                po_arg, dnn_mem_t(po_md, test_engine, /* prefill = */ true));
    }

    if (!prb->attr.scales.is_def()) {
        const auto &sc = prb->attr.scales;
        static const std::vector<int> supported_args {
                DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST};
        for (const auto &exec_arg : supported_args) {
            if (sc.is_def(exec_arg)) continue;

            const int exec_sc_arg = DNNL_ARG_ATTR_SCALES | exec_arg;
            dims_t dims = {};
            int64_t ndims = 1;
            const auto mask = sc.get_mask(
                    exec_arg, dnnl_matmul, 2, /* has_groups = */ false);

            if (mask > 0) {
                const auto &md = mem_map.at(exec_arg).md_;
                dims = md2dims(md, mask, false);
                ndims = static_cast<int>(dims.size());
            } else {
                dims = {1};
                ndims = 1;
            }
            const auto dt = sc.get(exec_arg).dt;
            auto scales_md
                    = dnn_mem_t::init_md(ndims, dims.data(), dt, tag::abx);
            mem_map.emplace(exec_sc_arg,
                    dnn_mem_t(scales_md, test_engine, /* prefill = */ true));
        }
    }

    if (!prb->attr.zero_points.is_def()) {
        const auto &zp = prb->attr.zero_points;
        static const std::vector<int> supported_args {
                DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST};
        for (const auto &exec_arg : supported_args) {
            if (zp.is_def(exec_arg)) continue;

            const int exec_zp_arg = DNNL_ARG_ATTR_ZERO_POINTS | exec_arg;
            dims_t dims = {};
            int64_t ndims = 1;
            const auto mask
                    = zp.get_mask(exec_arg, dnnl_matmul, /* ndims = */ 2);

            if (mask > 0) {
                const auto &md = mem_map.at(exec_arg).md_;
                dims = md2dims(md, mask, false);
                ndims = static_cast<int>(dims.size());
            } else {
                dims = {1};
                ndims = 1;
            }
            const auto dt = zp.get(exec_arg).dt;
            auto zp_md = dnn_mem_t::init_md(ndims, dims.data(), dt, tag::abx);
            mem_map.emplace(exec_zp_arg,
                    dnn_mem_t(zp_md, test_engine, /* prefill = */ true));
        }
    }
}

int init_ref_memory_args(dnn_mem_map_t &ref_mem_map, dnn_mem_map_t &mem_map,
        const prb_t *prb, const kernel_args_t &kernel_args, res_t *res) {

    if (has_bench_mode_modifier(mode_modifier_t::no_ref_memory)) return OK;

    const auto &ref_engine = get_cpu_engine();

    // Move cfg out of filling since its creation is not free.
    cfg_t cfg(prb, {SRC, WEI, BIA, DST});

    const bool need_fill_acc
            = (prb->beta != 0) || kernel_args.generate_skip_accumulation_;

    for (auto &entry : mem_map) {
        const int exec_arg = entry.first;
        // The function targets regular exec_args that are positive.
        // Negative args are used by bitwise mode and are broken in the
        // `default` branch because `&` always returns `true` for them.
        if (exec_arg <= 0) continue;

        auto &mem = entry.second; // `mem` is modified by filler (reorder).

        // Scratchpad memory relates to a primitive. If the reference needs it,
        // use the switch below to define a memory desc for it.
        if (exec_arg != DNNL_ARG_SCRATCHPAD) {
            ref_mem_map.emplace(exec_arg,
                    dnn_mem_t(mem.md_, dnnl_f32, tag::abx, ref_engine,
                            /* prefill = */ false));
        }

        auto &ref_mem = ref_mem_map[exec_arg];

        switch (exec_arg) {
            case DNNL_ARG_SRC:
                SAFE(fill_data(SRC, prb, cfg, mem, ref_mem, res), WARN);
                break;
            case DNNL_ARG_WEIGHTS:
                SAFE(fill_data(WEI, prb, cfg, mem, ref_mem, res), WARN);
                break;
            case DNNL_ARG_BIAS:
                SAFE(fill_data(BIA, prb, cfg, mem, ref_mem, res), WARN);
                break;
            case DNNL_ARG_DST: {
                const auto &po = prb->attr.post_ops;
                const int sum_idx = po.find(attr_t::post_ops_t::SUM);
                if (sum_idx >= 0) {
                    SAFE(fill_data(DST, prb, cfg, mem, ref_mem, res), WARN);
                }
            } break;
            case DNNL_ARG_DST_1: {
                if (need_fill_acc) {
                    SAFE(fill_data(DST, prb, cfg, mem, ref_mem, res), WARN);
                }
            } break;
            default:
                SAFE(init_ref_memory_args_default_case(
                             exec_arg, mem, ref_mem, prb->attr, res),
                        WARN);
                break;
        }
    }

    if (need_fill_acc) {
        // Beta requires the same values for the reference and the kernel.
        if (prb->use_dst_as_acc()) {
            auto &acc_fp = ref_mem_map.at(DNNL_ARG_DST_1);
            auto &dst_fp = ref_mem_map.at(DNNL_ARG_DST);
            auto &dst_dt = mem_map.at(DNNL_ARG_DST);

            SAFE(dst_fp.reorder(acc_fp), WARN);
            SAFE(dst_dt.reorder(dst_fp), WARN);
        }
    }

    // A hack to pass brgemm attributes to reference execution since some
    // members change the computation flow for correctness validation.
    dnnl_dims_t dims = {1};
    auto workspace_md = dnn_mem_t::init_md(1, dims, dnnl_u8, tag::abx);
    ref_mem_map.emplace(DNNL_ARG_WORKSPACE,
            dnn_mem_t(workspace_md, ref_engine, /* prefill = */ false,
                    {false, (void *)&kernel_args.generate_skip_accumulation_}));
    ref_mem_map.at(DNNL_ARG_WORKSPACE).map();

    return OK;
}

int scales_post_processing(dnn_mem_map_t &mem_map) {
#if !defined(DNNL_EXPERIMENTAL_UKERNEL)
    // Internal API has specific implementation details w.r.t. scales.
    // If either source or weights scales are present in the descriptor, then
    // the kernel expects to get a vector of 16 float values (v16) of "fused"
    // scale values (src * wei) under the pointer passed in the
    // brgemm_post_ops_data_t struct.
    // However, if the weights scale is per channel, then the kernel expects
    // just `N` values, even if `N < 16`.
    // The same applies to a destination scale. Since it's handled separately
    // from source and weights and must be a single value, it must be a v16
    // memory.
    //
    // To smoothly take care of this detail, the code below will **always**
    // update the WEIGHTS scale (even if not present) with a proper memory of
    // 16 or N elements, depending on the case.
    // It will update the destination memory to contain 16 elements as well.
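    //
    // Illustrative example (assumed values): a src scale of 0.5f and a single
    // weights scale of 2.0f become a v16 buffer of sixteen 1.0f values; with
    // per-channel weights scales and N = 8, only 8 fused values are kept.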

    const bool has_src_scale
            = mem_map.count(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC);
    const bool has_wei_scale
            = mem_map.count(DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS);
    const bool has_dst_scale
            = mem_map.count(DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST);

    const auto replace_mem_to_v16 = [&](dnnl_data_type_t dt, int exec_arg,
                                            float val) {
        dims_t dims = {16};
        auto new_md = dnn_mem_t::init_md(1, dims.data(), dt, tag::abx);
        dnn_mem_t new_m(new_md, get_test_engine(), /* prefill = */ true);
        if (!new_m.is_mapped()) new_m.map();
        for (int64_t i = 0; i < new_m.nelems(); i++) {
            new_m.set_elem(i, val);
        }
        mem_map[DNNL_ARG_ATTR_SCALES | exec_arg] = std::move(new_m);
    };

    if (has_wei_scale) {
        const auto &wei_scales_m
                = mem_map.at(DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS);
        // First, update the values...
        if (has_src_scale) {
            const auto &src_scales_m
                    = mem_map.at(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC);
            assert(src_scales_m.nelems() == 1);
            const float src_val = src_scales_m.get_elem(0);
            for (int64_t i = 0; i < wei_scales_m.nelems(); i++) {
                float val = wei_scales_m.get_elem(i) * src_val;
                wei_scales_m.set_elem(i, val);
            }
        }
        // Second, update memory for a single scale.
        if (wei_scales_m.nelems() == 1) {
            replace_mem_to_v16(wei_scales_m.dt(), DNNL_ARG_WEIGHTS,
                    wei_scales_m.get_elem(0));
        }
    } else if (has_src_scale) {
        const auto &src_scales_m
                = mem_map.at(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC);
        assert(src_scales_m.nelems() == 1);
        // Create a v16 weights scales memory and put the src value there.
        replace_mem_to_v16(
                src_scales_m.dt(), DNNL_ARG_WEIGHTS, src_scales_m.get_elem(0));
    }

    if (has_dst_scale) {
        const auto &dst_scales_m
                = mem_map.at(DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST);
        assert(dst_scales_m.nelems() == 1);
        // Create a v16 dst scales memory and broadcast the inverted dst value
        // there.
        replace_mem_to_v16(dst_scales_m.dt(), DNNL_ARG_DST,
                1.f / dst_scales_m.get_elem(0));
    }

#else // !defined(DNNL_EXPERIMENTAL_UKERNEL)
    // ukernel API takes split pointers for scales, no need to update them on
    // the user level.
#endif
    return OK;
}

int binary_post_op_preprocessing(
        std::vector<const void *> &binary_po_v, const dnn_mem_map_t &mem_map) {
    // Preprocessing must happen in two stages:
    // 1. Collect all argument values and sort them.
    // 2. Insert memory pointers corresponding to the arguments in order to
    //    satisfy the kernel expectations.
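    // E.g. (illustrative): with two binary post-ops, the sorted set yields
    // binary_po_v = {pointer for post-op 0, pointer for post-op 1}, i.e., the
    // pointers follow the post-op order.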
    std::set<int> arg_vals;
    for (const auto &map_entry : mem_map) {
        const int exec_arg = map_entry.first;

        const int post_ops_range = DNNL_ARG_ATTR_MULTIPLE_POST_OP(31)
                - DNNL_ARG_ATTR_MULTIPLE_POST_OP(0);
        const bool is_post_ops_arg = (exec_arg & post_ops_range);
        if (!is_post_ops_arg) continue;

        arg_vals.insert(exec_arg);
    }

    binary_po_v.reserve(arg_vals.size());
    for (const auto &set_entry : arg_vals) {
        void *handle = mem_map.at(set_entry).get_mapped_pointer<void>();
        binary_po_v.push_back(handle);
    }

    return OK;
}

int init_hw_config(const kernel_args_t &kernel_args) {
#if !defined(DNNL_EXPERIMENTAL_UKERNEL)
#if defined(brg_x64)
    if (kernel_args.need_tile_config_) {
        DNN_SAFE(namespace_impl::amx_tile_configure(kernel_args.palette), WARN);
    }
#endif
#else // !defined(DNNL_EXPERIMENTAL_UKERNEL)
    DNN_SAFE(dnnl_brgemm_set_hw_context(kernel_args.brgemm_), WARN);
#endif
    return OK;
}

int release_hw_config(const kernel_args_t &kernel_args) {
#if !defined(DNNL_EXPERIMENTAL_UKERNEL)
#if defined(brg_x64)
    if (kernel_args.need_tile_config_) {
        DNN_SAFE(namespace_impl::amx_tile_release(), WARN);
    }
#endif
#endif
    return OK;
}

// `release_hw_config` and `brgemm_finalize` do the same thing - release the
// hw resources they allocated. The difference is on the implementation side -
// the ukernel API has lazy initialization and resets the state per
// `set_hw_config` call, while the internal API doesn't - it just sets the new
// state unconditionally. Because of that laziness, the test wants to ensure
// resetting is correct and, thus, releasing is done once - at the end of the
// suite - while for the internal API it is done after each case.
int brgemm_finalize() {
#if defined(DNNL_EXPERIMENTAL_UKERNEL)
    DNN_SAFE(dnnl_brgemm_release_hw_context(), WARN);
#endif
    return OK;
}

int doit(const prb_t *prb, res_t *res) {
    if (bench_mode == bench_mode_t::list) return res->state = LISTED, OK;

    skip_start(res);
    if (res->state == SKIPPED) return OK;

    // Need this here as brgemm has no primitive creation step.
    skip_invalid_prb(prb, res);
    if (res->state == SKIPPED) return OK;

    kernel_args_t kernel_args(prb, res);
    SAFE(init_kernel(kernel_args), WARN);
    if (res->state == SKIPPED) return OK;
    if (bench_mode == bench_mode_t::init) return res->state = INITIALIZED, OK;

#if !defined(DNNL_EXPERIMENTAL_UKERNEL)
    auto brgemm_kernel = make_benchdnn_dnnl_wrapper(kernel_args.brgemm_kernel_);
#else
    auto brgemm = make_benchdnn_dnnl_wrapper(kernel_args.brgemm_);
    auto transform = make_benchdnn_dnnl_wrapper(kernel_args.transform_);
#endif

    dnn_mem_map_t mem_map, ref_mem_map;
    init_memory_args(mem_map, prb, kernel_args);
    TIME_FILL(SAFE(
            init_ref_memory_args(ref_mem_map, mem_map, prb, kernel_args, res),
            WARN));

    // "Library" args are needed to get dst for comparison.
    // "Reference" args are used as usual.
    args_t args(mem_map), ref_args(ref_mem_map);

    // The implementation memory must be mapped to set up pointer arguments for
    // the brgemm implementation call. This assumes that mapping is effectively
    // a no-op on the target device.
    for (auto &kv : mem_map) {
        if (!kv.second.is_mapped()) kv.second.map();
    }

    const char *src_ptr = (const char *)mem_map.at(DNNL_ARG_SRC);
    const char *wei_ptr = (const char *)mem_map.at(DNNL_ARG_WEIGHTS);
    char *acc_ptr = (char *)mem_map.at(DNNL_ARG_DST_1);
    char *dst_ptr = (char *)mem_map.at(DNNL_ARG_DST);
    if (prb->use_dst_as_acc()) acc_ptr = dst_ptr;

    SAFE(scales_post_processing(mem_map), WARN);

    std::vector<const void *> binary_po_v;
    SAFE(binary_post_op_preprocessing(binary_po_v, mem_map), WARN);

    const float *dst_scales_ptr
            = mem_map.count(DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST)
            ? (const float *)mem_map.at(DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST)
            : nullptr;

#if !defined(DNNL_EXPERIMENTAL_UKERNEL)
    std::vector<namespace_impl::brgemm_batch_element_t> v_batch_element(
            prb->batch_size);
    for (size_t i = 0; i < v_batch_element.size(); i++) {
        if (prb->batch_kind == "addr") {
            v_batch_element[i].ptr.A
                    = src_ptr + i * prb->get_src_batch_offset();
            v_batch_element[i].ptr.B
                    = wei_ptr + i * prb->get_wei_batch_offset();
        } else if (prb->batch_kind == "offs") {
            v_batch_element[i].offset.A = i * prb->get_src_batch_offset();
            v_batch_element[i].offset.B = i * prb->get_wei_batch_offset();
        }
    }

    // For internal API, scales are combined. See `scales_post_processing`.
    const float *scales_ptr
            = mem_map.count(DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS)
            ? (const float *)mem_map.at(DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS)
            : nullptr;
    const int32_t *dst_zp_ptr
            = mem_map.count(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST)
            ? (const int32_t *)mem_map.at(
                    DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST)
            : nullptr;
    int32_t zp_a_val = mem_map.count(DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC)
            ? *(const int32_t *)mem_map.at(
                    DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC)
            : 0;
    const char *bia_dt_ptr = mem_map.count(DNNL_ARG_BIAS)
            ? (const char *)mem_map.at(DNNL_ARG_BIAS)
            : nullptr;

    const auto &wei_md = mem_map.at(DNNL_ARG_WEIGHTS).md_;
    const auto &wei_md_extra = static_cast<dnnl_memory_desc_t>(wei_md)->extra;
    // This relies on internal knowledge of how compensation is implemented.
    const size_t wei_offset_s8s8 = kernel_args.original_wei_md_size_;
    const size_t wei_offset_zp = wei_offset_s8s8
            + ((wei_md_extra.flags
                       & dnnl::impl::memory_extra_flags::compensation_conv_s8s8)
                            ? prb->get_ldb() * sizeof(int32_t)
                            : 0);
    char *src_comp_ptr = const_cast<char *>(wei_ptr) + wei_offset_zp;

    namespace_impl::brgemm_post_ops_data_t post_ops_data(
            /* bias */ bia_dt_ptr,
            /* scales */ scales_ptr,
            /* binary_post_ops_rhs */ binary_po_v.data(),
            /* oc_logical_off */ 0, /* dst_row_logical_off */ 0,
            // TODO: though the field is called `data_C_ptr_`, the name is
            // misleading since dst_ptr must actually be used there to have the
            // binary injector working for the per_tensor policy.
            /* data_C_ptr_ */ dst_ptr, /* first_mb_matrix_addr_off */ 0,
            /* a_zp_compensations */ src_comp_ptr,
            /* b_zp_compensations */ nullptr,
            /* c_zp_values */ dst_zp_ptr,
            /* skip_accumulation */
            kernel_args.generate_skip_accumulation_,
            /* zp_a_val */ zp_a_val,
            /* do_only_comp */ false,
            /* do_only_zp_a_val */ false,
            /* dst_scales */ dst_scales_ptr);

    // Note: hardware lacking native s8s8 support expects the compensation
    // buffer passed through a scratchpad argument in the postops execution
    // call.
    const bool has_scratchpad = mem_map.count(DNNL_ARG_SCRATCHPAD);
    const bool need_hidden_compensation = !has_scratchpad
            && prb->get_dt(SRC) == dnnl_s8 && prb->get_dt(WEI) == dnnl_s8;
    char *scratchpad_ptr = need_hidden_compensation
            ? (const_cast<char *>(wei_ptr) + wei_offset_s8s8)
            : has_scratchpad ? (char *)mem_map.at(DNNL_ARG_SCRATCHPAD)
                             : nullptr;

#else // !defined(DNNL_EXPERIMENTAL_UKERNEL)
    char *wei_packed_ptr = (char *)mem_map.at(DNNL_ARG_WEIGHTS_1);

    char *scratchpad_ptr = mem_map.count(DNNL_ARG_SCRATCHPAD)
            ? (char *)mem_map.at(DNNL_ARG_SCRATCHPAD)
            : nullptr;

    if (kernel_args.need_pack_) {
        DNN_SAFE(dnnl_transform_execute(transform, wei_ptr, wei_packed_ptr),
                WARN);
    } else {
        const auto &wei_dt = mem_map.at(DNNL_ARG_WEIGHTS);
        auto &wei_packed_dt = mem_map.at(DNNL_ARG_WEIGHTS_1);
        SAFE(wei_packed_dt.reorder(wei_dt), WARN);
    }

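    // The offsets are interleaved {A, B} pairs, one pair per batch element,
    // e.g. (illustrative) for batch_size = 2:
    //     {0, 0, get_src_batch_offset(), get_wei_batch_offset()}.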
    std::vector<dnnl_dim_t> offsets(2 * prb->batch_size);
    for (dnnl_dim_t i = 0; i < prb->batch_size; i++) {
        offsets[2 * i + 0] = i * prb->get_src_batch_offset();
        offsets[2 * i + 1] = i * prb->get_wei_batch_offset();
    }

    dnnl_ukernel_attr_params_t attr_params_ptr;
    DNN_SAFE(dnnl_ukernel_attr_params_create(&attr_params_ptr), WARN);
    auto attr_params = make_benchdnn_dnnl_wrapper(attr_params_ptr);
    DNN_SAFE(dnnl_ukernel_attr_params_set_post_ops_args(
                     attr_params, binary_po_v.data()),
            WARN);

    const void *src_scales_ptr
            = mem_map.count(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC)
            ? (const void *)mem_map.at(DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC)
            : nullptr;
    const void *wei_scales_ptr
            = mem_map.count(DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS)
            ? (const void *)mem_map.at(DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS)
            : nullptr;
    DNN_SAFE(dnnl_ukernel_attr_params_set_A_scales(attr_params, src_scales_ptr),
            WARN);
    DNN_SAFE(dnnl_ukernel_attr_params_set_B_scales(attr_params, wei_scales_ptr),
            WARN);
    DNN_SAFE(dnnl_ukernel_attr_params_set_D_scales(attr_params, dst_scales_ptr),
            WARN);
#endif

    SAFE(init_hw_config(kernel_args), WARN);

#if !defined(DNNL_EXPERIMENTAL_UKERNEL)
    if (prb->batch_kind == "addr") {
        brgemm_kernel_execute_postops(brgemm_kernel, prb->batch_size,
                v_batch_element.data(), acc_ptr, dst_ptr, post_ops_data,
                scratchpad_ptr);
    } else if (prb->batch_kind == "offs") {
        brgemm_kernel_execute_postops(brgemm_kernel, prb->batch_size, src_ptr,
                wei_ptr, v_batch_element.data(), acc_ptr, dst_ptr,
                post_ops_data, scratchpad_ptr);
    }
#else // !defined(DNNL_EXPERIMENTAL_UKERNEL)
    // `prb->use_dst_as_acc()=true` makes `dst_ptr=acc_ptr`, and the rest
    // should be handled by the API.
    DNN_SAFE(dnnl_brgemm_execute_postops(brgemm, src_ptr, wei_packed_ptr,
                     offsets.data(), acc_ptr, dst_ptr, scratchpad_ptr,
                     attr_params),
            WARN);
#endif
    res->state = EXECUTED;

    if (has_bench_mode_bit(mode_bit_t::corr)) {
        check_correctness(prb, {DST}, args, ref_args, setup_cmp, res, prb->dir);
    }

    // Create a bind matching the internals to run performance measurements.
#if !defined(DNNL_EXPERIMENTAL_UKERNEL)
    perf_function_t perf_func = std::bind(brgemm_kernel_execute_postops_wrapper,
            kernel_args.brgemm_kernel_, prb->batch_kind, prb->batch_size,
            src_ptr, wei_ptr, v_batch_element.data(), acc_ptr, dst_ptr,
            post_ops_data, scratchpad_ptr, std::placeholders::_1,
            std::placeholders::_2);
#else // !defined(DNNL_EXPERIMENTAL_UKERNEL)
    perf_function_t perf_func = std::bind(brgemm_kernel_execute_postops_wrapper,
            kernel_args.brgemm_, prb->use_dst_as_acc(), src_ptr, wei_packed_ptr,
            offsets, acc_ptr, dst_ptr, scratchpad_ptr, attr_params_ptr,
            std::placeholders::_1, std::placeholders::_2);
#endif

    measure_perf(prb->ctx_exe, res, perf_func, args);

    SAFE(release_hw_config(kernel_args), WARN);

    return OK;
}

#else

// For buildability of non-x64 configurations.
int brgemm_finalize() {
    return OK;
}

int doit(const prb_t *prb, res_t *res) {
    return OK;
}

#endif

} // namespace brgemm