mirror of
https://github.com/uxlfoundation/oneDNN.git
synced 2025-10-20 18:43:49 +08:00
1153 lines
47 KiB
C++
1153 lines
47 KiB
C++
/*******************************************************************************
|
|
* Copyright 2017-2025 Intel Corporation
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*******************************************************************************/
|
|
|
|
#ifndef DNNL_COMMON_HPP
|
|
#define DNNL_COMMON_HPP
|
|
|
|
#include <functional>
|
|
#include <stddef.h>
|
|
#include <stdint.h>
|
|
#include <vector>
|
|
#include <unordered_map>
|
|
|
|
#include "oneapi/dnnl/dnnl.h"
|
|
|
|
#include "common.hpp"
|
|
#include "dnn_types.hpp"
|
|
#include "dnnl_debug.hpp"
|
|
#include "dnnl_memory.hpp"
|
|
#include "utils/compare.hpp"
|
|
#include "utils/dims.hpp"
|
|
#include "utils/dnnl_query.hpp"
|
|
#include "utils/fill.hpp"
|
|
#include "utils/impl_filter.hpp"
|
|
#include "utils/numeric.hpp"
|
|
#include "utils/parallel.hpp"
|
|
|
|
#include "tests/test_thread.hpp"
|
|
|
|
#define for_ for
|
|
|
|
// Evaluates expression `f` that yields a `dnnl_status_t`. On any non-success
// status:
// * If severity `s` is CRIT or WARN, prints a diagnostic with the failing
//   function/file/line and the stringified status.
// * If `s` is CRIT, terminates the process with exit code 2.
// * Otherwise executes `return FAIL;` in the enclosing function.
// Note: because of the `return FAIL`, this macro may only be used inside
// functions returning `int` (OK/FAIL convention).
#define DNN_SAFE(f, s) \
    do { \
        dnnl_status_t status__ = f; \
        if (status__ != dnnl_success) { \
            if ((s) == CRIT || (s) == WARN) { \
                BENCHDNN_PRINT(0, \
                        "Error: Function '%s' at (%s:%d) returned '%s'\n", \
                        __FUNCTION__, __FILE__, __LINE__, \
                        status2str(status__)); \
                fflush(0); \
                if ((s) == CRIT) exit(2); \
            } \
            return FAIL; \
        } \
    } while (0)
|
|
|
|
// Like `DNN_SAFE` but every failure is treated as critical: prints the
// diagnostic and terminates the process with exit code 2. Since it never
// executes `return`, it is usable in functions with any return type
// (e.g. constructors or `void` functions).
#define DNN_SAFE_V(f) \
    do { \
        dnnl_status_t status__ = (f); \
        if (status__ != dnnl_success) { \
            BENCHDNN_PRINT(0, \
                    "Error: Function '%s' at (%s:%d) returned '%s'\n", \
                    __FUNCTION__, __FILE__, __LINE__, status2str(status__)); \
            fflush(0); \
            exit(2); \
        } \
    } while (0)
|
|
|
|
// Unlike `DNN_SAFE` this one returns `dnnl_status_t`, not `OK/FAIL`.
// Intended for use inside functions whose own return type is `dnnl_status_t`:
// the failing status is propagated to the caller unchanged, with no printing
// and no process termination.
#define DNN_SAFE_STATUS(f) \
    do { \
        dnnl_status_t status__ = (f); \
        if (status__ != dnnl_success) { return status__; } \
    } while (0)
|
|
|
|
#ifndef DNNL_EXPERIMENTAL_PROFILING
|
|
#if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL \
|
|
|| DNNL_GPU_RUNTIME == DNNL_RUNTIME_SYCL
|
|
using dnnl_profiling_data_kind_t = int;
|
|
extern "C" dnnl_status_t dnnl_reset_profiling(dnnl_stream_t stream);
|
|
extern "C" dnnl_status_t dnnl_query_profiling_data(dnnl_stream_t stream,
|
|
dnnl_profiling_data_kind_t data_kind, int *num_entries, uint64_t *data);
|
|
#endif
|
|
#endif
|
|
|
|
int check_pd_cache(const_dnnl_primitive_desc_t pd, res_t *res);
|
|
int check_primitive_cache(dnnl_primitive_t p, res_t *res);
|
|
|
|
extern dnnl_engine_kind_t engine_tgt_kind;
|
|
extern size_t engine_index;
|
|
extern isa_hints_t hints;
|
|
extern int default_num_streams;
|
|
extern int num_streams;
|
|
|
|
// Thin RAII-style wrapper over the C handle `dnnl_engine_t`.
// Ownership is tracked via `is_owner_`; constructor/destructor bodies are
// defined elsewhere, so the exact ownership rules of each ctor (kind-based,
// handle-based, copy) should be confirmed in the implementation file.
// Copy-assignment is deleted; copying is only possible via the copy ctor.
struct engine_t {
    engine_t(dnnl_engine_kind_t engine_kind);
    engine_t(dnnl_engine_t engine);
    engine_t(const engine_t &other);
    ~engine_t();
    // Implicit conversion to the underlying C handle for direct API calls.
    operator dnnl_engine_t() const { return engine_; }

private:
    engine_t &operator=(engine_t &other) = delete;
    dnnl_engine_t engine_; // underlying C engine handle
    bool is_owner_; // presumably guards handle destruction in dtor — see impl
};
|
|
|
|
// Thin RAII-style wrapper over the C handle `dnnl_stream_t`.
// The default ctor produces an empty, non-owning wrapper; the engine-based
// ctor and the dtor are defined elsewhere (the ctor optionally accepts an
// interop object, e.g. a runtime-specific queue — confirm in the impl file).
// Copy is disallowed; move-assignment is supported.
struct stream_t {
    stream_t() : stream_(nullptr), is_owner_(false) {}
    stream_t(dnnl_engine_t engine, void *interop_obj = nullptr);
    ~stream_t();
    // Implicit conversion to the underlying C handle for direct API calls.
    operator dnnl_stream_t() const { return stream_; }
    stream_t &operator=(stream_t &&rhs);

private:
    BENCHDNN_DISALLOW_COPY_AND_ASSIGN(stream_t);
    dnnl_stream_t stream_; // underlying C stream handle
    bool is_owner_; // presumably guards handle destruction in dtor — see impl
};
|
|
|
|
// Engine used to run oneDNN primitives for testing.
// Meyers singleton: constructed on first use with the globally selected
// engine kind (`engine_tgt_kind`) and shared for the rest of the process.
inline const engine_t &get_test_engine() {
    static const engine_t instance(engine_tgt_kind);
    return instance;
}
|
|
|
|
// Engine used to run all reference native implementations and CPU
// implementations used by `--fast-ref` option.
inline const engine_t &get_cpu_engine() {
#if DNNL_CPU_RUNTIME == DNNL_RUNTIME_NONE
    // In case of lacking CPU engine, just re-use testing one.
    return get_test_engine();
#else
    // Meyers singleton: a single process-wide CPU engine.
    static const engine_t instance(dnnl_cpu);
    return instance;
#endif
}
|
|
|
|
bool is_cpu(const dnnl_engine_t &engine = get_test_engine());
|
|
bool is_gpu(const dnnl_engine_t &engine = get_test_engine());
|
|
bool is_sycl_engine(const dnnl_engine_t &engine = get_test_engine());
|
|
bool is_opencl_engine(const dnnl_engine_t &engine = get_test_engine());
|
|
bool is_nvidia_gpu(const dnnl_engine_t &engine = get_test_engine());
|
|
bool is_f64_supported(const dnnl_engine_t &engine = get_test_engine());
|
|
bool is_amd_gpu(const dnnl_engine_t &engine = get_test_engine());
|
|
bool is_generic_gpu(const dnnl_engine_t &engine = get_test_engine());
|
|
|
|
// Extended version of dnnl_sycl_interop_memory_kind_t enumeration.
enum class memory_kind_ext_t {
    usm, // Same as dnnl_sycl_interop_usm
    buffer, // Same as dnnl_sycl_interop_buffer
    usm_device, // USM allocated via malloc_device()
    usm_shared, // USM allocated via malloc_shared()
};

