/*******************************************************************************
* Copyright 2017-2025 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <float.h>
#include <math.h>
#include <stddef.h>
#include <stdio.h>
#include <stdlib.h>
#include <random>
#include <sstream>
#include "oneapi/dnnl/dnnl.h"
#include "utils/fill.hpp"
#include "utils/parallel.hpp"
#include "dnnl_common.hpp"
#include "dnnl_memory.hpp"
#include "bnorm/bnorm.hpp"
namespace bnorm {
int fill_mean(const prb_t *prb, const cfg_t &cfg, dnn_mem_t &mem_fp,
dnn_mem_t &mem_dt) {
// Refer to modes documentation for filling principles.
if (has_bench_mode_bit(mode_bit_t::bitwise)) {
// Mean must be computed unless it is passed by the user directly.
if (!(prb->flags & GLOB_STATS)) return OK;
return fill_random_real(mem_dt, mem_fp, nullptr);
}
if (has_bench_mode_bit(mode_bit_t::perf)) {
return fill_random_real(
mem_dt, mem_fp, nullptr, get_perf_fill_cfg(mem_dt.dt()));
}
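// Deterministic per-channel means: 0.25f * 2^(c % 7) gives exactly
// representable powers of two in [0.25, 16]. ALG_0 without GLOB_STATS
// expects a zero mean instead.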
benchdnn_parallel_nd(prb->ic, [&](int64_t c) {
float val = 0.f;
if (cfg.check_alg_ != ALG_0 || (prb->flags & GLOB_STATS))
val = 0.25f * (1 << (c % 7));
mem_fp.set_f32_elem(c, val);
});
if (mem_dt && prb->use_stats()) SAFE(mem_dt.reorder(mem_fp), WARN);
return OK;
}
int fill_src(const prb_t *prb, const cfg_t &cfg, dnn_mem_t &mem_fp,
dnn_mem_t &mem_dt, const dnn_mem_t &ref_mean, res_t *res) {
// Refer to modes documentation for filling principles.
if (has_bench_mode_bit(mode_bit_t::bitwise)) {
return fill_random_real(mem_dt, mem_fp, res);
}
if (has_bench_mode_bit(mode_bit_t::perf)) {
return fill_random_real(
mem_dt, mem_fp, res, get_perf_fill_cfg(mem_dt.dt()));
}
benchdnn_parallel_nd(prb->ic, [&](int64_t c) {
const float m = ref_mean.get_f32_elem(c);
for (int64_t mb = 0; mb < prb->mb; ++mb) {
// l[0] must be even so that consecutive elements form +/- pairs
// (see `sgn` below: s_{i} + s_{i+1} = 2 * m).
int64_t l_base = mb * prb->id * prb->ih * prb->iw + c * 239 * 2;
int64_t off = data_off(prb, mb, c, 0, 0, 0);
for_(int64_t d = 0; d < prb->id; ++d)
for_(int64_t h = 0; h < prb->ih; ++h)
for (int64_t w = 0; w < prb->iw; ++w) {
const int64_t sp = d * prb->ih * prb->iw + h * prb->iw + w;
const int64_t l = l_base + sp;
// Shortcut for zero values.
if (cfg.check_alg_ == ALG_0
&& !flip_coin(l / 2 * 257ULL, cfg.density_)) {
mem_fp.set_f32_elem(off + sp, 0);
continue;
}
const int64_t gen = (l / 2 * 1637) & cfg.flex_mask_;
const int sgn = l % 2 == 0 ? 1 : -1; // s_{i} + s_{i+1} = 2 * m
const float f = 1.f * sgn * gen / (1 << cfg.flex_bits_);
float val = cfg.check_alg_ == ALG_0 ? f : m * (1.f + f);
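// With GLOB_STATS the data doesn't have to reproduce the statistics,
// so plain integers in [-32, 32] are used. For odd L the +/- pairing
// leaves one element unpaired; it is pinned to the mean so the
// computed mean stays exact.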
if (prb->flags & GLOB_STATS) {
val = (l % 65) - 32;
} else if (cfg.L_ % 2
&& (mb * prb->id * prb->ih * prb->iw + sp
== cfg.L_ - 1)) {
val = m;
}
mem_fp.set_f32_elem(
off + sp, round_to_nearest_representable(prb->dt, val));
}
}
});
if (mem_dt) SAFE(mem_dt.reorder(mem_fp), WARN);
return OK;
}
int fill_variance(const prb_t *prb, const cfg_t &cfg, dnn_mem_t &mem_fp,
dnn_mem_t &mem_dt, const dnn_mem_t &ref_src,
const dnn_mem_t &ref_mean) {
// Refer to modes documentation for filling principles.
if (has_bench_mode_bit(mode_bit_t::bitwise)) {
// Variance must be computed unless it is passed by the user directly.
if (!(prb->flags & GLOB_STATS)) return OK;
// Variance is non-negative by definition, hence the [0, 16] range.
fill_cfg_t fill_cfg(dnnl_f32, 0.f, 16.f, /* int = */ false,
attr_t::post_ops_t::kind_t::ADD, "variance");
return fill_random_real(mem_dt, mem_fp, nullptr, fill_cfg);
}
if (has_bench_mode_bit(mode_bit_t::perf)) {
return fill_random_real(
mem_dt, mem_fp, nullptr, get_perf_fill_cfg(mem_dt.dt()));
}
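// With GLOB_STATS use fixed even integers in [0, 12] per channel;
// otherwise compute the exact reference variance from the already
// filled source so that the statistics match the data.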
benchdnn_parallel_nd(prb->ic, [&](int64_t c) {
float val = 0.f;
if (prb->flags & GLOB_STATS) {
val = ((c % 7) << 1);
} else {
const float m = ref_mean.get_f32_elem(c);
for (int64_t mb = 0; mb < prb->mb; ++mb) {
int64_t off = data_off(prb, mb, c, 0, 0, 0);
for_(int64_t d = 0; d < prb->id; ++d)
for_(int64_t h = 0; h < prb->ih; ++h)
for (int64_t w = 0; w < prb->iw; ++w) {
const int64_t sp = d * prb->ih * prb->iw + h * prb->iw + w;
const float s = ref_src.get_f32_elem(sp + off);
val += (s - m) * (s - m);
}
}
val /= cfg.L_;
}
mem_fp.set_f32_elem(c, val);
});
if (mem_dt && prb->use_stats()) SAFE(mem_dt.reorder(mem_fp), WARN);
return OK;
}
int fill_src_add(const prb_t *prb, const cfg_t &cfg, dnn_mem_t &mem_fp,
dnn_mem_t &mem_dt, const dnn_mem_t &ref_src,
const dnn_mem_t &ref_mean) {
const bool fill_src_add = prb->fuse_add_relu();
if (!fill_src_add) return OK;
// Refer to modes documentation for filling principles.
if (has_bench_mode_bit(mode_bit_t::bitwise)) {
return fill_random_real(mem_dt, mem_fp, nullptr);
}
if (has_bench_mode_bit(mode_bit_t::perf)) {
return fill_random_real(
mem_dt, mem_fp, nullptr, get_perf_fill_cfg(mem_dt.dt()));
}
benchdnn_parallel_nd(prb->ic, prb->mb, prb->id, prb->ih, prb->iw,
[&](int64_t c, int64_t mb, int64_t d, int64_t h, int64_t w) {
const int64_t l_base
= mb * prb->id * prb->ih * prb->iw + c * 239 * 2;
const int64_t sp = d * prb->ih * prb->iw + h * prb->iw + w;
const int64_t offset = data_off(prb, mb, c, 0, 0, 0) + sp;
const int64_t l = l_base + sp;
const float s = ref_src.get_f32_elem(offset);
const float m = ref_mean.get_f32_elem(c);
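// Zeroed source points (typically from the ALG_0 sparsity shortcut)
// carry no sign information relative to the mean, so they get a fixed
// positive `Add` value instead.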
if (!(prb->flags & GLOB_STATS) && s == 0) {
mem_fp.set_f32_elem(offset, 1.f);
return;
}
float val = 0.f;
if (prb->flags & GLOB_STATS) {
val = (l % 17) - 8;
} else {
// The main purpose of such filling is to avoid catastrophic
// cancellation. To achieve that, the sign of the final `Add` tensor
// values matches the sign the data would have after applying bnorm:
// values below the mean are negative, values equal to or above the
// mean are positive.
const int64_t mod2_base = (mb + c + d + h + w) % 5;
const float mod2_val = 1.f / (2LL << mod2_base);
const int64_t sign_val = s < m ? -1 : 1;
val = mod2_val * sign_val;
}
mem_fp.set_f32_elem(
offset, round_to_nearest_representable(prb->dt, val));
});
if (mem_dt) SAFE(mem_dt.reorder(mem_fp), WARN);
return OK;
}
int fill_scale(const prb_t *prb, dnn_mem_t &mem_fp, dnn_mem_t &mem_dt) {
const bool use_sc = prb->use_sc();
if (!use_sc) return OK;
// Refer to modes documentation for filling principles.
if (has_bench_mode_bit(mode_bit_t::bitwise)) {
return fill_random_real(mem_dt, mem_fp, nullptr);
}
if (has_bench_mode_bit(mode_bit_t::perf)) {
return fill_random_real(
mem_dt, mem_fp, nullptr, get_perf_fill_cfg(mem_dt.dt()));
}
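// Scales are powers of two: (1.f / 8) * 2^(c % 7) spans [1/8, 8] and is
// exactly representable; GLOB_STATS scales the values up to [1, 64].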
benchdnn_parallel_nd(prb->ic, [&](int64_t c) {
float val = (1.f / 8) * (1 << (c % 7));
if (prb->flags & GLOB_STATS) val *= 8.f;
mem_fp.set_f32_elem(c, val);
});
if (mem_dt) SAFE(mem_dt.reorder(mem_fp), WARN);
return OK;
}
int fill_shift(const prb_t *prb, dnn_mem_t &mem_fp, dnn_mem_t &mem_dt) {
const bool use_sh = prb->use_sh();
if (!use_sh) return OK;
// Refer to modes documentation for filling principles.
if (has_bench_mode_bit(mode_bit_t::bitwise)) {
return fill_random_real(mem_dt, mem_fp, nullptr);
}
if (has_bench_mode_bit(mode_bit_t::perf)) {
return fill_random_real(
mem_dt, mem_fp, nullptr, get_perf_fill_cfg(mem_dt.dt()));
}
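// Shifts cycle through {-1, 0, 1} times a power of two in [1/512, 1/8];
// GLOB_STATS scales the values up by 512x.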
benchdnn_parallel_nd(prb->ic, [&](int64_t c) {
float val = ((c % 3) - 1) * (1.f / 512 * (1 << (c % 7)));
if (prb->flags & GLOB_STATS) val *= 512.f;
mem_fp.set_f32_elem(c, val);
});
if (mem_dt) SAFE(mem_dt.reorder(mem_fp), WARN);
return OK;
}
int prepare_fwd(const prb_t *prb, dnn_mem_map_t &mem_map,
dnn_mem_map_t &ref_mem_map, res_t *res) {
cfg_t cfg(prb);
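// Filling order matters: src is generated around the reference mean,
// and the reference variance is then computed from the generated src.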
auto &mean = mem_map.at(DNNL_ARG_MEAN);
auto &ref_mean = ref_mem_map.at(DNNL_ARG_MEAN);
SAFE(fill_mean(prb, cfg, ref_mean, mean), WARN);
auto &src = mem_map.at(DNNL_ARG_SRC);
auto &ref_src = ref_mem_map.at(DNNL_ARG_SRC);
SAFE(fill_src(prb, cfg, ref_src, src, ref_mean, res), WARN);
// Need a copy of the source data for inplace mode in bitwise testing.
if (has_bench_mode_bit(mode_bit_t::bitwise) && prb->inplace) {
auto &src_copy = mem_map.at(-DNNL_ARG_SRC);
SAFE(bool(src_copy) ? OK : FAIL, WARN);
SAFE(src_copy.reorder(src), WARN);
}
auto &var = mem_map.at(DNNL_ARG_VARIANCE);
auto &ref_var = ref_mem_map.at(DNNL_ARG_VARIANCE);
SAFE(fill_variance(prb, cfg, ref_var, var, ref_src, ref_mean), WARN);
auto &src_add = mem_map.at(DNNL_ARG_SRC_1);
auto &ref_src_add = ref_mem_map.at(DNNL_ARG_SRC_1);
SAFE(fill_src_add(prb, cfg, ref_src_add, src_add, ref_src, ref_mean), WARN);
auto &scale = mem_map.at(DNNL_ARG_SCALE);
auto &ref_scale = ref_mem_map.at(DNNL_ARG_SCALE);
SAFE(fill_scale(prb, ref_scale, scale), WARN);
auto &shift = mem_map.at(DNNL_ARG_SHIFT);
auto &ref_shift = ref_mem_map.at(DNNL_ARG_SHIFT);
SAFE(fill_shift(prb, ref_shift, shift), WARN);
return OK;
}
int prepare_bwd(
const prb_t *prb, dnn_mem_map_t &mem_map, dnn_mem_map_t &ref_mem_map) {
auto &mem_fp = ref_mem_map.at(DNNL_ARG_DIFF_DST);
auto &mem_dt = mem_map.at(DNNL_ARG_DIFF_DST);
const auto nelems = mem_fp.nelems();
if (nelems == 0) return OK;
// Refer to modes documentation for filling principles.
if (has_bench_mode_bit(mode_bit_t::bitwise)) {
SAFE(fill_random_real(mem_dt, mem_fp, nullptr), WARN);
// Need a copy of the diff_dst data for inplace mode in bitwise testing.
if (prb->inplace) {
auto &diff_dst_copy = mem_map.at(-DNNL_ARG_DIFF_DST);
SAFE(bool(diff_dst_copy) ? OK : FAIL, WARN);
SAFE(diff_dst_copy.reorder(mem_dt), WARN);
}
return OK;
}
if (has_bench_mode_bit(mode_bit_t::perf)) {
return fill_random_real(
mem_dt, mem_fp, nullptr, get_perf_fill_cfg(mem_dt.dt()));
}
// Idea behind the filling: integer diff_dst values keep norms small,
// unlike fp32 values in the [-1.f, 1.f] range. To decrease norms
// further, make the data fairly sparse, since the answers sum over all
// diff_dst values.
/* Use fixed partitioning to get the same filling for any number of threads */
const int64_t n_chunks = 16;
const int64_t chunk_size = div_up(nelems, n_chunks);
benchdnn_parallel_nd(n_chunks, [&](int64_t idx_chunk) {
int64_t idx_start = idx_chunk * chunk_size;
int64_t idx_end = MIN2(idx_start + chunk_size, nelems);
// Note: we use a different seed for each chunk to avoid
// repeating patterns. We could use discard(idx_start) instead,
// but it has O(idx_start) complexity. We also add 1 to avoid
// seeding with 0.
std::minstd_rand msr(idx_start + 1);
msr.discard(1);
std::uniform_int_distribution<> igen_val(-2, 2);
std::uniform_int_distribution<> igen_coin(0, 256 * 1024);
// at least 20 non-zero elems
float sparsity = MAX2(0.05f, MIN2(1.f, 20.f / nelems));
for (int64_t idx = idx_start; idx < idx_end; ++idx) {
float value = flip_coin(igen_coin(msr), sparsity)
? round_to_nearest_representable(prb->dt, igen_val(msr))
: 0;
mem_fp.set_f32_elem(idx, value);
}
});
SAFE(mem_dt.reorder(mem_fp), WARN);
return OK;
}
int check_fwd_ws(dnn_mem_map_t &mem_map, res_t *res) {
const auto &dst_dt = mem_map.at(DNNL_ARG_DST);
const auto &ws_dt = mem_map.at(DNNL_ARG_WORKSPACE);
if (ws_dt.ndims() == 0) return OK;
/* So far we know ws is just a bit-mask of whether the value was
* negative or positive */
const auto nelems = dst_dt.nelems(true);
const uint8_t *ws = (const uint8_t *)ws_dt;
/* some internal knowledge: flags in ws are either stored as bytes (e.g.
* for the ref implementation) or as bits (e.g. for the jitted one); in
* the latter case the ws memory has fewer elements than the data memory */
enum { ws_byte, ws_bit } ws_type;
ws_type = ws_dt.nelems(true) < nelems ? ws_bit : ws_byte;
/* more internal knowledge: dst_dt and ws_dt are expected to have exactly
* the same data layout, and dst_dt padded regions are expected to be
* zero, and the respective ws_dt elements should be set accordingly */
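// Walk dst in groups of 8 elements so that, in bit mode, a single ws
// byte covers the whole group.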
for (int64_t i = 0; i < nelems; i += 8) {
for (int64_t j = 0; j < MIN2(8, nelems - i); ++j) {
const float data = dst_dt.get_f32_elem(i + j);
const bool want = data > 0.f;
const bool bit_set = ws_type == ws_byte ? *ws : !!(*ws & (1 << j));
const bool ok = bit_set == want;
res->errors += !ok;
bool dump = false || (!ok && (res->errors < 10 || verbose >= 10))
|| (verbose >= 50 && i < 30);
if (dump) {
BENCHDNN_PRINT(0, "[%4ld] ws exp:%d got:%d (data:%g:%a)\n",
(long)(i + j), want, bit_set, data, data);
}
if (ws_type == ws_byte) ++ws;
}
if (ws_type == ws_bit) ++ws;
}
if (res->errors) res->state = FAILED;
if (res->state == EXECUTED) res->state = PASSED;
return res->state == FAILED ? FAIL : OK;
}
dnnl_status_t init_pd(init_pd_args_t<prb_t> &init_pd_args) {
const prb_t *prb = init_pd_args.prb;
const dir_t dir = init_pd_args.dir;
res_t *res = init_pd_args.res;
bool force_f32_dt = init_pd_args.force_f32_dt;
auto src_d = dnn_mem_t::init_md(prb->ndims, prb->data_dims().data(),
force_f32_dt ? dnnl_f32 : prb->dt, prb->tag, prb->strides[0]);
auto dnnl_attr = make_benchdnn_dnnl_wrapper(
create_dnnl_attr(prb->attr, attr_args_t()));
auto flags = (dnnl_normalization_flags_t)prb->flags;
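// src keeps the user-requested tag; dst (or diff) layouts are left for
// the library to choose via tag::any.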
if (dir & FLAG_FWD) {
auto dst_d = dnn_mem_t::init_md(prb->ndims, prb->data_dims().data(),
force_f32_dt ? dnnl_f32 : prb->dt, tag::any, prb->strides[1]);
auto prop = prb->dir & FLAG_INF ? dnnl_forward_inference
: dnnl_forward_training;
TIME_C_PD(DNN_SAFE_STATUS(
dnnl_batch_normalization_forward_primitive_desc_create(
&init_pd_args.pd, init_pd_args.engine, prop,
init_pd_args.src_md ? init_pd_args.src_md : src_d,
dst_d, prb->eps, flags, dnnl_attr)));
} else {
auto diff_src_d = dnn_mem_t::init_md(prb->ndims,
prb->data_dims().data(), force_f32_dt ? dnnl_f32 : prb->dt,
tag::any, prb->strides[0]);
auto diff_dst_d = dnn_mem_t::init_md(prb->ndims,
prb->data_dims().data(), force_f32_dt ? dnnl_f32 : prb->dt,
tag::any, prb->strides[1]);
auto prop = prb->dir & FLAG_WEI ? dnnl_backward : dnnl_backward_data;
TIME_C_PD(DNN_SAFE_STATUS(
dnnl_batch_normalization_backward_primitive_desc_create(
&init_pd_args.pd, init_pd_args.engine, prop, diff_src_d,
diff_dst_d, src_d, prb->eps, flags, init_pd_args.hint,
dnnl_attr)));
}
return dnnl_success;
}
void skip_unimplemented_prb(const prb_t *prb, res_t *res) {
skip_unimplemented_data_type({prb->dt}, prb->dir, res);
skip_unimplemented_sum_po(
prb->attr, res, dnnl_batch_normalization, prb->dt);
skip_unimplemented_binary_po(prb->attr, res);
skip_unimplemented_prelu_po(prb->attr, res, dnnl_batch_normalization);
// Non-zero alpha is not supported for training in general.
const auto &po = prb->attr.post_ops;
const auto relu_idx = po.find(attr_t::post_ops_t::kind_t::RELU);
if (relu_idx >= 0) {
const auto &e = po.entry[relu_idx];
float alpha = e.eltwise.alpha;
bool alpha_ok = IMPLICATION(alpha != 0.f, (prb->dir & FLAG_INF));
if (!alpha_ok) {
res->state = SKIPPED;
res->reason = skip_reason::case_not_supported;
}
}
// BN+Add+ReLU fusion is not supported on CPU
if (is_cpu() && prb->fuse_add_relu()) {
res->state = SKIPPED;
res->reason = skip_reason::case_not_supported;
}
// int8: u8 is not supported; s8 is supported only for forward with
// global stats.
const bool u8_not_ok = prb->dt == dnnl_u8;
const bool s8_not_ok = prb->dt == dnnl_s8
&& ((prb->dir & FLAG_BWD) || (prb->flags & GLOB_STATS) == 0);
if (s8_not_ok || u8_not_ok) {
res->state = SKIPPED;
res->reason = skip_reason::case_not_supported;
}
}
void skip_invalid_prb(const prb_t *prb, res_t *res) {
// MIOpen cannot handle inplace cases correctly.
if (is_amd_gpu() && prb->inplace) {
res->state = SKIPPED;
return;
}
// See `skip_invalid_inplace` for details.
if (prb->inplace) {
skip_invalid_inplace(res, prb->dt, prb->dt, prb->tag, prb->tag);
if (res->state == SKIPPED) return;
}
}
void setup_cmp(compare::compare_t &cmp, const prb_t *prb, data_kind_t kind,
const args_t &ref_args) {
// Since bwd testing is done using results from forward, which are
// random fp32 values, diff_scale starts fluctuating, so we use norm
// checking for data, SC, and SH.
const bool compare_with_norm = (prb->dir & FLAG_BWD);
cmp.set_norm_validation_mode(compare_with_norm);
// The digits difference must be non-negative for safe left-shifting in
// case `digits_dt` exceeds `digits_f32`.
const int safe_digits = MAX2(0, digits_dt(dnnl_f32) - digits_dt(prb->dt));
const float trh_coeff = (1 << safe_digits);
float trh = trh_coeff
* ((kind == SRC || kind == DST || kind == SRC_1) ? 6e-7f : 0);
if ((kind == SC || kind == SH) && prb->dir & FLAG_BWD)
trh = trh_coeff * 5e-6f;
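// `trh_coeff` rescales the f32-based thresholds above by the precision
// gap between f32 and the tested data type.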
#ifdef DNNL_EXPERIMENTAL
const bool bnorm_single_pass
= dnnl::impl::experimental::use_bnorm_stats_one_pass();
#else
const bool bnorm_single_pass = false;
#endif
const bool use_relaxed_validation
= is_nvidia_gpu() || is_amd_gpu() || bnorm_single_pass;
if (use_relaxed_validation) {
// Nvidia (cuDNN) and AMD (MIOpen): store unbiased variance which
// requires rescaling by `(N - 1) / N`, where `N = MB * Spatial`.
// Hence, we cannot set the threshold to 0...
// Also mean could be computed using a single pass formula.
//
// On Intel GPUs mean and variance could be rounded incorrectly because
// they are calculated using fast but potentially unstable formula.
if (kind == MEAN) trh = 1e-7f;
if (kind == VAR) trh = 5e-7f;
if (kind == DST) trh = 2e-6f;
}
cmp.set_threshold(trh);
// TODO: improve bf16 filling
if (prb->dt == dnnl_bf16) cmp.set_zero_trust_percent(99.f);
// When the error is larger than `trh`, it could be due to catastrophic
// cancellation in the final result, which is computed as `Y = a * X + b`.
// When `a * X` is close to `b` in magnitude with the opposite sign, the
// subtraction `|Y| = |a*X - (-b)|` cancels the leading digits, and even a
// small error in `a * X` can leave no meaningful digits in the mantissa
// of the final result.
//
// Since the lambda is called when the stack is unavailable, `prb` and
// `kind` must be captured by value to avoid dangling references.
const auto bnorm_add_check =
[&, kind, prb](
const compare::compare_t::driver_check_func_args_t &args) {
bool ok = is_nvidia_gpu() && args.diff < args.trh;
if (ok) return true;
if (!((prb->dir & FLAG_FWD) && kind == DST && prb->use_sh()))
return false;
const auto &sh = ref_args.find(DNNL_ARG_SHIFT);
const auto &dst = ref_args.find(DNNL_ARG_DST);
const int64_t c
= dst.get_idx(args.idx, 1 << 1 /* channel_mask */);
const float beta = sh.get_f32_elem(c);
// Using an empirically derived threshold, check if
// cancellation error in `|Y| = |a*X - (-b)|` is huge.
const float abs_exp = fabsf(args.exp);
const float norm_denom = abs_exp > FLT_MIN ? abs_exp : 1.f;
const float abs_exp_delta = fabsf(args.exp - beta);
bool maybe_cancel_error = abs_exp_delta / norm_denom > 1.f;
if (!maybe_cancel_error) return false;
// Check for error in `a * X`
float diff_aX = fabsf((args.exp - beta) - (args.got - beta));
float rel_diff_aX = diff_aX
/ (abs_exp_delta > FLT_MIN ? abs_exp_delta : 1.f);
return rel_diff_aX <= args.trh;
};
cmp.set_driver_check_function(bnorm_add_check);
}
std::vector<int> supported_exec_args(dir_t dir) {
static const std::vector<int> exec_fwd_args = {
DNNL_ARG_SRC,
DNNL_ARG_SRC_1,
DNNL_ARG_MEAN,
DNNL_ARG_VARIANCE,
DNNL_ARG_SCALE,
DNNL_ARG_SHIFT,
DNNL_ARG_DST,
DNNL_ARG_WORKSPACE,
};
static const std::vector<int> exec_bwd_args = {
DNNL_ARG_SRC,
DNNL_ARG_SRC_1,
DNNL_ARG_MEAN,
DNNL_ARG_VARIANCE,
DNNL_ARG_SCALE,
DNNL_ARG_WORKSPACE,
DNNL_ARG_DIFF_DST,
DNNL_ARG_DIFF_SCALE,
DNNL_ARG_DIFF_SHIFT,
DNNL_ARG_DIFF_SRC,
DNNL_ARG_DIFF_SRC_1,
};
return (dir & FLAG_FWD) ? exec_fwd_args : exec_bwd_args;
}
int init_ref_memory_args(dnn_mem_map_t &ref_mem_map, dnn_mem_map_t &mem_map,
dnnl_primitive_t prim, const prb_t *prb, res_t *res,
dnnl_primitive_t prim_ref) {
if (has_bench_mode_modifier(mode_modifier_t::no_ref_memory)) return OK;
if (!ref_mem_map.empty()) { erase_unused_args(ref_mem_map, mem_map); }
// TODO: this function still allocates the full memory footprint needed
// to fill the data, and each argument can't be destroyed right away
// since filling requires all of them at once.
const auto &ref_engine = get_cpu_engine();
const bool is_fwd_prim = is_fwd_prop_kind(query_prop_kind(query_pd(prim)));
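// Reference memories are plain f32 with a dense `abx` layout on the
// CPU engine, regardless of the tested configuration.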
for (auto &entry : mem_map) {
const int exec_arg = entry.first;
// The function targets regular exec args, which are positive.
// Negative args are used by bitwise mode and would be broken in the
// `default` branch since `&` with them always returns `true`.
if (exec_arg <= 0) continue;
auto &mem = entry.second;
// Scratchpad memory relates to a primitive. If the reference needs it,
// use the switch below to define a memory desc for it.
if (exec_arg != DNNL_ARG_SCRATCHPAD && exec_arg != DNNL_ARG_WORKSPACE) {
ref_mem_map.emplace(exec_arg,
dnn_mem_t(mem.md_, dnnl_f32, tag::abx, ref_engine,
/* prefill = */ false));
}
switch (exec_arg) {
case DNNL_ARG_DST:
case DNNL_ARG_DIFF_SRC: break; // Outputs don't require filling.
case DNNL_ARG_MEAN:
case DNNL_ARG_VARIANCE:
if (prb->dir & FLAG_INF) {
const dnnl_dims_t dims1d = {prb->ic};
ref_mem_map[exec_arg] = dnn_mem_t(1, dims1d, dnnl_f32,
tag::abx, ref_engine, /* prefill = */ false);
}
break;
default: break;
}
}
if (is_fwd_prim) {
SAFE(prepare_fwd(prb, mem_map, ref_mem_map, res), WARN);
} else {
SAFE(prepare_bwd(prb, mem_map, ref_mem_map), WARN);
}
// Don't keep reference memory if it is not used further.
if (!has_bench_mode_bit(mode_bit_t::corr)) {
ref_mem_map.clear();
return OK;
}
// Reference code uses a different kind of workspace. Adjust to the
// reference needs: insert a reference workspace if it comes non-empty
// from the library side.
if (query_md_ndims(mem_map.at(DNNL_ARG_WORKSPACE).md_) > 0 && is_fwd_prim) {
const auto &src_md = ref_mem_map[DNNL_ARG_SRC].md_;
ref_mem_map[DNNL_ARG_WORKSPACE] = dnn_mem_t(
src_md, dnnl_u8, tag::abx, ref_engine, /* prefill = */ false);
}
return OK;
}
std::vector<data_kind_t> get_kinds_to_check(const prb_t *prb, dir_t dir) {
std::vector<data_kind_t> check_kinds;
if ((prb->dir & FLAG_FWD) && (dir & FLAG_FWD)) {
if (!(prb->flags & GLOB_STATS) && !(prb->dir & FLAG_INF)) {
check_kinds.push_back(MEAN);
check_kinds.push_back(VAR);
}
check_kinds.push_back(DST);
} else if ((prb->dir & FLAG_BWD) && (dir & FLAG_BWD)) {
if (prb->dir & FLAG_WEI) {
if (prb->use_sc()) check_kinds.push_back(SC);
if (prb->use_sh()) check_kinds.push_back(SH);
}
check_kinds.push_back(SRC);
}
// `check_kinds` is empty for `(prb->dir & FLAG_BWD) && (dir & FLAG_FWD)`.
return check_kinds;
}
int createit(std::vector<benchdnn_dnnl_wrapper_t<dnnl_primitive_t>> &v_prim,
const prb_t *prb, res_t *res) {
v_prim.resize(2); // just fwd or fwd + bwd.
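// When backward is tested, forward acts as a service primitive that
// only produces the inputs backward consumes.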
SAFE(init_prim(prb->ctx_init, v_prim[0], init_pd, prb, res, FLAG_FWD,
nullptr, /* is_service_prim = */ prb->dir & FLAG_BWD),
WARN);
if (prb->dir & FLAG_BWD) {
SAFE(init_prim(prb->ctx_init, v_prim[1], init_pd, prb, res, FLAG_BWD,
query_pd(v_prim[0])),
WARN);
}
return OK;
}
int checkit(std::vector<benchdnn_dnnl_wrapper_t<dnnl_primitive_t>> &v_prim,
const prb_t *prb, res_t *res) {
if (has_bench_mode_bit(mode_bit_t::exec)) {
SAFE(check_total_size(res), WARN);
if (v_prim[1]) SAFE(check_total_size(res), WARN);
}
if (has_bench_mode_bit(mode_bit_t::corr)) {
SAFE(check_caches(v_prim[0], prb, res), WARN);
if (v_prim[1]) { SAFE(check_caches(v_prim[1], prb, res), WARN); }
}
return OK;
}
int doit(const std::vector<benchdnn_dnnl_wrapper_t<dnnl_primitive_t>> &v_prim,
const prb_t *prb, res_t *res) {
set_zmalloc_max_expected_size(res->mem_size_args.zmalloc_expected_size);
const auto &prim = prb->dir & FLAG_FWD ? v_prim[0] : v_prim[1];
dnn_mem_map_t mem_map, ref_mem_map;
init_memory_args<prb_t>(
mem_map, prb, v_prim[0], supported_exec_args(FLAG_FWD));
TIME_FILL(SAFE(
init_ref_memory_args(ref_mem_map, mem_map, v_prim[0], prb, res),
WARN));
args_t args(mem_map), ref_args(ref_mem_map);
SAFE(execute_and_wait(v_prim[0], args, res), WARN);
check_correctness(prb, get_kinds_to_check(prb, FLAG_FWD), args, ref_args,
setup_cmp, res, FLAG_FWD);
SAFE(check_bitwise(prim, get_kinds_to_check(prb, FLAG_FWD), args, prb->attr,
prb->inplace, res),
WARN);
if (prb->debug_check_ws && has_bench_mode_bit(mode_bit_t::corr))
check_fwd_ws(mem_map, res);
if (prb->dir & FLAG_BWD) {
// Pass the same memory map since backward needs data produced by forward.
init_memory_args<prb_t>(
mem_map, prb, v_prim[1], supported_exec_args(FLAG_BWD));
TIME_FILL(SAFE(
init_ref_memory_args(ref_mem_map, mem_map, v_prim[1], prb, res),
WARN));
args = args_t(mem_map);
ref_args = args_t(ref_mem_map);
SAFE(execute_and_wait(v_prim[1], args, res), WARN);
check_correctness(prb, get_kinds_to_check(prb, FLAG_BWD), args,
ref_args, setup_cmp, res, FLAG_BWD);
SAFE(check_bitwise(prim, get_kinds_to_check(prb, FLAG_BWD), args,
prb->attr, prb->inplace, res),
WARN);
}
return measure_perf(prb->ctx_exe, res, prim, args);
}
} // namespace bnorm