api, tests, examples: add const_dnnl_memory_desc_t alias

This commit is contained in:
Denis Samoilov
2022-09-20 11:50:34 -07:00
parent 257baf0d99
commit 421979c8f2
10 changed files with 407 additions and 405 deletions

View File

@ -116,14 +116,14 @@ static void init_data_memory(uint32_t dim, const dnnl_dim_t *dims,
}
dnnl_status_t prepare_reorder(dnnl_memory_t *user_memory, // in
const dnnl_memory_desc_t *prim_memory_md, // in
const_dnnl_memory_desc_t prim_memory_md, // in
dnnl_engine_t prim_engine, // in: primitive's engine
int dir_is_user_to_prim, // in: user -> prim or prim -> user
dnnl_memory_t *prim_memory, // out: primitive's memory created
dnnl_primitive_t *reorder, // out: reorder primitive created
uint32_t *net_index, // primitive index in net (inc if reorder created)
dnnl_primitive_t *net, args_t *net_args) { // net params
const dnnl_memory_desc_t *user_memory_md;
const_dnnl_memory_desc_t user_memory_md;
dnnl_memory_get_memory_desc(*user_memory, &user_memory_md);
dnnl_engine_t user_mem_engine;
@ -238,7 +238,7 @@ void simple_net(dnnl_engine_kind_t engine_kind) {
conv_internal_dst_memory;
// create memory for dst data, we don't need reorder it to user data
const dnnl_memory_desc_t *dst_md
const_dnnl_memory_desc_t dst_md
= dnnl_primitive_desc_query_md(conv_pd, dnnl_query_dst_md, 0);
CHECK(dnnl_memory_create(
&conv_internal_dst_memory, dst_md, engine, DNNL_MEMORY_ALLOCATE));
@ -247,12 +247,12 @@ void simple_net(dnnl_engine_kind_t engine_kind) {
// if required
dnnl_primitive_t conv_reorder_src, conv_reorder_weights;
const dnnl_memory_desc_t *src_md
const_dnnl_memory_desc_t src_md
= dnnl_primitive_desc_query_md(conv_pd, dnnl_query_src_md, 0);
CHECK(prepare_reorder(&conv_user_src_memory, src_md, engine, 1,
&conv_internal_src_memory, &conv_reorder_src, &n, net, net_args));
const dnnl_memory_desc_t *weights_md
const_dnnl_memory_desc_t weights_md
= dnnl_primitive_desc_query_md(conv_pd, dnnl_query_weights_md, 0);
CHECK(prepare_reorder(&conv_user_weights_memory, weights_md, engine, 1,
&conv_internal_weights_memory, &conv_reorder_weights, &n, net,
@ -282,7 +282,7 @@ void simple_net(dnnl_engine_kind_t engine_kind) {
// create relu memory descriptor on dst memory descriptor
// from previous primitive
const dnnl_memory_desc_t *relu_src_md
const_dnnl_memory_desc_t relu_src_md
= dnnl_primitive_desc_query_md(conv_pd, dnnl_query_dst_md, 0);
// create a relu
@ -292,7 +292,7 @@ void simple_net(dnnl_engine_kind_t engine_kind) {
NULL));
dnnl_memory_t relu_dst_memory;
const dnnl_memory_desc_t *relu_dst_md
const_dnnl_memory_desc_t relu_dst_md
= dnnl_primitive_desc_query_md(relu_pd, dnnl_query_dst_md, 0);
CHECK(dnnl_memory_create(
&relu_dst_memory, relu_dst_md, engine, DNNL_MEMORY_ALLOCATE));
@ -319,7 +319,7 @@ void simple_net(dnnl_engine_kind_t engine_kind) {
// create lrn src memory descriptor using dst memory descriptor
// from previous primitive
const dnnl_memory_desc_t *lrn_src_md = relu_dst_md;
const_dnnl_memory_desc_t lrn_src_md = relu_dst_md;
// create a lrn primitive descriptor
dnnl_primitive_desc_t lrn_pd;
@ -329,12 +329,12 @@ void simple_net(dnnl_engine_kind_t engine_kind) {
// create primitives for lrn dst and workspace memory
dnnl_memory_t lrn_dst_memory;
const dnnl_memory_desc_t *lrn_dst_md
const_dnnl_memory_desc_t lrn_dst_md
= dnnl_primitive_desc_query_md(lrn_pd, dnnl_query_dst_md, 0);
CHECK(dnnl_memory_create(
&lrn_dst_memory, lrn_dst_md, engine, DNNL_MEMORY_ALLOCATE));
dnnl_memory_t lrn_ws_memory;
const dnnl_memory_desc_t *lrn_ws_md
const_dnnl_memory_desc_t lrn_ws_md
= dnnl_primitive_desc_query_md(lrn_pd, dnnl_query_workspace_md, 0);
CHECK(dnnl_memory_create(
&lrn_ws_memory, lrn_ws_md, engine, DNNL_MEMORY_ALLOCATE));
@ -364,7 +364,7 @@ void simple_net(dnnl_engine_kind_t engine_kind) {
// create pooling memory descriptor on dst descriptor
// from previous primitive
const dnnl_memory_desc_t *pool_src_md = lrn_dst_md;
const_dnnl_memory_desc_t pool_src_md = lrn_dst_md;
// create descriptors for dst pooling data
dnnl_memory_desc_t pool_dst_any_md;
@ -385,7 +385,7 @@ void simple_net(dnnl_engine_kind_t engine_kind) {
// create memory for workspace
dnnl_memory_t pool_ws_memory;
const dnnl_memory_desc_t *pool_ws_md
const_dnnl_memory_desc_t pool_ws_md
= dnnl_primitive_desc_query_md(pool_pd, dnnl_query_workspace_md, 0);
CHECK(dnnl_memory_create(
&pool_ws_memory, pool_ws_md, engine, DNNL_MEMORY_ALLOCATE));
@ -396,7 +396,7 @@ void simple_net(dnnl_engine_kind_t engine_kind) {
// if required
dnnl_primitive_t pool_reorder_dst;
dnnl_memory_t pool_internal_dst_memory;
const dnnl_memory_desc_t *pool_dst_md
const_dnnl_memory_desc_t pool_dst_md
= dnnl_primitive_desc_query_md(pool_pd, dnnl_query_dst_md, 0);
n += 1; // tentative workaround: preserve space for pooling that should
// happen before the reorder

View File

@ -101,14 +101,14 @@ static void init_data_memory(uint32_t dim, const dnnl_dim_t *dims,
}
dnnl_status_t prepare_reorder(dnnl_memory_t *user_memory, // in
const dnnl_memory_desc_t *prim_memory_md, // in
const_dnnl_memory_desc_t prim_memory_md, // in
dnnl_engine_t prim_engine, // in: primitive's engine
int dir_is_user_to_prim, // in: user -> prim or prim -> user
dnnl_memory_t *prim_memory, // out: primitive's memory created
dnnl_primitive_t *reorder, // out: reorder primitive created
uint32_t *net_index, // primitive index in net (inc if reorder created)
dnnl_primitive_t *net, args_t *net_args) { // net params
const dnnl_memory_desc_t *user_memory_md;
const_dnnl_memory_desc_t user_memory_md;
dnnl_memory_get_memory_desc(*user_memory, &user_memory_md);
dnnl_engine_t user_mem_engine;
@ -228,7 +228,7 @@ void simple_net() {
conv_internal_dst_memory;
// create memory for dst data, we don't need to reorder it to user data
const dnnl_memory_desc_t *conv_dst_md
const_dnnl_memory_desc_t conv_dst_md
= dnnl_primitive_desc_query_md(conv_pd, dnnl_query_dst_md, 0);
CHECK(dnnl_memory_create(&conv_internal_dst_memory, conv_dst_md, engine,
DNNL_MEMORY_ALLOCATE));
@ -237,13 +237,13 @@ void simple_net() {
// if required
dnnl_primitive_t conv_reorder_src, conv_reorder_weights;
const dnnl_memory_desc_t *conv_src_md
const_dnnl_memory_desc_t conv_src_md
= dnnl_primitive_desc_query_md(conv_pd, dnnl_query_src_md, 0);
CHECK(prepare_reorder(&conv_user_src_memory, conv_src_md, engine, 1,
&conv_internal_src_memory, &conv_reorder_src, &n_fwd, net_fwd,
net_fwd_args));
const dnnl_memory_desc_t *conv_weights_md
const_dnnl_memory_desc_t conv_weights_md
= dnnl_primitive_desc_query_md(conv_pd, dnnl_query_weights_md, 0);
CHECK(prepare_reorder(&conv_user_weights_memory, conv_weights_md, engine, 1,
&conv_internal_weights_memory, &conv_reorder_weights, &n_fwd,
@ -276,7 +276,7 @@ void simple_net() {
// keep memory format of source same as the format of convolution
// output in order to avoid reorder
const dnnl_memory_desc_t *relu_src_md = conv_dst_md;
const_dnnl_memory_desc_t relu_src_md = conv_dst_md;
// create a relu primitive descriptor
dnnl_primitive_desc_t relu_pd;
@ -286,7 +286,7 @@ void simple_net() {
// create relu dst memory
dnnl_memory_t relu_dst_memory;
const dnnl_memory_desc_t *relu_dst_md
const_dnnl_memory_desc_t relu_dst_md
= dnnl_primitive_desc_query_md(relu_pd, dnnl_query_dst_md, 0);
CHECK(dnnl_memory_create(
&relu_dst_memory, relu_dst_md, engine, DNNL_MEMORY_ALLOCATE));
@ -314,7 +314,7 @@ void simple_net() {
// create lrn src memory descriptor using dst memory descriptor
// from previous primitive
const dnnl_memory_desc_t *lrn_src_md = relu_dst_md;
const_dnnl_memory_desc_t lrn_src_md = relu_dst_md;
// create a lrn primitive descriptor
dnnl_primitive_desc_t lrn_pd;
@ -325,14 +325,14 @@ void simple_net() {
// create primitives for lrn dst and workspace memory
dnnl_memory_t lrn_dst_memory, lrn_ws_memory;
const dnnl_memory_desc_t *lrn_dst_md
const_dnnl_memory_desc_t lrn_dst_md
= dnnl_primitive_desc_query_md(lrn_pd, dnnl_query_dst_md, 0);
CHECK(dnnl_memory_create(
&lrn_dst_memory, lrn_dst_md, engine, DNNL_MEMORY_ALLOCATE));
// create workspace only in training and only for forward primitive
// query lrn_pd for workspace, this memory will be shared with forward lrn
const dnnl_memory_desc_t *lrn_ws_md
const_dnnl_memory_desc_t lrn_ws_md
= dnnl_primitive_desc_query_md(lrn_pd, dnnl_query_workspace_md, 0);
CHECK(dnnl_memory_create(
&lrn_ws_memory, lrn_ws_md, engine, DNNL_MEMORY_ALLOCATE));
@ -371,7 +371,7 @@ void simple_net() {
{
// create pooling src memory descriptor using dst descriptor
// from previous primitive
const dnnl_memory_desc_t *pool_src_md = lrn_dst_md;
const_dnnl_memory_desc_t pool_src_md = lrn_dst_md;
// create descriptors for dst pooling data
dnnl_memory_desc_t pool_dst_md;
@ -386,7 +386,7 @@ void simple_net() {
// create memory for workspace
dnnl_memory_t pool_ws_memory;
const dnnl_memory_desc_t *pool_ws_md
const_dnnl_memory_desc_t pool_ws_md
= dnnl_primitive_desc_query_md(pool_pd, dnnl_query_workspace_md, 0);
CHECK(dnnl_memory_create(
&pool_ws_memory, pool_ws_md, engine, DNNL_MEMORY_ALLOCATE));
@ -395,7 +395,7 @@ void simple_net() {
// if required
dnnl_primitive_t pool_reorder_dst;
dnnl_memory_t pool_internal_dst_memory;
const dnnl_memory_desc_t *pool_dst_md
const_dnnl_memory_desc_t pool_dst_md
= dnnl_primitive_desc_query_md(pool_pd, dnnl_query_dst_md, 0);
n_fwd += 1; // tentative workaround: preserve space for pooling that should
// happen before the reorder
@ -437,10 +437,10 @@ void simple_net() {
// Pooling Backward
// pooling diff src memory descriptor
const dnnl_memory_desc_t *pool_diff_src_md = lrn_dst_md;
const_dnnl_memory_desc_t pool_diff_src_md = lrn_dst_md;
// pooling diff dst memory descriptor
const dnnl_memory_desc_t *pool_diff_dst_md = pool_dst_md;
const_dnnl_memory_desc_t pool_diff_dst_md = pool_dst_md;
// backward primitive descriptor needs to hint forward descriptor
dnnl_primitive_desc_t pool_bwd_pd;
@ -479,7 +479,7 @@ void simple_net() {
n_bwd++;
// Backward lrn
const dnnl_memory_desc_t *lrn_diff_dst_md = pool_diff_src_md;
const_dnnl_memory_desc_t lrn_diff_dst_md = pool_diff_src_md;
// create backward lrn descriptor
dnnl_primitive_desc_t lrn_bwd_pd;
@ -489,7 +489,7 @@ void simple_net() {
// create memory for lrn diff src
dnnl_memory_t lrn_diff_src_memory;
const dnnl_memory_desc_t *lrn_diff_src_md = dnnl_primitive_desc_query_md(
const_dnnl_memory_desc_t lrn_diff_src_md = dnnl_primitive_desc_query_md(
lrn_bwd_pd, dnnl_query_diff_src_md, 0);
CHECK(dnnl_memory_create(&lrn_diff_src_memory, lrn_diff_src_md, engine,
DNNL_MEMORY_ALLOCATE));
@ -508,7 +508,7 @@ void simple_net() {
n_bwd++;
// Backward relu
const dnnl_memory_desc_t *relu_diff_dst_md = lrn_diff_src_md;
const_dnnl_memory_desc_t relu_diff_dst_md = lrn_diff_src_md;
// create backward relu descriptor
dnnl_primitive_desc_t relu_bwd_pd;
@ -518,7 +518,7 @@ void simple_net() {
// create memory for relu diff src
dnnl_memory_t relu_diff_src_memory;
const dnnl_memory_desc_t *relu_diff_src_md = dnnl_primitive_desc_query_md(
const_dnnl_memory_desc_t relu_diff_src_md = dnnl_primitive_desc_query_md(
relu_bwd_pd, dnnl_query_diff_src_md, 0);
CHECK(dnnl_memory_create(&relu_diff_src_memory, relu_diff_src_md, engine,
DNNL_MEMORY_ALLOCATE));
@ -581,7 +581,7 @@ void simple_net() {
// format chosen by backward convolution
dnnl_primitive_t conv_bwd_reorder_src;
dnnl_memory_t conv_bwd_internal_src_memory;
const dnnl_memory_desc_t *conv_diff_src_md = dnnl_primitive_desc_query_md(
const_dnnl_memory_desc_t conv_diff_src_md = dnnl_primitive_desc_query_md(
conv_bwd_weights_pd, dnnl_query_src_md, 0);
CHECK(prepare_reorder(&conv_src_memory, conv_diff_src_md, engine, 1,
&conv_bwd_internal_src_memory, &conv_bwd_reorder_src, &n_bwd,
@ -595,7 +595,7 @@ void simple_net() {
// and format preferred by conv_diff_weights
dnnl_primitive_t conv_reorder_diff_dst;
dnnl_memory_t conv_internal_diff_dst_memory;
const dnnl_memory_desc_t *conv_diff_dst_md = dnnl_primitive_desc_query_md(
const_dnnl_memory_desc_t conv_diff_dst_md = dnnl_primitive_desc_query_md(
conv_bwd_weights_pd, dnnl_query_diff_dst_md, 0);
CHECK(prepare_reorder(&relu_diff_src_memory, conv_diff_dst_md, engine, 1,
@ -609,7 +609,7 @@ void simple_net() {
// create reorder primitives for conv diff weights memory
dnnl_primitive_t conv_reorder_diff_weights;
dnnl_memory_t conv_internal_diff_weights_memory;
const dnnl_memory_desc_t *conv_diff_weights_md
const_dnnl_memory_desc_t conv_diff_weights_md
= dnnl_primitive_desc_query_md(
conv_bwd_weights_pd, dnnl_query_diff_weights_md, 0);
n_bwd += 1; // tentative workaround: preserve space for conv_bwd_weights
@ -626,7 +626,7 @@ void simple_net() {
// create memory for diff bias memory
dnnl_memory_t conv_diff_bias_memory;
const dnnl_memory_desc_t *conv_diff_bias_md = dnnl_primitive_desc_query_md(
const_dnnl_memory_desc_t conv_diff_bias_md = dnnl_primitive_desc_query_md(
conv_bwd_weights_pd, dnnl_query_diff_weights_md, 1);
CHECK(dnnl_memory_create(&conv_diff_bias_memory, conv_diff_bias_md, engine,
DNNL_MEMORY_ALLOCATE));

View File

@ -95,7 +95,7 @@ static inline const char *engine_kind2str_upper(dnnl_engine_kind_t kind) {
static inline void read_from_dnnl_memory(void *handle, dnnl_memory_t mem) {
dnnl_engine_t eng;
dnnl_engine_kind_t eng_kind;
const dnnl_memory_desc_t *md;
const_dnnl_memory_desc_t md;
if (!handle) COMPLAIN_EXAMPLE_ERROR_AND_EXIT("%s", "handle is NULL.");
@ -133,7 +133,7 @@ static inline void read_from_dnnl_memory(void *handle, dnnl_memory_t mem) {
static inline void write_to_dnnl_memory(void *handle, dnnl_memory_t mem) {
dnnl_engine_t eng;
dnnl_engine_kind_t eng_kind;
const dnnl_memory_desc_t *md;
const_dnnl_memory_desc_t md;
if (!handle) COMPLAIN_EXAMPLE_ERROR_AND_EXIT("%s", "handle is NULL.");

View File

@ -142,7 +142,7 @@ dnnl_status_t DNNL_API dnnl_primitive_desc_query(
/// needed.
/// @returns NULL in case of any error.
///
const dnnl_memory_desc_t DNNL_API *dnnl_primitive_desc_query_md(
const_dnnl_memory_desc_t DNNL_API dnnl_primitive_desc_query_md(
const_dnnl_primitive_desc_t primitive_desc, dnnl_query_t what,
int index);
@ -743,7 +743,7 @@ dnnl_status_t DNNL_API dnnl_post_ops_get_params_dw(
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_post_ops_append_binary(dnnl_post_ops_t post_ops,
dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src1_desc);
dnnl_alg_kind_t alg_kind, const_dnnl_memory_desc_t src1_desc);
/// Returns the parameters of a binary post-op.
///
@ -757,7 +757,7 @@ dnnl_status_t DNNL_API dnnl_post_ops_append_binary(dnnl_post_ops_t post_ops,
/// post-op.
dnnl_status_t DNNL_API dnnl_post_ops_get_params_binary(
const_dnnl_post_ops_t post_ops, int index, dnnl_alg_kind_t *alg_kind,
const dnnl_memory_desc_t **src1_desc);
const_dnnl_memory_desc_t *src1_desc);
/// Appends a prelu forward post-op.
///
@ -868,7 +868,7 @@ dnnl_status_t DNNL_API dnnl_memory_desc_init_by_tag(
/// otherwise.
dnnl_status_t DNNL_API dnnl_memory_desc_init_submemory(
dnnl_memory_desc_t *memory_desc,
const dnnl_memory_desc_t *parent_memory_desc, const dnnl_dims_t dims,
const_dnnl_memory_desc_t parent_memory_desc, const dnnl_dims_t dims,
const dnnl_dims_t offsets);
/// Initializes a memory descriptor by reshaping an existing one. The new
@ -912,7 +912,7 @@ dnnl_status_t DNNL_API dnnl_memory_desc_init_submemory(
/// otherwise.
dnnl_status_t DNNL_API dnnl_memory_desc_reshape(
dnnl_memory_desc_t *out_memory_desc,
const dnnl_memory_desc_t *in_memory_desc, int ndims,
const_dnnl_memory_desc_t in_memory_desc, int ndims,
const dnnl_dims_t dims);
/// Initializes a memory descriptor by permuting axes in an existing one.
@ -957,7 +957,7 @@ dnnl_status_t DNNL_API dnnl_memory_desc_reshape(
/// otherwise.
dnnl_status_t DNNL_API dnnl_memory_desc_permute_axes(
dnnl_memory_desc_t *out_memory_desc,
const dnnl_memory_desc_t *in_memory_desc, const int *permutation);
const_dnnl_memory_desc_t in_memory_desc, const int *permutation);
/// Queries a memory descriptor for various pieces of information.
///
@ -1006,7 +1006,7 @@ dnnl_status_t DNNL_API dnnl_memory_desc_permute_axes(
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_memory_desc_query(
const dnnl_memory_desc_t *memory_desc, dnnl_query_t what, void *result);
const_dnnl_memory_desc_t memory_desc, dnnl_query_t what, void *result);
/// Compares two memory descriptors.
///
@ -1018,15 +1018,14 @@ dnnl_status_t DNNL_API dnnl_memory_desc_query(
/// @returns 1 if the descriptors are the same.
/// @returns 0 if the descriptors are different.
int DNNL_API dnnl_memory_desc_equal(
const dnnl_memory_desc_t *lhs, const dnnl_memory_desc_t *rhs);
const_dnnl_memory_desc_t lhs, const_dnnl_memory_desc_t rhs);
/// Returns the size of a memory descriptor.
///
/// @param memory_desc Memory descriptor.
/// @returns The number of bytes required for memory described by a memory
/// descriptor.
size_t DNNL_API dnnl_memory_desc_get_size(
const dnnl_memory_desc_t *memory_desc);
size_t DNNL_API dnnl_memory_desc_get_size(const_dnnl_memory_desc_t memory_desc);
/// Returns the size of data type.
///
@ -1055,7 +1054,7 @@ size_t DNNL_API dnnl_data_type_size(dnnl_data_type_t data_type);
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_memory_create(dnnl_memory_t *memory,
const dnnl_memory_desc_t *memory_desc, dnnl_engine_t engine,
const_dnnl_memory_desc_t memory_desc, dnnl_engine_t engine,
void *handle);
/// Returns the memory descriptor for a memory object.
@ -1065,7 +1064,7 @@ dnnl_status_t DNNL_API dnnl_memory_create(dnnl_memory_t *memory,
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_memory_get_memory_desc(
const_dnnl_memory_t memory, const dnnl_memory_desc_t **memory_desc);
const_dnnl_memory_t memory, const_dnnl_memory_desc_t *memory_desc);
/// Returns the engine of a memory object.
///
@ -1169,8 +1168,8 @@ dnnl_status_t DNNL_API dnnl_memory_destroy(dnnl_memory_t memory);
/// otherwise.
dnnl_status_t DNNL_API dnnl_reorder_primitive_desc_create(
dnnl_primitive_desc_t *reorder_primitive_desc,
const dnnl_memory_desc_t *src_desc, dnnl_engine_t src_engine,
const dnnl_memory_desc_t *dst_desc, dnnl_engine_t dst_engine,
const_dnnl_memory_desc_t src_desc, dnnl_engine_t src_engine,
const_dnnl_memory_desc_t dst_desc, dnnl_engine_t dst_engine,
const_dnnl_primitive_attr_t attr);
/// @} dnnl_api_reorder
@ -1194,8 +1193,8 @@ dnnl_status_t DNNL_API dnnl_reorder_primitive_desc_create(
/// otherwise.
dnnl_status_t DNNL_API dnnl_concat_primitive_desc_create(
dnnl_primitive_desc_t *concat_primitive_desc, dnnl_engine_t engine,
const dnnl_memory_desc_t *dst_desc, int n, int concat_dimension,
const dnnl_memory_desc_t *src_descs, const_dnnl_primitive_attr_t attr);
const_dnnl_memory_desc_t dst_desc, int n, int concat_dimension,
const_dnnl_memory_desc_t src_descs, const_dnnl_primitive_attr_t attr);
/// @} dnnl_api_concat
@ -1216,8 +1215,8 @@ dnnl_status_t DNNL_API dnnl_concat_primitive_desc_create(
/// otherwise.
dnnl_status_t DNNL_API dnnl_sum_primitive_desc_create(
dnnl_primitive_desc_t *sum_primitive_desc, dnnl_engine_t engine,
const dnnl_memory_desc_t *dst_desc, int n, const float *scales,
const dnnl_memory_desc_t *src_descs, const_dnnl_primitive_attr_t attr);
const_dnnl_memory_desc_t dst_desc, int n, const float *scales,
const_dnnl_memory_desc_t src_descs, const_dnnl_primitive_attr_t attr);
/// @} dnnl_api_sum
@ -1250,8 +1249,8 @@ dnnl_status_t DNNL_API dnnl_sum_primitive_desc_create(
/// otherwise.
dnnl_status_t DNNL_API dnnl_binary_primitive_desc_create(
dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src0_desc,
const dnnl_memory_desc_t *src1_desc, const dnnl_memory_desc_t *dst_desc,
dnnl_alg_kind_t alg_kind, const_dnnl_memory_desc_t src0_desc,
const_dnnl_memory_desc_t src1_desc, const_dnnl_memory_desc_t dst_desc,
const_dnnl_primitive_attr_t attr);
/// @} dnnl_api_binary
@ -1299,9 +1298,9 @@ dnnl_status_t DNNL_API dnnl_binary_primitive_desc_create(
dnnl_status_t DNNL_API dnnl_convolution_forward_primitive_desc_create(
dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind,
const dnnl_memory_desc_t *src_desc,
const dnnl_memory_desc_t *weights_desc,
const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_desc,
const_dnnl_memory_desc_t src_desc,
const_dnnl_memory_desc_t weights_desc,
const_dnnl_memory_desc_t bias_desc, const_dnnl_memory_desc_t dst_desc,
const dnnl_dims_t strides, const dnnl_dims_t dilates,
const dnnl_dims_t padding_l, const dnnl_dims_t padding_r,
const_dnnl_primitive_attr_t attr);
@ -1342,9 +1341,9 @@ dnnl_status_t DNNL_API dnnl_convolution_forward_primitive_desc_create(
/// otherwise.
dnnl_status_t DNNL_API dnnl_convolution_backward_data_primitive_desc_create(
dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *diff_src_desc,
const dnnl_memory_desc_t *weights_desc,
const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides,
dnnl_alg_kind_t alg_kind, const_dnnl_memory_desc_t diff_src_desc,
const_dnnl_memory_desc_t weights_desc,
const_dnnl_memory_desc_t diff_dst_desc, const dnnl_dims_t strides,
const dnnl_dims_t dilates, const dnnl_dims_t padding_l,
const dnnl_dims_t padding_r, const_dnnl_primitive_desc_t hint_fwd_pd,
const_dnnl_primitive_attr_t attr);
@ -1387,10 +1386,10 @@ dnnl_status_t DNNL_API dnnl_convolution_backward_data_primitive_desc_create(
/// otherwise.
dnnl_status_t DNNL_API dnnl_convolution_backward_weights_primitive_desc_create(
dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc,
const dnnl_memory_desc_t *diff_weights_desc,
const dnnl_memory_desc_t *diff_bias_desc,
const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides,
dnnl_alg_kind_t alg_kind, const_dnnl_memory_desc_t src_desc,
const_dnnl_memory_desc_t diff_weights_desc,
const_dnnl_memory_desc_t diff_bias_desc,
const_dnnl_memory_desc_t diff_dst_desc, const dnnl_dims_t strides,
const dnnl_dims_t dilates, const dnnl_dims_t padding_l,
const dnnl_dims_t padding_r, const_dnnl_primitive_desc_t hint_fwd_pd,
const_dnnl_primitive_attr_t attr);
@ -1439,9 +1438,9 @@ dnnl_status_t DNNL_API dnnl_convolution_backward_weights_primitive_desc_create(
dnnl_status_t DNNL_API dnnl_deconvolution_forward_primitive_desc_create(
dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind,
const dnnl_memory_desc_t *src_desc,
const dnnl_memory_desc_t *weights_desc,
const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_desc,
const_dnnl_memory_desc_t src_desc,
const_dnnl_memory_desc_t weights_desc,
const_dnnl_memory_desc_t bias_desc, const_dnnl_memory_desc_t dst_desc,
const dnnl_dims_t strides, const dnnl_dims_t dilates,
const dnnl_dims_t padding_l, const dnnl_dims_t padding_r,
const_dnnl_primitive_attr_t attr);
@ -1481,9 +1480,9 @@ dnnl_status_t DNNL_API dnnl_deconvolution_forward_primitive_desc_create(
/// otherwise.
dnnl_status_t DNNL_API dnnl_deconvolution_backward_data_primitive_desc_create(
dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *diff_src_desc,
const dnnl_memory_desc_t *weights_desc,
const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides,
dnnl_alg_kind_t alg_kind, const_dnnl_memory_desc_t diff_src_desc,
const_dnnl_memory_desc_t weights_desc,
const_dnnl_memory_desc_t diff_dst_desc, const dnnl_dims_t strides,
const dnnl_dims_t dilates, const dnnl_dims_t padding_l,
const dnnl_dims_t padding_r, const_dnnl_primitive_desc_t hint_fwd_pd,
const_dnnl_primitive_attr_t attr);
@ -1527,10 +1526,10 @@ dnnl_status_t DNNL_API dnnl_deconvolution_backward_data_primitive_desc_create(
dnnl_status_t DNNL_API
dnnl_deconvolution_backward_weights_primitive_desc_create(
dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc,
const dnnl_memory_desc_t *diff_weights_desc,
const dnnl_memory_desc_t *diff_bias_desc,
const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides,
dnnl_alg_kind_t alg_kind, const_dnnl_memory_desc_t src_desc,
const_dnnl_memory_desc_t diff_weights_desc,
const_dnnl_memory_desc_t diff_bias_desc,
const_dnnl_memory_desc_t diff_dst_desc, const dnnl_dims_t strides,
const dnnl_dims_t dilates, const dnnl_dims_t padding_l,
const dnnl_dims_t padding_r, const_dnnl_primitive_desc_t hint_fwd_pd,
const_dnnl_primitive_attr_t attr);
@ -1554,7 +1553,7 @@ dnnl_deconvolution_backward_weights_primitive_desc_create(
/// otherwise.
dnnl_status_t DNNL_API dnnl_shuffle_forward_primitive_desc_create(
dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
dnnl_prop_kind_t prop_kind, const dnnl_memory_desc_t *data_desc,
dnnl_prop_kind_t prop_kind, const_dnnl_memory_desc_t data_desc,
int axis, dnnl_dim_t group_size, const_dnnl_primitive_attr_t attr);
/// Creates a primitive descriptor for a shuffle backward propagation primitive
@ -1571,7 +1570,7 @@ dnnl_status_t DNNL_API dnnl_shuffle_forward_primitive_desc_create(
/// otherwise.
dnnl_status_t DNNL_API dnnl_shuffle_backward_primitive_desc_create(
dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
const dnnl_memory_desc_t *diff_data_desc, int axis,
const_dnnl_memory_desc_t diff_data_desc, int axis,
dnnl_dim_t group_size, const_dnnl_primitive_desc_t hint_fwd_pd,
const_dnnl_primitive_attr_t attr);
@ -1598,7 +1597,7 @@ dnnl_status_t DNNL_API dnnl_shuffle_backward_primitive_desc_create(
dnnl_status_t DNNL_API dnnl_eltwise_forward_primitive_desc_create(
dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind,
const dnnl_memory_desc_t *data_desc, float alpha, float beta,
const_dnnl_memory_desc_t data_desc, float alpha, float beta,
const_dnnl_primitive_attr_t attr);
/// Creates a primitive descriptor for an eltwise backward propagation
@ -1620,8 +1619,8 @@ dnnl_status_t DNNL_API dnnl_eltwise_forward_primitive_desc_create(
/// otherwise.
dnnl_status_t DNNL_API dnnl_eltwise_backward_primitive_desc_create(
dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *diff_data_desc,
const dnnl_memory_desc_t *data_desc, float alpha, float beta,
dnnl_alg_kind_t alg_kind, const_dnnl_memory_desc_t diff_data_desc,
const_dnnl_memory_desc_t data_desc, float alpha, float beta,
const_dnnl_primitive_desc_t hint_fwd_pd,
const_dnnl_primitive_attr_t attr);
@ -1647,7 +1646,7 @@ dnnl_status_t DNNL_API dnnl_eltwise_backward_primitive_desc_create(
dnnl_status_t DNNL_API dnnl_softmax_forward_primitive_desc_create(
dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind,
const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *dst_desc,
const_dnnl_memory_desc_t src_desc, const_dnnl_memory_desc_t dst_desc,
int softmax_axis, const_dnnl_primitive_attr_t attr);
/// Creates a primitive descriptor for a softmax backward propagation primitive.
@ -1667,9 +1666,9 @@ dnnl_status_t DNNL_API dnnl_softmax_forward_primitive_desc_create(
/// otherwise.
dnnl_status_t DNNL_API dnnl_softmax_backward_primitive_desc_create(
dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *diff_src_desc,
const dnnl_memory_desc_t *diff_dst_desc,
const dnnl_memory_desc_t *dst_desc, int softmax_axis,
dnnl_alg_kind_t alg_kind, const_dnnl_memory_desc_t diff_src_desc,
const_dnnl_memory_desc_t diff_dst_desc,
const_dnnl_memory_desc_t dst_desc, int softmax_axis,
const_dnnl_primitive_desc_t hint_fwd_pd,
const_dnnl_primitive_attr_t attr);
@ -1710,7 +1709,7 @@ dnnl_status_t DNNL_API dnnl_softmax_backward_primitive_desc_create(
dnnl_status_t DNNL_API dnnl_pooling_forward_primitive_desc_create(
dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind,
const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *dst_desc,
const_dnnl_memory_desc_t src_desc, const_dnnl_memory_desc_t dst_desc,
const dnnl_dims_t strides, const dnnl_dims_t kernel,
const dnnl_dims_t dilation, const dnnl_dims_t padding_l,
const dnnl_dims_t padding_r, const_dnnl_primitive_attr_t attr);
@ -1746,8 +1745,8 @@ dnnl_status_t DNNL_API dnnl_pooling_forward_primitive_desc_create(
/// otherwise.
dnnl_status_t DNNL_API dnnl_pooling_backward_primitive_desc_create(
dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *diff_src_desc,
const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides,
dnnl_alg_kind_t alg_kind, const_dnnl_memory_desc_t diff_src_desc,
const_dnnl_memory_desc_t diff_dst_desc, const dnnl_dims_t strides,
const dnnl_dims_t kernel, const dnnl_dims_t dilation,
const dnnl_dims_t padding_l, const dnnl_dims_t padding_r,
const_dnnl_primitive_desc_t hint_forward_primitive_desc,
@ -1776,8 +1775,8 @@ dnnl_status_t DNNL_API dnnl_pooling_backward_primitive_desc_create(
/// otherwise.
dnnl_status_t DNNL_API dnnl_prelu_forward_primitive_desc_create(
dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
dnnl_prop_kind_t prop_kind, const dnnl_memory_desc_t *data_desc,
const dnnl_memory_desc_t *weights_desc,
dnnl_prop_kind_t prop_kind, const_dnnl_memory_desc_t data_desc,
const_dnnl_memory_desc_t weights_desc,
const_dnnl_primitive_attr_t attr);
/// Creates a primitive descriptor for a PReLU (leaky ReLU with trainable
@ -1801,10 +1800,10 @@ dnnl_status_t DNNL_API dnnl_prelu_forward_primitive_desc_create(
/// otherwise.
dnnl_status_t DNNL_API dnnl_prelu_backward_primitive_desc_create(
dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
const dnnl_memory_desc_t *data_desc,
const dnnl_memory_desc_t *weights_desc,
const dnnl_memory_desc_t *diff_data_desc,
const dnnl_memory_desc_t *diff_weights_desc,
const_dnnl_memory_desc_t data_desc,
const_dnnl_memory_desc_t weights_desc,
const_dnnl_memory_desc_t diff_data_desc,
const_dnnl_memory_desc_t diff_weights_desc,
const_dnnl_primitive_desc_t hint_fwd_pd,
const_dnnl_primitive_attr_t attr);
@ -1832,7 +1831,7 @@ dnnl_status_t DNNL_API dnnl_prelu_backward_primitive_desc_create(
dnnl_status_t DNNL_API dnnl_lrn_forward_primitive_desc_create(
dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind,
const dnnl_memory_desc_t *data_desc, dnnl_dim_t local_size, float alpha,
const_dnnl_memory_desc_t data_desc, dnnl_dim_t local_size, float alpha,
float beta, float k, const_dnnl_primitive_attr_t attr);
/// Creates a primitive descriptor for an LRN backward propagation primitive.
@ -1854,8 +1853,8 @@ dnnl_status_t DNNL_API dnnl_lrn_forward_primitive_desc_create(
/// otherwise.
dnnl_status_t DNNL_API dnnl_lrn_backward_primitive_desc_create(
dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *diff_data_desc,
const dnnl_memory_desc_t *data_desc, dnnl_dim_t local_size, float alpha,
dnnl_alg_kind_t alg_kind, const_dnnl_memory_desc_t diff_data_desc,
const_dnnl_memory_desc_t data_desc, dnnl_dim_t local_size, float alpha,
float beta, float k, const_dnnl_primitive_desc_t hint_fwd_pd,
const_dnnl_primitive_attr_t attr);
@ -1883,7 +1882,7 @@ dnnl_status_t DNNL_API dnnl_lrn_backward_primitive_desc_create(
/// otherwise.
dnnl_status_t DNNL_API dnnl_batch_normalization_forward_primitive_desc_create(
dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
dnnl_prop_kind_t prop_kind, const dnnl_memory_desc_t *data_desc,
dnnl_prop_kind_t prop_kind, const_dnnl_memory_desc_t data_desc,
float epsilon, unsigned flags, const_dnnl_primitive_attr_t attr);
/// Creates a primitive descriptor for a batch normalization backward
@ -1909,8 +1908,8 @@ dnnl_status_t DNNL_API dnnl_batch_normalization_forward_primitive_desc_create(
/// otherwise.
dnnl_status_t DNNL_API dnnl_batch_normalization_backward_primitive_desc_create(
dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
dnnl_prop_kind_t prop_kind, const dnnl_memory_desc_t *diff_data_desc,
const dnnl_memory_desc_t *data_desc, float epsilon, unsigned flags,
dnnl_prop_kind_t prop_kind, const_dnnl_memory_desc_t diff_data_desc,
const_dnnl_memory_desc_t data_desc, float epsilon, unsigned flags,
const_dnnl_primitive_desc_t hint_fwd_pd,
const_dnnl_primitive_attr_t attr);
@ -1944,8 +1943,8 @@ dnnl_status_t DNNL_API dnnl_batch_normalization_backward_primitive_desc_create(
/// otherwise.
dnnl_status_t DNNL_API dnnl_layer_normalization_forward_primitive_desc_create(
dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
dnnl_prop_kind_t prop_kind, const dnnl_memory_desc_t *src_desc,
const dnnl_memory_desc_t *dst_desc, const dnnl_memory_desc_t *stat_desc,
dnnl_prop_kind_t prop_kind, const_dnnl_memory_desc_t src_desc,
const_dnnl_memory_desc_t dst_desc, const_dnnl_memory_desc_t stat_desc,
float epsilon, unsigned flags, const_dnnl_primitive_attr_t attr);
/// Creates a primitive descriptor for a layer normalization backward
@ -1977,9 +1976,9 @@ dnnl_status_t DNNL_API dnnl_layer_normalization_forward_primitive_desc_create(
/// otherwise.
dnnl_status_t DNNL_API dnnl_layer_normalization_backward_primitive_desc_create(
dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
dnnl_prop_kind_t prop_kind, const dnnl_memory_desc_t *diff_src_desc,
const dnnl_memory_desc_t *diff_dst_desc,
const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *stat_desc,
dnnl_prop_kind_t prop_kind, const_dnnl_memory_desc_t diff_src_desc,
const_dnnl_memory_desc_t diff_dst_desc,
const_dnnl_memory_desc_t src_desc, const_dnnl_memory_desc_t stat_desc,
float epsilon, unsigned flags, const_dnnl_primitive_desc_t hint_fwd_pd,
const_dnnl_primitive_attr_t attr);
@ -2010,9 +2009,9 @@ dnnl_status_t DNNL_API dnnl_layer_normalization_backward_primitive_desc_create(
/// otherwise.
dnnl_status_t DNNL_API dnnl_inner_product_forward_primitive_desc_create(
dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
dnnl_prop_kind_t prop_kind, const dnnl_memory_desc_t *src_desc,
const dnnl_memory_desc_t *weights_desc,
const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_desc,
dnnl_prop_kind_t prop_kind, const_dnnl_memory_desc_t src_desc,
const_dnnl_memory_desc_t weights_desc,
const_dnnl_memory_desc_t bias_desc, const_dnnl_memory_desc_t dst_desc,
const_dnnl_primitive_attr_t attr);
/// Creates a primitive descriptor for an inner product backward propagation
@ -2034,9 +2033,9 @@ dnnl_status_t DNNL_API dnnl_inner_product_forward_primitive_desc_create(
/// otherwise.
dnnl_status_t DNNL_API dnnl_inner_product_backward_data_primitive_desc_create(
dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
const dnnl_memory_desc_t *diff_src_desc,
const dnnl_memory_desc_t *weights_desc,
const dnnl_memory_desc_t *diff_dst_desc,
const_dnnl_memory_desc_t diff_src_desc,
const_dnnl_memory_desc_t weights_desc,
const_dnnl_memory_desc_t diff_dst_desc,
const_dnnl_primitive_desc_t hint_fwd_pd,
const_dnnl_primitive_attr_t attr);
@ -2063,10 +2062,10 @@ dnnl_status_t DNNL_API dnnl_inner_product_backward_data_primitive_desc_create(
dnnl_status_t DNNL_API
dnnl_inner_product_backward_weights_primitive_desc_create(
dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
const dnnl_memory_desc_t *src_desc,
const dnnl_memory_desc_t *diff_weights_desc,
const dnnl_memory_desc_t *diff_bias_desc,
const dnnl_memory_desc_t *diff_dst_desc,
const_dnnl_memory_desc_t src_desc,
const_dnnl_memory_desc_t diff_weights_desc,
const_dnnl_memory_desc_t diff_bias_desc,
const_dnnl_memory_desc_t diff_dst_desc,
const_dnnl_primitive_desc_t hint_fwd_pd,
const_dnnl_primitive_attr_t attr);
@ -2282,13 +2281,13 @@ dnnl_status_t DNNL_API dnnl_vanilla_rnn_forward_primitive_desc_create(
dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
dnnl_prop_kind_t prop_kind, const dnnl_alg_kind_t activation,
const dnnl_rnn_direction_t direction,
const dnnl_memory_desc_t *src_layer_desc,
const dnnl_memory_desc_t *src_iter_desc,
const dnnl_memory_desc_t *weights_layer_desc,
const dnnl_memory_desc_t *weights_iter_desc,
const dnnl_memory_desc_t *bias_desc,
const dnnl_memory_desc_t *dst_layer_desc,
const dnnl_memory_desc_t *dst_iter_desc, unsigned flags, float alpha,
const_dnnl_memory_desc_t src_layer_desc,
const_dnnl_memory_desc_t src_iter_desc,
const_dnnl_memory_desc_t weights_layer_desc,
const_dnnl_memory_desc_t weights_iter_desc,
const_dnnl_memory_desc_t bias_desc,
const_dnnl_memory_desc_t dst_layer_desc,
const_dnnl_memory_desc_t dst_iter_desc, unsigned flags, float alpha,
float beta, const_dnnl_primitive_attr_t attr);
/// Creates a primitive descriptor for vanilla RNN backward propagation
@ -2349,20 +2348,20 @@ dnnl_status_t DNNL_API dnnl_vanilla_rnn_backward_primitive_desc_create(
dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
dnnl_prop_kind_t prop_kind, const dnnl_alg_kind_t activation,
const dnnl_rnn_direction_t direction,
const dnnl_memory_desc_t *src_layer_desc,
const dnnl_memory_desc_t *src_iter_desc,
const dnnl_memory_desc_t *weights_layer_desc,
const dnnl_memory_desc_t *weights_iter_desc,
const dnnl_memory_desc_t *bias_desc,
const dnnl_memory_desc_t *dst_layer_desc,
const dnnl_memory_desc_t *dst_iter_desc,
const dnnl_memory_desc_t *diff_src_layer_desc,
const dnnl_memory_desc_t *diff_src_iter_desc,
const dnnl_memory_desc_t *diff_weights_layer_desc,
const dnnl_memory_desc_t *diff_weights_iter_desc,
const dnnl_memory_desc_t *diff_bias_desc,
const dnnl_memory_desc_t *diff_dst_layer_desc,
const dnnl_memory_desc_t *diff_dst_iter_desc, unsigned flags,
const_dnnl_memory_desc_t src_layer_desc,
const_dnnl_memory_desc_t src_iter_desc,
const_dnnl_memory_desc_t weights_layer_desc,
const_dnnl_memory_desc_t weights_iter_desc,
const_dnnl_memory_desc_t bias_desc,
const_dnnl_memory_desc_t dst_layer_desc,
const_dnnl_memory_desc_t dst_iter_desc,
const_dnnl_memory_desc_t diff_src_layer_desc,
const_dnnl_memory_desc_t diff_src_iter_desc,
const_dnnl_memory_desc_t diff_weights_layer_desc,
const_dnnl_memory_desc_t diff_weights_iter_desc,
const_dnnl_memory_desc_t diff_bias_desc,
const_dnnl_memory_desc_t diff_dst_layer_desc,
const_dnnl_memory_desc_t diff_dst_iter_desc, unsigned flags,
float alpha, float beta, const_dnnl_primitive_desc_t hint_fwd_pd,
const_dnnl_primitive_attr_t attr);
@ -2414,15 +2413,15 @@ dnnl_status_t DNNL_API dnnl_vanilla_rnn_backward_primitive_desc_create(
dnnl_status_t DNNL_API dnnl_lstm_forward_primitive_desc_create(
dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
const dnnl_memory_desc_t *src_layer_desc,
const dnnl_memory_desc_t *src_iter_desc,
const dnnl_memory_desc_t *src_iter_c_desc,
const dnnl_memory_desc_t *weights_layer_desc,
const dnnl_memory_desc_t *weights_iter_desc,
const dnnl_memory_desc_t *bias_desc,
const dnnl_memory_desc_t *dst_layer_desc,
const dnnl_memory_desc_t *dst_iter_desc,
const dnnl_memory_desc_t *dst_iter_c_desc, unsigned flags,
const_dnnl_memory_desc_t src_layer_desc,
const_dnnl_memory_desc_t src_iter_desc,
const_dnnl_memory_desc_t src_iter_c_desc,
const_dnnl_memory_desc_t weights_layer_desc,
const_dnnl_memory_desc_t weights_iter_desc,
const_dnnl_memory_desc_t bias_desc,
const_dnnl_memory_desc_t dst_layer_desc,
const_dnnl_memory_desc_t dst_iter_desc,
const_dnnl_memory_desc_t dst_iter_c_desc, unsigned flags,
const_dnnl_primitive_attr_t attr);
/// Creates a primitive descriptor for an LSTM (with or without peephole)
@ -2475,16 +2474,16 @@ dnnl_status_t DNNL_API dnnl_lstm_forward_primitive_desc_create(
dnnl_status_t DNNL_API dnnl_lstm_forward_primitive_desc_create_v2(
dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
const dnnl_memory_desc_t *src_layer_desc,
const dnnl_memory_desc_t *src_iter_desc,
const dnnl_memory_desc_t *src_iter_c_desc,
const dnnl_memory_desc_t *weights_layer_desc,
const dnnl_memory_desc_t *weights_iter_desc,
const dnnl_memory_desc_t *weights_peephole_desc,
const dnnl_memory_desc_t *bias_desc,
const dnnl_memory_desc_t *dst_layer_desc,
const dnnl_memory_desc_t *dst_iter_desc,
const dnnl_memory_desc_t *dst_iter_c_desc, unsigned flags,
const_dnnl_memory_desc_t src_layer_desc,
const_dnnl_memory_desc_t src_iter_desc,
const_dnnl_memory_desc_t src_iter_c_desc,
const_dnnl_memory_desc_t weights_layer_desc,
const_dnnl_memory_desc_t weights_iter_desc,
const_dnnl_memory_desc_t weights_peephole_desc,
const_dnnl_memory_desc_t bias_desc,
const_dnnl_memory_desc_t dst_layer_desc,
const_dnnl_memory_desc_t dst_iter_desc,
const_dnnl_memory_desc_t dst_iter_c_desc, unsigned flags,
const_dnnl_primitive_attr_t attr);
/// Creates a primitive descriptor for an LSTM (with or without peephole and
@ -2544,17 +2543,17 @@ dnnl_status_t DNNL_API dnnl_lstm_forward_primitive_desc_create_v2(
dnnl_status_t DNNL_API dnnl_lstm_forward_primitive_desc_create_v3(
dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
const dnnl_memory_desc_t *src_layer_desc,
const dnnl_memory_desc_t *src_iter_desc,
const dnnl_memory_desc_t *src_iter_c_desc,
const dnnl_memory_desc_t *weights_layer_desc,
const dnnl_memory_desc_t *weights_iter_desc,
const dnnl_memory_desc_t *weights_peephole_desc,
const dnnl_memory_desc_t *weights_projection_desc,
const dnnl_memory_desc_t *bias_desc,
const dnnl_memory_desc_t *dst_layer_desc,
const dnnl_memory_desc_t *dst_iter_desc,
const dnnl_memory_desc_t *dst_iter_c_desc, unsigned flags,
const_dnnl_memory_desc_t src_layer_desc,
const_dnnl_memory_desc_t src_iter_desc,
const_dnnl_memory_desc_t src_iter_c_desc,
const_dnnl_memory_desc_t weights_layer_desc,
const_dnnl_memory_desc_t weights_iter_desc,
const_dnnl_memory_desc_t weights_peephole_desc,
const_dnnl_memory_desc_t weights_projection_desc,
const_dnnl_memory_desc_t bias_desc,
const_dnnl_memory_desc_t dst_layer_desc,
const_dnnl_memory_desc_t dst_iter_desc,
const_dnnl_memory_desc_t dst_iter_c_desc, unsigned flags,
const_dnnl_primitive_attr_t attr);
/// Creates a primitive descriptor for an LSTM backward propagation primitive.
@ -2624,24 +2623,24 @@ dnnl_status_t DNNL_API dnnl_lstm_forward_primitive_desc_create_v3(
dnnl_status_t DNNL_API dnnl_lstm_backward_primitive_desc_create(
dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
const dnnl_memory_desc_t *src_layer_desc,
const dnnl_memory_desc_t *src_iter_desc,
const dnnl_memory_desc_t *src_iter_c_desc,
const dnnl_memory_desc_t *weights_layer_desc,
const dnnl_memory_desc_t *weights_iter_desc,
const dnnl_memory_desc_t *bias_desc,
const dnnl_memory_desc_t *dst_layer_desc,
const dnnl_memory_desc_t *dst_iter_desc,
const dnnl_memory_desc_t *dst_iter_c_desc,
const dnnl_memory_desc_t *diff_src_layer_desc,
const dnnl_memory_desc_t *diff_src_iter_desc,
const dnnl_memory_desc_t *diff_src_iter_c_desc,
const dnnl_memory_desc_t *diff_weights_layer_desc,
const dnnl_memory_desc_t *diff_weights_iter_desc,
const dnnl_memory_desc_t *diff_bias_desc,
const dnnl_memory_desc_t *diff_dst_layer_desc,
const dnnl_memory_desc_t *diff_dst_iter_desc,
const dnnl_memory_desc_t *diff_dst_iter_c_desc, unsigned flags,
const_dnnl_memory_desc_t src_layer_desc,
const_dnnl_memory_desc_t src_iter_desc,
const_dnnl_memory_desc_t src_iter_c_desc,
const_dnnl_memory_desc_t weights_layer_desc,
const_dnnl_memory_desc_t weights_iter_desc,
const_dnnl_memory_desc_t bias_desc,
const_dnnl_memory_desc_t dst_layer_desc,
const_dnnl_memory_desc_t dst_iter_desc,
const_dnnl_memory_desc_t dst_iter_c_desc,
const_dnnl_memory_desc_t diff_src_layer_desc,
const_dnnl_memory_desc_t diff_src_iter_desc,
const_dnnl_memory_desc_t diff_src_iter_c_desc,
const_dnnl_memory_desc_t diff_weights_layer_desc,
const_dnnl_memory_desc_t diff_weights_iter_desc,
const_dnnl_memory_desc_t diff_bias_desc,
const_dnnl_memory_desc_t diff_dst_layer_desc,
const_dnnl_memory_desc_t diff_dst_iter_desc,
const_dnnl_memory_desc_t diff_dst_iter_c_desc, unsigned flags,
const_dnnl_primitive_desc_t hint_fwd_pd,
const_dnnl_primitive_attr_t attr);
@ -2716,26 +2715,26 @@ dnnl_status_t DNNL_API dnnl_lstm_backward_primitive_desc_create(
dnnl_status_t DNNL_API dnnl_lstm_backward_primitive_desc_create_v2(
dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
const dnnl_memory_desc_t *src_layer_desc,
const dnnl_memory_desc_t *src_iter_desc,
const dnnl_memory_desc_t *src_iter_c_desc,
const dnnl_memory_desc_t *weights_layer_desc,
const dnnl_memory_desc_t *weights_iter_desc,
const dnnl_memory_desc_t *weights_peephole_desc,
const dnnl_memory_desc_t *bias_desc,
const dnnl_memory_desc_t *dst_layer_desc,
const dnnl_memory_desc_t *dst_iter_desc,
const dnnl_memory_desc_t *dst_iter_c_desc,
const dnnl_memory_desc_t *diff_src_layer_desc,
const dnnl_memory_desc_t *diff_src_iter_desc,
const dnnl_memory_desc_t *diff_src_iter_c_desc,
const dnnl_memory_desc_t *diff_weights_layer_desc,
const dnnl_memory_desc_t *diff_weights_iter_desc,
const dnnl_memory_desc_t *diff_weights_peephole_desc,
const dnnl_memory_desc_t *diff_bias_desc,
const dnnl_memory_desc_t *diff_dst_layer_desc,
const dnnl_memory_desc_t *diff_dst_iter_desc,
const dnnl_memory_desc_t *diff_dst_iter_c_desc, unsigned flags,
const_dnnl_memory_desc_t src_layer_desc,
const_dnnl_memory_desc_t src_iter_desc,
const_dnnl_memory_desc_t src_iter_c_desc,
const_dnnl_memory_desc_t weights_layer_desc,
const_dnnl_memory_desc_t weights_iter_desc,
const_dnnl_memory_desc_t weights_peephole_desc,
const_dnnl_memory_desc_t bias_desc,
const_dnnl_memory_desc_t dst_layer_desc,
const_dnnl_memory_desc_t dst_iter_desc,
const_dnnl_memory_desc_t dst_iter_c_desc,
const_dnnl_memory_desc_t diff_src_layer_desc,
const_dnnl_memory_desc_t diff_src_iter_desc,
const_dnnl_memory_desc_t diff_src_iter_c_desc,
const_dnnl_memory_desc_t diff_weights_layer_desc,
const_dnnl_memory_desc_t diff_weights_iter_desc,
const_dnnl_memory_desc_t diff_weights_peephole_desc,
const_dnnl_memory_desc_t diff_bias_desc,
const_dnnl_memory_desc_t diff_dst_layer_desc,
const_dnnl_memory_desc_t diff_dst_iter_desc,
const_dnnl_memory_desc_t diff_dst_iter_c_desc, unsigned flags,
const_dnnl_primitive_desc_t hint_fwd_pd,
const_dnnl_primitive_attr_t attr);
@ -2819,28 +2818,28 @@ dnnl_status_t DNNL_API dnnl_lstm_backward_primitive_desc_create_v2(
dnnl_status_t DNNL_API dnnl_lstm_backward_primitive_desc_create_v3(
dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
const dnnl_memory_desc_t *src_layer_desc,
const dnnl_memory_desc_t *src_iter_desc,
const dnnl_memory_desc_t *src_iter_c_desc,
const dnnl_memory_desc_t *weights_layer_desc,
const dnnl_memory_desc_t *weights_iter_desc,
const dnnl_memory_desc_t *weights_peephole_desc,
const dnnl_memory_desc_t *weights_projection_desc,
const dnnl_memory_desc_t *bias_desc,
const dnnl_memory_desc_t *dst_layer_desc,
const dnnl_memory_desc_t *dst_iter_desc,
const dnnl_memory_desc_t *dst_iter_c_desc,
const dnnl_memory_desc_t *diff_src_layer_desc,
const dnnl_memory_desc_t *diff_src_iter_desc,
const dnnl_memory_desc_t *diff_src_iter_c_desc,
const dnnl_memory_desc_t *diff_weights_layer_desc,
const dnnl_memory_desc_t *diff_weights_iter_desc,
const dnnl_memory_desc_t *diff_weights_peephole_desc,
const dnnl_memory_desc_t *diff_weights_projection_desc,
const dnnl_memory_desc_t *diff_bias_desc,
const dnnl_memory_desc_t *diff_dst_layer_desc,
const dnnl_memory_desc_t *diff_dst_iter_desc,
const dnnl_memory_desc_t *diff_dst_iter_c_desc, unsigned flags,
const_dnnl_memory_desc_t src_layer_desc,
const_dnnl_memory_desc_t src_iter_desc,
const_dnnl_memory_desc_t src_iter_c_desc,
const_dnnl_memory_desc_t weights_layer_desc,
const_dnnl_memory_desc_t weights_iter_desc,
const_dnnl_memory_desc_t weights_peephole_desc,
const_dnnl_memory_desc_t weights_projection_desc,
const_dnnl_memory_desc_t bias_desc,
const_dnnl_memory_desc_t dst_layer_desc,
const_dnnl_memory_desc_t dst_iter_desc,
const_dnnl_memory_desc_t dst_iter_c_desc,
const_dnnl_memory_desc_t diff_src_layer_desc,
const_dnnl_memory_desc_t diff_src_iter_desc,
const_dnnl_memory_desc_t diff_src_iter_c_desc,
const_dnnl_memory_desc_t diff_weights_layer_desc,
const_dnnl_memory_desc_t diff_weights_iter_desc,
const_dnnl_memory_desc_t diff_weights_peephole_desc,
const_dnnl_memory_desc_t diff_weights_projection_desc,
const_dnnl_memory_desc_t diff_bias_desc,
const_dnnl_memory_desc_t diff_dst_layer_desc,
const_dnnl_memory_desc_t diff_dst_iter_desc,
const_dnnl_memory_desc_t diff_dst_iter_c_desc, unsigned flags,
const_dnnl_primitive_desc_t hint_fwd_pd,
const_dnnl_primitive_attr_t attr);
@ -2883,13 +2882,13 @@ dnnl_status_t DNNL_API dnnl_lstm_backward_primitive_desc_create_v3(
dnnl_status_t DNNL_API dnnl_gru_forward_primitive_desc_create(
dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
const dnnl_memory_desc_t *src_layer_desc,
const dnnl_memory_desc_t *src_iter_desc,
const dnnl_memory_desc_t *weights_layer_desc,
const dnnl_memory_desc_t *weights_iter_desc,
const dnnl_memory_desc_t *bias_desc,
const dnnl_memory_desc_t *dst_layer_desc,
const dnnl_memory_desc_t *dst_iter_desc, unsigned flags,
const_dnnl_memory_desc_t src_layer_desc,
const_dnnl_memory_desc_t src_iter_desc,
const_dnnl_memory_desc_t weights_layer_desc,
const_dnnl_memory_desc_t weights_iter_desc,
const_dnnl_memory_desc_t bias_desc,
const_dnnl_memory_desc_t dst_layer_desc,
const_dnnl_memory_desc_t dst_iter_desc, unsigned flags,
const_dnnl_primitive_attr_t attr);
/// Creates a primitive descriptor for GRU backward propagation primitive.
@ -2944,20 +2943,20 @@ dnnl_status_t DNNL_API dnnl_gru_forward_primitive_desc_create(
dnnl_status_t DNNL_API dnnl_gru_backward_primitive_desc_create(
dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
const dnnl_memory_desc_t *src_layer_desc,
const dnnl_memory_desc_t *src_iter_desc,
const dnnl_memory_desc_t *weights_layer_desc,
const dnnl_memory_desc_t *weights_iter_desc,
const dnnl_memory_desc_t *bias_desc,
const dnnl_memory_desc_t *dst_layer_desc,
const dnnl_memory_desc_t *dst_iter_desc,
const dnnl_memory_desc_t *diff_src_layer_desc,
const dnnl_memory_desc_t *diff_src_iter_desc,
const dnnl_memory_desc_t *diff_weights_layer_desc,
const dnnl_memory_desc_t *diff_weights_iter_desc,
const dnnl_memory_desc_t *diff_bias_desc,
const dnnl_memory_desc_t *diff_dst_layer_desc,
const dnnl_memory_desc_t *diff_dst_iter_desc, unsigned flags,
const_dnnl_memory_desc_t src_layer_desc,
const_dnnl_memory_desc_t src_iter_desc,
const_dnnl_memory_desc_t weights_layer_desc,
const_dnnl_memory_desc_t weights_iter_desc,
const_dnnl_memory_desc_t bias_desc,
const_dnnl_memory_desc_t dst_layer_desc,
const_dnnl_memory_desc_t dst_iter_desc,
const_dnnl_memory_desc_t diff_src_layer_desc,
const_dnnl_memory_desc_t diff_src_iter_desc,
const_dnnl_memory_desc_t diff_weights_layer_desc,
const_dnnl_memory_desc_t diff_weights_iter_desc,
const_dnnl_memory_desc_t diff_bias_desc,
const_dnnl_memory_desc_t diff_dst_layer_desc,
const_dnnl_memory_desc_t diff_dst_iter_desc, unsigned flags,
const_dnnl_primitive_desc_t hint_fwd_pd,
const_dnnl_primitive_attr_t attr);
@ -2996,13 +2995,13 @@ dnnl_status_t DNNL_API dnnl_gru_backward_primitive_desc_create(
dnnl_status_t DNNL_API dnnl_lbr_gru_forward_primitive_desc_create(
dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
const dnnl_memory_desc_t *src_layer_desc,
const dnnl_memory_desc_t *src_iter_desc,
const dnnl_memory_desc_t *weights_layer_desc,
const dnnl_memory_desc_t *weights_iter_desc,
const dnnl_memory_desc_t *bias_desc,
const dnnl_memory_desc_t *dst_layer_desc,
const dnnl_memory_desc_t *dst_iter_desc, unsigned flags,
const_dnnl_memory_desc_t src_layer_desc,
const_dnnl_memory_desc_t src_iter_desc,
const_dnnl_memory_desc_t weights_layer_desc,
const_dnnl_memory_desc_t weights_iter_desc,
const_dnnl_memory_desc_t bias_desc,
const_dnnl_memory_desc_t dst_layer_desc,
const_dnnl_memory_desc_t dst_iter_desc, unsigned flags,
const_dnnl_primitive_attr_t attr);
/// Creates a primitive descriptor for LBR GRU backward propagation primitive.
@ -3057,20 +3056,20 @@ dnnl_status_t DNNL_API dnnl_lbr_gru_forward_primitive_desc_create(
dnnl_status_t DNNL_API dnnl_lbr_gru_backward_primitive_desc_create(
dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
const dnnl_memory_desc_t *src_layer_desc,
const dnnl_memory_desc_t *src_iter_desc,
const dnnl_memory_desc_t *weights_layer_desc,
const dnnl_memory_desc_t *weights_iter_desc,
const dnnl_memory_desc_t *bias_desc,
const dnnl_memory_desc_t *dst_layer_desc,
const dnnl_memory_desc_t *dst_iter_desc,
const dnnl_memory_desc_t *diff_src_layer_desc,
const dnnl_memory_desc_t *diff_src_iter_desc,
const dnnl_memory_desc_t *diff_weights_layer_desc,
const dnnl_memory_desc_t *diff_weights_iter_desc,
const dnnl_memory_desc_t *diff_bias_desc,
const dnnl_memory_desc_t *diff_dst_layer_desc,
const dnnl_memory_desc_t *diff_dst_iter_desc, unsigned flags,
const_dnnl_memory_desc_t src_layer_desc,
const_dnnl_memory_desc_t src_iter_desc,
const_dnnl_memory_desc_t weights_layer_desc,
const_dnnl_memory_desc_t weights_iter_desc,
const_dnnl_memory_desc_t bias_desc,
const_dnnl_memory_desc_t dst_layer_desc,
const_dnnl_memory_desc_t dst_iter_desc,
const_dnnl_memory_desc_t diff_src_layer_desc,
const_dnnl_memory_desc_t diff_src_iter_desc,
const_dnnl_memory_desc_t diff_weights_layer_desc,
const_dnnl_memory_desc_t diff_weights_iter_desc,
const_dnnl_memory_desc_t diff_bias_desc,
const_dnnl_memory_desc_t diff_dst_layer_desc,
const_dnnl_memory_desc_t diff_dst_iter_desc, unsigned flags,
const_dnnl_primitive_desc_t hint_fwd_pd,
const_dnnl_primitive_attr_t attr);
@ -3114,14 +3113,14 @@ dnnl_status_t DNNL_API dnnl_lbr_gru_backward_primitive_desc_create(
dnnl_status_t DNNL_API dnnl_augru_forward_primitive_desc_create(
dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
const dnnl_memory_desc_t *src_layer_desc,
const dnnl_memory_desc_t *src_iter_desc,
const dnnl_memory_desc_t *attention_desc,
const dnnl_memory_desc_t *weights_layer_desc,
const dnnl_memory_desc_t *weights_iter_desc,
const dnnl_memory_desc_t *bias_desc,
const dnnl_memory_desc_t *dst_layer_desc,
const dnnl_memory_desc_t *dst_iter_desc, unsigned flags,
const_dnnl_memory_desc_t src_layer_desc,
const_dnnl_memory_desc_t src_iter_desc,
const_dnnl_memory_desc_t attention_desc,
const_dnnl_memory_desc_t weights_layer_desc,
const_dnnl_memory_desc_t weights_iter_desc,
const_dnnl_memory_desc_t bias_desc,
const_dnnl_memory_desc_t dst_layer_desc,
const_dnnl_memory_desc_t dst_iter_desc, unsigned flags,
const_dnnl_primitive_attr_t attr);
/// Creates a primitive descriptor for AUGRU backward propagation primitive.
@ -3178,22 +3177,22 @@ dnnl_status_t DNNL_API dnnl_augru_forward_primitive_desc_create(
dnnl_status_t DNNL_API dnnl_augru_backward_primitive_desc_create(
dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
const dnnl_memory_desc_t *src_layer_desc,
const dnnl_memory_desc_t *src_iter_desc,
const dnnl_memory_desc_t *attention_desc,
const dnnl_memory_desc_t *weights_layer_desc,
const dnnl_memory_desc_t *weights_iter_desc,
const dnnl_memory_desc_t *bias_desc,
const dnnl_memory_desc_t *dst_layer_desc,
const dnnl_memory_desc_t *dst_iter_desc,
const dnnl_memory_desc_t *diff_src_layer_desc,
const dnnl_memory_desc_t *diff_src_iter_desc,
const dnnl_memory_desc_t *diff_attention_desc,
const dnnl_memory_desc_t *diff_weights_layer_desc,
const dnnl_memory_desc_t *diff_weights_iter_desc,
const dnnl_memory_desc_t *diff_bias_desc,
const dnnl_memory_desc_t *diff_dst_layer_desc,
const dnnl_memory_desc_t *diff_dst_iter_desc, unsigned flags,
const_dnnl_memory_desc_t src_layer_desc,
const_dnnl_memory_desc_t src_iter_desc,
const_dnnl_memory_desc_t attention_desc,
const_dnnl_memory_desc_t weights_layer_desc,
const_dnnl_memory_desc_t weights_iter_desc,
const_dnnl_memory_desc_t bias_desc,
const_dnnl_memory_desc_t dst_layer_desc,
const_dnnl_memory_desc_t dst_iter_desc,
const_dnnl_memory_desc_t diff_src_layer_desc,
const_dnnl_memory_desc_t diff_src_iter_desc,
const_dnnl_memory_desc_t diff_attention_desc,
const_dnnl_memory_desc_t diff_weights_layer_desc,
const_dnnl_memory_desc_t diff_weights_iter_desc,
const_dnnl_memory_desc_t diff_bias_desc,
const_dnnl_memory_desc_t diff_dst_layer_desc,
const_dnnl_memory_desc_t diff_dst_iter_desc, unsigned flags,
const_dnnl_primitive_desc_t hint_fwd_pd,
const_dnnl_primitive_attr_t attr);
@ -3233,14 +3232,14 @@ dnnl_status_t DNNL_API dnnl_augru_backward_primitive_desc_create(
dnnl_status_t DNNL_API dnnl_lbr_augru_forward_primitive_desc_create(
dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
const dnnl_memory_desc_t *src_layer_desc,
const dnnl_memory_desc_t *src_iter_desc,
const dnnl_memory_desc_t *attention_desc,
const dnnl_memory_desc_t *weights_layer_desc,
const dnnl_memory_desc_t *weights_iter_desc,
const dnnl_memory_desc_t *bias_desc,
const dnnl_memory_desc_t *dst_layer_desc,
const dnnl_memory_desc_t *dst_iter_desc, unsigned flags,
const_dnnl_memory_desc_t src_layer_desc,
const_dnnl_memory_desc_t src_iter_desc,
const_dnnl_memory_desc_t attention_desc,
const_dnnl_memory_desc_t weights_layer_desc,
const_dnnl_memory_desc_t weights_iter_desc,
const_dnnl_memory_desc_t bias_desc,
const_dnnl_memory_desc_t dst_layer_desc,
const_dnnl_memory_desc_t dst_iter_desc, unsigned flags,
const_dnnl_primitive_attr_t attr);
/// Creates a primitive descriptor for LBR AUGRU backward propagation primitive.
@ -3297,22 +3296,22 @@ dnnl_status_t DNNL_API dnnl_lbr_augru_forward_primitive_desc_create(
dnnl_status_t DNNL_API dnnl_lbr_augru_backward_primitive_desc_create(
dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction,
const dnnl_memory_desc_t *src_layer_desc,
const dnnl_memory_desc_t *src_iter_desc,
const dnnl_memory_desc_t *attention_desc,
const dnnl_memory_desc_t *weights_layer_desc,
const dnnl_memory_desc_t *weights_iter_desc,
const dnnl_memory_desc_t *bias_desc,
const dnnl_memory_desc_t *dst_layer_desc,
const dnnl_memory_desc_t *dst_iter_desc,
const dnnl_memory_desc_t *diff_src_layer_desc,
const dnnl_memory_desc_t *diff_src_iter_desc,
const dnnl_memory_desc_t *diff_attention_desc,
const dnnl_memory_desc_t *diff_weights_layer_desc,
const dnnl_memory_desc_t *diff_weights_iter_desc,
const dnnl_memory_desc_t *diff_bias_desc,
const dnnl_memory_desc_t *diff_dst_layer_desc,
const dnnl_memory_desc_t *diff_dst_iter_desc, unsigned flags,
const_dnnl_memory_desc_t src_layer_desc,
const_dnnl_memory_desc_t src_iter_desc,
const_dnnl_memory_desc_t attention_desc,
const_dnnl_memory_desc_t weights_layer_desc,
const_dnnl_memory_desc_t weights_iter_desc,
const_dnnl_memory_desc_t bias_desc,
const_dnnl_memory_desc_t dst_layer_desc,
const_dnnl_memory_desc_t dst_iter_desc,
const_dnnl_memory_desc_t diff_src_layer_desc,
const_dnnl_memory_desc_t diff_src_iter_desc,
const_dnnl_memory_desc_t diff_attention_desc,
const_dnnl_memory_desc_t diff_weights_layer_desc,
const_dnnl_memory_desc_t diff_weights_iter_desc,
const_dnnl_memory_desc_t diff_bias_desc,
const_dnnl_memory_desc_t diff_dst_layer_desc,
const_dnnl_memory_desc_t diff_dst_iter_desc, unsigned flags,
const_dnnl_primitive_desc_t hint_fwd_pd,
const_dnnl_primitive_attr_t attr);
@ -3336,9 +3335,9 @@ dnnl_status_t DNNL_API dnnl_lbr_augru_backward_primitive_desc_create(
/// otherwise.
dnnl_status_t DNNL_API dnnl_matmul_primitive_desc_create(
dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
const dnnl_memory_desc_t *src_desc,
const dnnl_memory_desc_t *weights_desc,
const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_desc,
const_dnnl_memory_desc_t src_desc,
const_dnnl_memory_desc_t weights_desc,
const_dnnl_memory_desc_t bias_desc, const_dnnl_memory_desc_t dst_desc,
const_dnnl_primitive_attr_t attr);
/// @} dnnl_api_matmul
@ -3368,8 +3367,8 @@ dnnl_status_t DNNL_API dnnl_matmul_primitive_desc_create(
dnnl_status_t DNNL_API dnnl_resampling_forward_primitive_desc_create(
dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind,
const float *factors, const dnnl_memory_desc_t *src_desc,
const dnnl_memory_desc_t *dst_desc, const_dnnl_primitive_attr_t attr);
const float *factors, const_dnnl_memory_desc_t src_desc,
const_dnnl_memory_desc_t dst_desc, const_dnnl_primitive_attr_t attr);
/// Creates a primitive descriptor for a resampling backward propagation
/// primitive.
@ -3390,8 +3389,8 @@ dnnl_status_t DNNL_API dnnl_resampling_forward_primitive_desc_create(
dnnl_status_t DNNL_API dnnl_resampling_backward_primitive_desc_create(
dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
dnnl_alg_kind_t alg_kind, const float *factors,
const dnnl_memory_desc_t *diff_src_desc,
const dnnl_memory_desc_t *diff_dst_desc,
const_dnnl_memory_desc_t diff_src_desc,
const_dnnl_memory_desc_t diff_dst_desc,
const_dnnl_primitive_desc_t hint_fwd_pd,
const_dnnl_primitive_attr_t attr);
@ -3422,8 +3421,8 @@ dnnl_status_t DNNL_API dnnl_resampling_backward_primitive_desc_create(
/// otherwise.
dnnl_status_t DNNL_API dnnl_reduction_primitive_desc_create(
dnnl_primitive_desc_t *primitive_desc, dnnl_engine_t engine,
dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc,
const dnnl_memory_desc_t *dst_desc, float p, float eps,
dnnl_alg_kind_t alg_kind, const_dnnl_memory_desc_t src_desc,
const_dnnl_memory_desc_t dst_desc, float p, float eps,
const_dnnl_primitive_attr_t attr);
/// @} dnnl_api_reduction

View File

@ -2490,7 +2490,7 @@ struct memory : public handle<dnnl_memory_t> {
// This function mimics `handle::get()` and will be removed once
// `desc` is inherited from `handle`.
const dnnl_memory_desc_t *get() const { return &data; }
const_dnnl_memory_desc_t get() const { return &data; }
/// Constructs a zero (empty) memory descriptor. Such a memory
/// descriptor can be used to indicate absence of an argument.
@ -2887,7 +2887,7 @@ struct memory : public handle<dnnl_memory_t> {
/// Returns the associated memory descriptor.
desc get_desc() const {
const dnnl_memory_desc_t *cdesc;
const_dnnl_memory_desc_t cdesc;
error::wrap_c_api(dnnl_memory_get_memory_desc(get(), &cdesc),
"could not get a memory descriptor from a memory object");
return desc(*cdesc);
@ -3295,7 +3295,7 @@ struct post_ops : public handle<dnnl_post_ops_t> {
void get_params_binary(
int index, algorithm &aalgorithm, memory::desc &src1_desc) const {
dnnl_alg_kind_t c_alg;
const dnnl_memory_desc_t *data;
const_dnnl_memory_desc_t data;
error::wrap_c_api(
dnnl_post_ops_get_params_binary(get(), index, &c_alg, &data),
"could not get parameters of a binary post-op");
@ -3968,7 +3968,7 @@ struct primitive_desc_base : public handle<dnnl_primitive_desc_t> {
const bool is_backward = get_prop_kind() != prop_kind::forward_training
&& get_prop_kind() != prop_kind::forward_inference;
const dnnl_memory_desc_t *md = dnnl_primitive_desc_query_md(get(),
const_dnnl_memory_desc_t md = dnnl_primitive_desc_query_md(get(),
is_backward ? dnnl_query_diff_dst_md : dnnl_query_dst_md, 0);
return status == dnnl_success
@ -4055,7 +4055,7 @@ struct primitive_desc_base : public handle<dnnl_primitive_desc_t> {
DNNL_THROW_ERROR(dnnl_invalid_arguments,
"memory descriptor query is invalid");
const dnnl_memory_desc_t *cdesc = dnnl_primitive_desc_query_md(
const_dnnl_memory_desc_t cdesc = dnnl_primitive_desc_query_md(
get(), dnnl::convert_to_c(what), idx);
return cdesc ? memory::desc(*cdesc) : memory::desc();
}
@ -4256,7 +4256,7 @@ protected:
memory::dims query_dims(query what) const {
const bool is_backward = get_prop_kind() != prop_kind::forward_training
&& get_prop_kind() != prop_kind::forward_inference;
const dnnl_memory_desc_t *md = dnnl_primitive_desc_query_md(get(),
const_dnnl_memory_desc_t md = dnnl_primitive_desc_query_md(get(),
is_backward ? dnnl_query_diff_dst_md : dnnl_query_dst_md, 0);
const int nspatial_dims = md ? md->ndims - 2 : 0;
@ -4378,7 +4378,7 @@ protected:
return attr;
}
const dnnl_memory_desc_t *optional_arg(const memory::desc *md) {
const_dnnl_memory_desc_t optional_arg(const memory::desc *md) {
return md ? md->get() : nullptr;
}

View File

@ -1868,6 +1868,9 @@ typedef struct {
dnnl_memory_extra_desc_t extra;
} dnnl_memory_desc_t;
/// A memory descriptor handle.
typedef const dnnl_memory_desc_t *const_dnnl_memory_desc_t;
/// @struct dnnl_memory
/// An opaque structure to describe a memory.
struct dnnl_memory;
@ -2303,7 +2306,7 @@ typedef struct {
/// dnnl_query_*_f32 | float *
/// dnnl_query_*_f64 | double *
/// dnnl_query_*_str | const char **
/// dnnl_query_*_md | const #dnnl_memory_desc_t **
/// dnnl_query_*_md | #const_dnnl_memory_desc_t *
/// dnnl_query_*_pd | #const_dnnl_primitive_desc_t *
/// dnnl_query_cache_blob_id | const uint8_t **
/// dnnl_query_strides | const #dnnl_dims_t **

View File

@ -814,7 +814,7 @@ static int check_total_size(
return res->state == FAILED ? FAIL : OK;
}
static size_t get_md_size(const dnnl_memory_desc_t *md,
static size_t get_md_size(const_dnnl_memory_desc_t md,
bool add_ref_size = false, bool add_ref_out_size = false) {
const auto mem_size = dnnl_memory_desc_get_size(md);
// runtime mem size is not defined

View File

@ -245,7 +245,7 @@ dnnl_status_t init_pd(init_pd_args_t<prb_t> &init_pd_args) {
auto src_d = dnn_mem_t::init_md(
prb->ndims, prb->dims.data(), prb->dt[0], prb->tag[0]);
dnnl_memory_desc_t stat_d {};
const dnnl_memory_desc_t *stat_d_ptr = nullptr;
const_dnnl_memory_desc_t stat_d_ptr = nullptr;
if (prb->stat_tag != tag::undef) {
stat_d = dnn_mem_t::init_md(
prb->ndims - 1, prb->dims.data(), dnnl_f32, prb->stat_tag);

View File

@ -266,17 +266,17 @@ std::ostream &operator<<(std::ostream &s, const prb_t &prb) {
dnnl_status_t init_rnn_fwd_pd(dnnl_primitive_desc_t *pd, dnnl_engine_t engine,
const prb_t &prb, dnnl_prop_kind_t prop_kind,
const dnnl_memory_desc_t *src_layer_d,
const dnnl_memory_desc_t *src_iter_d,
const dnnl_memory_desc_t *src_iter_c_d,
const dnnl_memory_desc_t *attention_d,
const dnnl_memory_desc_t *weights_layer_d,
const dnnl_memory_desc_t *weights_iter_d,
const dnnl_memory_desc_t *weights_peephole_d,
const dnnl_memory_desc_t *weights_projection_d,
const dnnl_memory_desc_t *bias_d, const dnnl_memory_desc_t *dst_layer_d,
const dnnl_memory_desc_t *dst_iter_d,
const dnnl_memory_desc_t *dst_iter_c_d, dnnl_primitive_attr_t attr) {
const_dnnl_memory_desc_t src_layer_d,
const_dnnl_memory_desc_t src_iter_d,
const_dnnl_memory_desc_t src_iter_c_d,
const_dnnl_memory_desc_t attention_d,
const_dnnl_memory_desc_t weights_layer_d,
const_dnnl_memory_desc_t weights_iter_d,
const_dnnl_memory_desc_t weights_peephole_d,
const_dnnl_memory_desc_t weights_projection_d,
const_dnnl_memory_desc_t bias_d, const_dnnl_memory_desc_t dst_layer_d,
const_dnnl_memory_desc_t dst_iter_d,
const_dnnl_memory_desc_t dst_iter_c_d, dnnl_primitive_attr_t attr) {
dnnl_alg_kind_t kind = alg2kind(prb.alg);
dnnl_alg_kind_t f = activation2kind(prb.activation);
@ -326,29 +326,29 @@ dnnl_status_t init_rnn_fwd_pd(dnnl_primitive_desc_t *pd, dnnl_engine_t engine,
dnnl_status_t init_rnn_bwd_pd(dnnl_primitive_desc_t *pd, dnnl_engine_t engine,
const prb_t &prb, dnnl_prop_kind_t prop_kind,
const dnnl_memory_desc_t *src_layer_d,
const dnnl_memory_desc_t *src_iter_d,
const dnnl_memory_desc_t *src_iter_c_d,
const dnnl_memory_desc_t *attention_d,
const dnnl_memory_desc_t *weights_layer_d,
const dnnl_memory_desc_t *weights_iter_d,
const dnnl_memory_desc_t *weights_peephole_d,
const dnnl_memory_desc_t *weights_projection_d,
const dnnl_memory_desc_t *bias_d, const dnnl_memory_desc_t *dst_layer_d,
const dnnl_memory_desc_t *dst_iter_d,
const dnnl_memory_desc_t *dst_iter_c_d,
const dnnl_memory_desc_t *diff_src_layer_d,
const dnnl_memory_desc_t *diff_src_iter_d,
const dnnl_memory_desc_t *diff_src_iter_c_d,
const dnnl_memory_desc_t *diff_attention_d,
const dnnl_memory_desc_t *diff_weights_layer_d,
const dnnl_memory_desc_t *diff_weights_iter_d,
const dnnl_memory_desc_t *diff_weights_peephole_d,
const dnnl_memory_desc_t *diff_weights_projection_d,
const dnnl_memory_desc_t *diff_bias_d,
const dnnl_memory_desc_t *diff_dst_layer_d,
const dnnl_memory_desc_t *diff_dst_iter_d,
const dnnl_memory_desc_t *diff_dst_iter_c_d,
const_dnnl_memory_desc_t src_layer_d,
const_dnnl_memory_desc_t src_iter_d,
const_dnnl_memory_desc_t src_iter_c_d,
const_dnnl_memory_desc_t attention_d,
const_dnnl_memory_desc_t weights_layer_d,
const_dnnl_memory_desc_t weights_iter_d,
const_dnnl_memory_desc_t weights_peephole_d,
const_dnnl_memory_desc_t weights_projection_d,
const_dnnl_memory_desc_t bias_d, const_dnnl_memory_desc_t dst_layer_d,
const_dnnl_memory_desc_t dst_iter_d,
const_dnnl_memory_desc_t dst_iter_c_d,
const_dnnl_memory_desc_t diff_src_layer_d,
const_dnnl_memory_desc_t diff_src_iter_d,
const_dnnl_memory_desc_t diff_src_iter_c_d,
const_dnnl_memory_desc_t diff_attention_d,
const_dnnl_memory_desc_t diff_weights_layer_d,
const_dnnl_memory_desc_t diff_weights_iter_d,
const_dnnl_memory_desc_t diff_weights_peephole_d,
const_dnnl_memory_desc_t diff_weights_projection_d,
const_dnnl_memory_desc_t diff_bias_d,
const_dnnl_memory_desc_t diff_dst_layer_d,
const_dnnl_memory_desc_t diff_dst_iter_d,
const_dnnl_memory_desc_t diff_dst_iter_c_d,
const_dnnl_primitive_desc_t hint, dnnl_primitive_attr_t attr) {
dnnl_alg_kind_t kind = alg2kind(prb.alg);
dnnl_alg_kind_t f = activation2kind(prb.activation);

View File

@ -42,43 +42,43 @@ typedef enum { action_copy = 0, action_sum, action_concat } rnn_action_t;
dnnl_status_t init_rnn_fwd_pd(dnnl_primitive_desc_t *pd, dnnl_engine_t engine,
const prb_t &prb, dnnl_prop_kind_t prop_kind,
const dnnl_memory_desc_t *src_layer_d,
const dnnl_memory_desc_t *src_iter_d,
const dnnl_memory_desc_t *src_iter_c_d,
const dnnl_memory_desc_t *attention_d,
const dnnl_memory_desc_t *weights_layer_d,
const dnnl_memory_desc_t *weights_iter_d,
const dnnl_memory_desc_t *weights_peephole_d,
const dnnl_memory_desc_t *weights_projection_d,
const dnnl_memory_desc_t *bias_d, const dnnl_memory_desc_t *dst_layer_d,
const dnnl_memory_desc_t *dst_iter_d,
const dnnl_memory_desc_t *dst_iter_c_d, dnnl_primitive_attr_t attr);
const_dnnl_memory_desc_t src_layer_d,
const_dnnl_memory_desc_t src_iter_d,
const_dnnl_memory_desc_t src_iter_c_d,
const_dnnl_memory_desc_t attention_d,
const_dnnl_memory_desc_t weights_layer_d,
const_dnnl_memory_desc_t weights_iter_d,
const_dnnl_memory_desc_t weights_peephole_d,
const_dnnl_memory_desc_t weights_projection_d,
const_dnnl_memory_desc_t bias_d, const_dnnl_memory_desc_t dst_layer_d,
const_dnnl_memory_desc_t dst_iter_d,
const_dnnl_memory_desc_t dst_iter_c_d, dnnl_primitive_attr_t attr);
dnnl_status_t init_rnn_bwd_pd(dnnl_primitive_desc_t *pd, dnnl_engine_t engine,
const prb_t &prb, dnnl_prop_kind_t prop_kind,
const dnnl_memory_desc_t *src_layer_d,
const dnnl_memory_desc_t *src_iter_d,
const dnnl_memory_desc_t *src_iter_c_d,
const dnnl_memory_desc_t *attention_d,
const dnnl_memory_desc_t *weights_layer_d,
const dnnl_memory_desc_t *weights_iter_d,
const dnnl_memory_desc_t *weights_peephole_d,
const dnnl_memory_desc_t *weights_projection_d,
const dnnl_memory_desc_t *bias_d, const dnnl_memory_desc_t *dst_layer_d,
const dnnl_memory_desc_t *dst_iter_d,
const dnnl_memory_desc_t *dst_iter_c_d,
const dnnl_memory_desc_t *diff_src_layer_d,
const dnnl_memory_desc_t *diff_src_iter_d,
const dnnl_memory_desc_t *diff_src_iter_c_d,
const dnnl_memory_desc_t *diff_attention_d,
const dnnl_memory_desc_t *diff_weights_layer_d,
const dnnl_memory_desc_t *diff_weights_iter_d,
const dnnl_memory_desc_t *diff_weights_peephole_d,
const dnnl_memory_desc_t *diff_weights_projection_d,
const dnnl_memory_desc_t *diff_bias_d,
const dnnl_memory_desc_t *diff_dst_layer_d,
const dnnl_memory_desc_t *diff_dst_iter_d,
const dnnl_memory_desc_t *diff_dst_iter_c_d,
const_dnnl_memory_desc_t src_layer_d,
const_dnnl_memory_desc_t src_iter_d,
const_dnnl_memory_desc_t src_iter_c_d,
const_dnnl_memory_desc_t attention_d,
const_dnnl_memory_desc_t weights_layer_d,
const_dnnl_memory_desc_t weights_iter_d,
const_dnnl_memory_desc_t weights_peephole_d,
const_dnnl_memory_desc_t weights_projection_d,
const_dnnl_memory_desc_t bias_d, const_dnnl_memory_desc_t dst_layer_d,
const_dnnl_memory_desc_t dst_iter_d,
const_dnnl_memory_desc_t dst_iter_c_d,
const_dnnl_memory_desc_t diff_src_layer_d,
const_dnnl_memory_desc_t diff_src_iter_d,
const_dnnl_memory_desc_t diff_src_iter_c_d,
const_dnnl_memory_desc_t diff_attention_d,
const_dnnl_memory_desc_t diff_weights_layer_d,
const_dnnl_memory_desc_t diff_weights_iter_d,
const_dnnl_memory_desc_t diff_weights_peephole_d,
const_dnnl_memory_desc_t diff_weights_projection_d,
const_dnnl_memory_desc_t diff_bias_d,
const_dnnl_memory_desc_t diff_dst_layer_d,
const_dnnl_memory_desc_t diff_dst_iter_d,
const_dnnl_memory_desc_t diff_dst_iter_c_d,
const_dnnl_primitive_desc_t hint, dnnl_primitive_attr_t attr);
void init_buffer(float *buf, int64_t size, float value);