// Memory kind used when no explicit kind was requested.
const memory_kind_ext_t default_memory_kind = memory_kind_ext_t::usm;
|
|
|
|
extern memory_kind_ext_t memory_kind;
|
|
|
|
void init_isa_settings();
|
|
|
|
// Ordered list of (execution argument ID, memory) pairs used to pass
// arguments to primitive execution. Holds non-owning pointers to `dnn_mem_t`
// objects; the caller must keep those memories alive while `args_t` is used.
struct args_t {
    args_t() = default;
    // Builds the list from a full memory map (defined elsewhere).
    args_t(const dnn_mem_map_t &mem_map);

    // Registers `mem` under execution argument ID `arg`; returns *this for
    // chaining.
    args_t &set(int arg, const dnn_mem_t &mem);
    void clear() { args_.clear(); }

    int size() const { return (int)args_.size(); }

    // Returns the memory registered under `arg` (behavior for a missing ID
    // is defined in the implementation file).
    const dnn_mem_t &find(int arg) const;
    // Used in graph to link arguments together by updating current source with
    // previous destination.
    void replace(int arg, const dnn_mem_t *mem);

    // Positional accessors; `index` must be within [0, size()).
    int arg(int index) const { return args_[index].first; }
    const dnn_mem_t &dnn_mem(int index) const { return *args_[index].second; }

private:
    std::vector<std::pair<int, const dnn_mem_t *>> args_;
};
|
|
|
|
// Aggregates everything a driver's `init_pd` function needs as input plus the
// outputs it produces: the created primitive descriptor (`pd`) and whether
// the pd implementation iterator is supported (`is_iterator_supported`).
template <typename prb_t>
struct init_pd_args_t {
    init_pd_args_t(res_t *res, dnnl_engine_t engine, const prb_t *prb,
            dir_t dir, const_dnnl_primitive_desc_t hint,
            const_dnnl_memory_desc_t src_md, bool force_f32_dt)
        : pd(nullptr)
        , is_iterator_supported(true)
        , res(res)
        , engine(engine)
        , prb(prb)
        , dir(dir)
        , hint(hint)
        , src_md(src_md)
        , force_f32_dt(force_f32_dt) {}

    // Convenience ctor: same as above without the f32 data type override.
    init_pd_args_t(res_t *res, dnnl_engine_t engine, const prb_t *prb,
            dir_t dir, const_dnnl_primitive_desc_t hint,
            const_dnnl_memory_desc_t src_md)
        : init_pd_args_t(res, engine, prb, dir, hint, src_md, false) {}

    // Output members
    dnnl_primitive_desc_t pd;

    bool is_iterator_supported;

    // Input members
    res_t *res;
    dnnl_engine_t engine;
    const prb_t *prb;
    // Used to specify the prop_kind of the pd. Required for double-run drivers
    // to differentiate between fwd-for-bwd pd and actual bwd pd.
    dir_t dir;
    const_dnnl_primitive_desc_t hint;
    // Use for memory propagation between pd. Nullptr will ignore the setting.
    const_dnnl_memory_desc_t src_md;
    // When `true`, overrides prb data type with f32 for ALL memory descriptors
    // when creating pd objects.
    bool force_f32_dt;
};
|
|
|
|
// Aggregated CPU cache information filled by `get_cpu_cache_size` (defined
// elsewhere). Size fields appear to be in bytes — confirm against the
// implementation.
struct cpu_cache_args_t {
    size_t L2_size = 0;
    size_t L3_size = 0; // = L3_per_core
    size_t num_cores = 0;
    size_t total_socket_size = 0; // (L2 + L3_per_core) * num_cores
};
|
|
|
|
size_t get_cpu_ram_size();
|
|
int get_gpu_ram_sizes(size_t &ram_size, size_t &max_alloc_size);
|
|
int get_cpu_cache_size(cpu_cache_args_t &cache_args);
|
|
int get_gpu_cache_size(size_t &cache_size);
|
|
|
|
int check_total_size(res_t *res, dnnl_primitive_t prim_ref = nullptr);
|
|
bool is_fwd_training(dnnl_prop_kind_t prop_kind);
|
|
bool is_fwd_prop_kind(dnnl_prop_kind_t prop_kind);
|
|
int get_memory_footprint(const_dnnl_primitive_desc_t pd, res_t *res);
|
|
int check_same_pd(const dnnl_primitive_desc_t &pd_no_attr, res_t *res);
|
|
int test_persistent_cache_api(
|
|
benchdnn_dnnl_wrapper_t<dnnl_primitive_t> &prim, res_t *res);
|
|
// This call is used in zeropad only and still does check inside, too.
|
|
int check_mem_size(const_dnnl_memory_desc_t md, res_t *res);
|
|
// Only collects memory sizes from an input `const_pd` and puts the result into
|
|
// `mem_size_args`.
|
|
int collect_mem_size(check_mem_size_args_t &mem_size_args,
|
|
const_dnnl_primitive_desc_t const_pd, dir_t dir, bool need_skip = true);
|
|
|
|
// Decides whether a benchmarking loop has collected enough measurements.
// When a fixed run count is requested (`fix_times_per_prb` != 0), stop once
// that many runs were made. Otherwise stop once both the time budget
// (`max_ms_per_prb`) and the minimum run count (`min_times_per_prb`) are met.
inline bool should_stop(const timer::timer_t &t) {
    if (fix_times_per_prb) return t.times() >= fix_times_per_prb;
    return t.total_ms() >= max_ms_per_prb && t.times() >= min_times_per_prb;
}
|
|
|
|
void skip_unimplemented_data_type(
|
|
const std::vector<dnnl_data_type_t> &v_dt, dir_t dir, res_t *res);
|
|
void skip_unimplemented_sum_po(const attr_t &attr, res_t *res,
|
|
dnnl_primitive_kind_t pkind, dnnl_data_type_t src_dt,
|
|
dnnl_data_type_t dst_dt = dnnl_data_type_undef);
|
|
void skip_unimplemented_binary_po(const attr_t &attr, res_t *res);
|
|
void skip_unimplemented_prelu_po(
|
|
const attr_t &attr, res_t *res, dnnl_primitive_kind_t pkind);
|
|
void skip_invalid_inplace(res_t *res, dnnl_data_type_t sdt,
|
|
dnnl_data_type_t ddt, const std::string &stag, const std::string &dtag);
|
|
void skip_unimplemented_arg_scale(const attr_t &attr, res_t *res);
|
|
|
|
// Validates caching behavior for a created primitive `primw`:
// * Checks the primitive descriptor and primitive were served from the
//   in-memory caches (skipped under the parallel-creation modifier, see
//   comment below).
// * Exercises the persistent cache API.
// Returns OK/FAIL; a null `primw` is a no-op success.
template <typename prb_t>
int check_caches(benchdnn_dnnl_wrapper_t<dnnl_primitive_t> &primw,
        const prb_t *prb, res_t *res) {
    if (!primw) return OK;

    // Under assumption of a limited cache capacity, which is usually the case,
    // any check that rely on primitive (descriptor) resides in cache fails for
    // parallel creation modifier. There's no guarantee in a general case that
    // primitive stays in cache till the moment the check happens. Even with
    // primitive saving a state of picked from cache or not, there's a time gap
    // between creation of two same primitives over different engines where
    // first instance could be evicted. This approach may work under assumption
    // of infinite cache though.
    if (!has_bench_mode_modifier(mode_modifier_t::par_create)) {
        const_dnnl_primitive_desc_t pd = query_pd(primw);
        SAFE(create_in_thr_ctx(prb->ctx_init, check_pd_cache, pd, res), WARN);
        // Check primitive is picked up from the cache if applicable.
        SAFE(create_in_thr_ctx(
                     prb->ctx_init, check_primitive_cache, primw, res),
                WARN);
    }

    // Check primitive is picked up from the persistent cache if applicable.
    // Note: primw get re-written here to put a primitive from cache blob, if
    // GPU backend is OCL.
    SAFE(test_persistent_cache_api(primw, res), WARN);

    return OK;
}
|
|
|
|
// `check_dnnl_status` function is called to validate the result of primitive
|
|
// descriptor creation. Based on the status, it produces additional checks:
|
|
// * For `invalid_arguments` it just updates the `res` object with it.
|
|
// * For `unimplemented` it checks whether the lack of support is expected or
|
|
// not. It relies on `skip_unimplemented_prb` function declared and defined
|
|
// at every driver and expects it to find in correspondent namespace from
|
|
// where `prb_t` was picked up. If the case is unknown, `UNIMPLEMENTED` status
|
|
// will be returned.
|
|
template <typename prb_t>
int check_dnnl_status(dnnl_status_t status, const prb_t *prb, res_t *res) {
    // Nothing to classify on success or without a result object to update.
    if (!res || status == dnnl_success) return OK;

    switch (status) {
        case dnnl_invalid_arguments: res->state = INVALID_ARGUMENTS; break;
        case dnnl_unimplemented: {
            // Unconditionally set all Nvidia backend unimplemented cases as
            // not supported.
            if (is_nvidia_gpu() || is_amd_gpu()) {
                res->state = SKIPPED;
                res->reason = skip_reason::case_not_supported;
                return OK;
            }

            // Check driver specific cases of unimplemented functionality.
            skip_unimplemented_prb(prb, res);
            if (res->state == SKIPPED) return OK;

            // If the case is not known to be skipped, it is unimplemented.
            res->state = UNIMPLEMENTED;
        } break;
        default: assert(!"unexpected");
    }
    // Non-skipped failures propagate to the caller: DNN_SAFE prints the
    // diagnostic and returns FAIL from this function.
    DNN_SAFE(status, WARN);
    return OK;
}
|
|
|
|
// `fetch_impl` is responsible to provide a valid `pd` under certain conditions:
|
|
// 1. Either valid `pd` or `pd_it` were provided.
|
|
// 2a. It's a service primitive (fwd-for-bwd or cpu-for-gpu or
|
|
// simple-prims-of-complex-prim).
|
|
// 2b. It's a tested primitive and not all implementations hit skip-impl option
|
|
// values.
|
|
//
|
|
// Note: `res` can be empty when fetching impl for prim_ref support.
|
|
template <typename prb_t>
int fetch_impl(benchdnn_dnnl_wrapper_t<dnnl_primitive_desc_t> &pdw,
        init_pd_args_t<prb_t> &init_pd_args, const impl_filter_t &impl_filter,
        res_t *res, bool is_service_prim) {
    if (!init_pd_args.pd) return FAIL;

    // Wrapper is expected to come empty.
    assert(!pdw);

    // Transfer ownership of the freshly created pd into the wrapper.
    pdw.reset(init_pd_args.pd);

    // Service primitive is not supposed to utilize further logic.
    if (is_service_prim) return OK;

    // Iterate pd implementations until one passes the user's impl filter.
    while (true) {
        const auto impl_name = query_impl_info(pdw);
        if (!need_next_impl(impl_name, impl_filter)) return OK;

        BENCHDNN_PRINT(6, "[IMPL_FILTER] Implementation skipped: %s\n",
                impl_name.c_str());

        // Iterator is not supported, further logic is not applicable.
        if (!init_pd_args.is_iterator_supported) {
            if (res) res->state = SKIPPED;
            if (res) res->reason = skip_reason::skip_impl_hit;
            return OK;
        }

        auto status = dnnl_primitive_desc_next_impl(pdw);
        if (status == dnnl_last_impl_reached) {
            BENCHDNN_PRINT(2, "%s\n",
                    "[IMPL_FILTER] All implementations were skipped!");
            if (res) res->state = SKIPPED;
            if (res) res->reason = skip_reason::skip_impl_hit;
            // Clear the wrapper so the caller doesn't see a filtered-out pd.
            pdw.reset(nullptr);
            return OK;
        } else if (status == dnnl_success) {
            // Advanced to the next implementation; re-check it against the
            // filter on the next loop iteration.
            continue;
        } else {
            BENCHDNN_PRINT(0, "%s\n",
                    "[IMPL_FILTER] Unexpected status from pd iterator.");
            return FAIL;
        }
    }

    // Unreached fail status.
    return FAIL;
}
|
|
|
|
// This is an internal to `init_prim` function that utilizes the logic of
|
|
// creating a `pd` and `prim` and assign them to input wrappers. It allows to
|
|
// remove code duplication and keep all the logic in a single place.
|
|
// Creates a primitive descriptor via the driver-provided `init_pd_func`,
// applies status classification and implementation filtering, validates the
// scratchpad size against the GPU allocation limit, and finally creates the
// primitive into `primw`. Returns OK/FAIL; skip conditions are recorded in
// `res` with an OK return.
template <typename func_t, typename prb_t>
int create_primitive(benchdnn_dnnl_wrapper_t<dnnl_primitive_t> &primw,
        dnnl_engine_t engine, const func_t &init_pd_func, const prb_t *prb,
        res_t *res, dir_t dir, const_dnnl_primitive_desc_t hint,
        bool is_service_prim, const_dnnl_memory_desc_t src_md,
        bool force_f32_dt, bool is_graph_ref = false) {
    dnnl_status_t status = dnnl_success;
    dnnl_primitive_t prim {};

    benchdnn_dnnl_wrapper_t<dnnl_primitive_desc_t> pdw;

    init_pd_args_t<prb_t> init_pd_args(
            res, engine, prb, dir, hint, src_md, force_f32_dt);
    status = init_pd_func(init_pd_args);

    // Classify creation status (may mark the case SKIPPED or UNIMPLEMENTED).
    SAFE(check_dnnl_status(status, prb, res), WARN);
    if (res->state == SKIPPED) return OK;

    // Fetch also checks if user requested to skip certain implementations.
    SAFE(fetch_impl(pdw, init_pd_args, prb->impl_filter, res, is_service_prim),
            WARN);
    if (res->state == SKIPPED) return OK;

    // Note: Graph may contain more than one operation with identical `dir`.
    // It's required to collect all memory sizes regardless of `dir`.
    SAFE(collect_mem_size(res->mem_size_args, pdw, dir,
                 /* need_skip = */ !is_graph_ref),
            WARN);

    // The library scratchpad is allocated at create_primitive stage. The memory
    // check is moved after the creation stage. It's necessary to check the
    // library scratchpad size against gpu_max_alloc, otherwise, out_of_memory
    // would be issued by the library.
    if (res->mem_size_args.scratchpad_size > 0 && is_gpu()
            && query_scratchpad_mode(query_attr(pdw))
                    == dnnl_scratchpad_mode_library) {
        // Queried once and reused across calls (function-local statics).
        static size_t gpu_device_capacity = 0;
        static size_t gpu_max_alloc_capacity = 0;
        SAFE(get_gpu_ram_sizes(gpu_device_capacity, gpu_max_alloc_capacity),
                WARN);

        const bool fit
                = res->mem_size_args.scratchpad_size < gpu_max_alloc_capacity;
        if (!fit) {
            BENCHDNN_PRINT(1,
                    "[CHECK_MEM]: Size of the scratchpad %s "
                    "doesn't fit the allocation limit of %s.\n",
                    smart_bytes(res->mem_size_args.scratchpad_size).c_str(),
                    smart_bytes(gpu_max_alloc_capacity).c_str());
            res->state = SKIPPED;
            res->reason = skip_reason::not_enough_ram;
            return OK;
        }
    }

    TIME_C_PRIM(DNN_SAFE(dnnl_primitive_create(&prim, pdw), WARN));
    primw.reset(prim);

    return OK;
}
|
|
|
|
// Verifies that creating the pd with attributes does not fall back to a
// different implementation than the attribute-free pd (`check_same_pd`).
// Skipped entirely when the global `attr_same_pd_check` knob is off, the
// problem has default attributes, or post-ops contain a convolution.
// Note: temporarily mutates `prb->attr` via const_cast and restores it —
// not safe for concurrent use of the same `prb`.
template <typename func_t, typename prb_t>
int check_pd_w_and_wo_attr(dnnl_engine_t engine, const func_t &init_pd_func,
        const prb_t *prb, res_t *res, dir_t dir,
        const_dnnl_primitive_desc_t hint) {

    if (!attr_same_pd_check || prb->attr.is_def()) return OK;

    if (prb->attr.post_ops.convolution_index() != -1) return OK;

    // Check that adding attributes doesn't cause a fall back to another impl.
    auto *prb_mutable = const_cast<prb_t *>(prb);
    auto old_attr = prb_mutable->attr;
    prb_mutable->attr = attr_t(); // temporarily drop attributes
    init_pd_args_t<prb_t> init_pd_args_without_attr(
            res, engine, prb_mutable, dir, hint, /* src_md = */ nullptr);
    DNN_SAFE(init_pd_func(init_pd_args_without_attr), WARN);
    benchdnn_dnnl_wrapper_t<dnnl_primitive_desc_t> pdw(
            init_pd_args_without_attr.pd);
    prb_mutable->attr = old_attr; // restore original attributes
    SAFE(check_same_pd(pdw, res), WARN);
    return OK;
}
|
|
|
|
int check_ref_impl_hit(res_t *res);
|
|
|
|
// Main entry point for creating a tested primitive into `user_prim`:
// * Skips invalid problems early.
// * When the primitive cache is enabled, creates the primitive twice on
//   equivalent engines to exercise the cache (see rationale below).
// * For non-service primitives, records the implementation name, collects
//   the memory footprint, and in correctness mode runs attribute-fallback
//   and reference-implementation checks.
// Returns OK/FAIL; skip conditions are recorded in `res` with an OK return.
template <typename func_t, typename prb_t>
int init_prim(benchdnn_dnnl_wrapper_t<dnnl_primitive_t> &user_prim,
        const func_t &init_pd_func, const prb_t *prb, res_t *res,
        dir_t dir = FLAG_FWD, const_dnnl_primitive_desc_t hint = nullptr,
        bool is_service_prim = false) {
    benchdnn_dnnl_wrapper_t<dnnl_primitive_t> primw;

    skip_invalid_prb(prb, res);
    if (res->state == SKIPPED) return OK;
#ifndef DNNL_DISABLE_PRIMITIVE_CACHE

    int capacity = 0;
    DNN_SAFE(dnnl_get_primitive_cache_capacity(&capacity), FAIL);
    if (capacity > 0) {
        // The idea is to create the requested primitive twice using different
        // engines but the same device and context in the case of OpenCL and DPCPP.
        // Rationale: make sure that the primitive cache is robust in the case
        // where CPU and GPU engines are re-created because this is a commonly
        // used scenario in the frameworks.
        // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
        engine_t engine(get_test_engine());

        // The first primitive creation using a temporary engine.
        SAFE(create_primitive(primw, engine, init_pd_func, prb, res, dir, hint,
                     is_service_prim, /* src_md = */ nullptr,
                     /* force_f32_dt = */ false),
                WARN);
        if (res->state == SKIPPED) return OK;
    }

#endif
    // The second (if the cache is enabled) primitive creation using the global
    // test engine. This primitive is expected to come from the cache.
    SAFE(create_primitive(primw, get_test_engine(), init_pd_func, prb, res, dir,
                 hint, is_service_prim, /* src_md = */ nullptr,
                 /* force_f32_dt = */ false),
            WARN);
    if (res->state == SKIPPED) return OK;

    // Further checks are only for tested primitives.
    if (is_service_prim) {
        user_prim.reset(primw.release());
        return OK;
    }

    auto pd = query_pd(primw);
    res->impl_name = query_impl_info(pd);
    BENCHDNN_PRINT(5, "oneDNN implementation: %s\n", res->impl_name.c_str());
    // Collect memory footprint (perf report) for a given primitive descriptor.
    SAFE(get_memory_footprint(pd, res), WARN);

    if (has_bench_mode_bit(mode_bit_t::corr)) {
        // Check if adding attributes doesn't cause a fall back to another impl.
        SAFE(check_pd_w_and_wo_attr(
                     get_test_engine(), init_pd_func, prb, res, dir, hint),
                WARN);
        // Check if unexpected ref impl was hit.
        SAFE(check_ref_impl_hit(res), WARN);
    }

    user_prim.reset(primw.release());
    return res->state = INITIALIZED, OK;
}
|
|
|
|
// Overload that runs the main `init_prim` inside the given threading context
// `thr_ctx` via `create_in_thr_ctx`. The explicit function-pointer cast pins
// the intended `init_prim` instantiation (defaulted arguments can't be
// forwarded through a function pointer, so all arguments are passed
// explicitly).
template <typename func_t, typename prb_t>
int init_prim(const thr_ctx_t &thr_ctx,
        benchdnn_dnnl_wrapper_t<dnnl_primitive_t> &user_prim,
        const func_t &init_pd_func, prb_t *prb, res_t *res,
        dir_t dir = FLAG_FWD, const_dnnl_primitive_desc_t hint = nullptr,
        bool is_service_prim = false) {
    int (*f)(benchdnn_dnnl_wrapper_t<dnnl_primitive_t> &, func_t &,
            const prb_t *, res_t *, dir_t, const_dnnl_primitive_desc_t, bool)
            = init_prim<func_t, prb_t>;
    return create_in_thr_ctx(thr_ctx, f, user_prim, init_pd_func, prb, res, dir,
            hint, is_service_prim);
}
|
|
|
|
// `check_correctness` function is designed to be called from every driver where
|
|
// correctness validation is needed. It takes:
|
|
// * A pointer to a `prb_t` problem.
|
|
// * A vector of kinds to compare, to validate several outputs, if applicable.
|
|
// * Backend arguments to compare the output.
|
|
// * Driver's reference memory arguments to compute the reference path, then
|
|
// setup a compare object, and, finally, compare the output.
|
|
// * A reference to function that sets up the compare object, see description
|
|
// below.
|
|
// * A pointer to a `res_t` structure to update validation status.
|
|
// * An optional pointer to CPU primitive for speeding up reference path
|
|
// computation on GPU.
|
|
//
|
|
// The function doesn't return status since we rely on `res` to contain all
|
|
// necessary information about validation results.
|
|
//
|
|
// The function performs several validation steps:
|
|
// * Checks that padded area of all memories are properly zeroed.
|
|
// * Checks that GPU backend haven't modified out-of-boundary memory regions.
|
|
// * Executes driver's reference path, using the problem, driver reference
|
|
// arguments, and CPU primitive for GPU backend, if available.
|
|
// * For each kind to validate it:
|
|
// - Creates and sets up the compare object. Setting is done with
|
|
// `setup_cmp_func`.
|
|
// - Finds correspondent memory arguments from backend and reference and
|
|
// compares them.
|
|
// - Result of comparison is saved into `res` object.
|
|
//
|
|
// `setup_cmp_func` is a function that supposed to be defined in every driver's
|
|
// namespace. Its interface is:
|
|
// `void (compare::compare_t &, const prb_t *, data_kind_t, const args_t &);`
|
|
// It takes:
|
|
// * A reference to a `compare_t` object which the function modifies based on
|
|
// driver's needs.
|
|
// * A pointer to a `prb_t` problem.
|
|
// * `data_kind` value to help to setup threshold depending on output argument.
|
|
// * Driver's reference memory arguments since some drivers can't validate
|
|
// certain scenarios for sure without additional memory arguments.
|
|
// Returns nothing since the object is modified by reference due to lifetime of
|
|
// the compare object is controlled by `check_correctness`.
|
|
//
|
|
// Note: a dedicated non-templated type for `setup_cmp_func_t` could be used but
|
|
// since it relies on a `prb_t` type which is individual for each driver,
|
|
// it isn't possible without a template.
|
|
template <typename setup_cmp_func_t, typename prb_t>
void check_correctness(const prb_t *prb, const std::vector<data_kind_t> &kinds,
        const args_t &args, const args_t &ref_args,
        const setup_cmp_func_t &setup_cmp_func, res_t *res, dir_t dir,
        dnnl_primitive_t prim_ref = nullptr) {
    // Fast exit for any modes but correctness.
    if (!has_bench_mode_bit(mode_bit_t::corr)) return;

    // Report prim_ref run status for easier distinguishing between GPU failures
    // and ref CPU failures.
    if (prim_ref) {
        BENCHDNN_PRINT(1, "run ref: %s\n", res->prim_ref_repro.c_str());
    } else {
        BENCHDNN_PRINT(8, "%s\n", "[NAIVE_REF]: Start");
    }

    // Compute the reference answer (naive path or CPU primitive `prim_ref`).
    TIME_REF(compute_ref(prb, dir, ref_args, prim_ref));

    // Forward-for-backward service primitives define `kinds` as empty to skip
    // validation. This is to avoid extra checks on higher level.
    if (kinds.empty()) return;

    // Sanity checks on every library memory: padded area must stay zeroed and
    // out-of-boundary regions must not be overwritten by the backend.
    for (int i = 0; i < args.size(); ++i) {
        TIME_COMPARE(check_zero_padding(args.dnn_mem(i), args.arg(i), res));
        TIME_COMPARE(check_buffer_overwrite(args.dnn_mem(i), args.arg(i), res));
    }

    // Compare each requested output kind against its reference counterpart.
    for (const auto &kind : kinds) {
        compare::compare_t cmp;
        cmp.set_data_kind(kind);
        cmp.set_has_prim_ref(bool(prim_ref));
        // Driver-specific tuning of thresholds and comparison rules.
        setup_cmp_func(cmp, prb, kind, ref_args);

        int arg = data_kind2exec_arg(kind);
        assert(arg > 0);

        const auto &mem_dt = args.find(arg);
        const auto &mem_fp = ref_args.find(arg);

        TIME_COMPARE(cmp.compare(mem_fp, mem_dt, prb->attr, res));
    }

    // On a prim_ref failure, dump environment info that helps triage whether
    // the CPU reference itself misbehaved.
    if (prim_ref && res->state == FAILED) {
        static cpu_cache_args_t cpu_cache_args {};
        SAFE_V(get_cpu_cache_size(cpu_cache_args));

        BENCHDNN_PRINT(0,
                "[PRIM_REF][INFO]: L2_size:%zu bytes; per_core_L3_size:%zu "
                "bytes; nthr:%d; impl_name:%s\n",
                cpu_cache_args.L2_size, cpu_cache_args.L3_size,
                benchdnn_get_max_threads(),
                query_impl_info(query_pd(prim_ref)).c_str());
    }
}
|
|
|
|
using perf_function_t = std::function<dnnl_status_t(
|
|
const dnnl_stream_t &, const std::vector<dnnl_exec_arg_t> &)>;
|
|
|
|
int execute_and_wait(perf_function_t &exec_func, const dnnl_engine_t &engine,
|
|
const args_t &args, res_t *res = nullptr);
|
|
int execute_and_wait(
|
|
dnnl_primitive_t prim, const args_t &args, res_t *res = nullptr);
|
|
|
|
void reset_gpu_profiling(dnnl_stream_t stream);
|
|
|
|
void finalize();
|
|
|
|
int get_gpu_profiling_info(dnnl_stream_t stream, std::vector<uint64_t> &nsecs,
|
|
std::vector<uint64_t> &cycles, int expected_num_entries);
|
|
int measure_perf(const thr_ctx_t &ctx, res_t *res, perf_function_t &perf_func,
|
|
args_t &args);
|
|
int measure_perf(
|
|
const thr_ctx_t &ctx, res_t *res, dnnl_primitive_t prim, args_t &args);
|
|
|
|
std::vector<float> prepare_po_vals(const dnn_mem_t &dst_m, const args_t &args,
|
|
const std::vector<std::pair<int, int>> &v_po_masks,
|
|
const size_t dst_off);
|
|
|
|
bool check_md_consistency_with_tag(
|
|
const_dnnl_memory_desc_t md, const std::string &tag);
|
|
|
|
memory_kind_ext_t str2memory_kind(const char *str);
|
|
|
|
enum class execution_mode_t { direct, graph };
|
|
extern execution_mode_t execution_mode;
|
|
|
|
const char *execution_mode2str(execution_mode_t mode);
|
|
execution_mode_t str2execution_mode(const char *str);
|
|
|
|
float reorder_rescale_factor();
|
|
|
|
// The function converts a memory descriptor dims into a `dims_t` object under
|
|
// certain rules.
|
|
//
|
|
// `mask` argument picks what dimensions to put into a new object as is.
|
|
// `extend_by_ones` specifies the behavior with dimensions not matched by
|
|
// `mask`. When set to `true` (the default), a dim value of `1` is used
|
|
// for a not matched dimension. Thus, `ndims` of a new object will remain
|
|
// the same as for original md. When set to `false`, a dim is skipped and
|
|
// the final object could end up with smaller `ndims` (or `size()`) value.
|
|
// `groups` specify a vector of group values which decrease the final dimension
|
|
// values by dividing on the group size.
|
|
dims_t md2dims(const_dnnl_memory_desc_t md, int mask = -1,
|
|
bool extend_by_ones = true, const std::vector<int64_t> &groups = {});
|
|
|
|
// Function adjusts data type if fpmath mode is present or sum_dt is different
|
|
// from destination_dt. It is used in `cfg` objects that regulate filling.
|
|
dnnl_data_type_t deduce_cfg_data_type(
|
|
dnnl_data_type_t in_dt, const attr_t &attr, data_kind_t dk);
|
|
|
|
// `init_memory_args` is responsible for:
|
|
// * Constructing all necessary `dnn_mem_t` objects needed by the library
|
|
// primitive for the main operation and attributes.
|
|
// All these memories must utilize `prefill_memory` flag of `dnn_mem_t` ctor
|
|
// to verify reorders are working correctly and output memory was updated
|
|
// completely.
|
|
// * Stashing them with a proper exec_arg ID in a `mem_map` object.
|
|
// Caller is responsible for constructing reference memories and filling both
|
|
// the library and reference memories by calling `init_ref_memory_args`.
|
|
//
|
|
// Note: unordered_map is taken over std::vector because vector invalidates its
|
|
// references once the object emplaced due to memory re-allocations happening
|
|
// internally, while map doesn't invalidate its references when adding a new
|
|
// element which simplifies an implementation.
|
|
template <typename prb_t>
void init_memory_args(dnn_mem_map_t &mem_map, const prb_t *prb,
        dnnl_primitive_t prim, const std::vector<int> &supported_exec_args,
        const engine_t &test_engine = get_test_engine()) {
    // Backward case when forward is required will have mem_map not empty.
    // Remove all memories that are not in `supported_exec_args` to save on
    // initializing reference memories.
    //
    // Note: this code is pretty similar to the one in `erase_unused_args` but
    // with a slight change where the key is checked and bitwise correction.
    if (!mem_map.empty()) {
        // Collection of keys is required as evicting members along the way
        // invalidates references and makes traversing over the object
        // undefined behavior.
        std::vector<int> keys_to_erase;
        for (const auto &pair : mem_map) {
            const auto key = pair.first;
            bool key_found_in_exec_args = false;
            for (const auto &arg : supported_exec_args) {
                if (arg == key) {
                    key_found_in_exec_args = true;
                    break;
                }
            }
            // Don't remove stashed memory for bitwise validation.
            const bool bitwise_stash
                    = has_bench_mode_bit(mode_bit_t::bitwise) && key < 0;
            bool add_key_to_erase = !key_found_in_exec_args && !bitwise_stash;
            if (add_key_to_erase) keys_to_erase.push_back(key);
        }
        for (const auto &k : keys_to_erase) {
            mem_map.erase(k);
        }
    }

    auto const_pd = query_pd(prim);
    // NOTE(review): `const_po` and `prop_kind` are not referenced in the
    // visible body — possibly leftovers; confirm before removing.
    auto const_po = query_post_ops(const_pd);
    const auto prim_kind = query_prim_kind(const_pd);
    const auto prop_kind = query_prop_kind(const_pd);

    // True when any dimension of `md` is a runtime placeholder; such mds
    // can't size a memory directly, so the problem's own md is used instead.
    const auto has_runtime_dims = [](const_dnnl_memory_desc_t md) -> bool {
        for (int d = 0; d < query_md_ndims(md); ++d)
            if (query_md_dims(md)[d] == DNNL_RUNTIME_DIM_VAL) return true;
        return false;
    };

    if (prim_kind == dnnl_reorder) {
        // Reorder is special: src and dst may live on different engines.
        auto src_engine = query_engine(const_pd, dnnl_query_reorder_src_engine);
        auto dst_engine = query_engine(const_pd, dnnl_query_reorder_dst_engine);
        const auto &src_md = query_md(const_pd, DNNL_ARG_FROM);
        const auto &dst_md = query_md(const_pd, DNNL_ARG_TO);
        if (has_runtime_dims(src_md)) {
            mem_map.emplace(DNNL_ARG_FROM,
                    dnn_mem_t(prb->get_md(DNNL_ARG_FROM), src_engine,
                            /* prefill = */ true));
            mem_map.emplace(DNNL_ARG_TO,
                    dnn_mem_t(prb->get_md(DNNL_ARG_TO), dst_engine,
                            /* prefill = */ true));
        } else {
            mem_map.emplace(DNNL_ARG_FROM,
                    dnn_mem_t(src_md, src_engine, /* prefill = */ true));
            mem_map.emplace(DNNL_ARG_TO,
                    dnn_mem_t(dst_md, dst_engine, /* prefill = */ true));
        }
    } else {
        for (const auto &exec_arg : supported_exec_args) {
            if (exec_arg == DNNL_ARG_MULTIPLE_SRC) {
                // `DNNL_ARG_MULTIPLE_SRC` corresponds to a pack of inputs.
                const auto n_inputs = query_n_inputs(const_pd);
                for (int i = 0; i < n_inputs; i++) {
                    const auto &md = query_md(const_pd, exec_arg + i);
                    mem_map.emplace(exec_arg + i,
                            dnn_mem_t(md, test_engine, /* prefill = */ true));
                }
            } else {
                const bool is_arg_in_map
                        = mem_map.find(exec_arg) != mem_map.end();
                const auto &md = query_md(const_pd, exec_arg);
                // Check for ndims is needed when the driver supported args map
                // contains extra arguments for other purposes.
                const int ndims = query_md_ndims(md);
                if (is_arg_in_map) {
                    // It may happen that the map already has the argument but
                    // the library requires it in a different format, e.g., RNN
                    // BWD support on GPU (for better performance). It may also
                    // happen in a combination with `no_ref_memory` modifier,
                    // which requires the library memories map to handle such
                    // cases.
                    if (ndims > 0
                            && dnnl_memory_desc_equal(
                                       md, mem_map.at(exec_arg).md_)
                                    == 0) {
                        assert(!has_runtime_dims(md));
                        dnn_mem_t new_mem(
                                md, test_engine, /* prefill = */ true);
                        // Reorder user's data from the old memory to the new one.
                        auto st = new_mem.reorder(mem_map.at(exec_arg));
                        assert(st == OK);
                        // Silent best-effort bail-out in release builds.
                        if (st != OK) return;
                        mem_map[exec_arg] = std::move(new_mem);
                    }
                } else {
                    if (has_runtime_dims(md)) {
                        mem_map.emplace(exec_arg,
                                dnn_mem_t(prb->get_md(exec_arg), test_engine,
                                        /* prefill = */ true));
                    } else {
                        // In case when arguments get updated on backward when
                        // forward is required, `emplace` guarantees newly
                        // constructed element will be destroyed if an element
                        // with a key already present in the map. C++17 could
                        // use try_emplace instead to mitigate
                        // construction/destruction overhead.
                        mem_map.emplace(exec_arg,
                                dnn_mem_t(
                                        md, test_engine, /* prefill = */ true));
                    }
                }
            }
        }
    }
}
|
|
|
|
// Drop "destination" memory for in-place case. `args` will take care of
|
|
// setting proper pointers to make in-place mode happen.
|
|
// Note: must precede bitwise stash memory insertion to keep numbers
|
|
// estimated by memory checker correct.
|
|
if (prb->inplace) {
|
|
const bool inplace_fwd = (prb->dir & FLAG_FWD);
|
|
const bool inplace_bwd
|
|
= (prb->dir & FLAG_BWD) && !is_fwd_prop_kind(prop_kind);
|
|
if (inplace_fwd || inplace_bwd) {
|
|
const int inplace_dst_arg
|
|
= (prb->dir & FLAG_FWD) ? DNNL_ARG_DST : DNNL_ARG_DIFF_SRC;
|
|
mem_map[inplace_dst_arg] = dnn_mem_t();
|
|
}
|
|
}
|
|
|
|
// Bitwise mode demands exactly the same inputs between two runs. There are
|
|
// certain scenarios that affect original memory objects content. When such
|
|
// scenarios occur, memory objects have their original content overwritten.
|
|
// The logic below stashes additional memory objects for a copy of data
|
|
// which will get reordered before the second run.
|
|
//
|
|
// An implementation detail:
|
|
// All such memory objects' counterparts are created with the same arg value
|
|
// but with a negative sign. This is the only guaranteed value that is not
|
|
// used by the library.
|
|
if (has_bench_mode_bit(mode_bit_t::bitwise)) {
|
|
// A sum post-op has the destination memory data overwritten by the
|
|
// accumulation memory.
|
|
if (query_post_ops_has_kind(const_po, dnnl_sum)) {
|
|
const int query_arg = DNNL_ARG_DST;
|
|
const int insert_arg = -query_arg;
|
|
const auto &md = query_md(const_pd, query_arg);
|
|
if (has_runtime_dims(md)) {
|
|
mem_map.emplace(insert_arg,
|
|
dnn_mem_t(prb->get_md(query_arg), test_engine,
|
|
/* prefill = */ true));
|
|
} else {
|
|
mem_map.emplace(insert_arg,
|
|
dnn_mem_t(md, test_engine, /* prefill = */ true));
|
|
}
|
|
}
|
|
|
|
// An inplace mode uses the source memory object as the destination one.
|
|
// It results in the source is overwritten after the operation is done.
|
|
if (prb->inplace) {
|
|
const bool has_multiple_args = std::any_of(
|
|
supported_exec_args.begin(), supported_exec_args.end(),
|
|
[](int arg) { return arg == DNNL_ARG_MULTIPLE_SRC; });
|
|
const auto query_arg = is_fwd_prop_kind(prop_kind)
|
|
? (has_multiple_args ? DNNL_ARG_MULTIPLE_SRC : DNNL_ARG_SRC)
|
|
: DNNL_ARG_DIFF_DST;
|
|
const int insert_arg = -query_arg;
|
|
const auto &md = query_md(const_pd, query_arg);
|
|
if (has_runtime_dims(md)) {
|
|
mem_map.emplace(insert_arg,
|
|
dnn_mem_t(prb->get_md(query_arg), test_engine,
|
|
/* prefill = */ true));
|
|
} else {
|
|
mem_map.emplace(insert_arg,
|
|
dnn_mem_t(md, test_engine, /* prefill = */ true));
|
|
}
|
|
}
|
|
}
|
|
|
|
const auto &scratch_md = query_md(const_pd, DNNL_ARG_SCRATCHPAD);
|
|
mem_map.emplace(DNNL_ARG_SCRATCHPAD,
|
|
dnn_mem_t(scratch_md, test_engine, /* prefill = */ true));
|
|
|
|
// Binary post-op.
|
|
// TODO: currently run-time dimensions are not supported in binary post-op.
|
|
for (int idx = 0; idx < dnnl_post_ops_len(const_po); ++idx) {
|
|
if (dnnl_post_ops_get_kind(const_po, idx) != dnnl_binary) continue;
|
|
|
|
int po_arg1 = DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) | DNNL_ARG_SRC_1;
|
|
const auto &po_md1 = query_md(const_pd, po_arg1);
|
|
mem_map.emplace(
|
|
po_arg1, dnn_mem_t(po_md1, test_engine, /* prefill = */ true));
|
|
|
|
if (!query_post_ops_has_binary_alg_kind(
|
|
const_po, idx, dnnl_binary_select))
|
|
continue;
|
|
int po_arg2 = DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) | DNNL_ARG_SRC_2;
|
|
const auto &po_md2 = query_md(const_pd, po_arg2);
|
|
mem_map.emplace(
|
|
po_arg2, dnn_mem_t(po_md2, test_engine, /* prefill = */ true));
|
|
}
|
|
|
|
// Prelu post-op.
|
|
for (int idx = 0; idx < dnnl_post_ops_len(const_po); ++idx) {
|
|
if (dnnl_post_ops_get_kind(const_po, idx) != dnnl_prelu) continue;
|
|
|
|
const auto &orig_dst_md = query_md(const_pd, DNNL_ARG_DST);
|
|
benchdnn_dnnl_wrapper_t<dnnl_memory_desc_t> prb_dst_md;
|
|
if (has_runtime_dims(orig_dst_md)) {
|
|
prb_dst_md = prb->get_md(DNNL_ARG_DST);
|
|
}
|
|
const auto &dst_md = prb_dst_md ? prb_dst_md : orig_dst_md;
|
|
|
|
const auto ndims = query_md_ndims(dst_md);
|
|
int mask = 0;
|
|
dnnl_post_ops_get_params_prelu(const_po, idx, &mask);
|
|
|
|
// Deduce prelu weights dims based on input policy.
|
|
dims_t dims = md2dims(dst_md, mask);
|
|
|
|
int po_arg = DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) | DNNL_ARG_WEIGHTS;
|
|
mem_map.emplace(po_arg,
|
|
dnn_mem_t(ndims, dims.data(), dnnl_f32, tag::axb, test_engine,
|
|
/* prefill = */ true));
|
|
}
|
|
|
|
// Dropout
|
|
if (is_fwd_training(prop_kind) && !prb->attr.dropout.is_def()) {
|
|
const auto &dropout_md = query_md(const_pd, DNNL_ARG_ATTR_DROPOUT_MASK);
|
|
mem_map.emplace(DNNL_ARG_ATTR_DROPOUT_MASK,
|
|
dnn_mem_t(dropout_md, test_engine, /* prefill = */ true));
|
|
int64_t count = 1;
|
|
auto prob_md = dnn_mem_t::init_md(1, &count, dnnl_f32, tag::abx);
|
|
mem_map.emplace(DNNL_ARG_ATTR_DROPOUT_PROBABILITY,
|
|
dnn_mem_t(prob_md, test_engine, /* prefill = */ true));
|
|
|
|
auto seed_md = dnn_mem_t::init_md(1, &count, dnnl_s32, tag::abx);
|
|
mem_map.emplace(DNNL_ARG_ATTR_DROPOUT_SEED,
|
|
dnn_mem_t(seed_md, test_engine, /* prefill = */ true));
|
|
}
|
|
|
|
// Scales.
|
|
if (!prb->attr.scales.is_def()) {
|
|
const auto &sc = prb->attr.scales;
|
|
|
|
const auto &src_md = query_md(const_pd, DNNL_ARG_SRC);
|
|
const auto &wei_md = query_md(const_pd, DNNL_ARG_WEIGHTS);
|
|
const bool has_channel_groups
|
|
= (query_md_ndims(src_md) + 1) == query_md_ndims(wei_md);
|
|
|
|
const auto append_scales = [&](int exec_arg) {
|
|
const int exec_sc_arg = DNNL_ARG_ATTR_SCALES | exec_arg;
|
|
dims_t dims = {};
|
|
int64_t ndims = 1;
|
|
const auto mask = sc.get_mask(exec_arg, prim_kind,
|
|
query_md_ndims(wei_md), has_channel_groups);
|
|
const auto &groups = sc.get(exec_arg).groups;
|
|
|
|
if (mask > 0) {
|
|
const auto &md = query_md(const_pd, exec_arg);
|
|
if (has_runtime_dims(md)) {
|
|
const auto prb_md = prb->get_md(exec_arg);
|
|
dims = md2dims(prb_md, mask, false, groups);
|
|
ndims = static_cast<int>(dims.size());
|
|
} else {
|
|
dims = md2dims(md, mask, false, groups);
|
|
ndims = static_cast<int>(dims.size());
|
|
}
|
|
} else {
|
|
dims = {1};
|
|
ndims = 1;
|
|
}
|
|
const auto dt = sc.get(exec_arg).dt;
|
|
auto scales_md
|
|
= dnn_mem_t::init_md(ndims, dims.data(), dt, tag::abx);
|
|
mem_map.emplace(exec_sc_arg,
|
|
dnn_mem_t(scales_md, test_engine, /* prefill = */ true));
|
|
};
|
|
|
|
for (const auto &exec_arg : supported_exec_args) {
|
|
if (exec_arg == DNNL_ARG_MULTIPLE_SRC) {
|
|
// `DNNL_ARG_MULTIPLE_SRC` corresponds to a pack of inputs.
|
|
const auto n_inputs = query_n_inputs(const_pd);
|
|
for (int i = 0; i < n_inputs; i++) {
|
|
const auto i_exec_arg = exec_arg + i;
|
|
if (!sc.is_def(i_exec_arg)) append_scales(i_exec_arg);
|
|
}
|
|
} else {
|
|
if (!sc.is_def(exec_arg)) append_scales(exec_arg);
|
|
}
|
|
}
|
|
}
|
|
|
|
// Zero points.
|
|
if (!prb->attr.zero_points.is_def()) {
|
|
const auto &zp = prb->attr.zero_points;
|
|
|
|
const auto &wei_md = query_md(const_pd, DNNL_ARG_WEIGHTS);
|
|
|
|
const auto append_zero_points = [&](int exec_arg) {
|
|
const int exec_zp_arg = DNNL_ARG_ATTR_ZERO_POINTS | exec_arg;
|
|
const auto &e = zp.get(exec_arg);
|
|
int64_t ndims = 1;
|
|
dims_t dims = {};
|
|
const auto mask
|
|
= zp.get_mask(exec_arg, prim_kind, query_md_ndims(wei_md));
|
|
const auto &groups = e.groups;
|
|
|
|
if (mask > 0) {
|
|
const auto &md = query_md(const_pd, exec_arg);
|
|
if (has_runtime_dims(md)) {
|
|
const auto prb_md = prb->get_md(exec_arg);
|
|
dims = md2dims(prb_md, mask, false, groups);
|
|
ndims = static_cast<int>(dims.size());
|
|
} else {
|
|
dims = md2dims(md, mask, false, groups);
|
|
ndims = static_cast<int>(dims.size());
|
|
}
|
|
} else {
|
|
dims = {1};
|
|
ndims = 1;
|
|
}
|
|
auto zp_md = dnn_mem_t::init_md(ndims, dims.data(), e.dt, tag::abx);
|
|
mem_map.emplace(exec_zp_arg,
|
|
dnn_mem_t(zp_md, test_engine, /* prefill = */ true));
|
|
};
|
|
|
|
for (const auto &exec_arg : supported_exec_args) {
|
|
if (exec_arg == DNNL_ARG_MULTIPLE_SRC) {
|
|
// `DNNL_ARG_MULTIPLE_SRC` corresponds to a pack of inputs.
|
|
const auto n_inputs = query_n_inputs(const_pd);
|
|
for (int i = 0; i < n_inputs; i++) {
|
|
const auto i_exec_arg = exec_arg + i;
|
|
if (!zp.is_def(i_exec_arg)) append_zero_points(i_exec_arg);
|
|
}
|
|
} else {
|
|
if (!zp.is_def(exec_arg)) append_zero_points(exec_arg);
|
|
}
|
|
}
|
|
}
|
|
|
|
// rounding mode
|
|
if (!prb->attr.rounding_mode.is_def()) {
|
|
int64_t count = 1;
|
|
auto seed_md = dnn_mem_t::init_md(1, &count, dnnl_s32, tag::abx);
|
|
mem_map.emplace(DNNL_ARG_ATTR_ROUNDING_SEED,
|
|
dnn_mem_t(seed_md, test_engine, /* prefill = */ true));
|
|
}
|
|
}
|
|
|
|
// Removes entries from `ref_mem_map` that are no longer needed, using the
// library memory map `mem_map` (read-only) as the reference of what is in
// use. NOTE(review): exact removal criteria live in the out-of-line
// definition; confirm there.
void erase_unused_args(
        dnn_mem_map_t &ref_mem_map, const dnn_mem_map_t &mem_map);

// Updates the reference memory map `ref_mem_map` entry for `exec_arg` based
// on the CPU reference primitive `prim_ref` and the library memory
// `library_mem`. NOTE(review): the role of `swapped_dt` (presumably a data
// type the reference memory is converted to) must be confirmed at the
// definition site. Returns an int status (OK/FAIL by benchdnn convention).
int update_ref_mem_map_from_prim(dnnl_primitive_t prim_ref,
        const dnn_mem_t &library_mem, dnn_mem_map_t &ref_mem_map, int exec_arg,
        dnnl_data_type_t swapped_dt);

// Default-case initialization/filling for a pair of library (`mem`) and
// reference (`ref_mem`) memories for the given `exec_arg`. Optional
// per-argument filling configurations may be supplied through
// `fill_cfg_map`. Returns an int status (OK/FAIL by benchdnn convention).
int init_ref_memory_args_default_case(int exec_arg, dnn_mem_t &mem,
        dnn_mem_t &ref_mem, const attr_t &attr, res_t *res,
        const std::unordered_map<int, fill_cfg_t> &fill_cfg_map = {});

// Bitwise-mode validation entry point for primitive `prim` over the given
// data `kinds` and execution `args`. NOTE(review): description inferred from
// the name and the bitwise-mode logic elsewhere in this header (stashed
// negative-arg memories); confirm against the definition.
int check_bitwise(dnnl_primitive_t prim, const std::vector<data_kind_t> &kinds,
        const args_t &args, const attr_t &attr, bool inplace, res_t *res);
|
|
|
|
template <typename prb_t>
|
|
int init_prim_ref_common(benchdnn_dnnl_wrapper_t<dnnl_primitive_t> &prim_ref,
|
|
const prb_t *prb_cpu, res_t *res) {
|
|
|
|
init_pd_args_t<prb_t> init_pd_args(
|
|
/* res = */ nullptr, get_cpu_engine(), prb_cpu, prb_cpu->dir,
|
|
/* hint = */ nullptr, /* src_md = */ nullptr);
|
|
init_pd(init_pd_args);
|
|
|
|
benchdnn_dnnl_wrapper_t<dnnl_primitive_desc_t> pdw;
|
|
// `is_service_prim=true` prevents from filtering the implementation
|
|
// by name which is intended through a `get_prim_ref_impl_filter()`.
|
|
// As `fetch_impl` doesn't have any further logic related to it, it's
|
|
// safe to set it to `false`.
|
|
fetch_impl(pdw, init_pd_args, get_prim_ref_impl_filter(),
|
|
/* res = */ nullptr,
|
|
/* is_service_prim = */ false);
|
|
|
|
// Prim desc wasn't created - try the next set...
|
|
if (!pdw) return FAIL;
|
|
|
|
dnnl_primitive_t prim_ref_ptr {};
|
|
auto st = dnnl_primitive_create(&prim_ref_ptr, pdw);
|
|
// Primitive wasn't created - try the next set...
|
|
if (st != dnnl_success) return FAIL;
|
|
|
|
BENCHDNN_PRINT(5, "CPU reference oneDNN implementation: %s\n",
|
|
query_impl_info(pdw).c_str());
|
|
|
|
res->prim_ref_repro = prb_cpu->str();
|
|
// Replace engine kind for repro line from GPU to CPU.
|
|
const auto eng_pos = res->prim_ref_repro.find("engine=gpu");
|
|
if (eng_pos != std::string::npos) {
|
|
// Replace `g` in `gpu` with `c`
|
|
res->prim_ref_repro[eng_pos + 7] = 'c';
|
|
}
|
|
|
|
// Remove `--impl=XXX` as it doesn't affect prim_ref.
|
|
const auto impl_pos = res->prim_ref_repro.find("--impl=");
|
|
if (impl_pos != std::string::npos) {
|
|
// Search for the next space starting from `impl_pos` as names' length
|
|
// is variadic.
|
|
const auto end_impl_pos
|
|
= res->prim_ref_repro.find_first_of(" ", impl_pos);
|
|
assert(end_impl_pos != std::string::npos);
|
|
// `+ 1` is for extra space.
|
|
res->prim_ref_repro.erase(impl_pos, end_impl_pos - impl_pos + 1);
|
|
}
|
|
|
|
prim_ref.reset(prim_ref_ptr);
|
|
return OK;
|
|
}
|
|
|
|
#endif
|