/*******************************************************************************
* Copyright 2016-2025 Intel Corporation
* Copyright 2024-2025 FUJITSU LIMITED
* Copyright 2025 Arm Ltd. and affiliates
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

/// @file
/// C++ API

#ifndef ONEAPI_DNNL_DNNL_HPP
#define ONEAPI_DNNL_DNNL_HPP
// NOLINTBEGIN(readability-identifier-naming)

#include "oneapi/dnnl/dnnl_config.h"

/// @cond DO_NOT_DOCUMENT_THIS
#include <algorithm>
#include <cstdlib>
#include <iterator>
#include <memory>
#include <string>
#include <vector>
#include <unordered_map>

#include "oneapi/dnnl/dnnl.h"
#include "oneapi/dnnl/dnnl_common.hpp"

/// @endcond

/// @addtogroup dnnl_api oneDNN API
/// @{

/// oneDNN namespace
namespace dnnl {

/// @addtogroup dnnl_api_utils Utilities
/// Utility types and definitions.
/// @{

/// @cond DO_NOT_DOCUMENT_THIS
template <typename T>
void validate_container_size(const T &v, const char *error_message,
        int min_size = 1, int max_size = -1) {
    const int size = (int)v.size();
    if (size < min_size || (max_size >= 0 && size > max_size))
        DNNL_THROW_ERROR(dnnl_invalid_arguments, error_message);
}
/// @endcond

/// @cond DO_NOT_DOCUMENT_THIS
template <>
struct handle_traits<dnnl_memory_desc_t> {
    static dnnl_status_t destructor(dnnl_memory_desc_t p) {
        return dnnl_memory_desc_destroy(p);
    }
};

template <>
struct handle_traits<dnnl_memory_t> {
    static dnnl_status_t destructor(dnnl_memory_t p) {
        return dnnl_memory_destroy(p);
    }
};

template <>
struct handle_traits<dnnl_primitive_desc_t> {
    static dnnl_status_t destructor(dnnl_primitive_desc_t p) {
        return dnnl_primitive_desc_destroy(p);
    }
};

template <>
struct handle_traits<dnnl_primitive_t> {
    static dnnl_status_t destructor(dnnl_primitive_t p) {
        return dnnl_primitive_destroy(p);
    }
};

/// @endcond

/// @} dnnl_api_utils

struct stream;
struct memory;
struct primitive_desc;

/// @addtogroup dnnl_api_primitives Primitives
/// Compute primitives
/// @sa @ref dev_guide_basic_concepts
/// @{

/// @addtogroup dnnl_api_primitives_common Common
/// Common operations to create, destroy and inspect primitives
/// @{

/// Base class for all computational primitives.
struct primitive : public handle<dnnl_primitive_t> {
    /// Kinds of primitives supported by the library.
    enum class kind {
        /// Undefined primitive
        undef = dnnl_undefined_primitive,
        /// A reorder primitive.
        reorder = dnnl_reorder,
        /// A shuffle primitive.
        shuffle = dnnl_shuffle,
        /// An (out-of-place) tensor concatenation primitive.
        concat = dnnl_concat,
        /// A summation primitive.
        sum = dnnl_sum,
        /// A convolution primitive.
        convolution = dnnl_convolution,
        /// A deconvolution primitive.
        deconvolution = dnnl_deconvolution,
        /// An element-wise primitive.
        eltwise = dnnl_eltwise,
        /// An LRN primitive.
        lrn = dnnl_lrn,
        /// A batch normalization primitive.
        batch_normalization = dnnl_batch_normalization,
        /// An inner product primitive.
        inner_product = dnnl_inner_product,
        /// An RNN primitive.
        rnn = dnnl_rnn,
        /// A binary primitive.
        binary = dnnl_binary,
        /// A matmul (matrix multiplication) primitive.
        matmul = dnnl_matmul,
        /// A resampling primitive.
        resampling = dnnl_resampling,
        /// A pooling primitive.
        pooling = dnnl_pooling,
        /// A reduction primitive.
        reduction = dnnl_reduction,
        /// A PReLU primitive.
        prelu = dnnl_prelu,
        /// A softmax primitive.
        softmax = dnnl_softmax,
        /// A layer normalization primitive.
        layer_normalization = dnnl_layer_normalization,
        /// A group normalization primitive.
        group_normalization = dnnl_group_normalization,
    };

    using handle::handle;

    /// Default constructor. Constructs an empty object.
    primitive() = default;

    /// Constructs a primitive from a C API primitive descriptor.
    ///
    /// @param c_pd C API primitive descriptor.
    primitive(const_dnnl_primitive_desc_t c_pd);

    /// Constructs a primitive from a C API primitive descriptor and a cache blob.
    ///
    /// @param c_pd C API primitive descriptor.
    /// @param cache_blob Cache blob.
    primitive(const_dnnl_primitive_desc_t c_pd,
            const std::vector<uint8_t> &cache_blob);

    /// Constructs a primitive from a primitive descriptor.
    ///
    /// @param pd Primitive descriptor.
    primitive(const primitive_desc &pd);

    /// Constructs a primitive from a primitive descriptor and a cache blob.
    ///
    /// @param pd Primitive descriptor.
    /// @param cache_blob Cache blob.
    primitive(const primitive_desc &pd, const std::vector<uint8_t> &cache_blob);

    /// Returns the C API primitive descriptor of the underlying C API
    /// primitive.
    ///
    /// @returns The underlying C API primitive descriptor.
    inline const_dnnl_primitive_desc_t get_primitive_desc() const;

    /// Returns the kind of the primitive.
    ///
    /// @returns The primitive kind.
    inline kind get_kind() const;

    /// Returns a cache blob for the primitive.
    ///
    /// @returns Vector containing the cache blob.
    ///
    /// @note The cache blob can be empty. It's the user's responsibility to
    ///     check whether it's empty prior to passing it to the primitive
    ///     constructor.
    inline std::vector<uint8_t> get_cache_blob() const;

    /// Executes computations specified by the primitive in a specified stream.
    ///
    /// Arguments are passed via an arguments map containing <index,
    /// memory object> pairs. The index must be one of the `DNNL_ARG_*` values
    /// such as `DNNL_ARG_SRC`, and the memory must have a memory descriptor
    /// matching the one returned by
    /// primitive_desc::query_md(#query::exec_arg_md, index) unless using
    /// dynamic shapes (see #DNNL_RUNTIME_DIM_VAL).
    ///
    /// @param astream Stream object. The stream must belong to the same engine
    ///     as the primitive.
    /// @param args Arguments map.
    void execute(const stream &astream,
            const std::unordered_map<int, memory> &args) const;
};

/// Converts primitive kind enum value from C++ API to C API type.
///
/// @param akind C++ API primitive kind enum value.
/// @returns Corresponding C API primitive kind enum value.
inline dnnl_primitive_kind_t convert_to_c(primitive::kind akind) {
    return static_cast<dnnl_primitive_kind_t>(akind);
}

const_dnnl_primitive_desc_t primitive::get_primitive_desc() const {
    const_dnnl_primitive_desc_t pd;
    error::wrap_c_api(dnnl_primitive_get_primitive_desc(get(), &pd),
            "could not get a primitive descriptor from a primitive");
    return pd;
}

dnnl::primitive::kind primitive::get_kind() const {
    const_dnnl_primitive_desc_t pd = get_primitive_desc();
    // TODO (Roma): the code below is only needed because get_primitive_desc
    // returns a C type.
    dnnl_primitive_kind_t kind;
    error::wrap_c_api(dnnl_primitive_desc_query(
                              pd, dnnl_query_primitive_kind, 0, (void *)&kind),
            "could not get a primitive kind from a primitive descriptor");
    return static_cast<dnnl::primitive::kind>(kind);
}

std::vector<uint8_t> primitive::get_cache_blob() const {
    size_t size;
    error::wrap_c_api(dnnl_primitive_get_cache_blob(get(), &size, nullptr),
            "could not get cache blob size from a primitive");

    std::vector<uint8_t> cache_blob(size);
    error::wrap_c_api(
            dnnl_primitive_get_cache_blob(get(), &size, cache_blob.data()),
            "could not get a cache blob from a primitive");
    return cache_blob;
}
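
// Usage sketch: executing a primitive with the argument map described above,
// assuming a reorder primitive and two memory objects created elsewhere;
// DNNL_ARG_FROM and DNNL_ARG_TO are the standard execution argument indices
// from dnnl.h. Kept as a comment so the header itself is unaffected.
//
//     void run_reorder(dnnl::stream &s, const dnnl::primitive &reorder_prim,
//             const dnnl::memory &src_mem, const dnnl::memory &dst_mem) {
//         // Arguments are passed as an {index, memory object} map.
//         std::unordered_map<int, dnnl::memory> args
//                 = {{DNNL_ARG_FROM, src_mem}, {DNNL_ARG_TO, dst_mem}};
//         reorder_prim.execute(s, args);
//         s.wait(); // block until the computation completes
//     }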

/// @} dnnl_api_primitives_common

/// @addtogroup dnnl_api_attributes
///
/// A container for parameters that extend primitives behavior.
///
/// Attributes can also contain Post-ops, which are computations executed
/// after the primitive.
///
/// @sa @ref dev_guide_attributes
/// @sa @ref dev_guide_attributes_post_ops
///
/// @{

/// Scratchpad mode
enum class scratchpad_mode {
    /// The library manages the scratchpad allocation according to the policy
    /// specified by the `DNNL_ENABLE_CONCURRENT_EXEC`
    /// [build option](@ref dev_guide_build_options) (default).
    ///
    /// When `DNNL_ENABLE_CONCURRENT_EXEC=OFF` (default), the library
    /// scratchpad is common to all primitives to reduce the memory footprint.
    /// This configuration comes with limited thread-safety properties, namely
    /// primitives can be created and executed in parallel but cannot migrate
    /// between threads (in other words, each primitive should be executed in
    /// the same thread it was created in).
    ///
    /// When `DNNL_ENABLE_CONCURRENT_EXEC=ON`, the library scratchpad is
    /// private to each primitive. The memory footprint is larger than when
    /// using `DNNL_ENABLE_CONCURRENT_EXEC=OFF` but different primitives can be
    /// created and run concurrently (the same primitive cannot be run
    /// concurrently from two different threads though).
    library = dnnl_scratchpad_mode_library,
    /// The user manages the scratchpad allocation by querying and providing
    /// the scratchpad memory to primitives. This mode is thread-safe as long
    /// as the scratchpad buffers are not used concurrently by two primitive
    /// executions.
    user = dnnl_scratchpad_mode_user,
};

/// Converts a scratchpad mode enum value from C++ API to C API type.
///
/// @param mode C++ API scratchpad mode enum value.
/// @returns Corresponding C API scratchpad mode enum value.
inline dnnl_scratchpad_mode_t convert_to_c(scratchpad_mode mode) {
    return static_cast<dnnl_scratchpad_mode_t>(mode);
}
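
// Usage sketch: the user-managed scratchpad flow described above, assuming a
// dnnl::primitive_attr (declared later in this header) with its
// set_scratchpad_mode() setter, a primitive descriptor "pd" exposing
// scratchpad_desc(), an engine "eng", an argument map "args", and
// DNNL_ARG_SCRATCHPAD from dnnl.h.
//
//     dnnl::primitive_attr attr;
//     attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
//     // ... create the primitive descriptor "pd" with "attr", then:
//     dnnl::memory scratchpad(pd.scratchpad_desc(), eng);
//     args.insert({DNNL_ARG_SCRATCHPAD, scratchpad});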

/// Rounding mode
enum class rounding_mode {
    /// rounding mode dictated by the floating-point environment
    environment = dnnl_rounding_mode_environment,
    /// stochastic rounding mode where a random bias is added to the
    /// trailing mantissa bits before conversion.
    stochastic = dnnl_rounding_mode_stochastic
};

/// Converts a rounding mode enum value from C++ API to C API type.
///
/// @param mode C++ API rounding mode enum value.
/// @returns Corresponding C API rounding mode enum value.
inline dnnl_rounding_mode_t convert_to_c(rounding_mode mode) {
    return static_cast<dnnl_rounding_mode_t>(mode);
}

/// Quantization kind
enum class quantization_mode {
    /// used for unspecified quantization kind
    undef = dnnl_quantization_mode_undef,
    /// static quantization mode: quantization parameter is computed
    /// ahead of time and passed to oneDNN as an input.
    static_sazp = dnnl_quantization_mode_static_sazp,
    /// dynamic quantization mode following OCP MX spec: quantization
    /// parameter is computed by oneDNN following the OCP MX spec
    /// formula and written as an output.
    dynamic_mx = dnnl_quantization_mode_dynamic_mx,
};

/// Converts a quantization kind enum value from C++ API to C API type.
///
/// @param qmode C++ API quantization kind enum value.
/// @returns Corresponding C API quantization kind enum value.
inline dnnl_quantization_mode_t convert_to_c(quantization_mode qmode) {
    return static_cast<dnnl_quantization_mode_t>(qmode);
}

/// Propagation kind.
enum class prop_kind {
    /// Undefined propagation kind.
    undef = dnnl_prop_kind_undef,
    /// Forward data propagation (training mode). In this mode, primitives
    /// perform computations necessary for subsequent backward propagation.
    forward_training = dnnl_forward_training,
    /// Forward data propagation (inference mode). In this mode, primitives
    /// perform only computations that are necessary for inference and omit
    /// computations that are necessary only for backward propagation.
    forward_inference = dnnl_forward_inference,
    /// Forward data propagation,
    /// alias for #dnnl::prop_kind::forward_training.
    forward = dnnl_forward,
    /// Backward propagation (with respect to all parameters).
    backward = dnnl_backward,
    /// Backward data propagation.
    backward_data = dnnl_backward_data,
    /// Backward weights propagation.
    backward_weights = dnnl_backward_weights,
    /// Backward bias propagation.
    backward_bias = dnnl_backward_bias
};

/// Converts propagation kind enum value from C++ API to C API type.
///
/// @param akind C++ API propagation kind enum value.
/// @returns Corresponding C API propagation kind enum value.
inline dnnl_prop_kind_t convert_to_c(prop_kind akind) {
    return static_cast<dnnl_prop_kind_t>(akind);
}

/// Kinds of algorithms.
enum class algorithm {
    /// Undefined algorithm
    undef = dnnl_alg_kind_undef,
    /// Convolution algorithm that is chosen to be either direct or Winograd
    /// automatically
    convolution_auto = dnnl_convolution_auto,
    /// Direct convolution
    convolution_direct = dnnl_convolution_direct,
    /// Winograd convolution
    convolution_winograd = dnnl_convolution_winograd,
    /// Direct deconvolution
    deconvolution_direct = dnnl_deconvolution_direct,
    /// Winograd deconvolution
    deconvolution_winograd = dnnl_deconvolution_winograd,
    /// Elementwise: rectified linear unit (ReLU)
    eltwise_relu = dnnl_eltwise_relu,
    /// Elementwise: hyperbolic tangent non-linearity (tanh)
    eltwise_tanh = dnnl_eltwise_tanh,
    /// Elementwise: exponential linear unit (ELU)
    eltwise_elu = dnnl_eltwise_elu,
    /// Elementwise: square
    eltwise_square = dnnl_eltwise_square,
    /// Elementwise: abs
    eltwise_abs = dnnl_eltwise_abs,
    /// Elementwise: square root
    eltwise_sqrt = dnnl_eltwise_sqrt,
    /// Elementwise: swish (\f$x \cdot sigmoid(a \cdot x)\f$)
    eltwise_swish = dnnl_eltwise_swish,
    /// Elementwise: linear
    eltwise_linear = dnnl_eltwise_linear,
    /// Elementwise: soft_relu
    eltwise_soft_relu = dnnl_eltwise_soft_relu,
    /// Elementwise: mish
    eltwise_mish = dnnl_eltwise_mish,
    /// Elementwise: logistic
    eltwise_logistic = dnnl_eltwise_logistic,
    /// Elementwise: exponent
    eltwise_exp = dnnl_eltwise_exp,
    /// Elementwise: tanh-based gelu
    eltwise_gelu_tanh = dnnl_eltwise_gelu_tanh,
    /// Elementwise: erf-based gelu
    eltwise_gelu_erf = dnnl_eltwise_gelu_erf,
    /// Elementwise: natural logarithm
    eltwise_log = dnnl_eltwise_log,
    /// Elementwise: clip
    eltwise_clip = dnnl_eltwise_clip,
    /// Eltwise: clip version 2
    eltwise_clip_v2 = dnnl_eltwise_clip_v2,
    /// Elementwise: pow
    eltwise_pow = dnnl_eltwise_pow,
    /// Elementwise: round
    eltwise_round = dnnl_eltwise_round,
    /// Elementwise: hardswish
    eltwise_hardswish = dnnl_eltwise_hardswish,
    /// Elementwise: hardsigmoid
    eltwise_hardsigmoid = dnnl_eltwise_hardsigmoid,
    /// Elementwise: rectified linear unit (ReLU) (dst for backward)
    eltwise_relu_use_dst_for_bwd = dnnl_eltwise_relu_use_dst_for_bwd,
    /// Elementwise: hyperbolic tangent non-linearity (tanh) (dst for backward)
    eltwise_tanh_use_dst_for_bwd = dnnl_eltwise_tanh_use_dst_for_bwd,
    /// Elementwise: exponential linear unit (ELU) (dst for backward)
    eltwise_elu_use_dst_for_bwd = dnnl_eltwise_elu_use_dst_for_bwd,
    /// Elementwise: square root (dst for backward)
    eltwise_sqrt_use_dst_for_bwd = dnnl_eltwise_sqrt_use_dst_for_bwd,
    /// Elementwise: logistic (dst for backward)
    eltwise_logistic_use_dst_for_bwd = dnnl_eltwise_logistic_use_dst_for_bwd,
    /// Elementwise: exponent (dst for backward)
    eltwise_exp_use_dst_for_bwd = dnnl_eltwise_exp_use_dst_for_bwd,
    /// Elementwise: clip version 2 (dst for backward)
    eltwise_clip_v2_use_dst_for_bwd = dnnl_eltwise_clip_v2_use_dst_for_bwd,
    /// Local response normalization (LRN) across multiple channels
    lrn_across_channels = dnnl_lrn_across_channels,
    /// LRN within a single channel
    lrn_within_channel = dnnl_lrn_within_channel,
    /// Max pooling
    pooling_max = dnnl_pooling_max,
    /// Average pooling include padding
    pooling_avg_include_padding = dnnl_pooling_avg_include_padding,
    /// Average pooling exclude padding
    pooling_avg_exclude_padding = dnnl_pooling_avg_exclude_padding,
    /// RNN cell
    vanilla_rnn = dnnl_vanilla_rnn,
    /// LSTM cell
    vanilla_lstm = dnnl_vanilla_lstm,
    /// GRU cell
    vanilla_gru = dnnl_vanilla_gru,
    /// GRU cell with linear before reset. Differs from the vanilla GRU
    /// in how the new memory gate is calculated:
    /// \f$c_t = tanh(W_c*x_t + b_{c_x} + r_t*(U_c*h_{t-1}+b_{c_h})) \f$
    /// LBR GRU expects 4 bias tensors on input:
    /// \f$[b_{u}, b_{r}, b_{c_x}, b_{c_h}]\f$
    lbr_gru = dnnl_lbr_gru,
    /// AUGRU cell
    vanilla_augru = dnnl_vanilla_augru,
    /// AUGRU cell with linear before reset
    lbr_augru = dnnl_lbr_augru,
    /// Binary add
    binary_add = dnnl_binary_add,
    /// Binary mul
    binary_mul = dnnl_binary_mul,
    /// Binary max
    binary_max = dnnl_binary_max,
    /// Binary min
    binary_min = dnnl_binary_min,
    /// Binary div
    binary_div = dnnl_binary_div,
    /// Binary sub
    binary_sub = dnnl_binary_sub,
    /// Binary greater than or equal
    binary_ge = dnnl_binary_ge,
    /// Binary greater than
    binary_gt = dnnl_binary_gt,
    /// Binary less than or equal
    binary_le = dnnl_binary_le,
    /// Binary less than
    binary_lt = dnnl_binary_lt,
    /// Binary equal
    binary_eq = dnnl_binary_eq,
    /// Binary not equal
    binary_ne = dnnl_binary_ne,
    /// Binary select
    binary_select = dnnl_binary_select,
    /// Nearest Neighbor resampling method
    resampling_nearest = dnnl_resampling_nearest,
    /// Linear (Bilinear, Trilinear) resampling method
    resampling_linear = dnnl_resampling_linear,
    /// Reduction using max operation
    reduction_max = dnnl_reduction_max,
    /// Reduction using min operation
    reduction_min = dnnl_reduction_min,
    /// Reduction using sum operation
    reduction_sum = dnnl_reduction_sum,
    /// Reduction using mul operation
    reduction_mul = dnnl_reduction_mul,
    /// Reduction using mean operation
    reduction_mean = dnnl_reduction_mean,
    /// Reduction using norm_lp_max operation
    reduction_norm_lp_max = dnnl_reduction_norm_lp_max,
    /// Reduction using norm_lp_sum operation
    reduction_norm_lp_sum = dnnl_reduction_norm_lp_sum,
    /// Reduction using norm_lp_power_p_max operation
    reduction_norm_lp_power_p_max = dnnl_reduction_norm_lp_power_p_max,
    /// Reduction using norm_lp_power_p_sum operation
    reduction_norm_lp_power_p_sum = dnnl_reduction_norm_lp_power_p_sum,
    /// Softmax, numerically stable
    softmax_accurate = dnnl_softmax_accurate,
    /// LogSoftmax, numerically stable
    softmax_log = dnnl_softmax_log,
};

/// Converts algorithm kind enum value from C++ API to C API type.
/// @param aalgorithm C++ API algorithm kind enum value.
/// @returns Corresponding C API algorithm kind enum value.
inline dnnl_alg_kind_t convert_to_c(algorithm aalgorithm) {
    return static_cast<dnnl_alg_kind_t>(aalgorithm);
}

/// @} dnnl_api_attributes

/// @addtogroup dnnl_api_primitives_common
/// @{

/// Flags for normalization primitives.
enum class normalization_flags : unsigned {
    /// Use no normalization flags. If specified, the library computes mean and
    /// variance on forward propagation for training and inference, outputs
    /// them on forward propagation for training, and computes the respective
    /// derivatives on backward propagation.
    ///
    /// @note
    ///     Backward propagation of type #dnnl::prop_kind::backward_data has
    ///     the same behavior as #dnnl::prop_kind::backward.
    none = dnnl_normalization_flags_none,

    /// Use global statistics. If specified, the library uses mean and
    /// variance provided by the user as an input on forward propagation and
    /// does not compute their derivatives on backward propagation. Otherwise,
    /// the library computes mean and variance on forward propagation for
    /// training and inference, outputs them on forward propagation for
    /// training, and computes the respective derivatives on backward
    /// propagation.
    use_global_stats = dnnl_use_global_stats,

    /// Use scale parameter. If specified, the user is expected to pass scale as
    /// input on forward propagation. On backward propagation of type
    /// #dnnl::prop_kind::backward, the library computes its derivative.
    use_scale = dnnl_use_scale,

    /// Use shift parameter. If specified, the user is expected to pass shift as
    /// input on forward propagation. On backward propagation of type
    /// #dnnl::prop_kind::backward, the library computes its derivative.
    use_shift = dnnl_use_shift,

    /// Fuse normalization with ReLU. On training, normalization will require
    /// the workspace to implement backward propagation. On inference, the
    /// workspace is not required and behavior is the same as when normalization
    /// is fused with ReLU using the post-ops API.
    ///
    /// @note
    ///     The flag implies negative slope being 0. On training this is the only
    ///     configuration supported. For inference, to use non-zero negative slope
    ///     consider using @ref dev_guide_attributes_post_ops.
    fuse_norm_relu = dnnl_fuse_norm_relu,

    /// Fuse normalization with an elementwise binary Add operation
    /// followed by ReLU.
    /// During training, normalization will require a workspace to implement
    /// backward propagation. For inference, the workspace is not needed.
    /// On forward propagation, an elementwise binary Add operation is applied
    /// to the normalization results with an additional input tensor, followed
    /// by ReLU with a negative slope of 0.
    /// On backward propagation, the result of the backward ReLU operation
    /// with the input tensor and workspace from the forward pass is saved
    /// to an extra output tensor, and backward normalization is performed.
    fuse_norm_add_relu = dnnl_fuse_norm_add_relu,

    /// Use Root Mean Square (RMS) Normalization. In forward propagation,
    /// the mean is considered zero, and RMS norm is used instead of variance
    /// for scaling. Only the RMS norm is output during forward propagation for
    /// training. In backward propagation, the library calculates the derivative
    /// with respect to the RMS norm only, assuming the mean is zero.
    ///
    /// @note
    ///     When used with #dnnl::normalization_flags::use_global_stats,
    ///     only RMS norm is required to be provided as input.
    rms_norm = dnnl_rms_norm,
};

/// Converts normalization flags enum value from C++ API to C API type.
/// @param flags C++ API normalization flags enum value.
/// @returns Corresponding C API normalization flags enum value.
inline dnnl_normalization_flags_t convert_to_c(normalization_flags flags) {
    return static_cast<dnnl_normalization_flags_t>(flags);
}

/// @} dnnl_api_primitives_common

/// @addtogroup dnnl_api_rnn
/// @{

/// RNN cell flags.
enum class rnn_flags : unsigned {
    /// Undefined RNN flags
    undef = dnnl_rnn_flags_undef,
    /// Do not add weights gradient to existing diff_weights memory
    diff_weights_overwrite = dnnl_rnn_flags_diff_weights_overwrite,
};

/// Converts RNN cell flags enum value from C++ API to C API type.
/// @param flags C++ API RNN cell flags enum value.
/// @returns Corresponding C API RNN cell flags enum value.
inline dnnl_rnn_flags_t convert_to_c(rnn_flags flags) {
    return static_cast<dnnl_rnn_flags_t>(flags);
}

DNNL_DEFINE_BITMASK_OPS(normalization_flags)
DNNL_DEFINE_BITMASK_OPS(rnn_flags)
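
// Usage sketch: DNNL_DEFINE_BITMASK_OPS provides the usual bitwise operators
// for these flag enums, so individual flags can be combined and tested before
// being passed to a normalization primitive or converted to the C API type.
//
//     using nf = dnnl::normalization_flags;
//     nf flags = nf::use_scale | nf::use_shift; // combine flags
//     bool has_scale = (flags & nf::use_scale) != nf::none; // test a flag
//     dnnl_normalization_flags_t c_flags = convert_to_c(flags);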

/// A direction of RNN primitive execution
enum class rnn_direction {
    /// Undefined RNN direction.
    undef = dnnl_rnn_direction_undef,
    /// Unidirectional execution of RNN primitive from left to right.
    unidirectional_left2right = dnnl_unidirectional_left2right,
    /// Unidirectional execution of RNN primitive from right to left.
    unidirectional_right2left = dnnl_unidirectional_right2left,
    /// Bidirectional execution of RNN primitive with concatenation of the
    /// results.
    bidirectional_concat = dnnl_bidirectional_concat,
    /// Bidirectional execution of RNN primitive with summation of the
    /// results.
    bidirectional_sum = dnnl_bidirectional_sum,
};

/// Converts RNN direction enum value from C++ API to C API type.
/// @param dir C++ API RNN direction enum value.
/// @returns Corresponding C API RNN direction enum value.
inline dnnl_rnn_direction_t convert_to_c(rnn_direction dir) {
    return static_cast<dnnl_rnn_direction_t>(dir);
}

/// @} dnnl_api_rnn

/// @addtogroup dnnl_api_primitives_common
/// @{

/// Primitive descriptor query specification.
///
/// In general, queries are not used with the C++ API because most queries are
/// implemented as class members.
///
/// See @ref dnnl_query_t for more information.
enum class query {
    /// no query
    undef = dnnl_query_undef,

    /// execution engine
    engine = dnnl_query_engine,
    /// primitive kind
    primitive_kind = dnnl_query_primitive_kind,

    /// number of inputs expected
    num_of_inputs_s32 = dnnl_query_num_of_inputs_s32,
    /// number of outputs expected
    num_of_outputs_s32 = dnnl_query_num_of_outputs_s32,

    /// runtime estimation (seconds), unimplemented
    time_estimate_f64 = dnnl_query_time_estimate_f64,
    /// memory required for scratchpad (bytes)
    ///
    /// @sa @ref dev_guide_attributes_scratchpad
    memory_consumption_s64 = dnnl_query_memory_consumption_s64,

    /// scratchpad engine
    ///
    /// engine to be used for creating scratchpad memory
    scratchpad_engine = dnnl_query_scratchpad_engine,

    /// reorder source engine
    reorder_src_engine = dnnl_query_reorder_src_engine,
    /// reorder destination engine
    reorder_dst_engine = dnnl_query_reorder_dst_engine,

    /// implementation name
    impl_info_str = dnnl_query_impl_info_str,

    /// propagation kind
    prop_kind = dnnl_query_prop_kind,

    /// size of cache blob ID in bytes
    cache_blob_id_size_s64 = dnnl_query_cache_blob_id_size_s64,

    /// cache blob ID (pointer to array)
    cache_blob_id = dnnl_query_cache_blob_id,

    /// strides
    strides = dnnl_query_strides,
    /// dilations
    dilations = dnnl_query_dilations,
    /// left padding
    padding_l = dnnl_query_padding_l,
    /// right padding
    padding_r = dnnl_query_padding_r,
    /// epsilon
    epsilon_f32 = dnnl_query_epsilon_f32,
    /// flags
    flags = dnnl_query_flags,
    /// algorithm kind
    alg_kind = dnnl_query_alg_kind,
    /// alpha
    alpha_f32 = dnnl_query_alpha_f32,
    /// beta
    beta_f32 = dnnl_query_beta_f32,
    /// axis
    axis_s32 = dnnl_query_axis_s32,
    /// LRN parameter local size
    local_size_s64 = dnnl_query_local_size_s64,
    /// LRN parameter K
    k_f32 = dnnl_query_k_f32,
    /// Reduction parameter P
    p_f32 = dnnl_query_p_f32,
    /// Resampling parameter factors
    factors = dnnl_query_factors,
    /// RNN parameter cell kind
    cell_kind = dnnl_query_cell_kind,
    /// RNN parameter direction
    direction = dnnl_query_direction,
    /// RNN parameter activation kind
    activation_kind = dnnl_query_activation_kind,
    /// Pooling parameter kernel
    kernel = dnnl_query_kernel,
    /// Shuffle parameter group size
    group_size_s64 = dnnl_query_group_size_s64,

    /// source memory desc
    src_md = dnnl_query_src_md,
    /// source gradient (diff) memory desc
    diff_src_md = dnnl_query_diff_src_md,
    /// weights memory desc
    weights_md = dnnl_query_weights_md,
    /// weights gradient (diff) memory desc
    diff_weights_md = dnnl_query_diff_weights_md,
    /// destination memory desc
    dst_md = dnnl_query_dst_md,
    /// destination gradient (diff) memory desc
    diff_dst_md = dnnl_query_diff_dst_md,
    /// workspace memory desc
    workspace_md = dnnl_query_workspace_md,
    /// scratchpad memory desc
    scratchpad_md = dnnl_query_scratchpad_md,
    /// memory desc of an execute argument
    exec_arg_md = dnnl_query_exec_arg_md,

    /// number of dimensions
    ndims_s32 = dnnl_query_ndims_s32,
    /// vector of dimensions
    dims = dnnl_query_dims,
    /// data type
    data_type = dnnl_query_data_type,
    /// submemory offset
    submemory_offset_s64 = dnnl_query_submemory_offset_s64,
    /// vector of padded dimensions
    padded_dims = dnnl_query_padded_dims,
    /// vector of padded offsets
    padded_offsets = dnnl_query_padded_offsets,
    /// format kind
    format_kind = dnnl_query_format_kind,
    /// number of innermost blocks
    inner_nblks_s32 = dnnl_query_inner_nblks_s32,
    /// vector of sizes of the innermost blocks
    inner_blks = dnnl_query_inner_blks,
    /// vector of logical indices of the blocks
    inner_idxs = dnnl_query_inner_idxs,
    /// Sparse encoding
    sparse_encoding = dnnl_query_sparse_encoding,
    /// Number of non-zero entries
    nnz_s64 = dnnl_query_nnz_s64,
    /// Number of buffers required for a memory descriptor
    num_handles_s32 = dnnl_query_num_handles_s32,
};

/// Converts query enum value from C++ API to C API type.
/// @param aquery C++ API query enum value.
/// @returns Corresponding C API query enum value.
inline dnnl_query_t convert_to_c(query aquery) {
    return static_cast<dnnl_query_t>(aquery);
}
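
// Usage sketch: most queries above are exposed as member functions of the C++
// primitive_desc classes declared later in this header, so convert_to_c() is
// mainly useful when dropping down to the C API. Assuming a valid
// const_dnnl_primitive_desc_t handle "c_pd":
//
//     dnnl_dim_t scratchpad_bytes = 0;
//     dnnl_status_t st = dnnl_primitive_desc_query(c_pd,
//             convert_to_c(dnnl::query::memory_consumption_s64), 0,
//             &scratchpad_bytes);
//     // On dnnl_success, scratchpad_bytes holds the scratchpad size in bytes.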

/// @} dnnl_api_primitives_common

/// @} dnnl_api_primitives

/// @addtogroup dnnl_api_memory Memory
///
/// A container that describes and stores data. Memory objects can contain
/// data of various types and formats. There are two levels of abstraction:
///
/// 1. **Memory descriptor** -- engine-agnostic logical description of data
///     (number of dimensions, dimension sizes, and data type), and,
///     optionally, the information about the physical format of data in
///     memory. If this information is not known yet, a memory descriptor can
///     be created with #dnnl::memory::format_tag::any. This allows
///     compute-intensive primitives to choose the best format for
///     computation. The user is responsible for reordering the data into the
///     chosen format when formats do not match.
///
///     A memory descriptor can be initialized either by specifying dimensions
///     and a memory format tag or strides for each of them, or by
///     manipulating the dnnl_memory_desc_t structure directly.
///
///     @warning
///         The latter approach requires understanding how the physical data
///         representation is mapped to the structure and is discouraged. This
///         topic is discussed in @ref dev_guide_understanding_memory_formats.
///
///     The user can query the amount of memory required by a memory
///     descriptor using the #dnnl::memory::desc::get_size() function. The
///     size of data in general cannot be computed as the product of
///     dimensions multiplied by the size of the data type. So users are
///     required to use this function for better code portability.
///
///     Two memory descriptors can be compared using the equality and
///     inequality operators. The comparison is especially useful when
///     checking whether it is necessary to reorder data from the user's data
///     format to a primitive's format.
///
/// 2. **Memory object** -- an engine-specific object that handles the memory
///     buffer and its description (a memory descriptor). For the CPU engine or
///     with USM, the memory buffer handle is simply a pointer to @c void. The
///     memory buffer can be queried using #dnnl::memory::get_data_handle() and
///     set using #dnnl::memory::set_data_handle(). The underlying SYCL buffer,
///     when used, can be queried using #dnnl::sycl_interop::get_buffer and set
///     using #dnnl::sycl_interop::set_buffer. A memory object can also be
///     queried for the underlying memory descriptor and for its engine using
///     #dnnl::memory::get_desc() and dnnl::memory::get_engine().
///
/// Along with ordinary memory descriptors with all dimensions being positive,
/// the library supports *zero-volume* memory descriptors with one or more
/// dimensions set to zero. This is used to support the NumPy\* convention.
/// If a zero-volume memory is passed to a primitive, the primitive typically
/// does not perform any computations with this memory. For example:
///
/// - A concatenation primitive would ignore all memory objects with zeroes in
///   the concat dimension / axis.
///
/// - A forward convolution with a source memory object with zero in the
///   minibatch dimension would always produce a destination memory object
///   with a zero in the minibatch dimension and perform no computations.
///
/// - However, a forward convolution with a zero in one of the weights
///   dimensions is ill-defined and is considered to be an error by the
///   library because there is no clear definition of what the output values
///   should be.
///
/// Memory buffer of a zero-volume memory is never accessed.
///
/// @{

/// Memory object.
///
/// A memory object encapsulates a handle to a memory buffer allocated on a
/// specific engine, tensor dimensions, data type, and memory format, which is
/// the way tensor indices map to offsets in linear memory space. Memory
/// objects are passed to primitives during execution.
struct memory : public handle<dnnl_memory_t> {
    using handle::handle;

    /// Integer type for representing dimension sizes and indices.
    using dim = dnnl_dim_t;
    /// Vector of dimensions. Implementations are free to force a limit on the
    /// vector's length.
    using dims = std::vector<dim>;

    /// Helper function that validates that an `std::vector` of dimensions can
    /// be safely converted to the C API array ::dnnl_dims_t. Throws if
    /// validation fails.
    ///
    /// @param v Vector of dimensions.
    /// @param min_size Minimum expected size of the vector.
    template <typename T>
    static void validate_dims(const std::vector<T> &v, int min_size = 0) {
        validate_container_size(
                v, "dimensions are invalid", min_size, DNNL_MAX_NDIMS);
    }

    /// Data type specification.
    enum class data_type {
        /// Undefined data type (used for empty memory descriptors).
        undef = dnnl_data_type_undef,
        /// 4-bit float data type with 3-bit exponent and 0-bit mantissa.
        f4_e3m0 = dnnl_f4_e3m0,
        /// [MX-compliant 4-bit float data type](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) with 2-bit exponent and 1-bit mantissa.
        f4_e2m1 = dnnl_f4_e2m1,
        /// [MX-compliant 8-bit scale data type](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) with 8-bit exponent.
        e8m0 = dnnl_e8m0,
        /// [OFP8 standard 8-bit floating-point](https://www.opencompute.org/documents/ocp-8-bit-floating-point-specification-ofp8-revision-1-0-2023-06-20-pdf)
        /// with a 5-bit exponent and a 2-bit mantissa.
        f8_e5m2 = dnnl_f8_e5m2,
        /// [OFP8 standard 8-bit floating-point](https://www.opencompute.org/documents/ocp-8-bit-floating-point-specification-ofp8-revision-1-0-2023-06-20-pdf)
        /// with a 4-bit exponent and a 3-bit mantissa.
        f8_e4m3 = dnnl_f8_e4m3,
        /// [16-bit/half-precision floating point](https://en.wikipedia.org/wiki/Half-precision_floating-point_format).
        f16 = dnnl_f16,
        /// non-standard
        /// [16-bit floating point with 7-bit mantissa](https://en.wikipedia.org/wiki/Bfloat16_floating-point_format).
        bf16 = dnnl_bf16,
        /// [32-bit/single-precision floating point](https://en.wikipedia.org/wiki/Single-precision_floating-point_format).
        f32 = dnnl_f32,
        /// [64-bit/double-precision floating point](https://en.wikipedia.org/wiki/Double-precision_floating-point_format).
        f64 = dnnl_f64,
        /// 32-bit signed integer.
        s32 = dnnl_s32,
        /// 8-bit signed integer.
        s8 = dnnl_s8,
        /// 8-bit unsigned integer.
        u8 = dnnl_u8,
        /// 4-bit signed integer.
        s4 = dnnl_s4,
        /// 4-bit unsigned integer.
        u4 = dnnl_u4,
    };

    /// Returns size of data type in bytes.
    /// @returns The number of bytes occupied by data type.
    static size_t data_type_size(data_type adata_type) {
        return dnnl_data_type_size(convert_to_c(adata_type));
    }
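
    // Element sizes reported by data_type_size(), for example:
    //
    //     size_t f32_bytes = memory::data_type_size(memory::data_type::f32);   // 4
    //     size_t bf16_bytes = memory::data_type_size(memory::data_type::bf16); // 2
    //
    // As the memory documentation above notes, the size of a whole buffer
    // should still be obtained from memory::desc::get_size() rather than by
    // multiplying dimensions by the element size, since blocking and padding
    // may add extra space.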

    /// Memory format kind
    enum class format_kind {
        /// Undefined memory format kind, used for empty memory descriptors.
        undef = dnnl_format_kind_undef,
        /// A special format kind that indicates that the actual format will be
        /// selected by a primitive automatically.
        any = dnnl_format_kind_any,
        /// A tensor in a generic format described by the stride and blocking
        /// values in each dimension.
        blocked = dnnl_blocked,
        /// Format kind for sparse tensors.
        sparse = dnnl_format_kind_sparse,
        /// Format kind for host scalars.
        host_scalar = dnnl_format_kind_host_scalar,
        /// A special format kind that indicates that tensor format is opaque.
        opaque = dnnl_format_kind_opaque,
    };

    /// Sparse encodings.
    /// @sa @ref dev_guide_sparsity
    enum class sparse_encoding {
        /// Undefined sparse encoding kind, used for empty memory descriptors.
        undef = dnnl_sparse_encoding_undef,
        /// Compressed Sparse Row (CSR) encoding.
        csr = dnnl_csr,
        /// An encoding that is used for an opaque storage schema for
        /// tensors with unstructured sparsity. A memory descriptor with the
        /// packed encoding cannot be used to create a memory object. It can
        /// only be used to create a primitive descriptor to query the
        /// actual memory descriptor (similar to the format tag `any`).
        packed = dnnl_packed,
        /// Coordinate Sparse (COO) encoding.
        coo = dnnl_coo,
    };

    /// Memory format tag specification.
    ///
    /// Memory format tags can be further divided into two categories:
    ///
    /// - Domain-agnostic names, i.e. names that do not depend on the tensor
    ///   usage in the specific primitive. These names use letters from `a`
    ///   to `f` to denote logical dimensions and form the order in which the
    ///   dimensions are laid in memory. For example,
    ///   #dnnl::memory::format_tag::ab is used to denote a 2D tensor where the
    ///   second logical dimension (denoted as `b`) is the innermost, i.e.
    ///   has stride = 1, and the first logical dimension (`a`) is laid out in
    ///   memory with stride equal to the size of the second dimension. On the
    ///   other hand, #dnnl::memory::format_tag::ba is the transposed version
    ///   of the same tensor: the outermost dimension (`a`) becomes the
    ///   innermost one.
    ///
    /// - Domain-specific names, i.e. names that make sense only in the
    ///   context of a certain domain, such as CNN. These names are
    ///   aliases to the corresponding domain-agnostic tags and used mostly
    ///   for convenience. For example, #dnnl::memory::format_tag::nc
    ///   is used to denote 2D CNN activations tensor memory format, where
    ///   the channels dimension is the innermost one and the batch dimension
    ///   is the outermost one. Moreover, #dnnl::memory::format_tag::nc is
    ///   an alias for #dnnl::memory::format_tag::ab, because for
    ///   CNN primitives the logical dimensions of activations tensors come
    ///   in order: batch, channels, spatial. In other words, batch
    ///   corresponds to the first logical dimension (`a`), and channels
    ///   correspond to the second one (`b`).
    ///
    /// The following domain-specific notation applies to memory format tags:
    /// - @c 'n' denotes the mini-batch dimension
    /// - @c 'c' denotes a channels dimension
    /// - When there are multiple channel dimensions (for example,
    ///   in convolution weights tensor), @c 'i' and @c 'o' denote dimensions
    ///   of input and output channels
    /// - @c 'g' denotes a groups dimension for convolution weights
    /// - @c 'd', @c 'h', and @c 'w' denote spatial depth, height, and width
    ///   respectively
    ///
    /// See @ref dnnl_format_tag_t for a detailed description.
    enum class format_tag {
        /// Undefined memory format tag
        undef = dnnl_format_tag_undef,
        /// Placeholder memory format tag. Used to instruct the primitive to
        /// select a format automatically.
        any = dnnl_format_tag_any,

        /// plain 1D tensor
        a = dnnl_a,

        /// plain 2D tensor
        ab = dnnl_ab,
        /// permuted 2D tensor
        ba = dnnl_ba,

        /// plain 3D tensor
        abc = dnnl_abc,
        /// permuted 3D tensor
        acb = dnnl_acb,
        /// permuted 3D tensor
        bac = dnnl_bac,
        /// permuted 3D tensor
        bca = dnnl_bca,
        /// permuted 3D tensor
        cba = dnnl_cba,

        /// plain 4D tensor
        abcd = dnnl_abcd,
        /// permuted 4D tensor
        abdc = dnnl_abdc,
        /// permuted 4D tensor
        acbd = dnnl_acbd,
        /// permuted 4D tensor
        acdb = dnnl_acdb,
        /// permuted 4D tensor
        adbc = dnnl_adbc,
        /// permuted 4D tensor
        bacd = dnnl_bacd,
        /// permuted 4D tensor
        bcda = dnnl_bcda,
        /// permuted 4D tensor
        cdba = dnnl_cdba,
        /// permuted 4D tensor
        dcab = dnnl_dcab,

        /// plain 5D tensor
        abcde = dnnl_abcde,
        /// permuted 5D tensor
        abdec = dnnl_abdec,
        /// permuted 5D tensor
        acbde = dnnl_acbde,
        /// permuted 5D tensor
        acdeb = dnnl_acdeb,
        /// permuted 5D tensor
        bacde = dnnl_bacde,
        /// permuted 5D tensor
        bcdea = dnnl_bcdea,
        /// permuted 5D tensor
        cdeba = dnnl_cdeba,
        /// permuted 5D tensor
        decab = dnnl_decab,
        /// permuted 5D tensor
        abced = dnnl_abced,

        /// plain 6D tensor
        abcdef = dnnl_abcdef,
        /// permuted 6D tensor
        abdfce = dnnl_abdfce,
        /// permuted 6D tensor
        acbdef = dnnl_acbdef,
        /// permuted 6D tensor
        abdefc = dnnl_abdefc,
        /// permuted 6D tensor
        defcab = dnnl_defcab,
        /// permuted 6D tensor
        abcdfe = dnnl_abcdfe,

        /// plain 7D tensor
        abcdefg = dnnl_abcdefg,
        /// permuted 7D tensor
        abcdegf = dnnl_abcdegf,

        /// plain 8D tensor
        abcdefgh = dnnl_abcdefgh,
        /// permuted 8D tensor
        abcdefhg = dnnl_abcdefhg,

        /// plain 9D tensor
        abcdefghi = dnnl_abcdefghi,
        /// permuted 9D tensor
        abcdefgih = dnnl_abcdefgih,

        /// plain 10D tensor
        abcdefghij = dnnl_abcdefghij,
        /// permuted 10D tensor
        abcdefghji = dnnl_abcdefghji,

        /// plain 11D tensor
        abcdefghijk = dnnl_abcdefghijk,
        /// permuted 11D tensor
        abcdefghikj = dnnl_abcdefghikj,

        /// plain 12D tensor
        abcdefghijkl = dnnl_abcdefghijkl,
        /// permuted 12D tensor
        abcdefghijlk = dnnl_abcdefghijlk,

        /// 1D tensor; an alias for #dnnl::memory::format_tag::a
        x = a,
        /// 2D CNN activations tensor; an alias for #dnnl::memory::format_tag::ab
        nc = ab,
        /// 2D CNN activations tensor; an alias for #dnnl::memory::format_tag::ba
        cn = ba,
        /// 2D RNN statistics tensor; an alias for #dnnl::memory::format_tag::ab
        tn = ab,
        /// 2D RNN statistics tensor; an alias for #dnnl::memory::format_tag::ba
        nt = ba,
        /// 3D CNN activations tensor; an alias for #dnnl::memory::format_tag::abc
        ncw = abc,
        /// 3D CNN activations tensor; an alias for #dnnl::memory::format_tag::acb
        nwc = acb,
        /// 4D CNN activations tensor; an alias for #dnnl::memory::format_tag::abcd
        nchw = abcd,
        /// 4D CNN activations tensor; an alias for #dnnl::memory::format_tag::acdb
        nhwc = acdb,
        /// 4D CNN activations tensor; an alias for #dnnl::memory::format_tag::bcda
        chwn = bcda,
        /// 5D CNN activations tensor; an alias for #dnnl::memory::format_tag::abcde
        ncdhw = abcde,
        /// 5D CNN activations tensor; an alias for #dnnl::memory::format_tag::acdeb
        ndhwc = acdeb,

        /// 2D CNN weights tensor; an alias for #dnnl::memory::format_tag::ab
        oi = ab,
        /// 2D CNN weights tensor; an alias for #dnnl::memory::format_tag::ba
        io = ba,
        /// 3D CNN weights tensor; an alias for #dnnl::memory::format_tag::abc
        oiw = abc,
        /// 3D CNN weights tensor; an alias for #dnnl::memory::format_tag::acb
        owi = acb,
        /// 3D CNN weights tensor; an alias for #dnnl::memory::format_tag::cba
        wio = cba,
        /// 3D CNN weights tensor; an alias for #dnnl::memory::format_tag::bca
        iwo = bca,
        /// 4D CNN weights tensor; an alias for #dnnl::memory::format_tag::abcd
        oihw = abcd,
        /// 4D CNN weights tensor; an alias for #dnnl::memory::format_tag::cdba
        hwio = cdba,
        /// 4D CNN weights tensor; an alias for #dnnl::memory::format_tag::acdb
        ohwi = acdb,
        /// 4D CNN weights tensor; an alias for #dnnl::memory::format_tag::bcda
        ihwo = bcda,
        /// 4D CNN weights tensor; an alias for #dnnl::memory::format_tag::bacd
        iohw = bacd,
        /// 5D CNN weights tensor; an alias for #dnnl::memory::format_tag::abcde
        oidhw = abcde,
        /// 5D CNN weights tensor; an alias for #dnnl::memory::format_tag::cdeba
        dhwio = cdeba,
        /// 5D CNN weights tensor; an alias for #dnnl::memory::format_tag::acdeb
        odhwi = acdeb,
        /// 5D CNN weights tensor; an alias for #dnnl::memory::format_tag::bacde
        iodhw = bacde,
        /// 5D CNN weights tensor; an alias for #dnnl::memory::format_tag::bcdea
        idhwo = bcdea,

        /// 4D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::abcd
        goiw = abcd,
        /// 4D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::abdc
        gowi = abdc,
        /// 4D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::dcab
        wigo = dcab,
        /// 5D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::abdec
        gohwi = abdec,
        /// 5D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::abcde
        goihw = abcde,
        /// 5D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::decab
        hwigo = decab,
        /// 5D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::acbde
        giohw = acbde,
        /// 6D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::abcdef
        goidhw = abcdef,
        /// 6D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::acbdef
        giodhw = acbdef,
        /// 6D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::abdefc
        godhwi = abdefc,
        /// 6D CNN weights tensor with groups; an alias for #dnnl::memory::format_tag::defcab
        dhwigo = defcab,

        /// 3D RNN data tensor in the format (seq_length, batch, input
        /// channels); an alias for #dnnl::memory::format_tag::abc.
        tnc = abc,
        /// 3D RNN data tensor in the format (batch, seq_length, input
        /// channels); an alias for #dnnl::memory::format_tag::bac.
        ntc = bac,
        /// 4D RNN states tensor in the format (num_layers, num_directions,
        /// batch, state channels); an alias for #dnnl::memory::format_tag::abcd.
        ldnc = abcd,
        /// 5D RNN weights tensor in the format (num_layers, num_directions,
        /// input_channels, num_gates, output_channels);
        /// an alias for #dnnl::memory::format_tag::abcde.
        ///
        /// - For LSTM cells, the gates order is input, forget, candidate
        ///   and output gate.
        /// - For GRU cells, the gates order is update, reset and output gate.
        ldigo = abcde,
        /// 5D RNN weights tensor in the format (num_layers, num_directions,
        /// num_gates, output_channels, input_channels);
        /// an alias for #dnnl::memory::format_tag::abdec.
        ///
        /// - For LSTM cells, the gates order is input, forget, candidate
        ///   and output gate.
        /// - For GRU cells, the gates order is update, reset and output gate.
        ldgoi = abdec,
        /// 4D LSTM projection tensor in the format (num_layers, num_directions,
        /// num_channels_in_hidden_state, num_channels_in_recurrent_projection);
        /// an alias for #dnnl::memory::format_tag::abcd.
        ldio = abcd,
        /// 4D LSTM projection tensor in the format (num_layers, num_directions,
        /// num_channels_in_recurrent_projection, num_channels_in_hidden_state);
        /// an alias for #dnnl::memory::format_tag::abdc.
        ldoi = abdc,
        /// 4D RNN bias tensor in the format (num_layers, num_directions,
        /// num_gates, output_channels);
        /// an alias for #dnnl::memory::format_tag::abcd.
        ///
        /// - For LSTM cells, the gates order is input, forget, candidate
        ///   and output gate.
        /// - For GRU cells, the gates order is update, reset and output gate.
        ldgo = abcd,

        // Opaque blocked formats

        AB16b16a = dnnl_AB16b16a,
        AB16b32a = dnnl_AB16b32a,
        AB16b48a = dnnl_AB16b48a,
        AB16b64a = dnnl_AB16b64a,
        AB8b16a2b = dnnl_AB8b16a2b,
        AB8b32a2b = dnnl_AB8b32a2b,
        AB8b64a2b = dnnl_AB8b64a2b,
        AB4b16a4b = dnnl_AB4b16a4b,
        AB4b32a4b = dnnl_AB4b32a4b,
        AB4b64a4b = dnnl_AB4b64a4b,
        AB16b16a4b = dnnl_AB16b16a4b,
        AB16b32a4b = dnnl_AB16b32a4b,
        AB16b48a4b = dnnl_AB16b48a4b,
        AB16b64a4b = dnnl_AB16b64a4b,
        AB16b16a2b = dnnl_AB16b16a2b,
        AB16b32a2b = dnnl_AB16b32a2b,
        AB16b48a2b = dnnl_AB16b48a2b,
        AB16b64a2b = dnnl_AB16b64a2b,
        Ab4a = dnnl_Ab4a,
        Ab8a = dnnl_Ab8a,
        Ab32a = dnnl_Ab32a,
        Abc16a = dnnl_Abc16a,
        ABc16a16b = dnnl_ABc16a16b,
        ABc4a4b = dnnl_ABc4a4b,
        aBc16b = dnnl_aBc16b,
        aBc32b = dnnl_aBc32b,
        ABc16b16a = dnnl_ABc16b16a,
        AcB16b16a = dnnl_AcB16b16a,
        ABc16b32a = dnnl_ABc16b32a,
        AcB16b32a = dnnl_AcB16b32a,
        ABc16b48a = dnnl_ABc16b48a,
        AcB16b48a = dnnl_AcB16b48a,
        ABc16b64a = dnnl_ABc16b64a,
        AcB16b64a = dnnl_AcB16b64a,
        Abc4a = dnnl_Abc4a,
        aBc4b = dnnl_aBc4b,
        ABc4b16a4b = dnnl_ABc4b16a4b,
        AcB4b16a4b = dnnl_AcB4b16a4b,
        ABc4b32a4b = dnnl_ABc4b32a4b,
        AcB4b32a4b = dnnl_AcB4b32a4b,
        ABc4b64a4b = dnnl_ABc4b64a4b,
        AcB4b64a4b = dnnl_AcB4b64a4b,
        ABc2b8a4b = dnnl_ABc2b8a4b,
        ABc16a16b2a = dnnl_ABc16a16b2a,
        ABc16b16a4b = dnnl_ABc16b16a4b,
        ABc16b32a4b = dnnl_ABc16b32a4b,
        ABc16b48a4b = dnnl_ABc16b48a4b,
        ABc16b64a4b = dnnl_ABc16b64a4b,
        ABc16b16a2b = dnnl_ABc16b16a2b,
        ABc16b32a2b = dnnl_ABc16b32a2b,
        ABc16b48a2b = dnnl_ABc16b48a2b,
        ABc16b64a2b = dnnl_ABc16b64a2b,
        ABc4b4a = dnnl_ABc4b4a,
        ABc8a16b2a = dnnl_ABc8a16b2a,
        ABc8a8b = dnnl_ABc8a8b,
        ABc8a4b = dnnl_ABc8a4b,
        aBc8b = dnnl_aBc8b,
        ABc8b16a2b = dnnl_ABc8b16a2b,
        AcB8b16a2b = dnnl_AcB8b16a2b,
        ABc8b32a2b = dnnl_ABc8b32a2b,
        AcB8b32a2b = dnnl_AcB8b32a2b,
        ABc8b64a2b = dnnl_ABc8b64a2b,
        AcB8b64a2b = dnnl_AcB8b64a2b,
        ABc8b8a = dnnl_ABc8b8a,
        AcB8b8a = dnnl_AcB8b8a,
        Abcd8a = dnnl_Abcd8a,
        Abcd16a = dnnl_Abcd16a,
        Abcd32a = dnnl_Abcd32a,
        ABcd16a16b = dnnl_ABcd16a16b,
        aBcd16b = dnnl_aBcd16b,
        aBcd32b = dnnl_aBcd32b,
        ABcd16b16a = dnnl_ABcd16b16a,
        AcdB16b16a = dnnl_AcdB16b16a,
        ABcd16b32a = dnnl_ABcd16b32a,
        AcdB16b32a = dnnl_AcdB16b32a,
        ABcd16b48a = dnnl_ABcd16b48a,
        AcdB16b48a = dnnl_AcdB16b48a,
        ABcd16b64a = dnnl_ABcd16b64a,
        AcdB16b64a = dnnl_AcdB16b64a,
        aBCd16b16c = dnnl_aBCd16b16c,
        aBCd16c16b = dnnl_aBCd16c16b,
        Abcd4a = dnnl_Abcd4a,
        aBcd4b = dnnl_aBcd4b,
        ABcd4b16a4b = dnnl_ABcd4b16a4b,
        AcdB4b16a4b = dnnl_AcdB4b16a4b,
        ABcd4b32a4b = dnnl_ABcd4b32a4b,
        AcdB4b32a4b = dnnl_AcdB4b32a4b,
        ABcd4b64a4b = dnnl_ABcd4b64a4b,
        AcdB4b64a4b = dnnl_AcdB4b64a4b,
        ABcd2b8a4b = dnnl_ABcd2b8a4b,
        ABcd4b4a = dnnl_ABcd4b4a,
        ABcd4a4b = dnnl_ABcd4a4b,
        aBCd4c16b4c = dnnl_aBCd4c16b4c,
        aBCd2c8b4c = dnnl_aBCd2c8b4c,
        ABcd16a16b2a = dnnl_ABcd16a16b2a,
        ABcd16b16a4b = dnnl_ABcd16b16a4b,
        ABcd16b32a4b = dnnl_ABcd16b32a4b,
        ABcd16b48a4b = dnnl_ABcd16b48a4b,
        ABcd16b64a4b = dnnl_ABcd16b64a4b,
        ABcd16b16a2b = dnnl_ABcd16b16a2b,
        ABcd16b32a2b = dnnl_ABcd16b32a2b,
        ABcd16b48a2b = dnnl_ABcd16b48a2b,
        ABcd16b64a2b = dnnl_ABcd16b64a2b,
        aBCd16b16c2b = dnnl_aBCd16b16c2b,
        aBCd16c16b4c = dnnl_aBCd16c16b4c,
        aBCd16c16b2c = dnnl_aBCd16c16b2c,
        aBCd4c4b = dnnl_aBCd4c4b,
        aBCd4b4c = dnnl_aBCd4b4c,
        ABcd8a16b2a = dnnl_ABcd8a16b2a,
        ABcd8a8b = dnnl_ABcd8a8b,
        ABcd8a4b = dnnl_ABcd8a4b,
        ABcd8a2b = dnnl_ABcd8a2b,
        /// 4D tensor blocked by 2nd dimension with block size 8
        aBcd8b = dnnl_aBcd8b,
        ABcd8b16a2b = dnnl_ABcd8b16a2b,
        AcdB8b16a2b = dnnl_AcdB8b16a2b,
        ABcd8b32a2b = dnnl_ABcd8b32a2b,
        AcdB8b32a2b = dnnl_AcdB8b32a2b,
        ABcd8b64a2b = dnnl_ABcd8b64a2b,
        AcdB8b64a2b = dnnl_AcdB8b64a2b,
        aBCd8b16c2b = dnnl_aBCd8b16c2b,
        /// 4D tensor blocked by 1st and 2nd dimension with block size 8
        ABcd8b8a = dnnl_ABcd8b8a,
        AcdB8b8a = dnnl_AcdB8b8a,
        aBCd8b8c = dnnl_aBCd8b8c,
        aBCd8b4c = dnnl_aBCd8b4c,
        aBCd8c16b2c = dnnl_aBCd8c16b2c,
        aBCd8c8b = dnnl_aBCd8c8b,
        Abcde16a = dnnl_Abcde16a,
        Abcde32a = dnnl_Abcde32a,
        ABcde16a16b = dnnl_ABcde16a16b,
        aBcde16b = dnnl_aBcde16b,
        aBcde32b = dnnl_aBcde32b,
        ABcde16b16a = dnnl_ABcde16b16a,
        AcdeB16b16a = dnnl_AcdeB16b16a,
        ABcde16b32a = dnnl_ABcde16b32a,
        AcdeB16b32a = dnnl_AcdeB16b32a,
        ABcde16b48a = dnnl_ABcde16b48a,
        AcdeB16b48a = dnnl_AcdeB16b48a,
        ABcde16b64a = dnnl_ABcde16b64a,
        AcdeB16b64a = dnnl_AcdeB16b64a,
        aBCde16b16c = dnnl_aBCde16b16c,
        aBCde16c16b = dnnl_aBCde16c16b,
        aBCde2c8b4c = dnnl_aBCde2c8b4c,
        Abcde4a = dnnl_Abcde4a,
        aBcde4b = dnnl_aBcde4b,
        ABcde4b4a = dnnl_ABcde4b4a,
        ABcde4a4b = dnnl_ABcde4a4b,
        aBCde4b4c = dnnl_aBCde4b4c,
        aBCde4c16b4c = dnnl_aBCde4c16b4c,
        aBCde16b16c2b = dnnl_aBCde16b16c2b,
        aBCde16c16b4c = dnnl_aBCde16c16b4c,
        aBCde16c16b2c = dnnl_aBCde16c16b2c,
        aBCdef16c16b2c = dnnl_aBCdef16c16b2c,
        aBCde4c4b = dnnl_aBCde4c4b,
        Abcde8a = dnnl_Abcde8a,
        ABcde8a8b = dnnl_ABcde8a8b,
        ABcde8a4b = dnnl_ABcde8a4b,
        aBcde8b = dnnl_aBcde8b,
        ABcde8b16a2b = dnnl_ABcde8b16a2b,
        AcdeB8b16a2b = dnnl_AcdeB8b16a2b,
        ABcde8b32a2b = dnnl_ABcde8b32a2b,
        AcdeB8b32a2b = dnnl_AcdeB8b32a2b,
        ABcde8b64a2b = dnnl_ABcde8b64a2b,
        AcdeB8b64a2b = dnnl_AcdeB8b64a2b,
        ABcde4b16a4b = dnnl_ABcde4b16a4b,
        AcdeB4b16a4b = dnnl_AcdeB4b16a4b,
        ABcde4b32a4b = dnnl_ABcde4b32a4b,
        AcdeB4b32a4b = dnnl_AcdeB4b32a4b,
        ABcde4b64a4b = dnnl_ABcde4b64a4b,
        AcdeB4b64a4b = dnnl_AcdeB4b64a4b,
        ABcde16b16a4b = dnnl_ABcde16b16a4b,
        ABcde16b32a4b = dnnl_ABcde16b32a4b,
        ABcde16b48a4b = dnnl_ABcde16b48a4b,
        ABcde16b64a4b = dnnl_ABcde16b64a4b,
        ABcde16b16a2b = dnnl_ABcde16b16a2b,
        ABcde16b32a2b = dnnl_ABcde16b32a2b,
        ABcde16b48a2b = dnnl_ABcde16b48a2b,
        ABcde16b64a2b = dnnl_ABcde16b64a2b,
        ABcde2b8a4b = dnnl_ABcde2b8a4b,
        aBCde8b16c2b = dnnl_aBCde8b16c2b,
        ABcde8b8a = dnnl_ABcde8b8a,
        AcdeB8b8a = dnnl_AcdeB8b8a,
        aBCde8b8c = dnnl_aBCde8b8c,
        aBCde8b4c = dnnl_aBCde8b4c,
        ABcd4a8b8a4b = dnnl_ABcd4a8b8a4b,
        ABcd2a8b8a2b = dnnl_ABcd2a8b8a2b,
        aBCde4b8c8b4c = dnnl_aBCde4b8c8b4c,
        aBCde2b8c8b2c = dnnl_aBCde2b8c8b2c,
        aBCde8c16b2c = dnnl_aBCde8c16b2c,
        aBCde8c8b = dnnl_aBCde8c8b,
        aBcdef16b = dnnl_aBcdef16b,
        aBCdef16b16c = dnnl_aBCdef16b16c,
        aBCdef16c16b = dnnl_aBCdef16c16b,
        aBcdef4b = dnnl_aBcdef4b,
        aBCdef2c8b4c = dnnl_aBCdef2c8b4c,
        aBCdef4c4b = dnnl_aBCdef4c4b,
        aBCdef4b4c = dnnl_aBCdef4b4c,
        aBCdef8b8c = dnnl_aBCdef8b8c,
        aBCdef8b4c = dnnl_aBCdef8b4c,
        aBCdef8c16b2c = dnnl_aBCdef8c16b2c,
        aBCdef4c16b4c = dnnl_aBCdef4c16b4c,
        aBCdef8c8b = dnnl_aBCdef8c8b,
        aBdc16b = dnnl_aBdc16b,
        aBdc4b = dnnl_aBdc4b,
        aBdc8b = dnnl_aBdc8b,
        aBdC8b2c = dnnl_aBdC8b2c,
        aBdC8b4c = dnnl_aBdC8b4c,
        aBdec16b = dnnl_aBdec16b,
        aBdec4b = dnnl_aBdec4b,
        aBdec8b = dnnl_aBdec8b,
        aBdeC8b2c = dnnl_aBdeC8b2c,
        aBdeC8b4c = dnnl_aBdeC8b4c,
        aBdefc16b = dnnl_aBdefc16b,
        aCBdef16c16b = dnnl_aCBdef16c16b,
        aCBdef8b8c = dnnl_aCBdef8b8c,
        aCBdef16b16c = dnnl_aCBdef16b16c,
        aBdefc4b = dnnl_aBdefc4b,
        aBdefc8b = dnnl_aBdefc8b,
        aBdefC8b2c = dnnl_aBdefC8b2c,
        aBdefC8b4c = dnnl_aBdefC8b4c,
        Acb16a = dnnl_Acb16a,
        Acb4a = dnnl_Acb4a,
        Acb8a = dnnl_Acb8a,
        AcB8a2b = dnnl_AcB8a2b,
        AcB8a4b = dnnl_AcB8a4b,
        aCBd8b8c = dnnl_aCBd8b8c,
        aCBd16b16c = dnnl_aCBd16b16c,
        aCBd16c16b = dnnl_aCBd16c16b,
        aCBde8b8c = dnnl_aCBde8b8c,
        aCBde16b16c = dnnl_aCBde16b16c,
        aCBde16c16b = dnnl_aCBde16c16b,
        Acdb16a = dnnl_Acdb16a,
        Acdb4a = dnnl_Acdb4a,
        Acdb8a = dnnl_Acdb8a,
        AcdB8a2b = dnnl_AcdB8a2b,
        AcdB8a4b = dnnl_AcdB8a4b,
        Acdeb16a = dnnl_Acdeb16a,
        Acdeb4a = dnnl_Acdeb4a,
        Acdeb8a = dnnl_Acdeb8a,
        AcdeB8a2b = dnnl_AcdeB8a2b,
        AcdeB8a4b = dnnl_AcdeB8a4b,
        BAc8a8b = dnnl_BAc8a8b,
        BAc16a16b = dnnl_BAc16a16b,
        BAc16b16a = dnnl_BAc16b16a,
        BAcd8a8b = dnnl_BAcd8a8b,
        BAcd16a16b = dnnl_BAcd16a16b,
        BAcd16b16a = dnnl_BAcd16b16a,
        ABcd32a32b = dnnl_ABcd32a32b,
        BAcde16b16a = dnnl_BAcde16b16a,
        BAcde8a8b = dnnl_BAcde8a8b,
        BAcde16a16b = dnnl_BAcde16a16b,
        aBdec32b = dnnl_aBdec32b,
        Abcdef16a = dnnl_Abcdef16a,
        Abcdef32a = dnnl_Abcdef32a,
        Acdb32a = dnnl_Acdb32a,
        aBCd2b4c2b = dnnl_aBCd2b4c2b,
        aBCde2b4c2b = dnnl_aBCde2b4c2b,
        aBCdef2b4c2b = dnnl_aBCdef2b4c2b,
        aBCd2c4b2c = dnnl_aBCd2c4b2c,
        aBCde2c4b2c = dnnl_aBCde2c4b2c,
        aBCdef2c4b2c = dnnl_aBCdef2c4b2c,
        aBCd4b8c2b = dnnl_aBCd4b8c2b,
        aBCde4b8c2b = dnnl_aBCde4b8c2b,
        aBCdef4b8c2b = dnnl_aBCdef4b8c2b,
        aBCd4c8b2c = dnnl_aBCd4c8b2c,
        aBCde4c8b2c = dnnl_aBCde4c8b2c,
        aBCdef4c8b2c = dnnl_aBCdef4c8b2c,
        AB32a32b8a4b = dnnl_AB32a32b8a4b,
        AB32a32b8a2b = dnnl_AB32a32b8a2b,
        AB8a4b = dnnl_AB8a4b,
        AB8a2b = dnnl_AB8a2b,
        abDc16d = dnnl_abDc16d,
        abDc32d = dnnl_abDc32d,
        abDC16d4c = dnnl_abDC16d4c,
        abDC32d4c = dnnl_abDC32d4c,
        abCd32c = dnnl_abCd32c,
        abdEc16e = dnnl_abdEc16e,
        abdEc32e = dnnl_abdEc32e,
        abdEC16e4c = dnnl_abdEC16e4c,
        abdEC32e2c = dnnl_abdEC32e2c,
        abdEC32e4c = dnnl_abdEC32e4c,
        abdCe16c = dnnl_abdCe16c,
        abdCe32c = dnnl_abdCe32c,
        abdCE32c2e = dnnl_abdCE32c2e,
        aBCdef16c16b4c = dnnl_aBCdef16c16b4c,
        aBdC16b4c = dnnl_aBdC16b4c,
        aBdeC16b4c = dnnl_aBdeC16b4c,
        AcB16a4b = dnnl_AcB16a4b,
        AcdB16a2b = dnnl_AcdB16a2b,
        aBdefC16b4c = dnnl_aBdefC16b4c,
        AcdeB16a4b = dnnl_AcdeB16a4b,

        Acb32a = dnnl_Acb32a,
        AcB32a2b = dnnl_AcB32a2b,
        AcB32a4b = dnnl_AcB32a4b,
        Acb48a = dnnl_Acb48a,
        AcB48a2b = dnnl_AcB48a2b,
        AcB48a4b = dnnl_AcB48a4b,
        Acb64a = dnnl_Acb64a,
        AcB64a2b = dnnl_AcB64a2b,
        AcB64a4b = dnnl_AcB64a4b,
        cBa2b = dnnl_cBa2b,
        cBa4b = dnnl_cBa4b,
        aBdc32b = dnnl_aBdc32b,
        aBdC32b2c = dnnl_aBdC32b2c,
        aBdC32b4c = dnnl_aBdC32b4c,
        aBdc48b = dnnl_aBdc48b,
        aBdC48b2c = dnnl_aBdC48b2c,
        aBdC48b4c = dnnl_aBdC48b4c,
        aBdc64b = dnnl_aBdc64b,
        aBdC64b2c = dnnl_aBdC64b2c,
|
|
aBdC64b4c = dnnl_aBdC64b4c,
|
|
adcb = dnnl_adcb,
|
|
adCb2c = dnnl_adCb2c,
|
|
adCb4c = dnnl_adCb4c,
|
|
AcdB32a2b = dnnl_AcdB32a2b,
|
|
AcdB32a4b = dnnl_AcdB32a4b,
|
|
Acdb48a = dnnl_Acdb48a,
|
|
AcdB48a2b = dnnl_AcdB48a2b,
|
|
AcdB48a4b = dnnl_AcdB48a4b,
|
|
Acdb64a = dnnl_Acdb64a,
|
|
AcdB64a2b = dnnl_AcdB64a2b,
|
|
AcdB64a4b = dnnl_AcdB64a4b,
|
|
cdBa2b = dnnl_cdBa2b,
|
|
cdBa4b = dnnl_cdBa4b,
|
|
aBdeC32b2c = dnnl_aBdeC32b2c,
|
|
aBdeC32b4c = dnnl_aBdeC32b4c,
|
|
aBdec48b = dnnl_aBdec48b,
|
|
aBdeC48b2c = dnnl_aBdeC48b2c,
|
|
aBdeC48b4c = dnnl_aBdeC48b4c,
|
|
aBdec64b = dnnl_aBdec64b,
|
|
aBdeC64b2c = dnnl_aBdeC64b2c,
|
|
aBdeC64b4c = dnnl_aBdeC64b4c,
|
|
adecb = dnnl_adecb,
|
|
adeCb2c = dnnl_adeCb2c,
|
|
adeCb4c = dnnl_adeCb4c,
|
|
Acdeb32a = dnnl_Acdeb32a,
|
|
AcdeB32a2b = dnnl_AcdeB32a2b,
|
|
AcdeB32a4b = dnnl_AcdeB32a4b,
|
|
Acdeb48a = dnnl_Acdeb48a,
|
|
AcdeB48a2b = dnnl_AcdeB48a2b,
|
|
AcdeB48a4b = dnnl_AcdeB48a4b,
|
|
Acdeb64a = dnnl_Acdeb64a,
|
|
AcdeB64a2b = dnnl_AcdeB64a2b,
|
|
AcdeB64a4b = dnnl_AcdeB64a4b,
|
|
cdeBa2b = dnnl_cdeBa2b,
|
|
cdeBa4b = dnnl_cdeBa4b,
|
|
aBdefc32b = dnnl_aBdefc32b,
|
|
aBdefC32b2c = dnnl_aBdefC32b2c,
|
|
aBdefC32b4c = dnnl_aBdefC32b4c,
|
|
aBdefc48b = dnnl_aBdefc48b,
|
|
aBdefC48b2c = dnnl_aBdefC48b2c,
|
|
aBdefC48b4c = dnnl_aBdefC48b4c,
|
|
aBdefc64b = dnnl_aBdefc64b,
|
|
aBdefC64b2c = dnnl_aBdefC64b2c,
|
|
aBdefC64b4c = dnnl_aBdefC64b4c,
|
|
adefcb = dnnl_adefcb,
|
|
adefCb2c = dnnl_adefCb2c,
|
|
adefCb4c = dnnl_adefCb4c,
|
|
ABc32a32b = dnnl_ABc32a32b,
|
|
BAc8a16b2a = dnnl_BAc8a16b2a,
|
|
BAcd8a16b2a = dnnl_BAcd8a16b2a,
|
|
ABcde8a16b2a = dnnl_ABcde8a16b2a,
|
|
aCBd8b16c2b = dnnl_aCBd8b16c2b,
|
|
BAcde8a16b2a = dnnl_BAcde8a16b2a,
|
|
aCBde8b16c2b = dnnl_aCBde8b16c2b,
|
|
ABcde32a32b = dnnl_ABcde32a32b,
|
|
ABc4a8b8a4b = dnnl_ABc4a8b8a4b,
|
|
ABcde4a8b8a4b = dnnl_ABcde4a8b8a4b,
|
|
BAc4b8a8b4a = dnnl_BAc4b8a8b4a,
|
|
BAcd4b8a8b4a = dnnl_BAcd4b8a8b4a,
|
|
BAcde4b8a8b4a = dnnl_BAcde4b8a8b4a,
|
|
aBCd4b8c8b4c = dnnl_aBCd4b8c8b4c,
|
|
aBCdef4b8c8b4c = dnnl_aBCdef4b8c8b4c,
|
|
aBCdef8b16c2b = dnnl_aBCdef8b16c2b,
|
|
aCBdef8b16c2b = dnnl_aCBdef8b16c2b,
|
|
aBdC16b2c = dnnl_aBdC16b2c,
|
|
aBdeC16b2c = dnnl_aBdeC16b2c,
|
|
aBdefC16b2c = dnnl_aBdefC16b2c,
|
|
aBedc16b = dnnl_aBedc16b,
|
|
AcB16a2b = dnnl_AcB16a2b,
|
|
AcdB16a4b = dnnl_AcdB16a4b,
|
|
AcdeB16a2b = dnnl_AcdeB16a2b,
|
|
Adcb16a = dnnl_Adcb16a,
|
|
aCBd4c8b8c4b = dnnl_aCBd4c8b8c4b,
|
|
aCBde4c8b8c4b = dnnl_aCBde4c8b8c4b,
|
|
aCBdef4c8b8c4b = dnnl_aCBdef4c8b8c4b,
|
|
ABc32a16b = dnnl_ABc32a16b,
|
|
ABcd16a32b = dnnl_ABcd16a32b,
|
|
ABcd32a16b = dnnl_ABcd32a16b,
|
|
ABcde32a16b = dnnl_ABcde32a16b,
|
|
AB48a16b = dnnl_AB48a16b,
|
|
AB48a32b = dnnl_AB48a32b,
|
|
ABc40a16b = dnnl_ABc40a16b,
|
|
ABc40a32b = dnnl_ABc40a32b,
|
|
aBC48b16c = dnnl_aBC48b16c,
|
|
aBC48b32c = dnnl_aBC48b32c,
|
|
ABcd40a16b = dnnl_ABcd40a16b,
|
|
ABcd40a32b = dnnl_ABcd40a32b,
|
|
BA16a16b = dnnl_BA16a16b,
|
|
BA16a32b = dnnl_BA16a32b,
|
|
BA16a48b = dnnl_BA16a48b,
|
|
BA16a64b = dnnl_BA16a64b,
|
|
BA16a16b2a = dnnl_BA16a16b2a,
|
|
BA16a32b2a = dnnl_BA16a32b2a,
|
|
BA16a48b2a = dnnl_BA16a48b2a,
|
|
BA16a64b2a = dnnl_BA16a64b2a,
|
|
BA16a16b4a = dnnl_BA16a16b4a,
|
|
BA16a32b4a = dnnl_BA16a32b4a,
|
|
BA16a48b4a = dnnl_BA16a48b4a,
|
|
BA16a64b4a = dnnl_BA16a64b4a,
|
|
BA24b8a = dnnl_BA24b8a,
|
|
aCB24c8b = dnnl_aCB24c8b,
|
|
abDC24d8c = dnnl_abDC24d8c,
|
|
decbA16a = dnnl_decbA16a,
|
|
decbA8a = dnnl_decbA8a,
|
|
defcbA16a = dnnl_defcbA16a,
|
|
defcbA8a = dnnl_defcbA8a,
|
|
aCB16b16c = dnnl_aCB16b16c,
|
|
aCB16b32c = dnnl_aCB16b32c,
|
|
aCB16b48c = dnnl_aCB16b48c,
|
|
aCB16b64c = dnnl_aCB16b64c,
|
|
aCB16b16c2b = dnnl_aCB16b16c2b,
|
|
aCB16b32c2b = dnnl_aCB16b32c2b,
|
|
aCB16b48c2b = dnnl_aCB16b48c2b,
|
|
aCB16b64c2b = dnnl_aCB16b64c2b,
|
|
aCB16b16c4b = dnnl_aCB16b16c4b,
|
|
aCB16b32c4b = dnnl_aCB16b32c4b,
|
|
aCB16b48c4b = dnnl_aCB16b48c4b,
|
|
aCB16b64c4b = dnnl_aCB16b64c4b,
|
|
Acb24a = dnnl_Acb24a,
|
|
Acdb24a = dnnl_Acdb24a,
|
|
Acdeb24a = dnnl_Acdeb24a,
|
|
aBdc24b = dnnl_aBdc24b,
|
|
aBdec24b = dnnl_aBdec24b,
|
|
aBdefc24b = dnnl_aBdefc24b,
|
|
AcB24a2b = dnnl_AcB24a2b,
|
|
AcdB24a2b = dnnl_AcdB24a2b,
|
|
AcdeB24a2b = dnnl_AcdeB24a2b,
|
|
aBdC24b2c = dnnl_aBdC24b2c,
|
|
aBdeC24b2c = dnnl_aBdeC24b2c,
|
|
aBdefC24b2c = dnnl_aBdefC24b2c,
|
|
AcB24a4b = dnnl_AcB24a4b,
|
|
AcdB24a4b = dnnl_AcdB24a4b,
|
|
AcdeB24a4b = dnnl_AcdeB24a4b,
|
|
aBdC24b4c = dnnl_aBdC24b4c,
|
|
aBdeC24b4c = dnnl_aBdeC24b4c,
|
|
aBdefC24b4c = dnnl_aBdefC24b4c,
|
|
AB8b32a = dnnl_AB8b32a,
|
|
ABc8b32a = dnnl_ABc8b32a,
|
|
AcB8b32a = dnnl_AcB8b32a,
|
|
ABcd8b32a = dnnl_ABcd8b32a,
|
|
AcdB8b32a = dnnl_AcdB8b32a,
|
|
ABcde8b32a = dnnl_ABcde8b32a,
|
|
AcdeB8b32a = dnnl_AcdeB8b32a,
|
|
AB8b24a = dnnl_AB8b24a,
|
|
ABc8b24a = dnnl_ABc8b24a,
|
|
AcB8b24a = dnnl_AcB8b24a,
|
|
ABcd8b24a = dnnl_ABcd8b24a,
|
|
AcdB8b24a = dnnl_AcdB8b24a,
|
|
ABcde8b24a = dnnl_ABcde8b24a,
|
|
AcdeB8b24a = dnnl_AcdeB8b24a,
|
|
AB8b16a = dnnl_AB8b16a,
|
|
ABc8b16a = dnnl_ABc8b16a,
|
|
AcB8b16a = dnnl_AcB8b16a,
|
|
ABcd8b16a = dnnl_ABcd8b16a,
|
|
AcdB8b16a = dnnl_AcdB8b16a,
|
|
ABcde8b16a = dnnl_ABcde8b16a,
|
|
AcdeB8b16a = dnnl_AcdeB8b16a,
|
|
AB8b8a = dnnl_AB8b8a,
|
|
abDC8d8c = dnnl_abDC8d8c,
|
|
abDC16d8c = dnnl_abDC16d8c,
|
|
aCB8c8b = dnnl_aCB8c8b,
|
|
aCB16c8b = dnnl_aCB16c8b,
|
|
BA8b8a = dnnl_BA8b8a,
|
|
BA16b8a = dnnl_BA16b8a,
|
|
AB2a4b = dnnl_AB2a4b,
|
|
|
|
format_tag_last = dnnl_format_tag_last,
|
|
|
|
nCdhw16c = dnnl_nCdhw16c,
|
|
nCdhw4c = dnnl_nCdhw4c,
|
|
nCdhw8c = dnnl_nCdhw8c,
|
|
nChw16c = dnnl_nChw16c,
|
|
nChw4c = dnnl_nChw4c,
|
|
nChw8c = dnnl_nChw8c,
|
|
nCw16c = dnnl_nCw16c,
|
|
nCw4c = dnnl_nCw4c,
|
|
nCw8c = dnnl_nCw8c,
|
|
NCw16n16c = dnnl_NCw16n16c,
|
|
NChw16n16c = dnnl_NChw16n16c,
|
|
NCdhw16n16c = dnnl_NCdhw16n16c,
|
|
NCdhw32n32c = dnnl_NCdhw32n32c,
|
|
NChw32n32c = dnnl_NChw32n32c,
|
|
IOhw16i16o = dnnl_IOhw16i16o,
|
|
OI16i16o = dnnl_OI16i16o,
|
|
OI16i32o = dnnl_OI16i32o,
|
|
OI16i48o = dnnl_OI16i48o,
|
|
OI16i64o = dnnl_OI16i64o,
|
|
OI8i16o2i = dnnl_OI8i16o2i,
|
|
OI8i32o2i = dnnl_OI8i32o2i,
|
|
OI8i64o2i = dnnl_OI8i64o2i,
|
|
OI4i8o4i = dnnl_OI4i8o4i,
|
|
OI4i16o4i = dnnl_OI4i16o4i,
|
|
OI4i24o4i = dnnl_OI4i24o4i,
|
|
OI4i32o4i = dnnl_OI4i32o4i,
|
|
OI4i64o4i = dnnl_OI4i64o4i,
|
|
Ohwi32o = dnnl_Ohwi32o,
|
|
IOdhw16i16o = dnnl_IOdhw16i16o,
|
|
gIOhw16i16o = dnnl_gIOhw16i16o,
|
|
gOhwi32o = dnnl_gOhwi32o,
|
|
Goidhw16g = dnnl_Goidhw16g,
|
|
IOw8o8i = dnnl_IOw8o8i,
|
|
IOw16o16i = dnnl_IOw16o16i,
|
|
OIw16i16o = dnnl_OIw16i16o,
|
|
OwI16i16o = dnnl_OwI16i16o,
|
|
OIw16i32o = dnnl_OIw16i32o,
|
|
OwI16i32o = dnnl_OwI16i32o,
|
|
OIw16i48o = dnnl_OIw16i48o,
|
|
OwI16i48o = dnnl_OwI16i48o,
|
|
OIw16i64o = dnnl_OIw16i64o,
|
|
OwI16i64o = dnnl_OwI16i64o,
|
|
IOw16i16o = dnnl_IOw16i16o,
|
|
gIOw16i16o = dnnl_gIOw16i16o,
|
|
OIw16o16i = dnnl_OIw16o16i,
|
|
Oiw16o = dnnl_Oiw16o,
|
|
OIw4i8o4i = dnnl_OIw4i8o4i,
|
|
OwI4i8o4i = dnnl_OwI4i8o4i,
|
|
OIw4i16o4i = dnnl_OIw4i16o4i,
|
|
OwI4i16o4i = dnnl_OwI4i16o4i,
|
|
OIw4i24o4i = dnnl_OIw4i24o4i,
|
|
OwI4i24o4i = dnnl_OwI4i24o4i,
|
|
OIw4i32o4i = dnnl_OIw4i32o4i,
|
|
OwI4i32o4i = dnnl_OwI4i32o4i,
|
|
OIw4i64o4i = dnnl_OIw4i64o4i,
|
|
OwI4i64o4i = dnnl_OwI4i64o4i,
|
|
OIw2i8o4i = dnnl_OIw2i8o4i,
|
|
OIw4i4o = dnnl_OIw4i4o,
|
|
OIw4o4i = dnnl_OIw4o4i,
|
|
Oiw4o = dnnl_Oiw4o,
|
|
OIw8i16o2i = dnnl_OIw8i16o2i,
|
|
OwI8i16o2i = dnnl_OwI8i16o2i,
|
|
OIw8i32o2i = dnnl_OIw8i32o2i,
|
|
OwI8i32o2i = dnnl_OwI8i32o2i,
|
|
OIw8i64o2i = dnnl_OIw8i64o2i,
|
|
OwI8i64o2i = dnnl_OwI8i64o2i,
|
|
OIw8i8o = dnnl_OIw8i8o,
|
|
OwI8i8o = dnnl_OwI8i8o,
|
|
OIw8o16i2o = dnnl_OIw8o16i2o,
|
|
OIw8o8i = dnnl_OIw8o8i,
|
|
OIw8o4i = dnnl_OIw8o4i,
|
|
OIw16i16o4i = dnnl_OIw16i16o4i,
|
|
OIw16i32o4i = dnnl_OIw16i32o4i,
|
|
OIw16i48o4i = dnnl_OIw16i48o4i,
|
|
OIw16i64o4i = dnnl_OIw16i64o4i,
|
|
OIw16i16o2i = dnnl_OIw16i16o2i,
|
|
OIw16i32o2i = dnnl_OIw16i32o2i,
|
|
OIw16i48o2i = dnnl_OIw16i48o2i,
|
|
OIw16i64o2i = dnnl_OIw16i64o2i,
|
|
OIw16o16i2o = dnnl_OIw16o16i2o,
|
|
Owi16o = dnnl_Owi16o,
|
|
OwI16o2i = dnnl_OwI16o2i,
|
|
Iwo16i = dnnl_Iwo16i,
|
|
IwO16i2o = dnnl_IwO16i2o,
|
|
IwO16i4o = dnnl_IwO16i4o,
|
|
Owi4o = dnnl_Owi4o,
|
|
Owi8o = dnnl_Owi8o,
|
|
OwI8o2i = dnnl_OwI8o2i,
|
|
OwI8o4i = dnnl_OwI8o4i,
|
|
IOhw8o8i = dnnl_IOhw8o8i,
|
|
IOhw16o16i = dnnl_IOhw16o16i,
|
|
Ohwi16o = dnnl_Ohwi16o,
|
|
OhwI16o2i = dnnl_OhwI16o2i,
|
|
Ihwo16i = dnnl_Ihwo16i,
|
|
IhwO16i2o = dnnl_IhwO16i2o,
|
|
IhwO16i4o = dnnl_IhwO16i4o,
|
|
Ohwi4o = dnnl_Ohwi4o,
|
|
Ohwi8o = dnnl_Ohwi8o,
|
|
OhwI8o2i = dnnl_OhwI8o2i,
|
|
OhwI8o4i = dnnl_OhwI8o4i,
|
|
OIhw16i16o = dnnl_OIhw16i16o,
|
|
OhwI16i16o = dnnl_OhwI16i16o,
|
|
OIhw16i32o = dnnl_OIhw16i32o,
|
|
OhwI16i32o = dnnl_OhwI16i32o,
|
|
OIhw16i48o = dnnl_OIhw16i48o,
|
|
OhwI16i48o = dnnl_OhwI16i48o,
|
|
OIhw16i64o = dnnl_OIhw16i64o,
|
|
OhwI16i64o = dnnl_OhwI16i64o,
|
|
OIhw16o16i = dnnl_OIhw16o16i,
|
|
Oihw16o = dnnl_Oihw16o,
|
|
OIhw4i8o4i = dnnl_OIhw4i8o4i,
|
|
OhwI4i8o4i = dnnl_OhwI4i8o4i,
|
|
OIhw4i16o4i = dnnl_OIhw4i16o4i,
|
|
OhwI4i16o4i = dnnl_OhwI4i16o4i,
|
|
OIhw4i24o4i = dnnl_OIhw4i24o4i,
|
|
OhwI4i24o4i = dnnl_OhwI4i24o4i,
|
|
OIhw4i32o4i = dnnl_OIhw4i32o4i,
|
|
OhwI4i32o4i = dnnl_OhwI4i32o4i,
|
|
OIhw4i64o4i = dnnl_OIhw4i64o4i,
|
|
OhwI4i64o4i = dnnl_OhwI4i64o4i,
|
|
OIhw4i4o = dnnl_OIhw4i4o,
|
|
OIhw4o4i = dnnl_OIhw4o4i,
|
|
Oihw4o = dnnl_Oihw4o,
|
|
OIhw8i16o2i = dnnl_OIhw8i16o2i,
|
|
OhwI8i16o2i = dnnl_OhwI8i16o2i,
|
|
OIhw8i32o2i = dnnl_OIhw8i32o2i,
|
|
OhwI8i32o2i = dnnl_OhwI8i32o2i,
|
|
OIhw8i64o2i = dnnl_OIhw8i64o2i,
|
|
OhwI8i64o2i = dnnl_OhwI8i64o2i,
|
|
OIhw8i8o = dnnl_OIhw8i8o,
|
|
OhwI8i8o = dnnl_OhwI8i8o,
|
|
OIhw8o16i2o = dnnl_OIhw8o16i2o,
|
|
OIhw8o8i = dnnl_OIhw8o8i,
|
|
OIhw8o4i = dnnl_OIhw8o4i,
|
|
OIhw2i8o4i = dnnl_OIhw2i8o4i,
|
|
IOdhw8o8i = dnnl_IOdhw8o8i,
|
|
IOdhw16o16i = dnnl_IOdhw16o16i,
|
|
Odhwi16o = dnnl_Odhwi16o,
|
|
OdhwI16o2i = dnnl_OdhwI16o2i,
|
|
Idhwo16i = dnnl_Idhwo16i,
|
|
IdhwO16i2o = dnnl_IdhwO16i2o,
|
|
IdhwO16i4o = dnnl_IdhwO16i4o,
|
|
Odhwi4o = dnnl_Odhwi4o,
|
|
Odhwi8o = dnnl_Odhwi8o,
|
|
OdhwI8o2i = dnnl_OdhwI8o2i,
|
|
OdhwI8o4i = dnnl_OdhwI8o4i,
|
|
OIdhw16i16o = dnnl_OIdhw16i16o,
|
|
OdhwI16i16o = dnnl_OdhwI16i16o,
|
|
OIdhw16i32o = dnnl_OIdhw16i32o,
|
|
OdhwI16i32o = dnnl_OdhwI16i32o,
|
|
OIdhw16i48o = dnnl_OIdhw16i48o,
|
|
OdhwI16i48o = dnnl_OdhwI16i48o,
|
|
OIdhw16i64o = dnnl_OIdhw16i64o,
|
|
OdhwI16i64o = dnnl_OdhwI16i64o,
|
|
OIdhw16o16i = dnnl_OIdhw16o16i,
|
|
OIdhw16o16i2o = dnnl_OIdhw16o16i2o,
|
|
Oidhw16o = dnnl_Oidhw16o,
|
|
OIdhw4i4o = dnnl_OIdhw4i4o,
|
|
OIdhw4o4i = dnnl_OIdhw4o4i,
|
|
Oidhw4o = dnnl_Oidhw4o,
|
|
OIdhw8i16o2i = dnnl_OIdhw8i16o2i,
|
|
OdhwI8i16o2i = dnnl_OdhwI8i16o2i,
|
|
OIdhw8i32o2i = dnnl_OIdhw8i32o2i,
|
|
OdhwI8i32o2i = dnnl_OdhwI8i32o2i,
|
|
OIdhw8i64o2i = dnnl_OIdhw8i64o2i,
|
|
OdhwI8i64o2i = dnnl_OdhwI8i64o2i,
|
|
OIdhw4i8o4i = dnnl_OIdhw4i8o4i,
|
|
OdhwI4i8o4i = dnnl_OdhwI4i8o4i,
|
|
OIdhw4i16o4i = dnnl_OIdhw4i16o4i,
|
|
OdhwI4i16o4i = dnnl_OdhwI4i16o4i,
|
|
OIdhw16i16o4i = dnnl_OIdhw16i16o4i,
|
|
OIdhw16i32o4i = dnnl_OIdhw16i32o4i,
|
|
OIdhw16i48o4i = dnnl_OIdhw16i48o4i,
|
|
OIdhw16i64o4i = dnnl_OIdhw16i64o4i,
|
|
OIdhw16i16o2i = dnnl_OIdhw16i16o2i,
|
|
OIdhw16i32o2i = dnnl_OIdhw16i32o2i,
|
|
OIdhw16i48o2i = dnnl_OIdhw16i48o2i,
|
|
OIdhw16i64o2i = dnnl_OIdhw16i64o2i,
|
|
OIdhw4i24o4i = dnnl_OIdhw4i24o4i,
|
|
OdhwI4i24o4i = dnnl_OdhwI4i24o4i,
|
|
OIdhw4i32o4i = dnnl_OIdhw4i32o4i,
|
|
OdhwI4i32o4i = dnnl_OdhwI4i32o4i,
|
|
OIdhw4i64o4i = dnnl_OIdhw4i64o4i,
|
|
OdhwI4i64o4i = dnnl_OdhwI4i64o4i,
|
|
OIdhw2i8o4i = dnnl_OIdhw2i8o4i,
|
|
OIdhw8i8o = dnnl_OIdhw8i8o,
|
|
OdhwI8i8o = dnnl_OdhwI8i8o,
|
|
OIdhw8o8i = dnnl_OIdhw8o8i,
|
|
OIdhw8o4i = dnnl_OIdhw8o4i,
|
|
gIOw8o8i = dnnl_gIOw8o8i,
|
|
gIOw16o16i = dnnl_gIOw16o16i,
|
|
gOIw16i16o = dnnl_gOIw16i16o,
|
|
gOIw16o16i = dnnl_gOIw16o16i,
|
|
gOiw16o = dnnl_gOiw16o,
|
|
gOIw4i16o4i = dnnl_gOIw4i16o4i,
|
|
gOIw2i8o4i = dnnl_gOIw2i8o4i,
|
|
gOIw4i4o = dnnl_gOIw4i4o,
|
|
gOIw4o4i = dnnl_gOIw4o4i,
|
|
gOiw4o = dnnl_gOiw4o,
|
|
gOIw8i16o2i = dnnl_gOIw8i16o2i,
|
|
gOIw8i8o = dnnl_gOIw8i8o,
|
|
gOIw8o16i2o = dnnl_gOIw8o16i2o,
|
|
gOIw8o8i = dnnl_gOIw8o8i,
|
|
gOIw8o4i = dnnl_gOIw8o4i,
|
|
gOIw16i16o4i = dnnl_gOIw16i16o4i,
|
|
gOIw16i16o2i = dnnl_gOIw16i16o2i,
|
|
gOIw16o16i2o = dnnl_gOIw16o16i2o,
|
|
gOwi16o = dnnl_gOwi16o,
|
|
gOwI16o2i = dnnl_gOwI16o2i,
|
|
gIwo16i = dnnl_gIwo16i,
|
|
gIwO16i2o = dnnl_gIwO16i2o,
|
|
gIwO16i4o = dnnl_gIwO16i4o,
|
|
gOwi4o = dnnl_gOwi4o,
|
|
gOwi8o = dnnl_gOwi8o,
|
|
gOwI8o2i = dnnl_gOwI8o2i,
|
|
gOwI8o4i = dnnl_gOwI8o4i,
|
|
Goiw8g = dnnl_Goiw8g,
|
|
Goiw16g = dnnl_Goiw16g,
|
|
gIOhw8o8i = dnnl_gIOhw8o8i,
|
|
gIOhw16o16i = dnnl_gIOhw16o16i,
|
|
gOhwi16o = dnnl_gOhwi16o,
|
|
gOhwI16o2i = dnnl_gOhwI16o2i,
|
|
gIhwo16i = dnnl_gIhwo16i,
|
|
gIhwO16i2o = dnnl_gIhwO16i2o,
|
|
gIhwO16i4o = dnnl_gIhwO16i4o,
|
|
gOhwi4o = dnnl_gOhwi4o,
|
|
gOhwi8o = dnnl_gOhwi8o,
|
|
gOhwI8o2i = dnnl_gOhwI8o2i,
|
|
gOhwI8o4i = dnnl_gOhwI8o4i,
|
|
Goihw16g = dnnl_Goihw16g,
|
|
gOIhw16i16o = dnnl_gOIhw16i16o,
|
|
gOIhw16o16i = dnnl_gOIhw16o16i,
|
|
gOihw16o = dnnl_gOihw16o,
|
|
gOIhw4i16o4i = dnnl_gOIhw4i16o4i,
|
|
gOIhw2i8o4i = dnnl_gOIhw2i8o4i,
|
|
gOIhw4i4o = dnnl_gOIhw4i4o,
|
|
gOIhw4o4i = dnnl_gOIhw4o4i,
|
|
gOihw4o = dnnl_gOihw4o,
|
|
Goihw8g = dnnl_Goihw8g,
|
|
gOIhw8i16o2i = dnnl_gOIhw8i16o2i,
|
|
gOIhw8i8o = dnnl_gOIhw8i8o,
|
|
gOIhw8o16i2o = dnnl_gOIhw8o16i2o,
|
|
OIw4o8i8o4i = dnnl_OIw4o8i8o4i,
|
|
OIdhw4o8i8o4i = dnnl_OIdhw4o8i8o4i,
|
|
OIhw4o8i8o4i = dnnl_OIhw4o8i8o4i,
|
|
OIhw2o8i8o2i = dnnl_OIhw2o8i8o2i,
|
|
gOIw4o8i8o4i = dnnl_gOIw4o8i8o4i,
|
|
gOIdhw4o8i8o4i = dnnl_gOIdhw4o8i8o4i,
|
|
gOIhw4o8i8o4i = dnnl_gOIhw4o8i8o4i,
|
|
gOIhw2o8i8o2i = dnnl_gOIhw2o8i8o2i,
|
|
OIhw16i16o4i = dnnl_OIhw16i16o4i,
|
|
OIhw16i32o4i = dnnl_OIhw16i32o4i,
|
|
OIhw16i48o4i = dnnl_OIhw16i48o4i,
|
|
OIhw16i64o4i = dnnl_OIhw16i64o4i,
|
|
OIhw16i16o2i = dnnl_OIhw16i16o2i,
|
|
OIhw16i32o2i = dnnl_OIhw16i32o2i,
|
|
OIhw16i48o2i = dnnl_OIhw16i48o2i,
|
|
OIhw16i64o2i = dnnl_OIhw16i64o2i,
|
|
OIhw16o16i2o = dnnl_OIhw16o16i2o,
|
|
gOIhw16i16o4i = dnnl_gOIhw16i16o4i,
|
|
gOIhw16i16o2i = dnnl_gOIhw16i16o2i,
|
|
gOIhw16o16i2o = dnnl_gOIhw16o16i2o,
|
|
gOIhw8o8i = dnnl_gOIhw8o8i,
|
|
gOIhw8o4i = dnnl_gOIhw8o4i,
|
|
gIOdhw16i16o = dnnl_gIOdhw16i16o,
|
|
gIOdhw8o8i = dnnl_gIOdhw8o8i,
|
|
gIOdhw16o16i = dnnl_gIOdhw16o16i,
|
|
gOdhwi16o = dnnl_gOdhwi16o,
|
|
gOdhwI16o2i = dnnl_gOdhwI16o2i,
|
|
gIdhwo16i = dnnl_gIdhwo16i,
|
|
gIdhwO16i2o = dnnl_gIdhwO16i2o,
|
|
gIdhwO16i4o = dnnl_gIdhwO16i4o,
|
|
gOdhwi4o = dnnl_gOdhwi4o,
|
|
gOdhwi8o = dnnl_gOdhwi8o,
|
|
gOdhwI8o2i = dnnl_gOdhwI8o2i,
|
|
gOdhwI8o4i = dnnl_gOdhwI8o4i,
|
|
gOIdhw16i16o = dnnl_gOIdhw16i16o,
|
|
gOIdhw16o16i = dnnl_gOIdhw16o16i,
|
|
gOIdhw16o16i2o = dnnl_gOIdhw16o16i2o,
|
|
gOidhw16o = dnnl_gOidhw16o,
|
|
gOIdhw4i4o = dnnl_gOIdhw4i4o,
|
|
gOIdhw4o4i = dnnl_gOIdhw4o4i,
|
|
gOidhw4o = dnnl_gOidhw4o,
|
|
gOIdhw8i16o2i = dnnl_gOIdhw8i16o2i,
|
|
gOIdhw4i16o4i = dnnl_gOIdhw4i16o4i,
|
|
gOIdhw16i16o4i = dnnl_gOIdhw16i16o4i,
|
|
gOIdhw16i16o2i = dnnl_gOIdhw16i16o2i,
|
|
gOIdhw2i8o4i = dnnl_gOIdhw2i8o4i,
|
|
gOIdhw8i8o = dnnl_gOIdhw8i8o,
|
|
gOIdhw8o8i = dnnl_gOIdhw8o8i,
|
|
gOIdhw8o4i = dnnl_gOIdhw8o4i,
|
|
gOIw2i4o2i = dnnl_gOIw2i4o2i,
|
|
gOIhw2i4o2i = dnnl_gOIhw2i4o2i,
|
|
gOIdhw2i4o2i = dnnl_gOIdhw2i4o2i,
|
|
gOIw2o4i2o = dnnl_gOIw2o4i2o,
|
|
gOIhw2o4i2o = dnnl_gOIhw2o4i2o,
|
|
gOIdhw2o4i2o = dnnl_gOIdhw2o4i2o,
|
|
gOIw4i8o2i = dnnl_gOIw4i8o2i,
|
|
gOIhw4i8o2i = dnnl_gOIhw4i8o2i,
|
|
gOIdhw4i8o2i = dnnl_gOIdhw4i8o2i,
|
|
gOIw4o8i2o = dnnl_gOIw4o8i2o,
|
|
gOIhw4o8i2o = dnnl_gOIhw4o8i2o,
|
|
gOIdhw4o8i2o = dnnl_gOIdhw4o8i2o,
|
|
|
|
ldOi16o = abDc16d,
|
|
ldOi32o = abDc32d,
|
|
ldOI16o4i = abDC16d4c,
|
|
ldOI32o4i = abDC32d4c,
|
|
ldgOi16o = abdEc16e,
|
|
ldgOI16o4i = abdEC16e4c,
|
|
ldgOi32o = abdEc32e,
|
|
ldgOI32o2i = abdEC32e2c,
|
|
ldgOI32o4i = abdEC32e4c,
|
|
OwI16o4i = dnnl_OwI16o4i,
|
|
OhwI16o4i = dnnl_OhwI16o4i,
|
|
gOwI16o4i = dnnl_gOwI16o4i,
|
|
gOhwI16o4i = dnnl_gOhwI16o4i,
|
|
OdhwI16o4i = dnnl_OdhwI16o4i,
|
|
gOdhwI16o4i = dnnl_gOdhwI16o4i,
|
|
|
|
Owi32o = dnnl_Owi32o,
|
|
OwI32o2i = dnnl_OwI32o2i,
|
|
OwI32o4i = dnnl_OwI32o4i,
|
|
Owi48o = dnnl_Owi48o,
|
|
OwI48o2i = dnnl_OwI48o2i,
|
|
OwI48o4i = dnnl_OwI48o4i,
|
|
Owi64o = dnnl_Owi64o,
|
|
OwI64o2i = dnnl_OwI64o2i,
|
|
OwI64o4i = dnnl_OwI64o4i,
|
|
Iwo32i = dnnl_Iwo32i,
|
|
IwO32i2o = dnnl_IwO32i2o,
|
|
IwO32i4o = dnnl_IwO32i4o,
|
|
Iwo48i = dnnl_Iwo48i,
|
|
IwO48i2o = dnnl_IwO48i2o,
|
|
IwO48i4o = dnnl_IwO48i4o,
|
|
Iwo64i = dnnl_Iwo64i,
|
|
IwO64i2o = dnnl_IwO64i2o,
|
|
IwO64i4o = dnnl_IwO64i4o,
|
|
wIo2i = dnnl_wIo2i,
|
|
wIo4i = dnnl_wIo4i,
|
|
gOwi32o = dnnl_gOwi32o,
|
|
gOwI32o2i = dnnl_gOwI32o2i,
|
|
gOwI32o4i = dnnl_gOwI32o4i,
|
|
gOwi48o = dnnl_gOwi48o,
|
|
gOwI48o2i = dnnl_gOwI48o2i,
|
|
gOwI48o4i = dnnl_gOwI48o4i,
|
|
gOwi64o = dnnl_gOwi64o,
|
|
gOwI64o2i = dnnl_gOwI64o2i,
|
|
gOwI64o4i = dnnl_gOwI64o4i,
|
|
gIwo32i = dnnl_gIwo32i,
|
|
gIwO32i2o = dnnl_gIwO32i2o,
|
|
gIwO32i4o = dnnl_gIwO32i4o,
|
|
gIwo48i = dnnl_gIwo48i,
|
|
gIwO48i2o = dnnl_gIwO48i2o,
|
|
gIwO48i4o = dnnl_gIwO48i4o,
|
|
gIwo64i = dnnl_gIwo64i,
|
|
gIwO64i2o = dnnl_gIwO64i2o,
|
|
gIwO64i4o = dnnl_gIwO64i4o,
|
|
gwio = dnnl_gwio,
|
|
gwIo2i = dnnl_gwIo2i,
|
|
gwIo4i = dnnl_gwIo4i,
|
|
OhwI32o = dnnl_OhwI32o,
|
|
OhwI32o2i = dnnl_OhwI32o2i,
|
|
OhwI32o4i = dnnl_OhwI32o4i,
|
|
Ohwi48o = dnnl_Ohwi48o,
|
|
OhwI48o2i = dnnl_OhwI48o2i,
|
|
OhwI48o4i = dnnl_OhwI48o4i,
|
|
Ohwi64o = dnnl_Ohwi64o,
|
|
OhwI64o2i = dnnl_OhwI64o2i,
|
|
OhwI64o4i = dnnl_OhwI64o4i,
|
|
Ihwo32i = dnnl_Ihwo32i,
|
|
IhwO32i2o = dnnl_IhwO32i2o,
|
|
IhwO32i4o = dnnl_IhwO32i4o,
|
|
Ihwo48i = dnnl_Ihwo48i,
|
|
IhwO48i2o = dnnl_IhwO48i2o,
|
|
IhwO48i4o = dnnl_IhwO48i4o,
|
|
Ihwo64i = dnnl_Ihwo64i,
|
|
IhwO64i2o = dnnl_IhwO64i2o,
|
|
IhwO64i4o = dnnl_IhwO64i4o,
|
|
hwIo2i = dnnl_hwIo2i,
|
|
hwIo4i = dnnl_hwIo4i,
|
|
gOhwI32o = dnnl_gOhwI32o,
|
|
gOhwI32o2i = dnnl_gOhwI32o2i,
|
|
gOhwI32o4i = dnnl_gOhwI32o4i,
|
|
gOhwi48o = dnnl_gOhwi48o,
|
|
gOhwI48o2i = dnnl_gOhwI48o2i,
|
|
gOhwI48o4i = dnnl_gOhwI48o4i,
|
|
gOhwi64o = dnnl_gOhwi64o,
|
|
gOhwI64o2i = dnnl_gOhwI64o2i,
|
|
gOhwI64o4i = dnnl_gOhwI64o4i,
|
|
gIhwo32i = dnnl_gIhwo32i,
|
|
gIhwO32i2o = dnnl_gIhwO32i2o,
|
|
gIhwO32i4o = dnnl_gIhwO32i4o,
|
|
gIhwo48i = dnnl_gIhwo48i,
|
|
gIhwO48i2o = dnnl_gIhwO48i2o,
|
|
gIhwO48i4o = dnnl_gIhwO48i4o,
|
|
gIhwo64i = dnnl_gIhwo64i,
|
|
gIhwO64i2o = dnnl_gIhwO64i2o,
|
|
gIhwO64i4o = dnnl_gIhwO64i4o,
|
|
ghwio = dnnl_ghwio,
|
|
ghwIo2i = dnnl_ghwIo2i,
|
|
ghwIo4i = dnnl_ghwIo4i,
|
|
Odhwi32o = dnnl_Odhwi32o,
|
|
OdhwI32o2i = dnnl_OdhwI32o2i,
|
|
OdhwI32o4i = dnnl_OdhwI32o4i,
|
|
Odhwi48o = dnnl_Odhwi48o,
|
|
OdhwI48o2i = dnnl_OdhwI48o2i,
|
|
OdhwI48o4i = dnnl_OdhwI48o4i,
|
|
Odhwi64o = dnnl_Odhwi64o,
|
|
OdhwI64o2i = dnnl_OdhwI64o2i,
|
|
OdhwI64o4i = dnnl_OdhwI64o4i,
|
|
Idhwo32i = dnnl_Idhwo32i,
|
|
IdhwO32i2o = dnnl_IdhwO32i2o,
|
|
IdhwO32i4o = dnnl_IdhwO32i4o,
|
|
Idhwo48i = dnnl_Idhwo48i,
|
|
IdhwO48i2o = dnnl_IdhwO48i2o,
|
|
IdhwO48i4o = dnnl_IdhwO48i4o,
|
|
Idhwo64i = dnnl_Idhwo64i,
|
|
IdhwO64i2o = dnnl_IdhwO64i2o,
|
|
IdhwO64i4o = dnnl_IdhwO64i4o,
|
|
dhwIo2i = dnnl_dhwIo2i,
|
|
dhwIo4i = dnnl_dhwIo4i,
|
|
gOdhwi32o = dnnl_gOdhwi32o,
|
|
gOdhwI32o2i = dnnl_gOdhwI32o2i,
|
|
gOdhwI32o4i = dnnl_gOdhwI32o4i,
|
|
gOdhwi48o = dnnl_gOdhwi48o,
|
|
gOdhwI48o2i = dnnl_gOdhwI48o2i,
|
|
gOdhwI48o4i = dnnl_gOdhwI48o4i,
|
|
gOdhwi64o = dnnl_gOdhwi64o,
|
|
gOdhwI64o2i = dnnl_gOdhwI64o2i,
|
|
gOdhwI64o4i = dnnl_gOdhwI64o4i,
|
|
gIdhwo32i = dnnl_gIdhwo32i,
|
|
gIdhwO32i2o = dnnl_gIdhwO32i2o,
|
|
gIdhwO32i4o = dnnl_gIdhwO32i4o,
|
|
gIdhwo48i = dnnl_gIdhwo48i,
|
|
gIdhwO48i2o = dnnl_gIdhwO48i2o,
|
|
gIdhwO48i4o = dnnl_gIdhwO48i4o,
|
|
gIdhwo64i = dnnl_gIdhwo64i,
|
|
gIdhwO64i2o = dnnl_gIdhwO64i2o,
|
|
gIdhwO64i4o = dnnl_gIdhwO64i4o,
|
|
gdhwio = dnnl_gdhwio,
|
|
gdhwIo2i = dnnl_gdhwIo2i,
|
|
gdhwIo4i = dnnl_gdhwIo4i,
|
|
ldIo32i = dnnl_ldIo32i,
|
|
ldgIo16i = dnnl_ldgIo16i,
|
|
ldgIo32i = dnnl_ldgIo32i,
|
|
ldgIO32i2o = dnnl_ldgIO32i2o,
|
|
nCdhw32c = dnnl_nCdhw32c,
|
|
nChw32c = dnnl_nChw32c,
|
|
nCw32c = dnnl_nCw32c,
|
|
NCw32n16c = dnnl_NCw32n16c,
|
|
NChw32n16c = dnnl_NChw32n16c,
|
|
NCdhw32n16c = dnnl_NCdhw32n16c,
|
|
NCw32n32c = dnnl_NCw32n32c,
|
|
OI16i16o4i = dnnl_OI16i16o4i,
|
|
IOw8o16i2o = dnnl_IOw8o16i2o,
|
|
IOhw8o16i2o = dnnl_IOhw8o16i2o,
|
|
Owhi16o = dnnl_Owhi16o,
|
|
OIdhw8o16i2o = dnnl_OIdhw8o16i2o,
|
|
IOdhw8o16i2o = dnnl_IOdhw8o16i2o,
|
|
Goiw4g = dnnl_Goiw4g,
|
|
gIOw8o16i2o = dnnl_gIOw8o16i2o,
|
|
Goiw32g = dnnl_Goiw32g,
|
|
Goihw4g = dnnl_Goihw4g,
|
|
gIOhw8o16i2o = dnnl_gIOhw8o16i2o,
|
|
Goihw32g = dnnl_Goihw32g,
|
|
gOwhi16o = dnnl_gOwhi16o,
|
|
IOw4i8o8i4o = dnnl_IOw4i8o8i4o,
|
|
IOhw4i8o8i4o = dnnl_IOhw4i8o8i4o,
|
|
IOdhw4i8o8i4o = dnnl_IOdhw4i8o8i4o,
|
|
gIOw4i8o8i4o = dnnl_gIOw4i8o8i4o,
|
|
gIOhw4i8o8i4o = dnnl_gIOhw4i8o8i4o,
|
|
gIOdhw4i8o8i4o = dnnl_gIOdhw4i8o8i4o,
|
|
gOIdhw8o16i2o = dnnl_gOIdhw8o16i2o,
|
|
gIOdhw8o16i2o = dnnl_gIOdhw8o16i2o,
|
|
Goidhw32g = dnnl_Goidhw32g,
|
|
OI16i32o4i = dnnl_OI16i32o4i,
|
|
OI16i48o4i = dnnl_OI16i48o4i,
|
|
OI16i64o4i = dnnl_OI16i64o4i,
|
|
OI16i16o2i = dnnl_OI16i16o2i,
|
|
OI16i32o2i = dnnl_OI16i32o2i,
|
|
OI16i48o2i = dnnl_OI16i48o2i,
|
|
OI16i64o2i = dnnl_OI16i64o2i,
|
|
aBdeC16c16b4c = dnnl_aBdeC16c16b4c,
|
|
AcB16b16a2b = dnnl_AcB16b16a2b,
|
|
aBdC16c16b2c = dnnl_aBdC16c16b2c,
|
|
AcB16b16a4b = dnnl_AcB16b16a4b,
|
|
aBdC16c16b4c = dnnl_aBdC16c16b4c,
|
|
AcdB16b16a2b = dnnl_AcdB16b16a2b,
|
|
aBdefC16c16b4c = dnnl_aBdefC16c16b4c,
|
|
AcdeB16b16a4b = dnnl_AcdeB16b16a4b,
|
|
AcB16b32a2b = dnnl_AcB16b32a2b,
|
|
AcB16b32a4b = dnnl_AcB16b32a4b,
|
|
AcB16b48a2b = dnnl_AcB16b48a2b,
|
|
AcB16b48a4b = dnnl_AcB16b48a4b,
|
|
AcB16b64a2b = dnnl_AcB16b64a2b,
|
|
AcB16b64a4b = dnnl_AcB16b64a4b,
|
|
aBdC16c32b2c = dnnl_aBdC16c32b2c,
|
|
aBdC16c32b4c = dnnl_aBdC16c32b4c,
|
|
aBdC16c48b2c = dnnl_aBdC16c48b2c,
|
|
aBdC16c48b4c = dnnl_aBdC16c48b4c,
|
|
aBdC16c64b2c = dnnl_aBdC16c64b2c,
|
|
aBdC16c64b4c = dnnl_aBdC16c64b4c,
|
|
AcdB16b32a2b = dnnl_AcdB16b32a2b,
|
|
AcdB16b32a4b = dnnl_AcdB16b32a4b,
|
|
AcdB16b48a2b = dnnl_AcdB16b48a2b,
|
|
AcdB16b48a4b = dnnl_AcdB16b48a4b,
|
|
AcdB16b64a2b = dnnl_AcdB16b64a2b,
|
|
AcdB16b64a4b = dnnl_AcdB16b64a4b,
|
|
aBdeC16c32b2c = dnnl_aBdeC16c32b2c,
|
|
aBdeC16c32b4c = dnnl_aBdeC16c32b4c,
|
|
aBdeC16c48b2c = dnnl_aBdeC16c48b2c,
|
|
aBdeC16c48b4c = dnnl_aBdeC16c48b4c,
|
|
aBdeC16c64b2c = dnnl_aBdeC16c64b2c,
|
|
aBdeC16c64b4c = dnnl_aBdeC16c64b4c,
|
|
AcdeB16b32a2b = dnnl_AcdeB16b32a2b,
|
|
AcdeB16b32a4b = dnnl_AcdeB16b32a4b,
|
|
AcdeB16b48a2b = dnnl_AcdeB16b48a2b,
|
|
AcdeB16b48a4b = dnnl_AcdeB16b48a4b,
|
|
AcdeB16b64a2b = dnnl_AcdeB16b64a2b,
|
|
AcdeB16b64a4b = dnnl_AcdeB16b64a4b,
|
|
aBdefC16c32b2c = dnnl_aBdefC16c32b2c,
|
|
aBdefC16c32b4c = dnnl_aBdefC16c32b4c,
|
|
aBdefC16c48b2c = dnnl_aBdefC16c48b2c,
|
|
aBdefC16c48b4c = dnnl_aBdefC16c48b4c,
|
|
aBdefC16c64b2c = dnnl_aBdefC16c64b2c,
|
|
aBdefC16c64b4c = dnnl_aBdefC16c64b4c,
|
|
OwI16i16o2i = dnnl_OwI16i16o2i,
|
|
gOwI16i16o2i = dnnl_gOwI16i16o2i,
|
|
OhwI16i16o2i = dnnl_OhwI16i16o2i,
|
|
gOhwI16i16o2i = dnnl_gOhwI16i16o2i,
|
|
OdhwI16i16o2i = dnnl_OdhwI16i16o2i,
|
|
gOdhwI16i16o2i = dnnl_gOdhwI16i16o2i,
|
|
OwI16i16o4i = dnnl_OwI16i16o4i,
|
|
gOwI16i16o4i = dnnl_gOwI16i16o4i,
|
|
OhwI16i16o4i = dnnl_OhwI16i16o4i,
|
|
gOhwI16i16o4i = dnnl_gOhwI16i16o4i,
|
|
OdhwI16i16o4i = dnnl_OdhwI16i16o4i,
|
|
gOdhwI16i16o4i = dnnl_gOdhwI16i16o4i,
|
|
OwI16i32o2i = dnnl_OwI16i32o2i,
|
|
OwI16i32o4i = dnnl_OwI16i32o4i,
|
|
OwI16i48o2i = dnnl_OwI16i48o2i,
|
|
OwI16i48o4i = dnnl_OwI16i48o4i,
|
|
OwI16i64o2i = dnnl_OwI16i64o2i,
|
|
OwI16i64o4i = dnnl_OwI16i64o4i,
|
|
gOwI16i32o2i = dnnl_gOwI16i32o2i,
|
|
gOwI16i32o4i = dnnl_gOwI16i32o4i,
|
|
gOwI16i48o2i = dnnl_gOwI16i48o2i,
|
|
gOwI16i48o4i = dnnl_gOwI16i48o4i,
|
|
gOwI16i64o2i = dnnl_gOwI16i64o2i,
|
|
gOwI16i64o4i = dnnl_gOwI16i64o4i,
|
|
OhwI16i32o2i = dnnl_OhwI16i32o2i,
|
|
OhwI16i32o4i = dnnl_OhwI16i32o4i,
|
|
OhwI16i48o2i = dnnl_OhwI16i48o2i,
|
|
OhwI16i48o4i = dnnl_OhwI16i48o4i,
|
|
OhwI16i64o2i = dnnl_OhwI16i64o2i,
|
|
OhwI16i64o4i = dnnl_OhwI16i64o4i,
|
|
gOhwI16i32o2i = dnnl_gOhwI16i32o2i,
|
|
gOhwI16i32o4i = dnnl_gOhwI16i32o4i,
|
|
gOhwI16i48o2i = dnnl_gOhwI16i48o2i,
|
|
gOhwI16i48o4i = dnnl_gOhwI16i48o4i,
|
|
gOhwI16i64o2i = dnnl_gOhwI16i64o2i,
|
|
gOhwI16i64o4i = dnnl_gOhwI16i64o4i,
|
|
OdhwI16i32o2i = dnnl_OdhwI16i32o2i,
|
|
OdhwI16i32o4i = dnnl_OdhwI16i32o4i,
|
|
OdhwI16i48o2i = dnnl_OdhwI16i48o2i,
|
|
OdhwI16i48o4i = dnnl_OdhwI16i48o4i,
|
|
OdhwI16i64o2i = dnnl_OdhwI16i64o2i,
|
|
OdhwI16i64o4i = dnnl_OdhwI16i64o4i,
|
|
IdhwO16o32i2o = dnnl_IdhwO16o32i2o,
|
|
IdhwO16o32i4o = dnnl_IdhwO16o32i4o,
|
|
IdhwO16o48i2o = dnnl_IdhwO16o48i2o,
|
|
IdhwO16o48i4o = dnnl_IdhwO16o48i4o,
|
|
IdhwO16o64i2o = dnnl_IdhwO16o64i2o,
|
|
IdhwO16o64i4o = dnnl_IdhwO16o64i4o,
|
|
gOdhwI16i32o2i = dnnl_gOdhwI16i32o2i,
|
|
gOdhwI16i32o4i = dnnl_gOdhwI16i32o4i,
|
|
gOdhwI16i48o2i = dnnl_gOdhwI16i48o2i,
|
|
gOdhwI16i48o4i = dnnl_gOdhwI16i48o4i,
|
|
gOdhwI16i64o2i = dnnl_gOdhwI16i64o2i,
|
|
gOdhwI16i64o4i = dnnl_gOdhwI16i64o4i,
|
|
gIdhwO16o32i2o = dnnl_gIdhwO16o32i2o,
|
|
gIdhwO16o32i4o = dnnl_gIdhwO16o32i4o,
|
|
gIdhwO16o48i2o = dnnl_gIdhwO16o48i2o,
|
|
gIdhwO16o48i4o = dnnl_gIdhwO16o48i4o,
|
|
gIdhwO16o64i2o = dnnl_gIdhwO16o64i2o,
|
|
gIdhwO16o64i4o = dnnl_gIdhwO16o64i4o,
|
|
IwO16o16i2o = dnnl_IwO16o16i2o,
|
|
IwO16o16i4o = dnnl_IwO16o16i4o,
|
|
IhwO16o16i2o = dnnl_IhwO16o16i2o,
|
|
IhwO16o16i4o = dnnl_IhwO16o16i4o,
|
|
IdhwO16o16i2o = dnnl_IdhwO16o16i2o,
|
|
IdhwO16o16i4o = dnnl_IdhwO16o16i4o,
|
|
gIwO16o16i2o = dnnl_gIwO16o16i2o,
|
|
gIwO16o16i4o = dnnl_gIwO16o16i4o,
|
|
gIhwO16o16i2o = dnnl_gIhwO16o16i2o,
|
|
gIhwO16o16i4o = dnnl_gIhwO16o16i4o,
|
|
gIdhwO16o16i2o = dnnl_gIdhwO16o16i2o,
|
|
gIdhwO16o16i4o = dnnl_gIdhwO16o16i4o,
|
|
IwO16o32i2o = dnnl_IwO16o32i2o,
|
|
IwO16o32i4o = dnnl_IwO16o32i4o,
|
|
IwO16o48i2o = dnnl_IwO16o48i2o,
|
|
IwO16o48i4o = dnnl_IwO16o48i4o,
|
|
IwO16o64i2o = dnnl_IwO16o64i2o,
|
|
IwO16o64i4o = dnnl_IwO16o64i4o,
|
|
gIwO16o32i2o = dnnl_gIwO16o32i2o,
|
|
gIwO16o32i4o = dnnl_gIwO16o32i4o,
|
|
gIwO16o48i2o = dnnl_gIwO16o48i2o,
|
|
gIwO16o48i4o = dnnl_gIwO16o48i4o,
|
|
gIwO16o64i2o = dnnl_gIwO16o64i2o,
|
|
gIwO16o64i4o = dnnl_gIwO16o64i4o,
|
|
IhwO16o32i2o = dnnl_IhwO16o32i2o,
|
|
IhwO16o32i4o = dnnl_IhwO16o32i4o,
|
|
IhwO16o48i2o = dnnl_IhwO16o48i2o,
|
|
IhwO16o48i4o = dnnl_IhwO16o48i4o,
|
|
IhwO16o64i2o = dnnl_IhwO16o64i2o,
|
|
IhwO16o64i4o = dnnl_IhwO16o64i4o,
|
|
gIhwO16o32i2o = dnnl_gIhwO16o32i2o,
|
|
gIhwO16o32i4o = dnnl_gIhwO16o32i4o,
|
|
gIhwO16o48i2o = dnnl_gIhwO16o48i2o,
|
|
gIhwO16o48i4o = dnnl_gIhwO16o48i4o,
|
|
gIhwO16o64i2o = dnnl_gIhwO16o64i2o,
|
|
gIhwO16o64i4o = dnnl_gIhwO16o64i4o,
|
|
aBdeC16c16b2c = dnnl_aBdeC16c16b2c,
|
|
aBdefC16c16b2c = dnnl_aBdefC16c16b2c,
|
|
AcdB16b16a4b = dnnl_AcdB16b16a4b,
|
|
AcdeB16b16a2b = dnnl_AcdeB16b16a2b,
|
|
hwioG16g = dnnl_hwioG16g,
|
|
hwioG8g = dnnl_hwioG8g,
|
|
dhwioG16g = dnnl_dhwioG16g,
|
|
dhwioG8g = dnnl_dhwioG8g,
|
|
ABc4a2b = dnnl_ABc4a2b,
|
|
ABc8a2b = dnnl_ABc8a2b,
|
|
ABcd4a2b = dnnl_ABcd4a2b,
|
|
ABcde4a2b = dnnl_ABcde4a2b,
|
|
ABcde8a2b = dnnl_ABcde8a2b,
|
|
ABcd4a8b8a2b = dnnl_ABcd4a8b8a2b,
|
|
NCdhw40n32c = dnnl_NCdhw40n32c,
|
|
NChw40n32c = dnnl_NChw40n32c,
|
|
NCw40n32c = dnnl_NCw40n32c,
|
|
OIdhw4o8i8o2i = dnnl_OIdhw4o8i8o2i,
|
|
OIhw4o8i8o2i = dnnl_OIhw4o8i8o2i,
|
|
OIw4o8i8o2i = dnnl_OIw4o8i8o2i,
|
|
gOIdhw4o8i8o2i = dnnl_gOIdhw4o8i8o2i,
|
|
gOIhw4o8i8o2i = dnnl_gOIhw4o8i8o2i,
|
|
gOIw4o8i8o2i = dnnl_gOIw4o8i8o2i,
|
|
IOdhw4i8o8i2o = dnnl_IOdhw4i8o8i2o,
|
|
IOhw4i8o8i2o = dnnl_IOhw4i8o8i2o,
|
|
IOw4i8o8i2o = dnnl_IOw4i8o8i2o,
|
|
gIOdhw4i8o8i2o = dnnl_gIOdhw4i8o8i2o,
|
|
gIOhw4i8o8i2o = dnnl_gIOhw4i8o8i2o,
|
|
gIOw4i8o8i2o = dnnl_gIOw4i8o8i2o,
|
|
aBCd8b2c = dnnl_aBCd8b2c,
|
|
ABcde40a16b = dnnl_ABcde40a16b,
|
|
ABcde40a32b = dnnl_ABcde40a32b,
|
|
aBCde8b2c = dnnl_aBCde8b2c,
|
|
ABcde4a8b8a2b = dnnl_ABcde4a8b8a2b,
|
|
ABc4a8b8a2b = dnnl_ABc4a8b8a2b,
|
|
aBCdef4b8c8b2c = dnnl_aBCdef4b8c8b2c,
|
|
aBCde4b8c8b2c = dnnl_aBCde4b8c8b2c,
|
|
aBCd4b8c8b2c = dnnl_aBCd4b8c8b2c,
|
|
BAcde4b8a8b2a = dnnl_BAcde4b8a8b2a,
|
|
BAcd4b8a8b2a = dnnl_BAcd4b8a8b2a,
|
|
BAc4b8a8b2a = dnnl_BAc4b8a8b2a,
|
|
aCBdef4c8b8c2b = dnnl_aCBdef4c8b8c2b,
|
|
aCBde4c8b8c2b = dnnl_aCBde4c8b8c2b,
|
|
aCBd4c8b8c2b = dnnl_aCBd4c8b8c2b,
|
|
aBCdef8b2c = dnnl_aBCdef8b2c,
|
|
AB32a16b = dnnl_AB32a16b,
|
|
AB32a32b = dnnl_AB32a32b,
|
|
BA4b8a8b2a = dnnl_BA4b8a8b2a,
|
|
BA4b8a8b4a = dnnl_BA4b8a8b4a,
|
|
aBC32b16c = dnnl_aBC32b16c,
|
|
aBC32b32c = dnnl_aBC32b32c,
|
|
aCB4c8b8c2b = dnnl_aCB4c8b8c2b,
|
|
aCB4c8b8c4b = dnnl_aCB4c8b8c4b,
|
|
ABc2b8a16b4a = dnnl_ABc2b8a16b4a,
|
|
ABcd2b8a16b4a = dnnl_ABcd2b8a16b4a,
|
|
ABcde2b8a16b4a = dnnl_ABcde2b8a16b4a,
|
|
ABc2a8b16a4b = dnnl_ABc2a8b16a4b,
|
|
ABc2a8b16a2b = dnnl_ABc2a8b16a2b,
|
|
ABc2b32a8b = dnnl_ABc2b32a8b,
|
|
ABcd2a8b16a4b = dnnl_ABcd2a8b16a4b,
|
|
ABcd2a8b16a2b = dnnl_ABcd2a8b16a2b,
|
|
aCBd2c8b16c2b = dnnl_aCBd2c8b16c2b,
|
|
ABcd2b32a8b = dnnl_ABcd2b32a8b,
|
|
aBCd2c8b16c2b = dnnl_aBCd2c8b16c2b,
|
|
ABcde2a8b16a4b = dnnl_ABcde2a8b16a4b,
|
|
ABcde2a8b16a2b = dnnl_ABcde2a8b16a2b,
|
|
aCBde2c8b16c2b = dnnl_aCBde2c8b16c2b,
|
|
ABcde2b32a8b = dnnl_ABcde2b32a8b,
|
|
aBC2b8c16b2c = dnnl_aBC2b8c16b2c,
|
|
aBCd2b8c16b2c = dnnl_aBCd2b8c16b2c,
|
|
aBCde2b8c16b2c = dnnl_aBCde2b8c16b2c,
|
|
aBCdef2b8c16b2c = dnnl_aBCdef2b8c16b2c,
|
|
BAcde2b8a16b4a = dnnl_BAcde2b8a16b4a,
|
|
BAcd2b8a16b4a = dnnl_BAcd2b8a16b4a,
|
|
BAc2b8a16b4a = dnnl_BAc2b8a16b4a,
|
|
BAcde2b8a16b2a = dnnl_BAcde2b8a16b2a,
|
|
BAcd2b8a16b2a = dnnl_BAcd2b8a16b2a,
|
|
BAc2b8a16b2a = dnnl_BAc2b8a16b2a,
|
|
aBCde2c8b16c2b = dnnl_aBCde2c8b16c2b,
|
|
aBCdef2c8b16c2b = dnnl_aBCdef2c8b16c2b,
|
|
aCBdef2c8b16c2b = dnnl_aCBdef2c8b16c2b,
|
|
aBCd2b8c16b4c = dnnl_aBCd2b8c16b4c,
|
|
aBCde2b8c16b4c = dnnl_aBCde2b8c16b4c,
|
|
NCdhw40n16c = dnnl_NCdhw40n16c,
|
|
NCw40n16c = dnnl_NCw40n16c,
|
|
NChw40n16c = dnnl_NChw40n16c,
|
|
NCw2c32n8c = dnnl_NCw2c32n8c,
|
|
NChw2c32n8c = dnnl_NChw2c32n8c,
|
|
NCdhw2c32n8c = dnnl_NCdhw2c32n8c,
|
|
OIw2i8o16i4o = dnnl_OIw2i8o16i4o,
|
|
OIhw2i8o16i4o = dnnl_OIhw2i8o16i4o,
|
|
OIdhw2i8o16i4o = dnnl_OIdhw2i8o16i4o,
|
|
OIw2o8i16o4i = dnnl_OIw2o8i16o4i,
|
|
OIw2o8i16o2i = dnnl_OIw2o8i16o2i,
|
|
IOw2i8o16i4o = dnnl_IOw2i8o16i4o,
|
|
IOw2i8o16i2o = dnnl_IOw2i8o16i2o,
|
|
OIhw2o8i16o4i = dnnl_OIhw2o8i16o4i,
|
|
OIhw2o8i16o2i = dnnl_OIhw2o8i16o2i,
|
|
IOhw2i8o16i4o = dnnl_IOhw2i8o16i4o,
|
|
IOhw2i8o16i2o = dnnl_IOhw2i8o16i2o,
|
|
OIdhw2o8i16o4i = dnnl_OIdhw2o8i16o4i,
|
|
OIdhw2o8i16o2i = dnnl_OIdhw2o8i16o2i,
|
|
IOdhw2i8o16i4o = dnnl_IOdhw2i8o16i4o,
|
|
IOdhw2i8o16i2o = dnnl_IOdhw2i8o16i2o,
|
|
gOIw2o8i16o2i = dnnl_gOIw2o8i16o2i,
|
|
gIOw2i8o16i2o = dnnl_gIOw2i8o16i2o,
|
|
gIOhw2i8o16i2o = dnnl_gIOhw2i8o16i2o,
|
|
gIOdhw2i8o16i2o = dnnl_gIOdhw2i8o16i2o,
|
|
gOIhw2o8i16o2i = dnnl_gOIhw2o8i16o2i,
|
|
gOIdhw2o8i16o2i = dnnl_gOIdhw2o8i16o2i,
|
|
gOIw2o8i16o4i = dnnl_gOIw2o8i16o4i,
|
|
gOIhw2o8i16o4i = dnnl_gOIhw2o8i16o4i,
|
|
BA4b8a16b2a = dnnl_BA4b8a16b2a,
|
|
BA4b8a16b4a = dnnl_BA4b8a16b4a,
|
|
aCB4c8b16c2b = dnnl_aCB4c8b16c2b,
|
|
aCB4c8b16c4b = dnnl_aCB4c8b16c4b,
|
|
aCB16c2b = dnnl_aCB16c2b,
|
|
aCB16c4b = dnnl_aCB16c4b,
|
|
BA16b2a = dnnl_BA16b2a,
|
|
BA16b4a = dnnl_BA16b4a,
|
|
BA4b4a = dnnl_BA4b4a,
|
|
BA8b4a = dnnl_BA8b4a,
|
|
aBC16b16c = dnnl_aBC16b16c,
|
|
aBC16b32c = dnnl_aBC16b32c,
|
|
AB16a16b = dnnl_AB16a16b,
|
|
AB16a32b = dnnl_AB16a32b,
|
|
ABcde16a16b2a = dnnl_ABcde16a16b2a,
|
|
aBCdef16b16c2b = dnnl_aBCdef16b16c2b,
|
|
Acedb16a = dnnl_Acedb16a,
|
|
aBdfec16b = dnnl_aBdfec16b,
|
|
Odwhi16o = dnnl_Odwhi16o,
|
|
gOdwhi16o = dnnl_gOdwhi16o,
|
|
abdEC64e2c = dnnl_abdEC64e2c,
|
|
abdEC64e4c = dnnl_abdEC64e4c,
|
|
ldgOI64o2i = abdEC64e2c,
|
|
ldgOI64o4i = abdEC64e4c,
|
|
abCd4c = dnnl_abCd4c,
|
|
abCde4c = dnnl_abCde4c,
|
|
abCdef4c = dnnl_abCdef4c,
|
|
abCde32c = dnnl_abCde32c,
|
|
abCdef32c = dnnl_abCdef32c,
|
|
aCdefB16b32c2b = dnnl_aCdefB16b32c2b,
|
|
aCdefB16b32c4b = dnnl_aCdefB16b32c4b,
|
|
aCdefB16b48c2b = dnnl_aCdefB16b48c2b,
|
|
aCdefB16b48c4b = dnnl_aCdefB16b48c4b,
|
|
aCdefB16b64c2b = dnnl_aCdefB16b64c2b,
|
|
aCdefB16b64c4b = dnnl_aCdefB16b64c4b,
|
|
BcdeA16a32b2a = dnnl_BcdeA16a32b2a,
|
|
BcdeA16a32b4a = dnnl_BcdeA16a32b4a,
|
|
BcdeA16a48b2a = dnnl_BcdeA16a48b2a,
|
|
BcdeA16a48b4a = dnnl_BcdeA16a48b4a,
|
|
BcdeA16a64b2a = dnnl_BcdeA16a64b2a,
|
|
BcdeA16a64b4a = dnnl_BcdeA16a64b4a,
|
|
aCdefb32c = dnnl_aCdefb32c,
|
|
aCdefB32c2b = dnnl_aCdefB32c2b,
|
|
aCdefB32c4b = dnnl_aCdefB32c4b,
|
|
aCdefb48c = dnnl_aCdefb48c,
|
|
aCdefB48c2b = dnnl_aCdefB48c2b,
|
|
aCdefB48c4b = dnnl_aCdefB48c4b,
|
|
aCdefb64c = dnnl_aCdefb64c,
|
|
aCdefB64c2b = dnnl_aCdefB64c2b,
|
|
aCdefB64c4b = dnnl_aCdefB64c4b,
|
|
Bcdea32b = dnnl_Bcdea32b,
|
|
BcdeA32b2a = dnnl_BcdeA32b2a,
|
|
BcdeA32b4a = dnnl_BcdeA32b4a,
|
|
Bcdea48b = dnnl_Bcdea48b,
|
|
BcdeA48b2a = dnnl_BcdeA48b2a,
|
|
BcdeA48b4a = dnnl_BcdeA48b4a,
|
|
Bcdea64b = dnnl_Bcdea64b,
|
|
BcdeA64b2a = dnnl_BcdeA64b2a,
|
|
BcdeA64b4a = dnnl_BcdeA64b4a,
|
|
Bca32b = dnnl_Bca32b,
|
|
BcA32b2a = dnnl_BcA32b2a,
|
|
BcA32b4a = dnnl_BcA32b4a,
|
|
Bca48b = dnnl_Bca48b,
|
|
BcA48b2a = dnnl_BcA48b2a,
|
|
BcA48b4a = dnnl_BcA48b4a,
|
|
Bca64b = dnnl_Bca64b,
|
|
BcA64b2a = dnnl_BcA64b2a,
|
|
BcA64b4a = dnnl_BcA64b4a,
|
|
aCdb32c = dnnl_aCdb32c,
|
|
aCdB32c2b = dnnl_aCdB32c2b,
|
|
aCdB32c4b = dnnl_aCdB32c4b,
|
|
aCdb48c = dnnl_aCdb48c,
|
|
aCdB48c2b = dnnl_aCdB48c2b,
|
|
aCdB48c4b = dnnl_aCdB48c4b,
|
|
aCdb64c = dnnl_aCdb64c,
|
|
aCdB64c2b = dnnl_aCdB64c2b,
|
|
aCdB64c4b = dnnl_aCdB64c4b,
|
|
BcA16a16b2a = dnnl_BcA16a16b2a,
|
|
BcA16a16b4a = dnnl_BcA16a16b4a,
|
|
BcdA16a16b2a = dnnl_BcdA16a16b2a,
|
|
BcdA16a16b4a = dnnl_BcdA16a16b4a,
|
|
BcdeA16a16b2a = dnnl_BcdeA16a16b2a,
|
|
BcdeA16a16b4a = dnnl_BcdeA16a16b4a,
|
|
aCdB16b16c2b = dnnl_aCdB16b16c2b,
|
|
aCdB16b16c4b = dnnl_aCdB16b16c4b,
|
|
aCdeB16b16c2b = dnnl_aCdeB16b16c2b,
|
|
aCdeB16b16c4b = dnnl_aCdeB16b16c4b,
|
|
aCdefB16b16c2b = dnnl_aCdefB16b16c2b,
|
|
aCdefB16b16c4b = dnnl_aCdefB16b16c4b,
|
|
BcA16a32b2a = dnnl_BcA16a32b2a,
|
|
BcA16a32b4a = dnnl_BcA16a32b4a,
|
|
BcA16a48b2a = dnnl_BcA16a48b2a,
|
|
BcA16a48b4a = dnnl_BcA16a48b4a,
|
|
BcA16a64b2a = dnnl_BcA16a64b2a,
|
|
BcA16a64b4a = dnnl_BcA16a64b4a,
|
|
aCdB16b32c2b = dnnl_aCdB16b32c2b,
|
|
aCdB16b32c4b = dnnl_aCdB16b32c4b,
|
|
aCdB16b48c2b = dnnl_aCdB16b48c2b,
|
|
aCdB16b48c4b = dnnl_aCdB16b48c4b,
|
|
aCdB16b64c2b = dnnl_aCdB16b64c2b,
|
|
aCdB16b64c4b = dnnl_aCdB16b64c4b,
|
|
BcdA16a32b2a = dnnl_BcdA16a32b2a,
|
|
BcdA16a32b4a = dnnl_BcdA16a32b4a,
|
|
BcdA16a48b2a = dnnl_BcdA16a48b2a,
|
|
BcdA16a48b4a = dnnl_BcdA16a48b4a,
|
|
BcdA16a64b2a = dnnl_BcdA16a64b2a,
|
|
BcdA16a64b4a = dnnl_BcdA16a64b4a,
|
|
aCdeB16b32c2b = dnnl_aCdeB16b32c2b,
|
|
aCdeB16b32c4b = dnnl_aCdeB16b32c4b,
|
|
aCdeB16b48c2b = dnnl_aCdeB16b48c2b,
|
|
aCdeB16b48c4b = dnnl_aCdeB16b48c4b,
|
|
aCdeB16b64c2b = dnnl_aCdeB16b64c2b,
|
|
aCdeB16b64c4b = dnnl_aCdeB16b64c4b,
|
|
Bca16b = dnnl_Bca16b,
|
|
BcA16b2a = dnnl_BcA16b2a,
|
|
BcA16b4a = dnnl_BcA16b4a,
|
|
Bcda16b = dnnl_Bcda16b,
|
|
BcdA16b2a = dnnl_BcdA16b2a,
|
|
BcdA16b4a = dnnl_BcdA16b4a,
|
|
Bcdea16b = dnnl_Bcdea16b,
|
|
BcdeA16b2a = dnnl_BcdeA16b2a,
|
|
BcdeA16b4a = dnnl_BcdeA16b4a,
|
|
aCdb16c = dnnl_aCdb16c,
|
|
aCdB16c2b = dnnl_aCdB16c2b,
|
|
aCdB16c4b = dnnl_aCdB16c4b,
|
|
aCdeb16c = dnnl_aCdeb16c,
|
|
aCdeB16c2b = dnnl_aCdeB16c2b,
|
|
aCdeB16c4b = dnnl_aCdeB16c4b,
|
|
aCdefb16c = dnnl_aCdefb16c,
|
|
aCdefB16c2b = dnnl_aCdefB16c2b,
|
|
aCdefB16c4b = dnnl_aCdefB16c4b,
|
|
Bcda32b = dnnl_Bcda32b,
|
|
BcdA32b2a = dnnl_BcdA32b2a,
|
|
BcdA32b4a = dnnl_BcdA32b4a,
|
|
Bcda48b = dnnl_Bcda48b,
|
|
BcdA48b2a = dnnl_BcdA48b2a,
|
|
BcdA48b4a = dnnl_BcdA48b4a,
|
|
Bcda64b = dnnl_Bcda64b,
|
|
BcdA64b2a = dnnl_BcdA64b2a,
|
|
BcdA64b4a = dnnl_BcdA64b4a,
|
|
aCdeb32c = dnnl_aCdeb32c,
|
|
aCdeB32c2b = dnnl_aCdeB32c2b,
|
|
aCdeB32c4b = dnnl_aCdeB32c4b,
|
|
aCdeb48c = dnnl_aCdeb48c,
|
|
aCdeB48c2b = dnnl_aCdeB48c2b,
|
|
aCdeB48c4b = dnnl_aCdeB48c4b,
|
|
aCdeb64c = dnnl_aCdeb64c,
|
|
aCdeB64c2b = dnnl_aCdeB64c2b,
|
|
aCdeB64c4b = dnnl_aCdeB64c4b,
|
|
NChw16n32c = dnnl_NChw16n32c,
|
|
goIw4i = dnnl_goIw4i,
|
|
goIw32i = dnnl_goIw32i,
|
|
goIhw4i = dnnl_goIhw4i,
|
|
goIhw32i = dnnl_goIhw32i,
|
|
goIdhw4i = dnnl_goIdhw4i,
|
|
goIdhw32i = dnnl_goIdhw32i,
|
|
cab = dnnl_cab,
|
|
cdab = dnnl_cdab,
|
|
cdeab = dnnl_cdeab,
|
|
woi = dnnl_woi,
|
|
hwoi = dnnl_hwoi,
|
|
dhwoi = dnnl_dhwoi,
|
|
Owi24o = dnnl_Owi24o,
|
|
Ohwi24o = dnnl_Ohwi24o,
|
|
Odhwi24o = dnnl_Odhwi24o,
|
|
gOwi24o = dnnl_gOwi24o,
|
|
gOhwi24o = dnnl_gOhwi24o,
|
|
gOdhwi24o = dnnl_gOdhwi24o,
|
|
OwI24o2i = dnnl_OwI24o2i,
|
|
OhwI24o2i = dnnl_OhwI24o2i,
|
|
OdhwI24o2i = dnnl_OdhwI24o2i,
|
|
gOwI24o2i = dnnl_gOwI24o2i,
|
|
gOhwI24o2i = dnnl_gOhwI24o2i,
|
|
gOdhwI24o2i = dnnl_gOdhwI24o2i,
|
|
OwI24o4i = dnnl_OwI24o4i,
|
|
OhwI24o4i = dnnl_OhwI24o4i,
|
|
OdhwI24o4i = dnnl_OdhwI24o4i,
|
|
gOwI24o4i = dnnl_gOwI24o4i,
|
|
gOhwI24o4i = dnnl_gOhwI24o4i,
|
|
gOdhwI24o4i = dnnl_gOdhwI24o4i,
|
|
OI8i32o = dnnl_OI8i32o,
|
|
OIw8i32o = dnnl_OIw8i32o,
|
|
OwI8i32o = dnnl_OwI8i32o,
|
|
OIhw8i32o = dnnl_OIhw8i32o,
|
|
OhwI8i32o = dnnl_OhwI8i32o,
|
|
OIdhw8i32o = dnnl_OIdhw8i32o,
|
|
OdhwI8i32o = dnnl_OdhwI8i32o,
|
|
OI8i24o = dnnl_OI8i24o,
|
|
OIw8i24o = dnnl_OIw8i24o,
|
|
OwI8i24o = dnnl_OwI8i24o,
|
|
OIhw8i24o = dnnl_OIhw8i24o,
|
|
OhwI8i24o = dnnl_OhwI8i24o,
|
|
OIdhw8i24o = dnnl_OIdhw8i24o,
|
|
OdhwI8i24o = dnnl_OdhwI8i24o,
|
|
OI8i16o = dnnl_OI8i16o,
|
|
OIw8i16o = dnnl_OIw8i16o,
|
|
OwI8i16o = dnnl_OwI8i16o,
|
|
OIhw8i16o = dnnl_OIhw8i16o,
|
|
OhwI8i16o = dnnl_OhwI8i16o,
|
|
OIdhw8i16o = dnnl_OIdhw8i16o,
|
|
OdhwI8i16o = dnnl_OdhwI8i16o,
|
|
OI8i8o = dnnl_OI8i8o,
|
|
AB4b8a4b = dnnl_AB4b8a4b,
|
|
AB4b24a4b = dnnl_AB4b24a4b,
|
|
ABc4b8a4b = dnnl_ABc4b8a4b,
|
|
AcB4b8a4b = dnnl_AcB4b8a4b,
|
|
ABc4b24a4b = dnnl_ABc4b24a4b,
|
|
AcB4b24a4b = dnnl_AcB4b24a4b,
|
|
ABcd4b8a4b = dnnl_ABcd4b8a4b,
|
|
AcdB4b8a4b = dnnl_AcdB4b8a4b,
|
|
ABcd4b24a4b = dnnl_ABcd4b24a4b,
|
|
AcdB4b24a4b = dnnl_AcdB4b24a4b,
|
|
ABcde4b8a4b = dnnl_ABcde4b8a4b,
|
|
AcdeB4b8a4b = dnnl_AcdeB4b8a4b,
|
|
ABcde4b24a4b = dnnl_ABcde4b24a4b,
|
|
AcdeB4b24a4b = dnnl_AcdeB4b24a4b,
|
|
Bca8b = dnnl_Bca8b,
|
|
BcA8b2a = dnnl_BcA8b2a,
|
|
Bcda8b = dnnl_Bcda8b,
|
|
BcdA8b2a = dnnl_BcdA8b2a,
|
|
Bcdea8b = dnnl_Bcdea8b,
|
|
BcdeA8b2a = dnnl_BcdeA8b2a,
|
|
aCdb8c = dnnl_aCdb8c,
|
|
aCdB8c2b = dnnl_aCdB8c2b,
|
|
aCdeb8c = dnnl_aCdeb8c,
|
|
aCdeB8c2b = dnnl_aCdeB8c2b,
|
|
aCdefb8c = dnnl_aCdefb8c,
|
|
aCdefB8c2b = dnnl_aCdefB8c2b,
|
|
Bca24b = dnnl_Bca24b,
|
|
BcA24b2a = dnnl_BcA24b2a,
|
|
Bcda24b = dnnl_Bcda24b,
|
|
BcdA24b2a = dnnl_BcdA24b2a,
|
|
Bcdea24b = dnnl_Bcdea24b,
|
|
BcdeA24b2a = dnnl_BcdeA24b2a,
|
|
aCdb24c = dnnl_aCdb24c,
|
|
aCdB24c2b = dnnl_aCdB24c2b,
|
|
aCdeb24c = dnnl_aCdeb24c,
|
|
aCdeB24c2b = dnnl_aCdeB24c2b,
|
|
aCdefb24c = dnnl_aCdefb24c,
|
|
aCdefB24c2b = dnnl_aCdefB24c2b,
|
|
Iwo8i = dnnl_Iwo8i,
|
|
IwO8i2o = dnnl_IwO8i2o,
|
|
Iwo24i = dnnl_Iwo24i,
|
|
IwO24i2o = dnnl_IwO24i2o,
|
|
Ihwo8i = dnnl_Ihwo8i,
|
|
IhwO8i2o = dnnl_IhwO8i2o,
|
|
Ihwo24i = dnnl_Ihwo24i,
|
|
IhwO24i2o = dnnl_IhwO24i2o,
|
|
Idhwo8i = dnnl_Idhwo8i,
|
|
IdhwO8i2o = dnnl_IdhwO8i2o,
|
|
Idhwo24i = dnnl_Idhwo24i,
|
|
IdhwO24i2o = dnnl_IdhwO24i2o,
|
|
gIwo8i = dnnl_gIwo8i,
|
|
gIwO8i2o = dnnl_gIwO8i2o,
|
|
gIwo24i = dnnl_gIwo24i,
|
|
gIwO24i2o = dnnl_gIwO24i2o,
|
|
gIhwo8i = dnnl_gIhwo8i,
|
|
gIhwO8i2o = dnnl_gIhwO8i2o,
|
|
gIhwo24i = dnnl_gIhwo24i,
|
|
gIhwO24i2o = dnnl_gIhwO24i2o,
|
|
gIdhwo8i = dnnl_gIdhwo8i,
|
|
gIdhwO8i2o = dnnl_gIdhwO8i2o,
|
|
gIdhwo24i = dnnl_gIdhwo24i,
|
|
gIdhwO24i2o = dnnl_gIdhwO24i2o,
|
|
OhwI24o = dnnl_OhwI24o,
|
|
gOhwI24o = dnnl_gOhwI24o,
|
|
AB8b24a2b = dnnl_AB8b24a2b,
|
|
ABc8b24a2b = dnnl_ABc8b24a2b,
|
|
AcB8b24a2b = dnnl_AcB8b24a2b,
|
|
ABcd8b24a2b = dnnl_ABcd8b24a2b,
|
|
AcdB8b24a2b = dnnl_AcdB8b24a2b,
|
|
ABcde8b24a2b = dnnl_ABcde8b24a2b,
|
|
AcdeB8b24a2b = dnnl_AcdeB8b24a2b,
|
|
AB8b8a2b = dnnl_AB8b8a2b,
|
|
ABc8b8a2b = dnnl_ABc8b8a2b,
|
|
AcB8b8a2b = dnnl_AcB8b8a2b,
|
|
ABcd8b8a2b = dnnl_ABcd8b8a2b,
|
|
AcdB8b8a2b = dnnl_AcdB8b8a2b,
|
|
ABcde8b8a2b = dnnl_ABcde8b8a2b,
|
|
AcdeB8b8a2b = dnnl_AcdeB8b8a2b,
|
|
OI8i8o2i = dnnl_OI8i8o2i,
|
|
OI8i24o2i = dnnl_OI8i24o2i,
|
|
OIw8i8o2i = dnnl_OIw8i8o2i,
|
|
OwI8i8o2i = dnnl_OwI8i8o2i,
|
|
OIw8i24o2i = dnnl_OIw8i24o2i,
|
|
OwI8i24o2i = dnnl_OwI8i24o2i,
|
|
OIhw8i8o2i = dnnl_OIhw8i8o2i,
|
|
OhwI8i8o2i = dnnl_OhwI8i8o2i,
|
|
OIhw8i24o2i = dnnl_OIhw8i24o2i,
|
|
OhwI8i24o2i = dnnl_OhwI8i24o2i,
|
|
OIdhw8i8o2i = dnnl_OIdhw8i8o2i,
|
|
OdhwI8i8o2i = dnnl_OdhwI8i8o2i,
|
|
OIdhw8i24o2i = dnnl_OIdhw8i24o2i,
|
|
OdhwI8i24o2i = dnnl_OdhwI8i24o2i,
|
|
BcA8b4a = dnnl_BcA8b4a,
|
|
BcdA8b4a = dnnl_BcdA8b4a,
|
|
BcdeA8b4a = dnnl_BcdeA8b4a,
|
|
aCdB8c4b = dnnl_aCdB8c4b,
|
|
aCdeB8c4b = dnnl_aCdeB8c4b,
|
|
aCdefB8c4b = dnnl_aCdefB8c4b,
|
|
BcA24b4a = dnnl_BcA24b4a,
|
|
BcdA24b4a = dnnl_BcdA24b4a,
|
|
BcdeA24b4a = dnnl_BcdeA24b4a,
|
|
aCdB24c4b = dnnl_aCdB24c4b,
|
|
aCdeB24c4b = dnnl_aCdeB24c4b,
|
|
aCdefB24c4b = dnnl_aCdefB24c4b,
|
|
ABc16a4b = dnnl_ABc16a4b,
|
|
ABcd16a4b = dnnl_ABcd16a4b,
|
|
ABcde16a4b = dnnl_ABcde16a4b,
|
|
IwO8i4o = dnnl_IwO8i4o,
|
|
IwO24i4o = dnnl_IwO24i4o,
|
|
IhwO8i4o = dnnl_IhwO8i4o,
|
|
IhwO24i4o = dnnl_IhwO24i4o,
|
|
IdhwO8i4o = dnnl_IdhwO8i4o,
|
|
IdhwO24i4o = dnnl_IdhwO24i4o,
|
|
gIwO8i4o = dnnl_gIwO8i4o,
|
|
gIwO24i4o = dnnl_gIwO24i4o,
|
|
gIhwO8i4o = dnnl_gIhwO8i4o,
|
|
gIhwO24i4o = dnnl_gIhwO24i4o,
|
|
gIdhwO8i4o = dnnl_gIdhwO8i4o,
|
|
gIdhwO24i4o = dnnl_gIdhwO24i4o,
|
|
        BA2a24b = dnnl_BA2a24b,
        aCB2b24c = dnnl_aCB2b24c,
        BA2a8b = dnnl_BA2a8b,
        aCB2b8c = dnnl_aCB2b8c,
        BA8a24b = dnnl_BA8a24b,
        aCB8b24c = dnnl_aCB8b24c,
        BA8a16b = dnnl_BA8a16b,
        aCB8b16c = dnnl_aCB8b16c,
        BA8a8b = dnnl_BA8a8b,
        aCB8b8c = dnnl_aCB8b8c,
        bcad = dnnl_bcad,
        cabd = dnnl_cabd,
        dabc = dnnl_dabc,
        decbA4a = dnnl_decbA4a,
        defcbA4a = dnnl_defcbA4a,
        hwioG4g = dnnl_hwioG4g,
        dhwioG4g = dnnl_dhwioG4g,
    };

    /// A memory descriptor.
    struct desc : public handle<dnnl_memory_desc_t> {
        using handle<dnnl_memory_desc_t>::handle;

        friend struct memory;

        /// Constructs a zero (empty) memory descriptor. Such a memory
        /// descriptor can be used to indicate absence of an argument.
        desc() {
            dnnl_memory_desc_t zero_md = nullptr;
            error::wrap_c_api(
                    dnnl_memory_desc_create_with_tag(&zero_md, 0, nullptr,
                            dnnl_data_type_undef, dnnl_format_tag_undef),
                    "could not create a zero memory descriptor");
            reset(zero_md);
        }

        /// Constructs a memory descriptor.
        ///
        /// @note
        ///     The logical order of dimensions corresponds to the `abc...`
        ///     format tag, and the physical meaning of the dimensions depends
        ///     both on the primitive that would operate on this memory and
        ///     the operation context.
        ///
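        /// Example (an illustrative sketch, not part of the formal API
        /// documentation): creating a plain dense f32 descriptor via a format
        /// tag. The variable name and the sizes below are arbitrary.
        /// @code
        ///     // 4D tensor in NCHW order: batch 2, 16 channels, 28x28 spatial.
        ///     dnnl::memory::desc src_md({2, 16, 28, 28},
        ///             dnnl::memory::data_type::f32,
        ///             dnnl::memory::format_tag::nchw);
        /// @endcode
        ///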
        /// @param adims Tensor dimensions.
        /// @param adata_type Data precision/type.
        /// @param aformat_tag Memory format tag.
        /// @param allow_empty A flag signifying whether construction is
        ///     allowed to fail without throwing an exception. In this case a
        ///     zero memory descriptor will be constructed. This flag is
        ///     optional and defaults to false.
        desc(const dims &adims, data_type adata_type, format_tag aformat_tag,
                bool allow_empty = false) {
            validate_dims(adims);
            dnnl_memory_desc_t md = nullptr;
            dnnl_status_t status = dnnl_memory_desc_create_with_tag(&md,
                    (int)adims.size(), adims.data(), convert_to_c(adata_type),
                    convert_to_c(aformat_tag));
            if (!allow_empty)
                error::wrap_c_api(status,
                        "could not construct a memory descriptor using a "
                        "format tag");
            reset(md);
        }

        /// Constructs a memory descriptor by strides.
        ///
        /// @note
        ///     The logical order of dimensions corresponds to the `abc...`
        ///     format tag, and the physical meaning of the dimensions depends
        ///     both on the primitive that would operate on this memory and
        ///     the operation context.
        ///
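        /// Example (an illustrative sketch): the strides below describe the
        /// same dense row-major layout that the `nchw` format tag would give;
        /// the variable names and sizes are arbitrary.
        /// @code
        ///     dnnl::memory::dims adims = {2, 16, 28, 28};
        ///     dnnl::memory::dims strides = {16 * 28 * 28, 28 * 28, 28, 1};
        ///     dnnl::memory::desc src_md(
        ///             adims, dnnl::memory::data_type::f32, strides);
        /// @endcode
        ///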
        /// @param adims Tensor dimensions.
        /// @param adata_type Data precision/type.
        /// @param strides Strides for each dimension.
        /// @param allow_empty A flag signifying whether construction is
        ///     allowed to fail without throwing an exception. In this case a
        ///     zero memory descriptor will be constructed. This flag is
        ///     optional and defaults to false.
        desc(const dims &adims, data_type adata_type, const dims &strides,
                bool allow_empty = false) {
            validate_dims(adims);
            if (!strides.empty()) validate_dims(strides, (int)adims.size());
            dnnl_memory_desc_t md = nullptr;
            dnnl_status_t status = dnnl_memory_desc_create_with_strides(&md,
                    (int)adims.size(), adims.data(), convert_to_c(adata_type),
                    strides.empty() ? nullptr : &strides[0]);
            if (!allow_empty)
                error::wrap_c_api(status,
                        "could not construct a memory descriptor using "
                        "strides");
            reset(md);
        }

        /// Function for creating a memory descriptor for CSR sparse encoding.
        ///
        /// The created memory descriptor will describe a memory object that
        /// contains 3 buffers. The buffers have the following meaning and
        /// assigned numbers (index):
        /// - 0: values
        /// - 1: indices
        /// - 2: pointers
        ///
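        /// Example (an illustrative sketch): a 4x6 f32 CSR matrix with two
        /// non-zero entries, using s32 indices and pointers; the variable
        /// name is arbitrary.
        /// @code
        ///     auto csr_md = dnnl::memory::desc::csr({4, 6},
        ///             dnnl::memory::data_type::f32, /* nnz = */ 2,
        ///             dnnl::memory::data_type::s32,
        ///             dnnl::memory::data_type::s32);
        /// @endcode
        ///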
        /// @param adims Tensor dimensions.
        /// @param adata_type Data precision/type.
        /// @param nnz Number of non-zero entries.
        /// @param index_dt Data type of indices.
        /// @param pointer_dt Data type of pointers.
        /// @param allow_empty A flag signifying whether construction is
        ///     allowed to fail without throwing an exception. In this case a
        ///     zero memory descriptor will be constructed. This flag is
        ///     optional and defaults to false.
        /// @sa @ref dev_guide_sparsity
        static desc csr(const dims &adims, data_type adata_type, dim nnz,
                data_type index_dt, data_type pointer_dt,
                bool allow_empty = false) {
            validate_dims(adims);
            dnnl_memory_desc_t md = nullptr;
            dnnl_status_t status = dnnl_memory_desc_create_with_csr_encoding(
                    &md, (int)adims.size(), adims.data(),
                    convert_to_c(adata_type), nnz, convert_to_c(index_dt),
                    convert_to_c(pointer_dt));
            if (!allow_empty)
                error::wrap_c_api(status,
                        "could not create a memory descriptor for CSR sparse "
                        "encoding");
            return desc {md};
        }

        /// Function for creating a memory descriptor for COO sparse encoding.
        ///
        /// The created memory descriptor will describe a memory object that
        /// contains n+1 buffers for an n-dimensional tensor.
        /// The buffers have the following meaning and assigned numbers (index):
        /// - 0: values
        /// - 1: indices for dimension 0
        /// - 2: indices for dimension 1 ...
        /// - n: indices for dimension n-1
        ///
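        /// Example (an illustrative sketch): a 4x6 f32 COO matrix with two
        /// non-zero entries and s32 indices; the variable name is arbitrary.
        /// @code
        ///     auto coo_md = dnnl::memory::desc::coo({4, 6},
        ///             dnnl::memory::data_type::f32, /* nnz = */ 2,
        ///             dnnl::memory::data_type::s32);
        /// @endcode
        ///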
        /// @param adims Tensor dimensions.
        /// @param adata_type Data precision/type.
        /// @param nnz Number of non-zero entries.
        /// @param index_dt Data type of indices.
        /// @param allow_empty A flag signifying whether construction is
        ///     allowed to fail without throwing an exception. In this case a
        ///     zero memory descriptor will be constructed. This flag is
        ///     optional and defaults to false.
        /// @sa @ref dev_guide_sparsity
        static desc coo(const dims &adims, data_type adata_type, dim nnz,
                data_type index_dt, bool allow_empty = false) {
            validate_dims(adims);
            dnnl_memory_desc_t md = nullptr;
            dnnl_status_t status = dnnl_memory_desc_create_with_coo_encoding(
                    &md, (int)adims.size(), adims.data(),
                    convert_to_c(adata_type), nnz, convert_to_c(index_dt));
            if (!allow_empty)
                error::wrap_c_api(status,
                        "could not create a memory descriptor for COO sparse "
                        "encoding");
            return desc {md};
        }

        /// Function for creating a memory descriptor for packed sparse
        /// encoding.
        ///
        /// The created memory descriptor cannot be used to create a memory
        /// object. It can only be used to create a primitive descriptor to
        /// query the actual memory descriptor (similar to the format tag
        /// `any`).
        ///
        /// @warning
        ///     The meaning and content of the handles of the memory object
        ///     that is created using the queried memory descriptor are
        ///     unspecified; therefore, using the content is undefined
        ///     behavior.
        ///
        /// @param adims Tensor dimensions.
        /// @param adata_type Data precision/type.
        /// @param nnz Number of non-zero entries.
        /// @param allow_empty A flag signifying whether construction is
        ///     allowed to fail without throwing an exception. In this case a
        ///     zero memory descriptor will be constructed. This flag is
        ///     optional and defaults to false.
        /// @sa @ref dev_guide_sparsity
        static desc packed(const dims &adims, data_type adata_type, dim nnz,
                bool allow_empty = false) {
            validate_dims(adims);
            dnnl_memory_desc_t md = nullptr;
            dnnl_status_t status = dnnl_memory_desc_create_with_packed_encoding(
                    &md, (int)adims.size(), adims.data(),
                    convert_to_c(adata_type), nnz);
            if (!allow_empty)
                error::wrap_c_api(status,
                        "could not create a memory descriptor for packed "
                        "sparse encoding");
            return desc {md};
        }

        /// Creates a memory descriptor for a scalar value that resides on the
        /// host.
        ///
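        /// Example (an illustrative sketch): describing a single f32 scalar
        /// that is passed from host memory; the variable name is arbitrary.
        /// @code
        ///     auto alpha_md = dnnl::memory::desc::host_scalar(
        ///             dnnl::memory::data_type::f32);
        /// @endcode
        ///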
        /// @param adata_type Data type of the scalar.
        /// @returns A memory descriptor for host-side scalar input.
        static desc host_scalar(data_type adata_type) {
            dnnl_memory_desc_t md = nullptr;
            error::wrap_c_api(dnnl_memory_desc_create_host_scalar(
                                      &md, convert_to_c(adata_type)),
                    "could not create a memory descriptor describing host side "
                    "scalar");
            return desc {md};
        }

        /// Construct a memory descriptor from a C API ::dnnl_memory_desc_t
        /// handle. The resulting handle is not weak and the C handle will be
        /// destroyed during the destruction of the C++ object.
        ///
        /// @param md The C API memory descriptor.
        desc(dnnl_memory_desc_t md) : handle<dnnl_memory_desc_t>(md) {}

        /// Construct a memory descriptor from a binary blob.
        ///
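        /// Example (an illustrative sketch, assuming the matching get_blob()
        /// query is available in this version of the API): round-tripping a
        /// descriptor through its serialized form.
        /// @code
        ///     dnnl::memory::desc md({2, 3}, dnnl::memory::data_type::f32,
        ///             dnnl::memory::format_tag::ab);
        ///     std::vector<uint8_t> blob = md.get_blob();
        ///     dnnl::memory::desc restored(blob);
        ///     assert(restored == md);
        /// @endcode
        ///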
        /// @param blob A binary blob previously queried from a memory
        ///     descriptor.
        desc(const std::vector<uint8_t> &blob) {
            dnnl_memory_desc_t md = nullptr;
            error::wrap_c_api(
                    dnnl_memory_desc_create_with_blob(&md, blob.data()),
                    "could not create a memory descriptor from blob");
            reset(md);
        }

        /// Constructs a memory descriptor for a region inside an area
        /// described by this memory descriptor.
        ///
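        /// Example (an illustrative sketch): selecting the first 8 channels of
        /// a 2x16x28x28 NCHW descriptor; names and sizes are arbitrary.
        /// @code
        ///     dnnl::memory::desc src_md({2, 16, 28, 28},
        ///             dnnl::memory::data_type::f32,
        ///             dnnl::memory::format_tag::nchw);
        ///     auto sub_md = src_md.submemory_desc(
        ///             {2, 8, 28, 28}, {0, 0, 0, 0});
        /// @endcode
        ///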
        /// @param adims Sizes of the region.
        /// @param offsets Offsets to the region from the encompassing
        ///     memory object in each dimension.
        /// @param allow_empty A flag signifying whether construction is
        ///     allowed to fail without throwing an exception. In this case a
        ///     zero memory descriptor will be returned. This flag is optional
        ///     and defaults to false.
        /// @returns A memory descriptor for the region.
        desc submemory_desc(const dims &adims, const dims &offsets,
                bool allow_empty = false) const {
            validate_dims(adims, get_ndims());
            validate_dims(offsets, get_ndims());
            dnnl_memory_desc_t sub_md = nullptr;
            dnnl_status_t status = dnnl_memory_desc_create_submemory(
                    &sub_md, get(), adims.data(), offsets.data());
            if (!allow_empty)
                error::wrap_c_api(status, "could not construct a sub-memory");
            return desc(sub_md);
        }

        /// Constructs a memory descriptor by reshaping an existing one. The
        /// new memory descriptor inherits the data type. This operation is
        /// valid only for memory descriptors that have format_kind set to
        /// #dnnl::memory::format_kind::blocked or
        /// #dnnl::memory::format_kind::any.
        ///
        /// The operation ensures that the transformation of the physical memory
        /// format corresponds to the transformation of the logical dimensions.
        /// If such transformation is impossible, the function either throws an
        /// exception (default) or returns a zero memory descriptor depending on
        /// the `allow_empty` flag.
        ///
        /// The reshape operation can be described as a combination of the
        /// following basic operations:
        /// 1. Add a dimension of size `1`. This is always possible.
        /// 2. Remove a dimension of size `1`. This is possible only if the
        ///    dimension has no padding (i.e.
        ///    `padded_dims[dim] == dims[dim] && dims[dim] == 1`).
        /// 3. Split a dimension into multiple ones. This is possible only if
        ///    the product of all tensor dimensions stays constant and the
        ///    dimension being split does not have padding (i.e.
        ///    `padded_dims[dim] == dims[dim]`).
        /// 4. Join multiple consecutive dimensions into a single one. As in
        ///    the cases above, this requires that the dimensions do not have
        ///    padding and that the memory format is such that in physical
        ///    memory these dimensions are dense and have the same order as
        ///    their logical counterparts. This also assumes that these
        ///    dimensions are not blocked.
        ///    - Here, 'dense' means:
        ///      `stride for dim[i] == (stride for dim[i + 1]) * dim[i + 1]`;
        ///    - And 'same order' means:
        ///      `i < j` if and only if `stride for dim[j] <= stride for dim[i]`.
        ///
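        /// Example (an illustrative sketch): splitting the channel dimension
        /// of a dense 2x16x28x28 descriptor into 4x4; names and sizes are
        /// arbitrary.
        /// @code
        ///     dnnl::memory::desc in_md({2, 16, 28, 28},
        ///             dnnl::memory::data_type::f32,
        ///             dnnl::memory::format_tag::nchw);
        ///     auto out_md = in_md.reshape({2, 4, 4, 28, 28});
        /// @endcode
        ///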
/// @warning
|
|
/// Some combinations of physical memory layout and/or offsets or
|
|
/// dimensions may result in a failure to make a reshape.
|
|
///
|
|
/// @param adims New dimensions. The product of dimensions must
|
|
/// remain constant.
|
|
/// @param allow_empty A flag signifying whether construction is
|
|
/// allowed to fail without throwing an exception. In this case a
|
|
/// zero memory descriptor will be returned. This flag is optional
|
|
/// and defaults to false.
|
|
/// @returns A new memory descriptor with new dimensions.
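        ///
        /// Example (illustrative): collapsing a dense `abcd` descriptor into a
        /// 2D one; both descriptors refer to the same 120-element buffer.
        /// @code
        ///     dnnl::memory::desc md({2, 3, 4, 5},
        ///             dnnl::memory::data_type::f32,
        ///             dnnl::memory::format_tag::abcd);
        ///     dnnl::memory::desc md_2d = md.reshape({6, 20});
        /// @endcode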
        desc reshape(const dims &adims, bool allow_empty = false) const {
            if (get_ndims()) validate_dims(adims, 1);
            dnnl_memory_desc_t out_md = nullptr;
            dnnl_status_t status = dnnl_memory_desc_reshape(
                    &out_md, get(), (int)adims.size(), adims.data());
            if (!allow_empty)
                error::wrap_c_api(
                        status, "could not reshape a memory descriptor");
            return desc(out_md);
        }

        /// Constructs a memory descriptor by permuting axes in an existing
        /// one.
        ///
        /// The physical memory layout representation is adjusted accordingly
        /// to maintain the consistency between the logical and physical parts
        /// of the memory descriptor. The new memory descriptor inherits the
        /// data type.
        ///
        /// This operation is valid only for memory descriptors that have
        /// format_kind set to #dnnl::memory::format_kind::blocked or
        /// #dnnl::memory::format_kind::any.
        ///
        /// The logical axes will be permuted in the following manner:
        /// @code
        /// for (i = 0; i < get_ndims(); i++)
        ///     new_desc.dims()[permutation[i]] = dims()[i];
        /// @endcode
        ///
        /// Example:
        /// @code
        /// std::vector<int> permutation = {1, 0}; // swap the first and
        ///                                        // the second axes
        /// dnnl::memory::desc in_md(
        ///         {2, 3}, data_type, memory::format_tag::ab);
        /// dnnl::memory::desc expect_out_md(
        ///         {3, 2}, data_type, memory::format_tag::ba);
        ///
        /// assert(in_md.permute_axes(permutation) == expect_out_md);
        /// @endcode
        ///
        /// @param permutation Axes permutation.
        /// @param allow_empty A flag signifying whether construction is
        ///     allowed to fail without throwing an exception. In this case a
        ///     zero memory descriptor will be returned. This flag is optional
        ///     and defaults to false.
        /// @returns A new memory descriptor with new dimensions.
        desc permute_axes(const std::vector<int> &permutation,
                bool allow_empty = false) const {
            validate_dims(permutation, get_ndims());
            dnnl_memory_desc_t out_md = nullptr;
            dnnl_status_t status = dnnl_memory_desc_permute_axes(
                    &out_md, get(), permutation.data());
            if (!allow_empty)
                error::wrap_c_api(status,
                        "could not permute axes of a memory descriptor");
            return desc(out_md);
        }

        /// Returns a number of dimensions of the memory descriptor.
        ///
        /// @returns A number of dimensions.
        int get_ndims() const { return query_s32(query::ndims_s32); }

        /// Returns padded dimensions of the memory descriptor.
        ///
        /// @returns A copy of the padded dimensions vector.
        memory::dims get_padded_dims() const {
            return query_dims(query::padded_dims);
        }

        /// Returns padded offsets of the memory descriptor.
        ///
        /// @returns A copy of the padded offsets vector.
        memory::dims get_padded_offsets() const {
            return query_dims(query::padded_offsets);
        }

        /// Returns a submemory offset of the memory descriptor.
        ///
        /// @returns A submemory offset.
        memory::dim get_submemory_offset() const {
            dnnl_dim_t submemory_offset;
            dnnl_status_t status = dnnl_memory_desc_query(
                    get(), dnnl_query_submemory_offset_s64, &submemory_offset);
            return status == dnnl_success ? submemory_offset : 0;
        }

        /// Returns strides of the memory descriptor.
        ///
        /// @note
        ///     This API is only applicable to memory descriptors with format
        ///     kind #dnnl_blocked.
        ///
        /// @returns A copy of the strides vector.
        /// @returns An empty #dnnl::memory::dims if the memory descriptor
        ///     does not have strides.
        memory::dims get_strides() const { return query_dims(query::strides); }

        /// Returns a number of inner blocks of the memory descriptor.
        ///
        /// @note
        ///     This API is only applicable to memory descriptors with format
        ///     kind #dnnl_blocked.
        ///
        /// @returns A number of inner blocks.
        int get_inner_nblks() const {
            return query_s32(query::inner_nblks_s32);
        }

        /// Returns inner blocks of the memory descriptor.
        ///
        /// @note
        ///     This API is only applicable to memory descriptors with format
        ///     kind #dnnl_blocked.
        ///
        /// @returns A copy of the inner blocks vector.
        /// @returns An empty #dnnl::memory::dims if the memory descriptor
        ///     does not have inner blocks.
        memory::dims get_inner_blks() const {
            return query_dims(query::inner_blks);
        }

        /// Returns inner indices of the memory descriptor.
        ///
        /// @note
        ///     This API is only applicable to memory descriptors with format
        ///     kind #dnnl_blocked.
        ///
        /// @returns A copy of the inner indices vector.
        /// @returns An empty #dnnl::memory::dims if the memory descriptor
        ///     does not have inner indices.
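        ///
        /// Together with get_inner_nblks() and get_inner_blks(), this allows
        /// decoding a blocked layout. Illustrative sketch for an `nChw16c`
        /// descriptor:
        /// @code
        ///     dnnl::memory::desc md({1, 32, 4, 4},
        ///             dnnl::memory::data_type::f32,
        ///             dnnl::memory::format_tag::nChw16c);
        ///     int nblks = md.get_inner_nblks();              // 1
        ///     dnnl::memory::dims blks = md.get_inner_blks(); // {16}
        ///     dnnl::memory::dims idxs = md.get_inner_idxs(); // {1} (channels)
        /// @endcode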
        memory::dims get_inner_idxs() const {
            return query_dims(query::inner_idxs);
        }

        /// Returns number of handles.
        ///
        /// @returns A number of handles.
        int get_num_handles() const {
            int nhandles;
            dnnl_status_t status = dnnl_memory_desc_query_v2(
                    get(), dnnl_query_num_handles_s32, 0, &nhandles);
            return status == dnnl_success ? nhandles : 0;
        }

        /// Returns a number of non-zero entries of the memory descriptor.
        ///
        /// @returns A number of non-zero entries.
        dim get_nnz() const {
            dnnl_dim_t nnz;
            dnnl_status_t status = dnnl_memory_desc_query_v2(
                    get(), dnnl_query_nnz_s64, 0, &nnz);
            return status == dnnl_success ? nnz : 0;
        }

        /// Returns the sparse encoding of the memory descriptor.
        ///
        /// @returns The sparse encoding kind.
        /// @sa @ref dev_guide_sparsity
        memory::sparse_encoding get_sparse_encoding() const {
            dnnl_sparse_encoding_t sparse_encoding;
            dnnl_status_t status = dnnl_memory_desc_query_v2(
                    get(), dnnl_query_sparse_encoding, 0, &sparse_encoding);
            return status == dnnl_success
                    ? static_cast<dnnl::memory::sparse_encoding>(
                            sparse_encoding)
                    : dnnl::memory::sparse_encoding::undef;
        }

        /// Returns the data type of the memory descriptor.
        ///
        /// @param index Data index. Defaults to 0.
        /// @returns The data type.
        memory::data_type get_data_type(int index = 0) const {
            return query_data_type(query::data_type, index);
        }

        /// Returns the format kind of the memory descriptor.
        ///
        /// @returns The format kind.
        memory::format_kind get_format_kind() const {
            dnnl_format_kind_t format_kind;
            dnnl_status_t status = dnnl_memory_desc_query(
                    get(), dnnl_query_format_kind, &format_kind);
            return status == dnnl_success
                    ? static_cast<dnnl::memory::format_kind>(format_kind)
                    : dnnl::memory::format_kind::undef;
        }

        /// Returns dimensions of the memory descriptor.
        ///
        /// Potentially expensive due to the data copy involved.
        /// @returns A copy of the dimensions vector.
        memory::dims get_dims() const { return query_dims(query::dims); }

        /// Returns size of the memory descriptor in bytes.
        /// @param index Data index. Defaults to 0.
        /// @returns The number of bytes required to allocate a memory buffer
        ///     for data with a particular @p index described by this memory
        ///     descriptor including the padding area.
        size_t get_size(int index = 0) const {
            return dnnl_memory_desc_get_size_v2(get(), index);
        }

        /// Returns a binary blob associated with the given memory descriptor.
        /// @returns The memory descriptor blob associated with the memory
        ///     descriptor.
        std::vector<uint8_t> get_blob() {
            size_t size;
            dnnl_status_t status
                    = dnnl_memory_desc_get_blob(nullptr, &size, get());
            error::wrap_c_api(
                    status, "could not get memory descriptor blob size");

            std::vector<uint8_t> out_blob(size);
            status = dnnl_memory_desc_get_blob(out_blob.data(), &size, get());
            error::wrap_c_api(status, "could not get memory descriptor blob");
            return out_blob;
        }

        /// Checks whether the memory descriptor is zero (empty).
        /// @returns @c true if the memory descriptor describes an empty
        ///     memory and @c false otherwise.
        bool is_zero() const { return get_ndims() == 0; }

        /// An equality operator.
        /// @param other Another memory descriptor.
        /// @returns Whether this and the other memory descriptors have
        ///     the same format tag, dimensions, strides, blocking, etc.
        bool operator==(const desc &other) const {
            return dnnl_memory_desc_equal(get(), other.get()) != 0;
        }

        /// An inequality operator.
        /// @param other Another memory descriptor.
        /// @returns Whether this and the other memory descriptors describe
        ///     different memory.
        bool operator!=(const desc &other) const { return !operator==(other); }

    private:
        memory::data_type query_data_type(query what, int index) const {
            dnnl_data_type_t data_type;
            dnnl_status_t status = dnnl_memory_desc_query_v2(
                    get(), dnnl::convert_to_c(what), index, &data_type);
            return status == dnnl_success
                    ? static_cast<dnnl::memory::data_type>(data_type)
                    : dnnl::memory::data_type::undef;
        }

        int query_s32(query what) const {
            int res;
            dnnl_status_t status = dnnl_memory_desc_query(
                    get(), dnnl::convert_to_c(what), &res);
            return status == dnnl_success ? res : 0;
        }

        memory::dims query_dims(query what) const {
            dnnl_dims_t *c_dims;
            dnnl_status_t status = dnnl_memory_desc_query(
                    get(), dnnl::convert_to_c(what), &c_dims);

            const int ndims
                    = (what == query::inner_idxs || what == query::inner_blks)
                    ? get_inner_nblks()
                    : get_ndims();

            return status == dnnl_success
                    ? memory::dims(*c_dims, *c_dims + ndims)
                    : memory::dims {};
        }
    };

    /// Default constructor.
    ///
    /// Constructs an empty memory object, which can be used to indicate
    /// absence of a parameter.
    memory() = default;

    /// Constructs a memory object.
    ///
    /// Unless @p handle is equal to #DNNL_MEMORY_NONE, the constructed memory
    /// object will have the underlying buffer set. In this case, the buffer
    /// will be initialized as if #dnnl::memory::set_data_handle() had been
    /// called.
    ///
    /// @sa memory::set_data_handle()
    ///
    /// @param md Memory descriptor.
    /// @param aengine Engine to store the data on.
    /// @param handle Handle of the memory buffer to use.
    ///     - A pointer to the user-allocated buffer. In this case the library
    ///       doesn't own the buffer.
    ///     - The #DNNL_MEMORY_ALLOCATE special value. Instructs the library to
    ///       allocate the buffer for the memory object. In this case the
    ///       library owns the buffer.
    ///     - #DNNL_MEMORY_NONE to create dnnl::memory without an underlying
    ///       buffer.
    memory(const desc &md, const engine &aengine, void *handle)
        : memory(md, aengine, std::vector<void *> {handle}) {}

    /// Constructs a memory object with multiple handles.
    ///
    /// Unless an element of @p handles is equal to #DNNL_MEMORY_NONE, the
    /// constructed memory object will have the underlying buffer set. In this
    /// case, the buffer will be initialized as if
    /// #dnnl::memory::set_data_handle() had been called.
    ///
    /// @sa memory::set_data_handle()
    ///
    /// @param md Memory descriptor.
    /// @param aengine Engine to store the data on.
    /// @param handles Handles of the memory buffers to use.
    ///     For each element of the @p handles vector the following applies:
    ///     - A pointer to the user-allocated buffer. In this case the library
    ///       doesn't own the buffer.
    ///     - The #DNNL_MEMORY_ALLOCATE special value. Instructs the library to
    ///       allocate the buffer for the memory object. In this case the
    ///       library owns the buffer.
    ///     - #DNNL_MEMORY_NONE Instructs the library to skip allocation of the
    ///       memory buffer.
    memory(const desc &md, const engine &aengine, std::vector<void *> handles) {
        dnnl_memory_t result;
        dnnl_status_t status = dnnl_memory_create_v2(&result, md.get(),
                aengine.get(), (int)handles.size(), handles.data());
        error::wrap_c_api(status, "could not create a memory object");
        reset(result);
    }

    /// Constructs a memory object.
    ///
    /// The underlying buffer(s) for the memory will be allocated by the
    /// library.
    /// @param md Memory descriptor.
    /// @param aengine Engine to store the data on.
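    ///
    /// Illustrative sketch of the two common construction options (`md` and
    /// `eng` are assumed to be an existing f32 descriptor and an engine):
    /// @code
    ///     // Let the library allocate and own the buffer.
    ///     dnnl::memory mem_lib(md, eng);
    ///
    ///     // Wrap a user-managed buffer of at least md.get_size() bytes.
    ///     std::vector<float> buf(md.get_size() / sizeof(float));
    ///     dnnl::memory mem_usr(md, eng, buf.data());
    /// @endcode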
    memory(const desc &md, const engine &aengine) {
        dnnl_status_t status;
        dnnl_memory_t result;
        const int nhandles = md.get_num_handles();

        std::vector<void *> handles(nhandles, DNNL_MEMORY_ALLOCATE);
        status = dnnl_memory_create_v2(&result, md.get(), aengine.get(),
                (int)handles.size(), handles.data());

        error::wrap_c_api(status, "could not create a memory object");
        reset(result);
    }

    /// Constructs a memory object that wraps a host-side scalar value.
    ///
    /// @note The scalar value is copied into the newly allocated memory
    ///     storage, so the user does not need to manage the lifetime of the
    ///     original scalar data.
    ///
    /// @tparam T Type of the scalar value.
    /// @param md Memory descriptor describing a scalar value residing on the
    ///     host.
    /// @param value The scalar value to be wrapped by the memory object.
    ///
    /// @throws error if the memory object could not be created.
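    ///
    /// Illustrative sketch (here `scalar_md` is assumed to be a descriptor of
    /// a single f32 host-side scalar, obtained elsewhere):
    /// @code
    ///     float alpha = 0.5f;
    ///     dnnl::memory alpha_mem(scalar_md, alpha); // copies the value
    ///     float readback = alpha_mem.get_host_scalar_value<float>();
    /// @endcode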
    template <typename T>
    memory(const desc &md, const T value) {
        dnnl_memory_t result;
        // Check that the size of T matches the data type size of the memory
        // descriptor (for host-side scalars, md.get_size() is the data type
        // size).
        if (sizeof(T) != md.get_size()) {
            DNNL_THROW_ERROR(dnnl_invalid_arguments,
                    "scalar type size does not match memory descriptor data "
                    "type size");
        } else {
            dnnl_status_t status = dnnl_memory_create_host_scalar(
                    &result, md.get(), (void *)&value);
            error::wrap_c_api(status, "could not create a memory object");
        }
        reset(result);
    }

    /// Returns the associated memory descriptor.
    desc get_desc() const {
        const_dnnl_memory_desc_t cdesc;
        error::wrap_c_api(dnnl_memory_get_memory_desc(get(), &cdesc),
                "could not get a memory descriptor from a memory object");
        dnnl_memory_desc_t cloned_md = nullptr;
        error::wrap_c_api(dnnl_memory_desc_clone(&cloned_md, cdesc),
                "could not clone a memory descriptor");
        return desc(cloned_md);
    }

    /// Returns the associated engine.
    engine get_engine() const {
        dnnl_engine_t c_engine;
        error::wrap_c_api(dnnl_memory_get_engine(get(), &c_engine),
                "could not get an engine from a memory object");
        return engine(c_engine, true);
    }

    /// Returns an underlying memory buffer that corresponds to the given
    /// index.
    ///
    /// On the CPU engine, or when using USM, this is a pointer to the
    /// allocated memory.
    void *get_data_handle(int index = 0) const {
        void *handle;
        error::wrap_c_api(dnnl_memory_get_data_handle_v2(get(), &handle, index),
                "could not get a native handle from a memory object");
        return handle;
    }

    /// Sets an underlying memory buffer that corresponds to the given index.
    ///
    /// @param handle Memory buffer to use. On the CPU engine or when USM is
    ///     used, the memory buffer is a pointer to the actual data. For OpenCL
    ///     it is a cl_mem. It must have at least
    ///     #dnnl::memory::desc::get_size() bytes allocated.
    /// @param index Memory index to attach the buffer. Defaults to 0.
    void set_data_handle(void *handle, int index = 0) const {
        error::wrap_c_api(dnnl_memory_set_data_handle_v2(get(), handle, index),
                "could not set native handle of a memory object");
    }

    /// Returns the scalar value stored in the memory object as type T.
    ///
    /// @tparam T Type to cast the scalar value to.
    template <typename T>
    T get_host_scalar_value() const {
        const_dnnl_memory_desc_t cdesc;
        error::wrap_c_api(dnnl_memory_get_memory_desc(get(), &cdesc),
                "could not get memory descriptor");

        if (sizeof(T) != dnnl_memory_desc_get_size_v2(cdesc, 0)) {
            DNNL_THROW_ERROR(dnnl_invalid_arguments,
                    "scalar type size does not match memory descriptor data "
                    "type size");
        }

        T value;
        error::wrap_c_api(dnnl_memory_get_host_scalar_value(get(), &value),
                "could not get host scalar value from a memory object");
        return value;
    }

    /// Sets the scalar value stored in the memory object.
    ///
    /// @note The scalar value is copied into the memory storage, so the user
    ///     does not need to manage the lifetime of the original scalar data.
    ///
    /// @param value The scalar value to set.
    template <typename T>
    void set_host_scalar_value(const T value) const {
        const_dnnl_memory_desc_t cdesc;
        error::wrap_c_api(dnnl_memory_get_memory_desc(get(), &cdesc),
                "could not get memory descriptor from a memory object");

        if (sizeof(T) != dnnl_memory_desc_get_size_v2(cdesc, 0)) {
            DNNL_THROW_ERROR(dnnl_invalid_arguments,
                    "scalar type size does not match memory descriptor data "
                    "type size");
        }

        error::wrap_c_api(dnnl_memory_set_host_scalar_value(get(), &value),
                "could not set host scalar value to a memory object");
    }

    /// Maps a memory object and returns a host-side pointer to a memory
    /// buffer with a copy of its contents. The memory buffer corresponds to
    /// the given index.
    ///
    /// Mapping enables read/write directly from/to the memory contents for
    /// engines that do not support direct memory access.
    ///
    /// Mapping is an exclusive operation - a memory object cannot be used in
    /// other operations until it is unmapped via #dnnl::memory::unmap_data()
    /// call.
    ///
    /// @note
    ///     Any primitives working with the memory should be completed before
    ///     the memory is mapped. Use #dnnl::stream::wait() to synchronize the
    ///     corresponding execution stream.
    ///
    /// @note
    ///     The map_data and unmap_data functions are provided mainly for
    ///     debug and testing purposes and their performance may be suboptimal.
    ///
    /// @tparam T Data type to return a pointer to.
    /// @param index Index of the buffer. Defaults to 0.
    /// @returns Pointer to the mapped memory.
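    ///
    /// Typical usage pattern (illustrative; `mem` is an f32 memory object):
    /// @code
    ///     float *ptr = mem.map_data<float>();
    ///     const size_t nelems = mem.get_desc().get_size() / sizeof(float);
    ///     for (size_t i = 0; i < nelems; ++i)
    ///         ptr[i] = 0.f;
    ///     mem.unmap_data(ptr);
    /// @endcode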
    template <typename T = void>
    T *map_data(int index = 0) const {
        void *mapped_ptr;
        error::wrap_c_api(dnnl_memory_map_data_v2(get(), &mapped_ptr, index),
                "could not map memory object data");
        return static_cast<T *>(mapped_ptr);
    }

    /// Unmaps a memory object and writes back any changes made to the
    /// previously mapped memory buffer. The memory buffer corresponds to
    /// the given index.
    ///
    /// @note
    ///     The map_data and unmap_data functions are provided mainly for
    ///     debug and testing purposes and their performance may be
    ///     suboptimal.
    ///
    /// @param mapped_ptr A pointer previously returned by
    ///     #dnnl::memory::map_data().
    /// @param index Index of the buffer. Defaults to 0.
    void unmap_data(void *mapped_ptr, int index = 0) const {
        error::wrap_c_api(dnnl_memory_unmap_data_v2(get(), mapped_ptr, index),
                "could not unmap memory object data");
    }

    static dnnl_data_type_t convert_to_c(data_type adata_type) {
        return static_cast<dnnl_data_type_t>(adata_type);
    }
    static dnnl_format_tag_t convert_to_c(format_tag format) {
        return static_cast<dnnl_format_tag_t>(format);
    }
};

inline bool operator==(dnnl_data_type_t a, memory::data_type b) {
    return a == memory::convert_to_c(b);
}
inline bool operator!=(dnnl_data_type_t a, memory::data_type b) {
    return !(a == b);
}
inline bool operator==(memory::data_type a, dnnl_data_type_t b) {
    return b == a;
}
inline bool operator!=(memory::data_type a, dnnl_data_type_t b) {
    return !(a == b);
}

inline bool operator==(dnnl_format_tag_t a, memory::format_tag b) {
    return a == memory::convert_to_c(b);
}
inline bool operator!=(dnnl_format_tag_t a, memory::format_tag b) {
    return !(a == b);
}
inline bool operator==(memory::format_tag a, dnnl_format_tag_t b) {
    return b == a;
}
inline bool operator!=(memory::format_tag a, dnnl_format_tag_t b) {
    return !(a == b);
}

/// @} dnnl_api_memory

/// @addtogroup dnnl_api_primitives
/// @{
/// @addtogroup dnnl_api_attributes Attributes
///
/// A container for parameters that extend primitives behavior.
///
/// @{

/// @cond DO_NOT_DOCUMENT_THIS
template <>
struct handle_traits<dnnl_post_ops_t> {
    static dnnl_status_t destructor(dnnl_post_ops_t p) {
        return dnnl_post_ops_destroy(p);
    }
};
/// @endcond

/// Post-ops.
///
/// Post-ops are computations executed after the main primitive computations
/// and are attached to the primitive via primitive attributes.
///
/// @sa @ref dev_guide_attributes_post_ops
///
struct post_ops : public handle<dnnl_post_ops_t> {
    using handle<dnnl_post_ops_t>::handle;

    /// Constructs an empty sequence of post-ops.
    post_ops() {
        dnnl_post_ops_t result;
        error::wrap_c_api(
                dnnl_post_ops_create(&result), "could not create post-ops");
        reset(result);
    }

    /// Creates post-ops primitive attribute from a C API ::dnnl_post_ops_t
    /// handle. The resulting handle is not weak and the C handle will be
    /// destroyed during the destruction of the C++ object.
    ///
    /// @param post_ops The C API post-ops primitive attribute.
    post_ops(dnnl_post_ops_t post_ops) : handle<dnnl_post_ops_t>(post_ops) {}

    /// Returns the number of post-ops entries.
    int len() const { return dnnl_post_ops_len(get()); }

    /// Returns the primitive kind of the post-op at the entry with a certain
    /// index.
    /// @param index Index of the post-op to return the kind for.
    /// @returns Primitive kind of the post-op at the specified index.
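    ///
    /// Illustrative iteration over an existing post-ops chain (`ops` is a
    /// previously constructed #dnnl::post_ops object):
    /// @code
    ///     for (int i = 0; i < ops.len(); ++i) {
    ///         if (ops.kind(i) == dnnl::primitive::kind::eltwise) {
    ///             dnnl::algorithm alg;
    ///             float alpha, beta;
    ///             ops.get_params_eltwise(i, alg, alpha, beta);
    ///         }
    ///     }
    /// @endcode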
    primitive::kind kind(int index) const {
        error::wrap_c_api(index < len() ? dnnl_success : dnnl_invalid_arguments,
                "post-ops index is out of range");
        return static_cast<primitive::kind>(
                dnnl_post_ops_get_kind(get(), index));
    }

    /// Appends an accumulation (sum) post-op. Prior to accumulating the
    /// result, the previous value will be reduced by the zero point
    /// @p zero_point and multiplied by a scaling factor @p scale.
    ///
    /// The kind of this post-op is #dnnl::primitive::kind::sum.
    ///
    /// This feature may improve performance for cases like dequantizing the
    /// asymmetrically quantized src1 tensor of the sum to the f32 domain
    /// before performing the sum operation, by subtracting @p zero_point
    /// before the scaling.
    ///
    /// In the simplest case when the accumulation is the only post-op,
    /// the computations will be `dst[:] := scale * (dst[:] - zero_point) +
    /// op(...)` instead of `dst[:] := op(...)`.
    ///
    /// If @p data_type is specified, the original dst tensor will be
    /// reinterpreted as a tensor with the provided data type. Because it is a
    /// reinterpretation, data_type and dst data type should have the same size.
    /// As a result, computations will be `dst[:] <- scale *
    /// (as_data_type(dst[:]) - zero_point) + op(...)` instead of
    /// `dst[:] <- op(...)`.
    ///
    /// @note
    ///     This post-op executes in-place and does not change the
    ///     destination layout.
    ///
    /// @param scale Scaling factor.
    /// @param zero_point Zero point.
    /// @param data_type Data type.
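    ///
    /// A minimal sketch of attaching a sum post-op through primitive
    /// attributes (illustrative; the resulting semantics is
    /// `dst = op(...) + dst`):
    /// @code
    ///     dnnl::post_ops po;
    ///     po.append_sum();
    ///     dnnl::primitive_attr attr;
    ///     attr.set_post_ops(po);
    /// @endcode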
    void append_sum(float scale = 1.f, int32_t zero_point = 0,
            memory::data_type data_type = memory::data_type::undef) {
        error::wrap_c_api(dnnl_post_ops_append_sum(get(), scale, zero_point,
                                  memory::convert_to_c(data_type)),
                "could not append a sum post-op");
    }

    /// Returns the parameters of an accumulation (sum) post-op.
    ///
    /// @param index Index of the sum post-op.
    /// @param scale Scaling factor of the sum post-op.
    void get_params_sum(int index, float &scale) const {
        error::wrap_c_api(dnnl_post_ops_get_params_sum(
                                  get(), index, &scale, nullptr, nullptr),
                "could not get parameters of a sum post-op");
    }

    /// Returns the parameters of an accumulation (sum) post-op.
    ///
    /// @param index Index of the sum post-op.
    /// @param scale Scaling factor of the sum post-op.
    /// @param data_type Data type of the sum post-op.
    void get_params_sum(
            int index, float &scale, memory::data_type &data_type) const {
        dnnl_data_type_t c_data_type;
        error::wrap_c_api(dnnl_post_ops_get_params_sum(
                                  get(), index, &scale, nullptr, &c_data_type),
                "could not get parameters of a sum post-op");
        data_type = static_cast<memory::data_type>(c_data_type);
    }

    /// Returns the parameters of an accumulation (sum) post-op.
    ///
    /// @param index Index of the sum post-op.
    /// @param scale Scaling factor of the sum post-op.
    /// @param zero_point Single scalar int32_t value of the zero point.
    /// @param data_type Data type of the sum post-op.
    void get_params_sum(int index, float &scale, int32_t &zero_point,
            memory::data_type &data_type) const {
        dnnl_data_type_t c_data_type;
        error::wrap_c_api(dnnl_post_ops_get_params_sum(get(), index, &scale,
                                  &zero_point, &c_data_type),
                "could not get parameters of a sum post-op");
        data_type = static_cast<memory::data_type>(c_data_type);
    }

    /// Appends an elementwise post-op.
    ///
    /// The kind of this post-op is #dnnl::primitive::kind::eltwise.
    ///
    /// In the simplest case when the elementwise is the only post-op, the
    /// computations would be `dst[:] := eltwise_op (op(...))` instead
    /// of `dst[:] <- op(...)`, where eltwise_op is configured with the given
    /// parameters.
    ///
    /// @param aalgorithm Elementwise algorithm.
    /// @param alpha Alpha parameter for the elementwise algorithm.
    /// @param beta Beta parameter for the elementwise algorithm.
    void append_eltwise(algorithm aalgorithm, float alpha, float beta) {
        error::wrap_c_api(dnnl_post_ops_append_eltwise(
                                  get(), convert_to_c(aalgorithm), alpha, beta),
                "could not append an elementwise post-op");
    }

    /// Returns parameters of an elementwise post-op.
    ///
    /// @param index Index of the post-op.
    /// @param aalgorithm Output elementwise algorithm kind.
    /// @param alpha Output alpha parameter for the elementwise algorithm.
    /// @param beta Output beta parameter for the elementwise algorithm.
    void get_params_eltwise(
            int index, algorithm &aalgorithm, float &alpha, float &beta) const {
        dnnl_alg_kind_t c_alg;
        error::wrap_c_api(dnnl_post_ops_get_params_eltwise(
                                  get(), index, &c_alg, &alpha, &beta),
                "could not get parameters of an elementwise post-op");
        aalgorithm = static_cast<dnnl::algorithm>(c_alg);
    }

    /// Appends a depthwise post-op convolution.
    ///
    /// This post-op can only be fused with a 2D 1x1 convolution (convolution
    /// with weights spatial dimensions equal to 1, i.e., kh = kw = 1).
    ///
    /// The kind of this post-op is #dnnl_convolution.
    ///
    /// The number of outputs for the primitive remains the same as before.
    /// The output spatial size can be derived as below:
    ///
    ///     output_height = ceil(output_height_1x1_convolution, stride)
    ///     output_width = ceil(output_width_1x1_convolution, stride)
    ///
    /// See @ref dev_guide_attributes_post_ops_depthwise and
    /// @ref dev_guide_attributes_post_ops_depthwise_fusion for more info.
    ///
    /// @param weights_data_type Weights data type of the depthwise post-op.
    /// @param bias_data_type Bias data type of the depthwise post-op.
    /// @param dst_data_type Output data type of the depthwise post-op.
    /// @param kernel_size Size of the kernel of the depthwise post-op.
    /// @param stride_size Size of the stride of the depthwise post-op.
    /// @param padding_l_size Size of the left and top paddings of the
    ///     depthwise post-op.
    void append_dw(memory::data_type weights_data_type,
            memory::data_type bias_data_type, memory::data_type dst_data_type,
            memory::dim kernel_size, memory::dim stride_size,
            memory::dim padding_l_size) {

        error::wrap_c_api(dnnl_post_ops_append_dw(get(),
                                  memory::convert_to_c(weights_data_type),
                                  memory::convert_to_c(bias_data_type),
                                  memory::convert_to_c(dst_data_type),
                                  kernel_size, stride_size, padding_l_size),
                "could not append depthwise post-op");
    }

    /// Returns the parameters of a depthwise post-op.
    ///
    /// @param index Index of the depthwise post-op.
    /// @param weights_data_type Weights data type of the depthwise post-op.
    /// @param bias_data_type Bias data type of the depthwise post-op.
    /// @param dst_data_type Output data type of the depthwise post-op.
    /// @param kernel_size Size of the kernel of the depthwise post-op.
    /// @param stride_size Size of the stride of the depthwise post-op.
    /// @param padding_l_size Size of the left and top paddings of the
    ///     depthwise post-op.
    void get_params_dw(int index, memory::data_type &weights_data_type,
            memory::data_type &bias_data_type, memory::data_type &dst_data_type,
            memory::dim &kernel_size, memory::dim &stride_size,
            memory::dim &padding_l_size) const {

        dnnl_data_type_t c_weights_data_type;
        dnnl_data_type_t c_bias_data_type;
        dnnl_data_type_t c_dst_data_type;
        dnnl_dim_t c_kernel_size;
        dnnl_dim_t c_stride_size;
        dnnl_dim_t c_padding_l_size;
        error::wrap_c_api(
                dnnl_post_ops_get_params_dw(get(), index, &c_weights_data_type,
                        &c_bias_data_type, &c_dst_data_type, &c_kernel_size,
                        &c_stride_size, &c_padding_l_size),
                "could not get parameters of depthwise post-op");

        weights_data_type = static_cast<memory::data_type>(c_weights_data_type);
        bias_data_type = static_cast<memory::data_type>(c_bias_data_type);
        dst_data_type = static_cast<memory::data_type>(c_dst_data_type);
        kernel_size = c_kernel_size;
        stride_size = c_stride_size;
        padding_l_size = c_padding_l_size;
    }

    /// Appends a binary post-op.
    ///
    /// This post operation is categorized as #dnnl_binary.
    ///
    /// In the simplest case when the binary is the only post operation, the
    /// computations will be:
    ///
    ///     dst[:] <- binary_op (dst[:], another_input[:])
    ///
    /// where binary_op is configured with the given parameters. binary_op
    /// supports broadcast semantics for a second operand.
    ///
    /// @param aalgorithm Binary algorithm for the post-op.
    /// @param src1_desc Memory descriptor of a second operand.
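    ///
    /// A minimal sketch (illustrative; `po` is a #dnnl::post_ops object and
    /// `C` is a placeholder for the number of channels): append a per-channel
    /// multiplication whose second operand is supplied at execution time
    /// under `DNNL_ARG_ATTR_MULTIPLE_POST_OP(idx) | DNNL_ARG_SRC_1`.
    /// @code
    ///     dnnl::memory::desc scale_md({1, C, 1, 1},
    ///             dnnl::memory::data_type::f32,
    ///             dnnl::memory::format_tag::abcd);
    ///     po.append_binary(dnnl::algorithm::binary_mul, scale_md);
    /// @endcode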
    void append_binary(algorithm aalgorithm, const memory::desc &src1_desc) {
        error::wrap_c_api(dnnl_post_ops_append_binary(get(),
                                  convert_to_c(aalgorithm), src1_desc.get()),
                "could not append a binary post-op");
    }

    /// Appends a binary post-op with ternary operators.
    ///
    /// This post operation is categorized as #dnnl_binary.
    ///
    /// In the simplest case when this is the only post operation, the
    /// computations will be:
    ///
    ///     dst[:] <- binary_op (dst[:], another_input1[:], another_input2[:])
    ///
    /// where binary_op is configured with the given parameters. binary_op
    /// supports broadcast semantics only for the second operand and not for the
    /// third operand.
    ///
    /// @param aalgorithm Binary algorithm for the post-op.
    /// @param src1_desc Memory descriptor of the second operand.
    /// @param src2_desc Memory descriptor of the third operand. If the specified
    ///     algorithm is not one that requires a ternary input, src2_desc will be
    ///     ignored.
    void append_binary(algorithm aalgorithm, const memory::desc &src1_desc,
            const memory::desc &src2_desc) {
        error::wrap_c_api(
                dnnl_post_ops_append_binary_v2(get(), convert_to_c(aalgorithm),
                        src1_desc.get(), src2_desc.get()),
                "could not append a binary post-op with ternary operators");
    }

    /// Returns the parameters of a binary post-op.
    ///
    /// @param index Index of the binary post-op.
    /// @param aalgorithm Output binary algorithm kind.
    /// @param src1_desc Output memory descriptor of a second operand.
    void get_params_binary(
            int index, algorithm &aalgorithm, memory::desc &src1_desc) const {
        dnnl_alg_kind_t c_alg;
        const_dnnl_memory_desc_t cdesc;
        error::wrap_c_api(
                dnnl_post_ops_get_params_binary(get(), index, &c_alg, &cdesc),
                "could not get parameters of a binary post-op");
        aalgorithm = static_cast<dnnl::algorithm>(c_alg);
        dnnl_memory_desc_t cloned_md = nullptr;
        error::wrap_c_api(dnnl_memory_desc_clone(&cloned_md, cdesc),
                "could not clone a memory descriptor");
        src1_desc = memory::desc(cloned_md);
    }

    /// Returns the parameters of a binary post-op with ternary operators.
    ///
    /// @param index Index of the binary post-op.
    /// @param aalgorithm Output binary algorithm kind.
    /// @param src1_desc Output memory descriptor of the second operand.
    /// @param src2_desc Output memory descriptor of the third operand.
    void get_params_binary(int index, algorithm &aalgorithm,
            memory::desc &src1_desc, memory::desc &src2_desc) const {
        dnnl_alg_kind_t c_alg;
        const_dnnl_memory_desc_t cdesc1, cdesc2;
        error::wrap_c_api(dnnl_post_ops_get_params_binary_v2(
                                  get(), index, &c_alg, &cdesc1, &cdesc2),
                "could not get parameters of a binary post-op with ternary "
                "operators");
        aalgorithm = static_cast<dnnl::algorithm>(c_alg);
        dnnl_memory_desc_t cloned_md1 = nullptr;
        dnnl_memory_desc_t cloned_md2 = nullptr;

        error::wrap_c_api(dnnl_memory_desc_clone(&cloned_md1, cdesc1),
                "could not clone a memory descriptor");
        src1_desc = memory::desc(cloned_md1);

        error::wrap_c_api(dnnl_memory_desc_clone(&cloned_md2, cdesc2),
                "could not clone a memory descriptor");
        src2_desc = memory::desc(cloned_md2);
    }

    /// Appends a prelu forward post-op.
    ///
    /// The kind of this post-op is #dnnl::primitive::kind::prelu.
    ///
    /// The post-op can be defined as:
    ///
    ///      dst[:] <- prelu(dst[:], weights[:])
    ///      prelu:
    ///      dst[:] <- dst[:] if dst[:] > 0
    ///      dst[:] <- dst[:] * weights[:] if dst[:] <= 0
    ///
    ///
    /// Example usage:
    /// @code
    ///     int mb = 32, oc = 32,
    ///             oh = 14, ow = 14; // convolution output params
    ///     // unique weights per output channel
    ///     vector<float> weights = { ... };
    ///     int oc_dim = 1; // mb_dim = 0, channel_dim = 1, height_dim = 2, ...
    ///
    ///     // construct a convolution descriptor
    ///     dnnl::convolution::desc conv_d;
    ///
    ///     dnnl::primitive_attr attr;
    ///     attr.append_prelu(1 << oc_dim);
    ///
    ///     dnnl::primitive_desc conv_pd(conv_d, attr, engine);
    ///     memory prelu_weights({{1}, dt::f32, {1}}, eng, weights.data());
    ///
    ///     std::unordered_map<int, memory> conv_args;
    ///
    ///     conv_args.insert(
    ///             {DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_WEIGHTS,
    ///                     prelu_weights})
    /// @endcode
    ///
    /// @note
    ///     The order of dimensions does not depend on how elements are laid
    ///     out in memory. For example:
    ///     - for a 2D CNN activations tensor the order is always (n, c)
    ///     - for a 4D CNN activations tensor the order is always (n, c, h, w)
    ///     - for a 5D CNN weights tensor the order is always
    ///       (g, oc, ic, kh, kw)
    ///
    /// The prelu weights tensor is passed at execution time. Its data type is
    /// implicitly assumed to be f32 with a plain layout (a, ab, acb, acdb,
    /// acdeb).
    ///
    /// @param mask Defines the correspondence between the output tensor
    ///     dimensions and the prelu weights tensor. The set i-th bit indicates
    ///     that a dedicated weights value is used for each index along that
    ///     dimension. Set the mask to 0 to use a common weights value
    ///     for the whole output tensor.
    void append_prelu(int mask) {
        error::wrap_c_api(dnnl_post_ops_append_prelu(get(), mask),
                "could not append a prelu post-op");
    }

    /// Returns the parameters of a prelu post-op.
    ///
    /// @param index Index of the prelu post-op.
    /// @param mask Weights mask of the prelu post-op.
    void get_params_prelu(int index, int &mask) const {
        error::wrap_c_api(dnnl_post_ops_get_params_prelu(get(), index, &mask),
                "could not get parameters of a prelu post-op");
    }
};

/// @cond DO_NOT_DOCUMENT_THIS
template <>
struct handle_traits<dnnl_primitive_attr_t> {
    static dnnl_status_t destructor(dnnl_primitive_attr_t p) {
        return dnnl_primitive_attr_destroy(p);
    }
};
/// @endcond

/// Primitive attributes.
///
/// @sa @ref dev_guide_attributes
struct primitive_attr : public handle<dnnl_primitive_attr_t> {
    using handle<dnnl_primitive_attr_t>::handle;

    /// Constructs default (empty) primitive attributes.
    primitive_attr() {
        dnnl_primitive_attr_t result;
        error::wrap_c_api(dnnl_primitive_attr_create(&result),
                "could not create primitive attribute");
        reset(result);
    }

    /// Creates primitive attributes from a C API ::dnnl_primitive_attr_t
    /// handle. The resulting handle is not weak and the C handle will be
    /// destroyed during the destruction of the C++ object.
    ///
    /// @param attr The C API primitive attributes.
    primitive_attr(dnnl_primitive_attr_t attr)
        : handle<dnnl_primitive_attr_t>(attr) {}

    /// Returns the parameters of a dropout attribute.
    ///
    /// @param mask_desc Output memory descriptor of a dropout mask.
    void get_dropout(memory::desc &mask_desc) const {
        const_dnnl_memory_desc_t cdesc;
        error::wrap_c_api(dnnl_primitive_attr_get_dropout(get(), &cdesc),
                "could not get parameters of a dropout attribute");
        dnnl_memory_desc_t cloned_md = nullptr;
        error::wrap_c_api(dnnl_memory_desc_clone(&cloned_md, cdesc),
                "could not clone a memory descriptor");
        mask_desc = memory::desc(cloned_md);
    }

    /// Sets the dropout primitive attribute.
    ///
    /// @param mask_desc Memory descriptor of a dropout mask.
    void set_dropout(const memory::desc &mask_desc) {
        error::wrap_c_api(
                dnnl_primitive_attr_set_dropout(get(), mask_desc.get()),
                "could not set dropout primitive attribute");
    }

    /// Returns the fpmath mode.
    fpmath_mode get_fpmath_mode() const {
        dnnl_fpmath_mode_t result;
        error::wrap_c_api(dnnl_primitive_attr_get_fpmath_mode(get(), &result),
                "could not get fpmath mode primitive attribute");
        return fpmath_mode(result);
    }

    /// Returns the fpmath mode and whether it also applies to integer
    /// primitives.
    ///
    /// @param mode Output fpmath mode.
    /// @param apply_to_int Output flag indicating whether floating-point
    ///     arithmetic is used for integer primitives.
    void get_fpmath_mode(fpmath_mode &mode, bool &apply_to_int) const {
        dnnl_fpmath_mode_t c_mode;
        int c_apply_to_int;
        error::wrap_c_api(dnnl_primitive_attr_get_fpmath_mode_v2(
                                  get(), &c_mode, &c_apply_to_int),
                "could not get fpmath mode primitive attribute");
        mode = fpmath_mode(c_mode);
        apply_to_int = static_cast<bool>(c_apply_to_int);
    }

    /// Sets the fpmath mode.
    ///
    /// @param mode Specified fpmath mode.
    /// @param apply_to_int A flag that, when true, also applies the mode to
    ///     integer primitives (allowing floating-point arithmetic for them).
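    ///
    /// Illustrative sketch: allow implicit down-conversion of f32 data to
    /// bf16 where the implementation supports it.
    /// @code
    ///     dnnl::primitive_attr attr;
    ///     attr.set_fpmath_mode(dnnl::fpmath_mode::bf16);
    /// @endcode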
    void set_fpmath_mode(fpmath_mode mode, bool apply_to_int = false) {
        error::wrap_c_api(dnnl_primitive_attr_set_fpmath_mode_v2(get(),
                                  dnnl::convert_to_c(mode), apply_to_int),
                "could not set fpmath mode primitive attribute");
    }

    /// Returns the accumulation mode.
    accumulation_mode get_accumulation_mode() const {
        dnnl_accumulation_mode_t result;
        error::wrap_c_api(
                dnnl_primitive_attr_get_accumulation_mode(get(), &result),
                "could not get accumulation mode primitive attribute");
        return accumulation_mode(result);
    }

    /// Sets the accumulation mode.
    ///
    /// @param mode Specified accumulation mode.
    void set_accumulation_mode(accumulation_mode mode) {
        error::wrap_c_api(dnnl_primitive_attr_set_accumulation_mode(
                                  get(), dnnl::convert_to_c(mode)),
                "could not set accumulation mode primitive attribute");
    }

    /// Returns the deterministic attribute value.
    bool get_deterministic() const {
        int result;
        error::wrap_c_api(dnnl_primitive_attr_get_deterministic(get(), &result),
                "could not get deterministic primitive attribute");
        return static_cast<bool>(result);
    }

    /// Sets the deterministic attribute value.
    ///
    /// @param value Specified deterministic mode.
    void set_deterministic(bool value) {
        error::wrap_c_api(dnnl_primitive_attr_set_deterministic(
                                  get(), static_cast<int>(value)),
                "could not set deterministic primitive attribute");
    }

    /// Returns the rounding mode attribute value.
    ///
    /// @param arg Argument to which the rounding mode query applies.
    /// @returns The rounding mode applied to the specified argument.
    rounding_mode get_rounding_mode(int arg) const {
        dnnl_rounding_mode_t result;
        error::wrap_c_api(dnnl_primitive_attr_get_rounding(get(), arg, &result),
                "could not get rounding mode primitive attribute");
        return rounding_mode(result);
    }

    /// Sets the rounding mode attribute value for a given argument.
    ///
    /// @param arg Argument for which to set the rounding mode.
    /// @param mode Rounding mode to apply.
    void set_rounding_mode(int arg, rounding_mode mode) {
        error::wrap_c_api(dnnl_primitive_attr_set_rounding(
                                  get(), arg, convert_to_c(mode)),
                "could not set rounding mode primitive attribute");
    }

    /// Returns the scratchpad mode.
    scratchpad_mode get_scratchpad_mode() const {
        dnnl_scratchpad_mode_t result;
        error::wrap_c_api(
                dnnl_primitive_attr_get_scratchpad_mode(get(), &result),
                "could not get scratchpad mode primitive attribute");
        return scratchpad_mode(result);
    }

    /// Sets the scratchpad mode.
    ///
    /// @param mode Specified scratchpad mode.
    void set_scratchpad_mode(scratchpad_mode mode) {
        error::wrap_c_api(dnnl_primitive_attr_set_scratchpad_mode(
                                  get(), dnnl::convert_to_c(mode)),
                "could not set scratchpad mode primitive attribute");
    }

    /// Sets scaling factors for primitive operations for a given memory
    /// argument. The scaling factors must be passed at execution time
    /// as an argument with index #DNNL_ARG_ATTR_SCALES | arg.
    ///
    /// @sa dnnl_primitive_attr_set_scales_mask
    ///
    /// @param arg Parameter argument index as passed to the
    ///     primitive::execute() call.
    /// @param mask Scaling factors correspondence mask that defines the
    ///     correspondence between the tensor dimensions and the @p scales
    ///     vector. The set i-th bit indicates that a dedicated scaling factor
    ///     is used for each index along that dimension. Set the mask to 0 to
    ///     use a common scaling factor for the whole output tensor.
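    ///
    /// Illustrative sketch: per-output-channel weight scales for an int8
    /// primitive. The weights' output-channel dimension is assumed to be
    /// dimension 0, so the mask is `1 << 0`; the f32 scales memory object is
    /// passed at execution time under
    /// `DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS`.
    /// @code
    ///     dnnl::primitive_attr attr;
    ///     attr.set_scales_mask(DNNL_ARG_WEIGHTS, 1 << 0);
    /// @endcode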
    void set_scales_mask(int arg, int mask) {
        error::wrap_c_api(dnnl_primitive_attr_set_scales_mask(get(), arg, mask),
                "could not set scales primitive attribute");
    }

    /// Sets primitive attributes scaling factors for a given memory
    /// argument. The scaling factors must be passed at execution time as
    /// an argument with index #DNNL_ARG_ATTR_SCALES | arg.
    ///
    /// @sa dnnl_primitive_attr_set_scales_v3
    ///
    /// @param arg Parameter argument index as passed to the
    ///     primitive::execute() call.
    /// @param mask Scaling factors correspondence mask that defines the
    ///     correspondence between the tensor dimensions and the @p scales array.
    ///     The set i-th bit indicates that a dedicated scaling factor is used for
    ///     each index along that dimension. Set the mask to 0 to use a common
    ///     scaling factor for the whole tensor.
    /// @param groups Scaling factors correspondence groups that define the
    ///     correspondence between the tensor dimensions and the scales array.
    ///     The group dimensions should only be provided for each logical dimension
    ///     that has the correspondence mask @p mask set.
    /// @param data_type Scaling factors data_type.
    /// @param is_on_host Indicates whether the scaling factor is a host-side
    ///     scalar.
    /// @param qmode Quantization mode, can be #quantization_mode::static_sazp
    ///     or #quantization_mode::dynamic_mx.
    void set_scales(int arg, int mask, const memory::dims &groups,
            memory::data_type data_type = memory::data_type::f32,
            bool is_on_host = false,
            quantization_mode qmode = quantization_mode::static_sazp) {
        error::wrap_c_api(dnnl_primitive_attr_set_scales_v3(get(), arg, mask,
                                  (int)groups.size(), groups.data(),
                                  memory::convert_to_c(data_type), is_on_host,
                                  convert_to_c(qmode)),
                "could not set scales primitive attribute");
    }

    /// Sets a single host-side scalar scaling factor for the specified memory
    /// argument. The scaling factor should be passed as a host scalar memory
    /// object at execution time with index #DNNL_ARG_ATTR_SCALES | arg.
    ///
    /// @note Using this API to set the scaling factor implies that the scales
    ///     attribute has `mask == 0` and an empty groups vector.
    ///
    /// @sa dnnl_primitive_attr_set_scales_v3
    ///
    /// @param arg Parameter argument index as passed to the
    ///     primitive::execute() call.
    /// @param data_type Scaling factors data_type.
    void set_host_scale(
            int arg, memory::data_type data_type = memory::data_type::f32) {
        error::wrap_c_api(dnnl_primitive_attr_set_scales_v3(get(), arg, 0, 0,
                                  nullptr, memory::convert_to_c(data_type), 1,
                                  dnnl_quantization_mode_static_sazp),
                "could not set scales primitive attribute");
    }

    /// Sets zero points for primitive operations for a given memory argument.
    /// The zero points must be passed at execution time as an argument with
    /// index #DNNL_ARG_ATTR_ZERO_POINTS | arg.
    ///
    /// @sa dnnl_primitive_attr_set_zero_points_mask
    ///
    /// @param arg Parameter argument index as passed to the
    ///     primitive::execute() call.
    /// @param mask Zero point correspondence mask that defines the
    ///     correspondence between the tensor dimensions and the @p
    ///     zero_points vector. The set i-th bit indicates that a dedicated
    ///     zero point is used for each index along that dimension. Set the
    ///     mask to 0 to use a common zero point for the whole output tensor.
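    ///
    /// Illustrative sketch: a per-tensor source zero point (mask 0) for an
    /// asymmetrically quantized int8 primitive; the s32 zero-point memory
    /// object is passed at execution time under
    /// `DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC`.
    /// @code
    ///     dnnl::primitive_attr attr;
    ///     attr.set_zero_points_mask(DNNL_ARG_SRC, /* mask = */ 0);
    /// @endcode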
    void set_zero_points_mask(int arg, int mask) {
        error::wrap_c_api(
                dnnl_primitive_attr_set_zero_points_mask(get(), arg, mask),
                "could not set zero points primitive attribute");
    }

    /// Sets zero points for primitive operations for a given memory argument.
    /// The zero points must be passed at execution time as an argument with
    /// index #DNNL_ARG_ATTR_ZERO_POINTS | arg.
    ///
    /// @note If `is_on_host` is true, sets a single host-side zero point
    ///     for the specified memory argument. The zero point should be
    ///     passed as a host scalar memory object at execution time with index
    ///     #DNNL_ARG_ATTR_ZERO_POINTS | arg.
    ///
    /// @sa dnnl_primitive_attr_set_zero_points_v2
    ///
    /// @param arg Parameter argument index as passed to the
    ///     primitive::execute() call.
    /// @param mask Zero point correspondence mask that defines the
    ///     correspondence between the tensor dimensions and the zero points
    ///     vector. The set i-th bit indicates that a dedicated zero point is
    ///     used for each index along that dimension. Set the mask to 0 to use
    ///     a common zero point for the whole output tensor.
    /// @param groups Zero point factors correspondence groups that define the
    ///     correspondence between the tensor dimensions and the zero points
    ///     array.
    ///     The i-th dimension indicates the number of groups of zero point
    ///     factors used for that logical dimension in the memory indicated by
    ///     @p arg.
    /// @param data_type Zero point factors data_type.
    /// @param is_on_host Indicates whether the zero point is a host-side
    ///     scalar.
    void set_zero_points(int arg, int mask, const memory::dims &groups,
            memory::data_type data_type = memory::data_type::s32,
            bool is_on_host = false) {
        error::wrap_c_api(dnnl_primitive_attr_set_zero_points_v2(get(), arg,
                                  mask, (int)groups.size(), groups.data(),
                                  memory::convert_to_c(data_type), is_on_host),
                "could not set zero points primitive attribute");
    }

    /// Sets a single host-side zero point for the specified memory argument.
    /// The zero point should be passed as a host scalar memory object at
    /// execution time with index #DNNL_ARG_ATTR_ZERO_POINTS | arg.
    ///
    /// @note Using this API to set the zero point implies that the zero
    ///     point attribute has `mask == 0` and an empty groups vector.
    ///
    /// @sa dnnl_primitive_attr_set_zero_points_v2
    ///
    /// @param arg Parameter argument index as passed to the
    ///     primitive::execute() call.
    /// @param data_type Zero point data type.
    void set_host_zero_point(
            int arg, memory::data_type data_type = memory::data_type::s32) {
        error::wrap_c_api(
                dnnl_primitive_attr_set_zero_points_v2(get(), arg, 0, 0,
                        nullptr, memory::convert_to_c(data_type), 1),
                "could not set zero points primitive attribute");
    }

    /// Sets precomputed reductions for primitive operations for a given memory
    /// argument. The precomputed reductions must be passed at execution time as
    /// an argument with index #DNNL_ARG_ATTR_PRECOMPUTED_REDUCTIONS | arg.
    ///
    /// @sa dnnl_primitive_attr_set_precomputed_reductions
    ///
    /// @param arg Parameter argument index as passed to the
    ///     primitive::execute() call.
    /// @param mask Precomputed reductions correspondence mask that defines the
    ///     correspondence between the tensor dimensions and the precomputed
    ///     reductions vector. The set i-th bit indicates that a dedicated
    ///     precomputed reduction is used for each index along that
    ///     dimension.
    /// @param groups Precomputed reduction factors correspondence groups that
    ///     define the correspondence between the tensor dimensions and the
    ///     precomputed reductions array.
    ///     The i-th dimension indicates the number of groups of precomputed
    ///     reduction factors used for that logical dimension in the memory
    ///     indicated by @p arg.
    /// @param data_type Precomputed reduction factors data_type.
    void set_precomputed_reductions(int arg, int mask,
            const memory::dims &groups,
            memory::data_type data_type = memory::data_type::s32) {
        error::wrap_c_api(dnnl_primitive_attr_set_precomputed_reductions(get(),
                                  arg, mask, (int)groups.size(), groups.data(),
                                  memory::convert_to_c(data_type)),
                "could not set precomputed reductions primitive attribute");
    }

    /// Returns post-ops previously set via set_post_ops().
    ///
    /// @returns Post-ops.
    post_ops get_post_ops() const {
        const_dnnl_post_ops_t const_c_post_ops;
        error::wrap_c_api(
                dnnl_primitive_attr_get_post_ops(get(), &const_c_post_ops),
                "could not get post-ops primitive attribute");
        dnnl_post_ops_t c_post_ops;
        error::wrap_c_api(dnnl_post_ops_clone(&c_post_ops, const_c_post_ops),
                "could not clone post-ops primitive attribute");
        return post_ops(c_post_ops);
    }

    /// Sets post-ops.
    ///
    /// @note
    ///     There is no way to check whether the post-ops would be supported
    ///     by the target primitive. Any error will be reported
    ///     by the respective primitive descriptor constructor.
    ///
    /// @param ops Post-ops object to copy post-ops from.
    void set_post_ops(const post_ops &ops) {
        error::wrap_c_api(dnnl_primitive_attr_set_post_ops(get(), ops.get()),
                "could not set post-ops primitive attribute");
    }

    /// Sets quantization scale and shift parameters for RNN data tensors.
    ///
    /// For performance reasons, the low-precision configuration of the RNN
    /// primitives expects input activations to have the unsigned 8-bit integer
    /// data type. The scale and shift parameters are used to quantize
    /// floating-point data to unsigned integer and must be passed to the RNN
    /// primitive using attributes.
    ///
    /// The quantization formula is `scale * data + shift`.
    ///
    /// Example usage:
    /// @code
    ///     // RNN parameters
    ///     int l = 2, t = 2, mb = 32, sic = 32, slc = 32, dic = 32, dlc = 32;
    ///     // Activations quantization parameters
    ///     float scale = 63.f, shift = 64.f;
    ///
    ///     primitive_attr attr;
    ///
    ///     // Set scale and shift for int8 quantization of activation
    ///     attr.set_rnn_data_qparams(scale, shift);
    ///
    ///     // Create an RNN primitive descriptor.
    ///     vanilla_rnn_forward::primitive_desc rnn_d(
    ///             engine, /* arguments */, attr);
    /// @endcode
    ///
    /// @note
    ///     Quantization scale and shift are common for src_layer, src_iter,
    ///     dst_iter, and dst_layer.
    ///
    /// @param scale The value to scale the data by.
    /// @param shift The value to shift the data by.
    void set_rnn_data_qparams(float scale, float shift) {
        error::wrap_c_api(
                dnnl_primitive_attr_set_rnn_data_qparams(get(), scale, shift),
                "could not set RNN data quantization parameters primitive "
                "attribute");
    }

    /// Returns the quantization scale and shift parameters for RNN data
    /// tensors.
    ///
    /// @note
    ///     Quantization scale and shift are common for src_layer, src_iter,
    ///     dst_iter, and dst_layer.
    ///
    /// @param scale The value to scale the data by.
    /// @param shift The value to shift the data by.
    void get_rnn_data_qparams(float &scale, float &shift) {
        float c_scale, c_shift;
        error::wrap_c_api(dnnl_primitive_attr_get_rnn_data_qparams(
                                  get(), &c_scale, &c_shift),
                "could not get RNN data quantization parameters primitive "
                "attribute");
        scale = c_scale;
        shift = c_shift;
    }

    /// Sets quantization scaling factors for RNN weights tensors. The
    /// low-precision configuration of the RNN primitives expects input weights
    /// to use the signed 8-bit integer data type. The scaling factors are
    /// used to quantize floating-point data to signed integer and must be
    /// passed to RNN primitives using attributes.
    ///
    /// @note
    ///     The dimension order is always native and does not depend on the
    ///     actual layout used. For example, five-dimensional weights always
    ///     have (l, d, i, g, o) logical dimension ordering.
    ///
    /// @note
    ///     Quantization scales are common for weights_layer and
    ///     weights_iteration.
    ///
    /// @param mask Scaling factors correspondence mask that defines the
    ///     correspondence between the output tensor dimensions and the @p
    ///     scales vector. The set i-th bit indicates that a dedicated scaling
    ///     factor should be used for each index along that dimension. Set the
    ///     mask to 0 to use a common scaling factor for the whole output
    ///     tensor.
    /// @param scales Constant vector of output scaling factors. The following
    ///     equality must hold:
    ///     \f$scales.size() = \prod\limits_{d \in mask} weights.dims[d].\f$
    ///     Violations can only be detected when the attributes are used to
    ///     create a primitive descriptor.
    void set_rnn_weights_qparams(int mask, const std::vector<float> &scales) {
        error::wrap_c_api(dnnl_primitive_attr_set_rnn_weights_qparams(get(),
                                  (int)scales.size(), mask, scales.data()),
                "could not set RNN weights quantization parameters primitive "
                "attribute");
    }

    /// Returns the quantization scaling factors for RNN weights tensors.
    ///
    /// @note
    ///     The dimension order is always native and does not depend on the
    ///     actual layout used. For example, five-dimensional weights always
    ///     have (l, d, i, g, o) logical dimension ordering.
    ///
    /// @param mask Scaling factors correspondence mask that defines the
    ///     correspondence between the output tensor dimensions and the @p
    ///     scales vector. The set i-th bit indicates that a dedicated scaling
    ///     factor should be used for each index along that dimension. Set the
    ///     mask to 0 to use a common scaling factor for the whole output
    ///     tensor.
    /// @param scales Constant vector of output scaling factors. The following
    ///     equality must hold:
    ///     \f$scales.size() = \prod\limits_{d \in mask} weights.dims[d].\f$
    ///     Violations can only be detected when the attributes are used to
    ///     create a primitive descriptor.
    void get_rnn_weights_qparams(int &mask, std::vector<float> &scales) {
        dnnl_dim_t count;
        int c_mask;
        const float *c_scales;
        error::wrap_c_api(dnnl_primitive_attr_get_rnn_weights_qparams(
                                  get(), &count, &c_mask, &c_scales),
                "could not get primitive RNN weights quantization "
                "parameters attributes");
        scales.resize(count);

        mask = c_mask;
        for (dnnl_dim_t c = 0; c < count; c++)
            scales[c] = c_scales[c];
    }

    /// Sets quantization scaling factors for RNN projection weights tensors.
    /// The low-precision configuration of the RNN primitives expects input
    /// weights to use the signed 8-bit integer data type. The scaling factors
    /// are used to quantize floating-point data to signed integer and must be
    /// passed to RNN primitives using attributes.
    ///
    /// @note
    ///     The dimension order is always native and does not depend on the
    ///     actual layout used. For example, five-dimensional weights always
    ///     have (l, d, i, g, o) logical dimension ordering.
    ///
    /// @note
    ///     Quantization scales are common for weights_layer and
    ///     weights_iteration.
    ///
    /// @param mask Scaling factors correspondence mask that defines the
    ///     correspondence between the output tensor dimensions and the @p
    ///     scales vector. The set i-th bit indicates that a dedicated scaling
    ///     factor should be used for each index along that dimension. Set the
    ///     mask to 0 to use a common scaling factor for the whole output
    ///     tensor.
    /// @param scales Constant vector of output scaling factors. The following
    ///     equality must hold:
    ///     \f$scales.size() = \prod\limits_{d \in mask} weights.dims[d].\f$
    ///     Violations can only be detected when the attributes are used to
    ///     create a primitive descriptor.
    void set_rnn_weights_projection_qparams(
            int mask, const std::vector<float> &scales) {
        error::wrap_c_api(
                dnnl_primitive_attr_set_rnn_weights_projection_qparams(
                        get(), (int)scales.size(), mask, scales.data()),
                "could not set primitive RNN weights projection quantization "
                "parameters attributes");
    }

    /// Returns the quantization scaling factors for RNN projection weights
    /// tensors.
    ///
    /// @note
    ///     The dimension order is always native and does not depend on the
    ///     actual layout used. For example, five-dimensional weights always
    ///     have (l, d, i, g, o) logical dimension ordering.
    ///
    /// @param mask Scaling factors correspondence mask that defines the
    ///     correspondence between the output tensor dimensions and the @p
    ///     scales vector. The set i-th bit indicates that a dedicated scaling
    ///     factor should be used for each index along that dimension. Set the
    ///     mask to 0 to use a common scaling factor for the whole output
    ///     tensor.
    /// @param scales Constant vector of output scaling factors. The following
    ///     equality must hold:
|
|
/// \f$scales.size() = \prod\limits_{d \in mask} weights.dims[d].\f$
|
|
/// Violations can only be detected when the attributes are used to
|
|
/// create a primitive descriptor.
|
|
void get_rnn_weights_projection_qparams(
|
|
int &mask, std::vector<float> &scales) {
|
|
dnnl_dim_t count;
|
|
int c_mask;
|
|
const float *c_scales;
|
|
error::wrap_c_api(
|
|
dnnl_primitive_attr_get_rnn_weights_projection_qparams(
|
|
get(), &count, &c_mask, &c_scales),
|
|
"could not get primitive RNN weights projection quantization "
|
|
"parameters attributes");
|
|
scales.resize(count);
|
|
|
|
mask = c_mask;
|
|
for (dnnl_dim_t c = 0; c < count; c++)
|
|
scales[c] = c_scales[c];
|
|
}
|
|
};
|
|
|
|
/// @} dnnl_api_attributes
|
|
|
|
/// @addtogroup dnnl_api_primitives_common
|
|
/// @{
|
|
|
|
/// Base class for all primitive descriptors.
|
|
struct primitive_desc_base : public handle<dnnl_primitive_desc_t> {
|
|
using handle<dnnl_primitive_desc_t>::handle;
|
|
|
|
/// Default constructor. Produces an empty object.
|
|
primitive_desc_base() = default;
|
|
|
|
/// Returns the engine of the primitive descriptor.
|
|
/// @returns The engine of the primitive descriptor.
|
|
engine get_engine() const { return query_engine(query::engine); }
|
|
|
|
/// Returns implementation name.
|
|
/// @returns The implementation name.
|
|
const char *impl_info_str() const {
|
|
const char *res;
|
|
error::wrap_c_api(dnnl_primitive_desc_query(
|
|
get(), dnnl_query_impl_info_str, 0, &res),
|
|
"could not retrieve implementation info string from a "
|
|
"primitive descriptor");
|
|
return res;
|
|
}
|
|
|
|
/// Returns a memory::dim value (same as int64_t).
|
|
/// @param what The value to query.
|
|
/// @returns The result of the query.
|
|
memory::dim query_s64(query what) const {
|
|
memory::dim res;
|
|
dnnl_status_t status = dnnl_primitive_desc_query(
|
|
get(), dnnl::convert_to_c(what), 0, &res);
|
|
return status == dnnl_success ? res : 0;
|
|
}
|
|
|
|
/// Returns strides.
|
|
/// @returns Strides.
|
|
/// @returns An empty #dnnl::memory::dims if the primitive does not have
|
|
/// a strides parameter.
|
|
memory::dims get_strides() const { return query_dims(query::strides); }
|
|
|
|
/// Returns dilations.
|
|
/// @returns Dilations.
|
|
/// @returns An empty #dnnl::memory::dims if the primitive does not have
|
|
/// a dilations parameter.
|
|
memory::dims get_dilations() const { return query_dims(query::dilations); }
|
|
|
|
/// Returns a left padding.
|
|
/// @returns A left padding.
|
|
/// @returns An empty #dnnl::memory::dims if the primitive does not have
|
|
/// a left padding parameter.
|
|
memory::dims get_padding_l() const { return query_dims(query::padding_l); }
|
|
|
|
/// Returns a right padding.
|
|
/// @returns A right padding.
|
|
/// @returns An empty #dnnl::memory::dims if the primitive does not have
|
|
/// a right padding parameter.
|
|
memory::dims get_padding_r() const { return query_dims(query::padding_r); }
|
|
|
|
/// Returns an epsilon.
|
|
/// @returns An epsilon.
|
|
/// @returns Zero if the primitive does not have an epsilon parameter.
|
|
float get_epsilon() const { return query_f32(query::epsilon_f32); }
|
|
|
|
/// Returns flags.
|
|
/// @tparam T Flags enumeration type.
|
|
/// @returns Flags.
|
|
/// @returns Zero if the primitive does not have a flags parameter.
|
|
template <typename T = unsigned>
|
|
T get_flags() const {
|
|
unsigned res;
|
|
dnnl_status_t status
|
|
= dnnl_primitive_desc_query(get(), dnnl_query_flags, 0, &res);
|
|
return static_cast<T>(status == dnnl_success ? res : 0x0U);
|
|
}
|
|
|
|
/// Returns an algorithm kind.
|
|
/// @returns An algorithm kind.
|
|
/// @returns #dnnl::algorithm::undef if the primitive does not have an
|
|
/// algorithm parameter.
|
|
dnnl::algorithm get_algorithm() const { return query_alg(query::alg_kind); }
|
|
|
|
/// Returns an alpha.
|
|
/// @returns An alpha.
|
|
/// @returns Zero if the primitive does not have an alpha parameter.
|
|
float get_alpha() const { return query_f32(query::alpha_f32); }
|
|
|
|
/// Returns a beta.
|
|
/// @returns A beta.
|
|
/// @returns Zero if the primitive does not have a beta parameter.
|
|
float get_beta() const { return query_f32(query::beta_f32); }
|
|
|
|
/// Returns an axis.
|
|
/// @returns An axis.
|
|
/// @returns A negative number if the primitive does not have an axis
|
|
/// parameter.
|
|
int get_axis() const {
|
|
int res;
|
|
dnnl_status_t status = dnnl_primitive_desc_query(
|
|
get(), dnnl_query_axis_s32, 0, &res);
|
|
return status == dnnl_success ? res : -1;
|
|
}
|
|
|
|
/// Returns an LRN local size parameter.
|
|
/// @returns An LRN local size parameter.
|
|
/// @returns Zero if the primitive does not have an LRN local size
|
|
/// parameter.
|
|
memory::dim get_local_size() const {
|
|
return query_s64(query::local_size_s64);
|
|
}
|
|
|
|
/// Returns an LRN K parameter.
|
|
/// @returns An LRN K parameter.
|
|
/// @returns Zero if the primitive does not have an LRN K parameter.
|
|
float get_k() const { return query_f32(query::k_f32); }
|
|
|
|
/// Returns a reduction P parameter.
|
|
/// @returns A reduction P parameter.
|
|
/// @returns Zero if the primitive does not have a reduction P parameter.
|
|
float get_p() const { return query_f32(query::p_f32); }
|
|
|
|
/// Returns a resampling factors parameters.
|
|
/// @returns A vector of factors.
|
|
/// @returns An empty vector if the primitive does not have a resampling
|
|
/// factors parameter.
|
|
std::vector<float> get_factors() const {
|
|
float *factors;
|
|
dnnl_status_t status = dnnl_primitive_desc_query(
|
|
get(), dnnl_query_factors, 0, &factors);
|
|
|
|
const bool is_backward = get_prop_kind() != prop_kind::forward_training
|
|
&& get_prop_kind() != prop_kind::forward_inference;
|
|
const_dnnl_memory_desc_t md = dnnl_primitive_desc_query_md(get(),
|
|
is_backward ? dnnl_query_diff_dst_md : dnnl_query_dst_md, 0);
|
|
|
|
int ndims;
|
|
error::wrap_c_api(
|
|
dnnl_memory_desc_query(md, dnnl_query_ndims_s32, &ndims),
|
|
"could not query ndims from a memory descriptor");
|
|
|
|
return status == dnnl_success
|
|
? std::vector<float>(factors, factors + (ndims - 2))
|
|
: std::vector<float> {};
|
|
}
|
|
|
|
/// Returns an RNN cell kind parameter.
|
|
/// @returns An RNN cell kind parameter.
|
|
/// @returns #dnnl::algorithm::undef if the primitive does not have an
|
|
/// RNN cell kind parameter.
|
|
dnnl::algorithm get_cell_kind() const {
|
|
return query_alg(query::cell_kind);
|
|
}
|
|
|
|
/// Returns an RNN direction parameter.
|
|
/// @returns An RNN direction parameter.
|
|
/// @returns #dnnl::rnn_direction::undef if the primitive does not have
|
|
/// an RNN direction parameter.
|
|
dnnl::rnn_direction get_direction() const {
|
|
dnnl_rnn_direction_t direction;
|
|
dnnl_status_t status = dnnl_primitive_desc_query(
|
|
get(), dnnl_query_direction, 0, &direction);
|
|
return status == dnnl_success
|
|
? static_cast<dnnl::rnn_direction>(direction)
|
|
: dnnl::rnn_direction::undef;
|
|
}
|
|
|
|
/// Returns an RNN activation kind parameter.
|
|
/// @returns An RNN activation kind parameter.
|
|
/// @returns #dnnl::algorithm::undef if the primitive does not have an
|
|
/// RNN activation kind parameter.
|
|
dnnl::algorithm get_activation_kind() const {
|
|
return query_alg(query::activation_kind);
|
|
}
|
|
|
|
/// Returns a pooling kernel parameter.
|
|
/// @returns A pooling kernel parameter.
|
|
/// @returns An empty #dnnl::memory::dims if the primitive does not have
|
|
/// a pooling kernel parameter.
|
|
memory::dims get_kernel() const { return query_dims(query::kernel); }
|
|
|
|
/// Returns a group size parameter.
|
|
/// @returns A group size parameter.
|
|
/// @returns Zero if the primitive does not have a group size
|
|
/// parameter.
|
|
memory::dim get_group_size() const {
|
|
return query_s64(query::group_size_s64);
|
|
}
|
|
|
|
/// Returns a propagation kind.
|
|
/// @returns A propagation kind.
|
|
/// @returns #dnnl::prop_kind::undef if the primitive does not have
|
|
/// a propagation parameter.
|
|
dnnl::prop_kind get_prop_kind() const {
|
|
dnnl_prop_kind_t prop_kind;
|
|
dnnl_status_t status = dnnl_primitive_desc_query(
|
|
get(), dnnl_query_prop_kind, 0, &prop_kind);
|
|
return status == dnnl_success ? static_cast<dnnl::prop_kind>(prop_kind)
|
|
: dnnl::prop_kind::undef;
|
|
}
|
|
|
|
/// Returns a memory descriptor.
|
|
///
|
|
/// @note
|
|
/// There are also convenience methods
|
|
/// #dnnl::primitive_desc_base::src_desc(),
|
|
/// #dnnl::primitive_desc_base::dst_desc(), and others.
|
|
///
|
|
/// @param what The kind of parameter to query; can be
|
|
/// #dnnl::query::src_md, #dnnl::query::dst_md, etc.
|
|
/// @param idx Index of the parameter. For example, convolution bias can
|
|
/// be queried with what = #dnnl::query::weights_md and idx = 1.
|
|
/// @returns The requested memory descriptor.
|
|
/// @returns A zero memory descriptor if the primitive does not have a
|
|
/// parameter of the specified kind or index.
|
|
memory::desc query_md(query what, int idx = 0) const {
|
|
std::vector<query> valid_q {query::src_md, query::diff_src_md,
|
|
query::weights_md, query::diff_weights_md, query::dst_md,
|
|
query::diff_dst_md, query::workspace_md, query::scratchpad_md,
|
|
query::exec_arg_md};
|
|
if (!std::any_of(valid_q.cbegin(), valid_q.cend(),
|
|
[=](query q) { return what == q; }))
|
|
DNNL_THROW_ERROR(dnnl_invalid_arguments,
|
|
"memory descriptor query is invalid");
|
|
|
|
const_dnnl_memory_desc_t cdesc = dnnl_primitive_desc_query_md(
|
|
get(), dnnl::convert_to_c(what), idx);
|
|
if (!cdesc) return memory::desc();
|
|
|
|
dnnl_memory_desc_t cloned_md = nullptr;
|
|
error::wrap_c_api(dnnl_memory_desc_clone(&cloned_md, cdesc),
|
|
"could not clone a memory descriptor");
|
|
|
|
return memory::desc(cloned_md);
|
|
}
|
|
|
|
/// Returns a source memory descriptor.
|
|
/// @param idx Source index.
|
|
/// @returns Source memory descriptor.
|
|
/// @returns A zero memory descriptor if the primitive does not have a
|
|
/// source parameter with index @p idx.
|
|
memory::desc src_desc(int idx) const {
|
|
return query_md(query::src_md, idx);
|
|
}
|
|
|
|
/// Returns a destination memory descriptor.
|
|
/// @param idx Destination index.
|
|
/// @returns Destination memory descriptor.
|
|
/// @returns A zero memory descriptor if the primitive does not have a
|
|
/// destination parameter with index @p idx.
|
|
memory::desc dst_desc(int idx) const {
|
|
return query_md(query::dst_md, idx);
|
|
}
|
|
|
|
/// Returns a weights memory descriptor.
|
|
/// @param idx Weights index.
|
|
/// @returns Weights memory descriptor.
|
|
/// @returns A zero memory descriptor if the primitive does not have a
|
|
/// weights parameter with index @p idx.
|
|
memory::desc weights_desc(int idx) const {
|
|
return query_md(query::weights_md, idx);
|
|
}
|
|
|
|
/// Returns a diff source memory descriptor.
|
|
/// @param idx Diff source index.
|
|
/// @returns Diff source memory descriptor.
|
|
/// @returns A zero memory descriptor if the primitive does not have a
|
|
/// diff source parameter with index @p idx.
|
|
memory::desc diff_src_desc(int idx) const {
|
|
return query_md(query::diff_src_md, idx);
|
|
}
|
|
|
|
/// Returns a diff destination memory descriptor.
|
|
/// @param idx Diff destination index.
|
|
/// @returns Diff destination memory descriptor.
|
|
/// @returns A zero memory descriptor if the primitive does not have a
|
|
/// diff destination parameter with index @p idx.
|
|
memory::desc diff_dst_desc(int idx) const {
|
|
return query_md(query::diff_dst_md, idx);
|
|
}
|
|
|
|
/// Returns a diff weights memory descriptor.
|
|
/// @param idx Diff weights index.
|
|
/// @returns Diff weights memory descriptor.
|
|
/// @returns A zero memory descriptor if the primitive does not have a
|
|
/// diff weights parameter with index @p idx.
|
|
memory::desc diff_weights_desc(int idx) const {
|
|
return query_md(query::diff_weights_md, idx);
|
|
}
|
|
|
|
// Separate versions without the index argument for documentation
|
|
// purposes.
|
|
|
|
/// Returns a source memory descriptor.
|
|
/// @returns Source memory descriptor.
|
|
/// @returns A zero memory descriptor if the primitive does not have a
|
|
/// source parameter.
|
|
memory::desc src_desc() const { return src_desc(0); }
|
|
|
|
/// Returns a destination memory descriptor.
|
|
/// @returns Destination memory descriptor.
|
|
/// @returns A zero memory descriptor if the primitive does not have a
|
|
/// destination parameter.
|
|
memory::desc dst_desc() const { return dst_desc(0); }
|
|
|
|
/// Returns a weights memory descriptor.
|
|
/// @returns Weights memory descriptor.
|
|
/// @returns A zero memory descriptor if the primitive does not have a
|
|
/// weights parameter.
|
|
memory::desc weights_desc() const { return weights_desc(0); }
|
|
|
|
/// Returns a diff source memory descriptor.
|
|
/// @returns Diff source memory descriptor.
|
|
/// @returns A zero memory descriptor if the primitive does not have a
|
|
/// diff source memory with.
|
|
memory::desc diff_src_desc() const { return diff_src_desc(0); }
|
|
|
|
/// Returns a diff destination memory descriptor.
|
|
/// @returns Diff destination memory descriptor.
|
|
/// @returns A zero memory descriptor if the primitive does not have a
|
|
/// diff destination parameter.
|
|
memory::desc diff_dst_desc() const { return diff_dst_desc(0); }
|
|
|
|
/// Returns a diff weights memory descriptor.
|
|
/// @returns Diff weights memory descriptor.
|
|
/// @returns A zero memory descriptor if the primitive does not have a
|
|
/// diff weights parameter.
|
|
memory::desc diff_weights_desc() const { return diff_weights_desc(0); }
|
|
|
|
/// Returns the workspace memory descriptor.
|
|
/// @returns Workspace memory descriptor.
|
|
/// @returns A zero memory descriptor if the primitive does not require
|
|
/// workspace parameter.
|
|
memory::desc workspace_desc() const {
|
|
return query_md(query::workspace_md, 0);
|
|
}
|
|
|
|
/// Returns the scratchpad memory descriptor.
|
|
/// @returns scratchpad memory descriptor.
|
|
/// @returns A zero memory descriptor if the primitive does not require
|
|
/// scratchpad parameter.
|
|
/// @sa @ref dev_guide_attributes_scratchpad
|
|
memory::desc scratchpad_desc() const {
|
|
return query_md(query::scratchpad_md, 0);
|
|
}
|
|
|
|
/// Returns the engine on which the scratchpad memory is located.
|
|
/// @returns The engine on which the scratchpad memory is located.
|
|
engine scratchpad_engine() const {
|
|
dnnl_engine_t c_engine;
|
|
error::wrap_c_api(dnnl_primitive_desc_query(get(),
|
|
dnnl::convert_to_c(query::scratchpad_engine),
|
|
0, &c_engine),
|
|
"could not retrieve scratchpad engine from a primitive "
|
|
"descriptor");
|
|
return engine(c_engine, true);
|
|
}
|
|
|
|
/// Returns the primitive attributes.
|
|
/// @returns The primitive attributes.
|
|
primitive_attr get_primitive_attr() const {
|
|
const_dnnl_primitive_attr_t const_c_attr;
|
|
error::wrap_c_api(dnnl_primitive_desc_get_attr(get(), &const_c_attr),
|
|
"could not get attributes from a primitive descriptor");
|
|
dnnl_primitive_attr_t c_attr;
|
|
error::wrap_c_api(dnnl_primitive_attr_clone(&c_attr, const_c_attr),
|
|
"could not clone primitive attributes");
|
|
return primitive_attr(c_attr);
|
|
}
|
|
|
|
/// Returns the kind of the primitive descriptor.
|
|
/// @returns The kind of the primitive descriptor.
|
|
dnnl::primitive::kind get_kind() const {
|
|
dnnl_primitive_kind_t kind;
|
|
error::wrap_c_api(dnnl_primitive_desc_query(get(),
|
|
dnnl_query_primitive_kind, 0, (void *)&kind),
|
|
"could not get primitive kind from a primitive descriptor");
|
|
return static_cast<dnnl::primitive::kind>(kind);
|
|
}
|
|
|
|
/// Returns the cache blob ID of the primitive descriptor.
|
|
/// @returns The cache blob ID of the primitive descriptor.
|
|
std::vector<uint8_t> get_cache_blob_id() const {
|
|
dnnl_dim_t count;
|
|
const uint8_t *c_id;
|
|
error::wrap_c_api(
|
|
dnnl_primitive_desc_query(get(),
|
|
dnnl::convert_to_c(query::cache_blob_id_size_s64), 0,
|
|
(void *)&count),
|
|
"could not get size of cache blob ID from a primitive "
|
|
"descriptor");
|
|
error::wrap_c_api(dnnl_primitive_desc_query(get(),
|
|
dnnl::convert_to_c(query::cache_blob_id), 0,
|
|
(void **)&c_id),
|
|
"could not get cache blob ID from a primitive descriptor");
|
|
std::vector<uint8_t> id(c_id, c_id + count);
|
|
return id;
|
|
}
|
|
|
|
protected:
|
|
/// Returns a float value.
|
|
/// @param what The value to query.
|
|
/// @returns The result of the query.
|
|
/// @returns Zero if the primitive doesn't support the query.
|
|
float query_f32(query what) const {
|
|
float res;
|
|
dnnl_status_t status = dnnl_primitive_desc_query(
|
|
get(), dnnl::convert_to_c(what), 0, &res);
|
|
return status == dnnl_success ? res : 0.0f;
|
|
}
|
|
|
|
/// Returns an #dnnl::algorithm value.
|
|
/// @param what The value to query.
|
|
/// @returns The result of the query.
|
|
/// @returns #dnnl::algorithm::undef if the primitive doesn't support
|
|
/// the query.
|
|
algorithm query_alg(query what) const {
|
|
dnnl_alg_kind_t res;
|
|
dnnl_status_t status = dnnl_primitive_desc_query(
|
|
get(), dnnl::convert_to_c(what), 0, &res);
|
|
return status == dnnl_success ? static_cast<dnnl::algorithm>(res)
|
|
: algorithm::undef;
|
|
}
|
|
|
|
/// Returns a memory::dims value.
|
|
/// @param what The value to query.
|
|
/// @returns The result of the query.
|
|
/// @returns An empty #dnnl::memory::dims if the primitive doesn't support
|
|
/// the query.
|
|
memory::dims query_dims(query what) const {
|
|
const bool is_backward = get_prop_kind() != prop_kind::forward_training
|
|
&& get_prop_kind() != prop_kind::forward_inference;
|
|
const_dnnl_memory_desc_t md = dnnl_primitive_desc_query_md(get(),
|
|
is_backward ? dnnl_query_diff_dst_md : dnnl_query_dst_md, 0);
|
|
|
|
int nspatial_dims = 0;
|
|
if (md) {
|
|
int ndims;
|
|
error::wrap_c_api(
|
|
dnnl_memory_desc_query(md, dnnl_query_ndims_s32, &ndims),
|
|
"could not query ndims from a memory descriptor");
|
|
nspatial_dims = ndims - 2;
|
|
}
|
|
|
|
dnnl_dims_t *c_dims;
|
|
dnnl_status_t status = dnnl_primitive_desc_query(
|
|
get(), dnnl::convert_to_c(what), 0, &c_dims);
|
|
return status == dnnl_success
|
|
? memory::dims(*c_dims, *c_dims + nspatial_dims)
|
|
: memory::dims {};
|
|
}
|
|
|
|
/// Returns an #dnnl::engine value.
|
|
/// @param what The value to query.
|
|
/// @returns The result of the query.
|
|
/// @returns A weak handle to the engine that the primitive descriptor was
|
|
/// created with.
|
|
engine query_engine(query what) const {
|
|
dnnl_engine_t c_engine;
|
|
error::wrap_c_api(dnnl_primitive_desc_query(get(),
|
|
dnnl::convert_to_c(what), 0, &c_engine),
|
|
"could not get an engine from a primitive_desc");
|
|
return engine(c_engine, true);
|
|
}
|
|
|
|
/// Resets the value of the handle to a clone of a C API primitive
|
|
/// descriptor.
|
|
/// @param pd A C API primitive descriptor to clone.
|
|
void reset_with_clone(const_dnnl_primitive_desc_t pd) {
|
|
dnnl_primitive_desc_t new_pd;
|
|
error::wrap_c_api(dnnl_primitive_desc_clone(&new_pd, pd),
|
|
"could not clone a primitive descriptor");
|
|
reset(new_pd);
|
|
}
|
|
|
|
/// Constructs a primitive descriptor base object from a clone of a C API
|
|
/// primitive descriptor after verifying that it is what the caller
|
|
/// expects.
|
|
///
|
|
/// @note
|
|
/// The @p prim_kind should map to a primitive that does not have
|
|
/// different values of propagation kind (e.g. #dnnl::binary).
|
|
/// @note
|
|
/// Primitive descriptor base constructed this way does not support
|
|
/// next_impl() (will throw).
|
|
///
|
|
/// @param pd C API primitive descriptor to clone.
|
|
/// @param prim_kind Expected primitive kind.
|
|
primitive_desc_base(
|
|
dnnl_primitive_desc_t pd, dnnl::primitive::kind prim_kind)
|
|
: primitive_desc_base(pd, prim_kind, dnnl::prop_kind::undef) {}
|
|
|
|
/// Constructs a primitive descriptor base object from a clone of a C API
|
|
/// primitive descriptor after verifying that it is what the caller
|
|
/// expects.
|
|
///
|
|
/// @note
|
|
/// Primitive descriptor base constructed this way does not support
|
|
/// next_impl() (will throw).
|
|
///
|
|
/// @param pd C API primitive descriptor to clone.
|
|
/// @param prim_kind Expected primitive kind.
|
|
/// @param aprop_kind Expected propagation kind.
|
|
primitive_desc_base(dnnl_primitive_desc_t pd,
|
|
dnnl::primitive::kind prim_kind, dnnl::prop_kind aprop_kind)
|
|
: primitive_desc_base(pd, prim_kind, aprop_kind, aprop_kind) {}
|
|
|
|
/// Constructs a primitive descriptor base object from a clone of a C API
|
|
/// primitive descriptor after verifying that it is what the caller
|
|
/// expects.
|
|
///
|
|
/// @note
|
|
/// Primitive descriptor base constructed this way does not support
|
|
/// next_impl() (will throw).
|
|
///
|
|
/// @param pd C API primitive descriptor to clone.
|
|
/// @param prim_kind Expected primitive kind.
|
|
/// @param prop_kind1 Expected propagation kind (option 1).
|
|
/// @param prop_kind2 Expected propagation kind (option 2). This value is
|
|
/// checked if the check with @p prop_kind1 fails.
|
|
primitive_desc_base(dnnl_primitive_desc_t pd,
|
|
dnnl::primitive::kind prim_kind, dnnl::prop_kind prop_kind1,
|
|
dnnl::prop_kind prop_kind2) {
|
|
// It is OK to pass an empty primitive descriptor
|
|
if (pd == nullptr) return;
|
|
|
|
dnnl_status_t rc;
|
|
|
|
dnnl_primitive_kind_t c_prim_kind = convert_to_c(prim_kind);
|
|
dnnl_prop_kind_t c_prop_kind1 = convert_to_c(prop_kind1);
|
|
dnnl_prop_kind_t c_prop_kind2 = convert_to_c(prop_kind2);
|
|
|
|
// Check that primitive kind matches
|
|
dnnl_primitive_kind_t pd_kind;
|
|
rc = dnnl_primitive_desc_query(
|
|
pd, dnnl_query_primitive_kind, 0, (void *)&pd_kind);
|
|
error::wrap_c_api(
|
|
rc, "could not get primitive kind from a primitive descriptor");
|
|
if (pd_kind != c_prim_kind)
|
|
DNNL_THROW_ERROR(dnnl_invalid_arguments,
|
|
"primitive descriptor operation kind mismatch");
|
|
|
|
// Check that propagation kind matches
|
|
dnnl_prop_kind_t pd_prop_kind;
|
|
rc = dnnl_primitive_desc_query(
|
|
pd, dnnl_query_prop_kind, 0, (void *)&pd_prop_kind);
|
|
|
|
// Something went wrong
|
|
if (rc != dnnl_success && rc != dnnl_unimplemented)
|
|
DNNL_THROW_ERROR(dnnl_invalid_arguments,
|
|
"could not get propagation kind from the primitive "
|
|
"descriptor");
|
|
|
|
// Everything is fine
|
|
if ((rc == dnnl_unimplemented && c_prop_kind1 == dnnl_prop_kind_undef)
|
|
|| (rc == dnnl_success
|
|
&& (pd_prop_kind == c_prop_kind1
|
|
|| pd_prop_kind == c_prop_kind2))) {
|
|
reset_with_clone(pd);
|
|
return;
|
|
}
|
|
|
|
// We could get the propagation kind but there is a mismatch
|
|
DNNL_THROW_ERROR(dnnl_invalid_arguments,
|
|
"primitive descriptor propagation kind mismatch");
|
|
}
|
|
|
|
/// Returns a constant reference to a static instance of default constructed
|
|
/// primitive attributes
|
|
static const primitive_attr &default_attr() {
|
|
static const primitive_attr attr;
|
|
return attr;
|
|
}
|
|
|
|
const_dnnl_memory_desc_t optional_arg(const memory::desc *md) {
|
|
return md ? md->get() : nullptr;
|
|
}
|
|
|
|
const dnnl_dim_t *optional_arg(const memory::dims *dims) {
|
|
return dims ? dims->data() : nullptr;
|
|
}
|
|
|
|
const float *optional_arg(const std::vector<float> *arg) {
|
|
return arg ? arg->data() : nullptr;
|
|
}
|
|
|
|
using base = primitive_desc_base;
|
|
};
|
|
|
|
/// @} dnnl_api_primitives_common
|
|
|
|
/// @addtogroup dnnl_api_reorder Reorder
|
|
///
|
|
/// A primitive to copy data between two memory objects. This primitive is
|
|
/// typically used to change the way the data is laid out in memory.
|
|
///
|
|
/// @sa @ref dev_guide_reorder in developer guide
|
|
///
|
|
/// @{
|
|
|
|
/// Reorder primitive.
|
|
struct reorder : public primitive {
|
|
/// Primitive descriptor for a reorder primitive.
|
|
struct primitive_desc : public primitive_desc_base {
|
|
using primitive_desc_base::primitive_desc_base;
|
|
|
|
/// Default constructor. Produces an empty object.
|
|
primitive_desc() = default;
|
|
|
|
/// Constructs a primitive descriptor for reorder primitive.
|
|
///
|
|
/// @note
|
|
/// If @p allow_empty is true, the constructor does not throw if a
|
|
/// primitive descriptor cannot be created.
|
|
///
|
|
/// @param src_engine Engine on which the source memory object will be
|
|
/// located.
|
|
/// @param src_md Source memory descriptor.
|
|
/// @param dst_engine Engine on which the destination memory object
|
|
/// will be located.
|
|
/// @param dst_md Destination memory descriptor.
|
|
/// @param attr Primitive attributes to use. Attributes are optional
|
|
/// and default to empty attributes.
|
|
/// @param allow_empty A flag signifying whether construction is allowed
|
|
/// to fail without throwing an exception. In this case an empty
|
|
/// object will be produced. This flag is optional and defaults to
|
|
/// false.
|
|
primitive_desc(const engine &src_engine, const memory::desc &src_md,
|
|
const engine &dst_engine, const memory::desc &dst_md,
|
|
const primitive_attr &attr = default_attr(),
|
|
bool allow_empty = false) {
|
|
dnnl_primitive_desc_t result;
|
|
dnnl_status_t status = dnnl_reorder_primitive_desc_create(&result,
|
|
src_md.get(), src_engine.get(), dst_md.get(),
|
|
dst_engine.get(), attr.get());
|
|
if (!allow_empty)
|
|
error::wrap_c_api(status,
|
|
"could not create a primitive descriptor for "
|
|
"the reorder primitive. Run workload with "
|
|
"environment variable ONEDNN_VERBOSE=all to get "
|
|
"additional diagnostic information.");
|
|
reset(status == dnnl_success ? result : dnnl_primitive_desc_t());
|
|
}
|
|
|
|
/// Constructs a primitive descriptor for reorder primitive.
|
|
///
|
|
/// @param src Source memory object. It is used to obtain the source
|
|
/// memory descriptor and engine.
|
|
/// @param dst Destination memory object. It is used to obtain the
|
|
/// destination memory descriptor and engine.
|
|
/// @param attr Primitive attributes to use. Attributes are optional
|
|
/// and default to empty attributes.
|
|
/// @param allow_empty A flag signifying whether construction is allowed
|
|
/// to fail without throwing an exception. In this case an empty
|
|
/// object will be produced. This flag is optional and defaults to
|
|
/// false.
|
|
primitive_desc(const memory &src, const memory &dst,
|
|
const primitive_attr &attr = default_attr(),
|
|
bool allow_empty = false) {
|
|
dnnl_primitive_desc_t result;
|
|
auto src_md = src.get_desc();
|
|
auto dst_md = dst.get_desc();
|
|
dnnl_status_t status = dnnl_reorder_primitive_desc_create(&result,
|
|
src_md.get(), src.get_engine().get(), dst_md.get(),
|
|
dst.get_engine().get(), attr.get());
|
|
if (!allow_empty)
|
|
error::wrap_c_api(status,
|
|
"could not create a primitive descriptor for "
|
|
"the reorder primitive. Run workload with "
|
|
"environment variable ONEDNN_VERBOSE=all to get "
|
|
"additional diagnostic information.");
|
|
reset(status == dnnl_success ? result : dnnl_primitive_desc_t());
|
|
}
|
|
|
|
/// Constructs a primitive descriptor for reorder primitive from a C
|
|
/// API primitive descriptor which must have a matching kind.
|
|
///
|
|
/// @param pd C API primitive descriptor for reorder primitive.
|
|
primitive_desc(dnnl_primitive_desc_t pd)
|
|
: primitive_desc_base(pd, dnnl::primitive::kind::reorder) {}
|
|
|
|
/// Returns the engine on which the source memory is allocated.
|
|
/// @returns The engine on which the source memory is allocated.
|
|
engine get_src_engine() const {
|
|
return query_engine(dnnl::query::reorder_src_engine);
|
|
}
|
|
|
|
/// Returns the engine on which the destination memory is allocated.
|
|
/// @returns The engine on which the destination memory is allocated.
|
|
engine get_dst_engine() const {
|
|
return query_engine(dnnl::query::reorder_dst_engine);
|
|
}
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::src_desc()const
|
|
memory::desc src_desc() const { return base::src_desc(0); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::dst_desc()const
|
|
memory::desc dst_desc() const { return base::dst_desc(0); }
|
|
};
|
|
|
|
/// Default constructor. Produces an empty object.
|
|
reorder() = default;
|
|
|
|
/// Constructs a reorder primitive.
|
|
/// @param pd Primitive descriptor for reorder primitive.
|
|
reorder(const primitive_desc &pd) : primitive(pd.get()) {}
|
|
|
|
/// Constructs a reorder primitive from a cache blob.
|
|
/// @param pd Primitive descriptor for reorder primitive.
|
|
/// @param cache_blob Cache blob.
|
|
reorder(const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
|
|
: primitive(pd.get(), cache_blob) {}
|
|
|
|
/// Constructs a reorder primitive that would reorder data between memory
|
|
/// objects having the same memory descriptors as memory objects @p src and
|
|
/// @p dst.
|
|
///
|
|
/// @param src Source memory object.
|
|
/// @param dst Destination memory object.
|
|
/// @param attr Primitive attributes to use (optional).
|
|
reorder(const memory &src, const memory &dst,
|
|
const primitive_attr &attr = primitive_attr())
|
|
: primitive(primitive_desc(src, dst, attr).get()) {}
|
|
|
|
using primitive::execute;
|
|
|
|
/// Executes the reorder primitive.
|
|
///
|
|
/// @param astream Stream object. The stream must belong to the same engine
|
|
/// as the primitive.
|
|
/// @param src Source memory object.
|
|
/// @param dst Destination memory object.
|
|
void execute(const stream &astream, memory &src, memory &dst) const {
|
|
primitive::execute(astream, {{DNNL_ARG_FROM, src}, {DNNL_ARG_TO, dst}});
|
|
}
|
|
};
|
|
|
|
/// @} dnnl_api_reorder
|
|
|
|
/// @addtogroup dnnl_api_concat Concat
|
|
///
|
|
/// A primitive to concatenate data by arbitrary dimension.
|
|
///
|
|
/// @sa @ref dev_guide_concat in developer guide
|
|
///
|
|
/// @{
|
|
|
|
/// @cond DO_NOT_DOCUMENT_THIS
|
|
inline std::vector<const_dnnl_memory_desc_t> convert_to_c(
|
|
const std::vector<memory::desc> &mds) {
|
|
std::vector<const_dnnl_memory_desc_t> c_mds;
|
|
c_mds.reserve(mds.size());
|
|
for (const auto &md : mds)
|
|
c_mds.push_back(md.get());
|
|
return c_mds;
|
|
}
|
|
/// @endcond
|
|
|
|
/// Tensor concatenation (concat) primitive.
|
|
struct concat : public primitive {
|
|
/// Primitive descriptor for a concat primitive.
|
|
struct primitive_desc : public primitive_desc_base {
|
|
using primitive_desc_base::primitive_desc_base;
|
|
|
|
/// Default constructor. Produces an empty object.
|
|
primitive_desc() = default;
|
|
|
|
/// Constructs a primitive descriptor for an out-of-place concatenation
|
|
/// primitive.
|
|
///
|
|
/// @param aengine Engine to perform the operation on.
|
|
/// @param dst Destination memory descriptor.
|
|
/// @param concat_dimension Source tensors will be concatenated over
|
|
/// dimension with this index. Note that order of dimensions does
|
|
/// not depend on memory format.
|
|
/// @param srcs Vector of source memory descriptors.
|
|
/// @param attr Primitive attributes to use. Attributes are optional
|
|
/// and default to empty attributes.
|
|
/// @param allow_empty A flag signifying whether construction is
|
|
/// allowed to fail without throwing an exception. In this case an
|
|
/// empty object will be produced. This flag is optional and
|
|
/// defaults to false.
|
|
primitive_desc(const engine &aengine, const memory::desc &dst,
|
|
int concat_dimension, const std::vector<memory::desc> &srcs,
|
|
const primitive_attr &attr = default_attr(),
|
|
bool allow_empty = false) {
|
|
auto c_srcs = convert_to_c(srcs);
|
|
|
|
dnnl_primitive_desc_t result;
|
|
dnnl_status_t status = dnnl_concat_primitive_desc_create(&result,
|
|
aengine.get(), dst.get(), (int)c_srcs.size(),
|
|
concat_dimension, c_srcs.data(), attr.get());
|
|
if (!allow_empty)
|
|
error::wrap_c_api(status,
|
|
"could not create a primitive descriptor for "
|
|
"the concat primitive. Run workload with "
|
|
"environment variable ONEDNN_VERBOSE=all to get "
|
|
"additional diagnostic information.");
|
|
reset(status == dnnl_success ? result : dnnl_primitive_desc_t());
|
|
}
|
|
|
|
/// Constructs a primitive descriptor for an out-of-place concatenation
|
|
/// primitive.
|
|
///
|
|
/// This version derives the destination memory descriptor
|
|
/// automatically.
|
|
///
|
|
/// @param aengine Engine to perform the operation on.
|
|
/// @param concat_dimension Source tensors will be concatenated over
|
|
/// dimension with this index. Note that order of dimensions does
|
|
/// not depend on memory format.
|
|
/// @param srcs Vector of source memory descriptors.
|
|
/// @param attr Primitive attributes to use. Attributes are optional
|
|
/// and default to empty attributes.
|
|
/// @param allow_empty A flag signifying whether construction is
|
|
/// allowed to fail without throwing an exception. In this case an
|
|
/// empty object will be produced. This flag is optional and
|
|
/// defaults to false.
|
|
primitive_desc(const engine &aengine, int concat_dimension,
|
|
const std::vector<memory::desc> &srcs,
|
|
const primitive_attr &attr = default_attr(),
|
|
bool allow_empty = false) {
|
|
auto c_api_srcs = convert_to_c(srcs);
|
|
|
|
dnnl_primitive_desc_t result;
|
|
dnnl_status_t status = dnnl_concat_primitive_desc_create(&result,
|
|
aengine.get(), nullptr, (int)c_api_srcs.size(),
|
|
concat_dimension, c_api_srcs.data(), attr.get());
|
|
if (!allow_empty)
|
|
error::wrap_c_api(status,
|
|
"could not create a primitive descriptor for "
|
|
"the concat primitive. Run workload with "
|
|
"environment variable ONEDNN_VERBOSE=all to get "
|
|
"additional diagnostic information.");
|
|
reset(status == dnnl_success ? result : dnnl_primitive_desc_t());
|
|
}
|
|
|
|
/// Constructs a primitive descriptor for concat primitive from a C
|
|
/// API primitive descriptor which must have a matching kind.
|
|
///
|
|
/// @param pd C API primitive descriptor for concat primitive.
|
|
primitive_desc(dnnl_primitive_desc_t pd)
|
|
: primitive_desc_base(pd, dnnl::primitive::kind::concat) {}
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::src_desc(int)const
|
|
memory::desc src_desc(int idx = 0) const { return base::src_desc(idx); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::dst_desc()const
|
|
memory::desc dst_desc() const { return base::dst_desc(0); }
|
|
};
|
|
|
|
/// Default constructor. Produces an empty object.
|
|
concat() = default;
|
|
|
|
/// Constructs a concatenation primitive.
|
|
/// @param pd Primitive descriptor for concatenation primitive.
|
|
concat(const primitive_desc &pd) : primitive(pd.get()) {}
|
|
|
|
/// Constructs a concatenation primitive from a cache blob.
|
|
/// @param pd Primitive descriptor for concatenation primitive.
|
|
/// @param cache_blob Cache blob.
|
|
concat(const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
|
|
: primitive(pd.get(), cache_blob) {}
|
|
};
|
|
|
|
/// @} dnnl_api_concat
|
|
|
|
/// @addtogroup dnnl_api_sum Sum
|
|
///
|
|
/// A primitive to sum multiple tensors.
|
|
///
|
|
/// @sa @ref dev_guide_sum in developer guide
|
|
///
|
|
/// @{
|
|
|
|
/// Out-of-place summation (sum) primitive.
|
|
struct sum : public primitive {
|
|
/// Primitive descriptor for a sum primitive.
|
|
struct primitive_desc : public primitive_desc_base {
|
|
using primitive_desc_base::primitive_desc_base;
|
|
|
|
/// Default constructor. Produces an empty object.
|
|
primitive_desc() = default;
|
|
|
|
/// Constructs a primitive descriptor for a sum primitive.
|
|
///
|
|
/// @param aengine Engine to perform the operation on.
|
|
/// @param dst Destination memory descriptor.
|
|
/// @param scales Vector of scales to multiply data in each source
|
|
/// memory by.
|
|
/// @param srcs Vector of source memory descriptors.
|
|
/// @param attr Primitive attributes to use. Attributes are optional
|
|
/// and default to empty attributes.
|
|
/// @param allow_empty A flag signifying whether construction is
|
|
/// allowed to fail without throwing an exception. In this case an
|
|
/// empty object will be produced. This flag is optional and
|
|
/// defaults to false.
|
|
primitive_desc(const engine &aengine, const memory::desc &dst,
|
|
const std::vector<float> &scales,
|
|
const std::vector<memory::desc> &srcs,
|
|
const primitive_attr &attr = default_attr(),
|
|
bool allow_empty = false) {
|
|
validate_container_size(scales,
|
|
"counts of scales and sources are not equal",
|
|
(int)srcs.size(), (int)srcs.size());
|
|
|
|
auto c_api_srcs = convert_to_c(srcs);
|
|
|
|
dnnl_primitive_desc_t result;
|
|
dnnl_status_t status = dnnl_sum_primitive_desc_create(&result,
|
|
aengine.get(), dst.get(), (int)c_api_srcs.size(),
|
|
scales.data(), c_api_srcs.data(), attr.get());
|
|
if (!allow_empty)
|
|
error::wrap_c_api(status,
|
|
"could not create a primitive descriptor for "
|
|
"the sum primitive. Run workload with "
|
|
"environment variable ONEDNN_VERBOSE=all to get "
|
|
"additional diagnostic information.");
|
|
reset(status == dnnl_success ? result : dnnl_primitive_desc_t());
|
|
}
|
|
|
|
/// Constructs a primitive descriptor for a sum primitive.
|
|
///
|
|
/// This version derives the destination memory descriptor
|
|
/// automatically.
|
|
///
|
|
/// @param aengine Engine on which to perform the operation.
|
|
/// @param scales Vector of scales by which to multiply data in each
|
|
/// source memory object.
|
|
/// @param srcs Vector of source memory descriptors.
|
|
/// @param attr Primitive attributes to use. Attributes are optional
|
|
/// and default to empty attributes.
|
|
/// @param allow_empty A flag signifying whether construction is
|
|
/// allowed to fail without throwing an exception. In this case an
|
|
/// empty object will be produced. This flag is optional and
|
|
/// defaults to false.
|
|
primitive_desc(const engine &aengine, const std::vector<float> &scales,
|
|
const std::vector<memory::desc> &srcs,
|
|
const primitive_attr &attr = default_attr(),
|
|
bool allow_empty = false) {
|
|
validate_container_size(scales,
|
|
"counts of scales and sources are not equal",
|
|
(int)srcs.size(), (int)srcs.size());
|
|
|
|
auto c_api_srcs = convert_to_c(srcs);
|
|
dnnl_primitive_desc_t result;
|
|
dnnl_status_t status = dnnl_sum_primitive_desc_create(&result,
|
|
aengine.get(), nullptr, (int)c_api_srcs.size(),
|
|
scales.data(), c_api_srcs.data(), attr.get());
|
|
if (!allow_empty)
|
|
error::wrap_c_api(status,
|
|
"could not create a primitive descriptor for "
|
|
"the sum primitive. Run workload with "
|
|
"environment variable ONEDNN_VERBOSE=all to get "
|
|
"additional diagnostic information.");
|
|
reset(status == dnnl_success ? result : dnnl_primitive_desc_t());
|
|
}
|
|
|
|
/// Constructs a primitive descriptor for sum primitive from a C API
|
|
/// primitive descriptor which must have a matching kind.
|
|
///
|
|
/// @param pd C API primitive descriptor for sum primitive.
|
|
primitive_desc(dnnl_primitive_desc_t pd)
|
|
: primitive_desc_base(pd, dnnl::primitive::kind::sum) {}
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::src_desc(int)const
|
|
memory::desc src_desc(int idx = 0) const { return base::src_desc(idx); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::dst_desc()const
|
|
memory::desc dst_desc() const { return base::dst_desc(0); }
|
|
};
|
|
|
|
/// Default constructor. Produces an empty object.
|
|
sum() = default;
|
|
|
|
/// Constructs a sum primitive.
|
|
/// @param pd Primitive descriptor for sum primitive.
|
|
sum(const primitive_desc &pd) : primitive(pd.get()) {}
|
|
|
|
/// Constructs a sum primitive from a cache blob.
|
|
/// @param pd Primitive descriptor for sum primitive.
|
|
/// @param cache_blob Cache blob.
|
|
sum(const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
|
|
: primitive(pd.get(), cache_blob) {}
|
|
};
|
|
|
|
/// @} dnnl_api_sum
|
|
|
|
/// @addtogroup dnnl_api_primitives_common
|
|
/// @{
|
|
|
|
/// A base class for descriptors of all primitives that support iteration
|
|
/// over multiple implementations.
|
|
struct primitive_desc : public primitive_desc_base {
|
|
using primitive_desc_base::primitive_desc_base;
|
|
|
|
primitive_desc() = default;
|
|
|
|
/// Changes the primitive descriptor to point to the next available
|
|
/// implementation.
|
|
///
|
|
/// @returns @c true on success and @c false if the last available
|
|
/// implementation has already been reached. In the latter case, the
|
|
/// primitive descriptor itself is kept unchanged.
|
|
bool next_impl() {
|
|
dnnl_status_t status = dnnl_primitive_desc_next_impl(get());
|
|
if (status == dnnl_last_impl_reached) return false;
|
|
error::wrap_c_api(status, "last available implementation is reached");
|
|
return true;
|
|
}
|
|
};
|
|
|
|
/// @} dnnl_api_primitives_common
|
|
|
|
/// @addtogroup dnnl_api_convolution Convolution
|
|
///
|
|
/// A primitive to perform 1D, 2D or 3D convolution. Supported variants are
|
|
/// forward propagation, backward propagation, and weights gradient with or
|
|
/// without bias.
|
|
///
|
|
/// @sa @ref dev_guide_convolution in developer guide
|
|
///
|
|
/// @{
|
|
|
|
/// Convolution forward propagation primitive.
|
|
struct convolution_forward : public primitive {
|
|
/// Primitive descriptor for a convolution forward propagation primitive.
|
|
struct primitive_desc : public dnnl::primitive_desc {
|
|
/// Default constructor. Produces an empty object.
|
|
primitive_desc() = default;
|
|
|
|
/// Constructs a primitive descriptor for a convolution forward
|
|
/// propagation primitive with bias.
|
|
///
|
|
/// @note
|
|
/// All the memory descriptors may be initialized with the
|
|
/// #dnnl::memory::format_tag::any value of @p format_tag.
|
|
///
|
|
/// Arrays @p strides, @p padding_l, and @p padding_r contain values
|
|
/// for spatial dimensions only and hence must have the same number of
|
|
/// elements as there are spatial dimensions. The order of values is
|
|
/// the same as in the tensor: depth (for 3D tensors), height (for 3D
|
|
/// and 2D tensors), and width.
|
|
///
|
|
/// @param aengine Engine to use.
|
|
/// @param aprop_kind Propagation kind. Possible values are
|
|
/// #dnnl::prop_kind::forward_training, and
|
|
/// #dnnl::prop_kind::forward_inference.
|
|
/// @param aalgorithm Convolution algorithm. Possible values are
|
|
/// #dnnl::algorithm::convolution_direct,
|
|
/// #dnnl::algorithm::convolution_winograd, and
|
|
/// #dnnl::algorithm::convolution_auto.
|
|
/// @param src_desc Source memory descriptor.
|
|
/// @param weights_desc Weights memory descriptor.
|
|
/// @param bias_desc Bias memory descriptor. Passing zero memory
|
|
/// descriptor disables the bias term.
|
|
/// @param dst_desc Destination memory descriptor.
|
|
/// @param strides Strides for each spatial dimension.
|
|
/// @param padding_l Vector of padding values for low indices for each
|
|
/// spatial dimension `([[front,] top,] left)`.
|
|
/// @param padding_r Vector of padding values for high indices for
|
|
/// each spatial dimension `([[back,] bottom,] right)`.
|
|
/// @param attr Primitive attributes to use. Attributes are optional
|
|
/// and default to empty attributes.
|
|
/// @param allow_empty A flag signifying whether construction is
|
|
/// allowed to fail without throwing an exception. In this case an
|
|
/// empty object will be produced. This flag is optional and
|
|
/// defaults to false.
|
|
primitive_desc(const engine &aengine, prop_kind aprop_kind,
|
|
algorithm aalgorithm, const memory::desc &src_desc,
|
|
const memory::desc &weights_desc, const memory::desc &bias_desc,
|
|
const memory::desc &dst_desc, const memory::dims &strides,
|
|
const memory::dims &padding_l, const memory::dims &padding_r,
|
|
const primitive_attr &attr = default_attr(),
|
|
bool allow_empty = false)
|
|
: primitive_desc(aengine, aprop_kind, aalgorithm, src_desc,
|
|
weights_desc, &bias_desc, dst_desc, strides, nullptr,
|
|
padding_l, padding_r, attr, allow_empty) {}
|
|
|
|
/// Constructs a primitive descriptor for a convolution forward
|
|
/// propagation primitive without bias.
|
|
///
|
|
/// @note
|
|
/// All the memory descriptors may be initialized with the
|
|
/// #dnnl::memory::format_tag::any value of @p format_tag.
|
|
///
|
|
/// Arrays @p strides, @p padding_l, and @p padding_r contain values
|
|
/// for spatial dimensions only and hence must have the same number of
|
|
/// elements as there are spatial dimensions. The order of values is
|
|
/// the same as in the tensor: depth (for 3D tensors), height (for 3D
|
|
/// and 2D tensors), and width.
|
|
///
|
|
/// @param aengine Engine to use.
|
|
/// @param aprop_kind Propagation kind. Possible values are
|
|
/// #dnnl::prop_kind::forward_training, and
|
|
/// #dnnl::prop_kind::forward_inference.
|
|
/// @param aalgorithm Convolution algorithm. Possible values are
|
|
/// #dnnl::algorithm::convolution_direct,
|
|
/// #dnnl::algorithm::convolution_winograd, and
|
|
/// #dnnl::algorithm::convolution_auto.
|
|
/// @param src_desc Source memory descriptor.
|
|
/// @param weights_desc Weights memory descriptor.
|
|
/// @param dst_desc Destination memory descriptor.
|
|
/// @param strides Strides for each spatial dimension.
|
|
/// @param padding_l Vector of padding values for low indices for each
|
|
/// spatial dimension `([[front,] top,] left)`.
|
|
/// @param padding_r Vector of padding values for high indices for
|
|
/// each spatial dimension `([[back,] bottom,] right)`.
|
|
/// @param attr Primitive attributes to use. Attributes are optional
|
|
/// and default to empty attributes.
|
|
/// @param allow_empty A flag signifying whether construction is
|
|
/// allowed to fail without throwing an exception. In this case an
|
|
/// empty object will be produced. This flag is optional and
|
|
/// defaults to false.
|
|
primitive_desc(const engine &aengine, prop_kind aprop_kind,
|
|
algorithm aalgorithm, const memory::desc &src_desc,
|
|
const memory::desc &weights_desc, const memory::desc &dst_desc,
|
|
const memory::dims &strides, const memory::dims &padding_l,
|
|
const memory::dims &padding_r,
|
|
const primitive_attr &attr = default_attr(),
|
|
bool allow_empty = false)
|
|
: primitive_desc(aengine, aprop_kind, aalgorithm, src_desc,
|
|
weights_desc, nullptr, dst_desc, strides, nullptr,
|
|
padding_l, padding_r, attr, allow_empty) {}
|
|
|
|
/// Constructs a primitive descriptor for a convolution forward
|
|
/// propagation primitive with bias.
|
|
///
|
|
/// @note
|
|
/// All the memory descriptors may be initialized with the
|
|
/// #dnnl::memory::format_tag::any value of @p format_tag.
|
|
///
|
|
/// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r
|
|
/// contain values for spatial dimensions only and hence must have the
|
|
/// same number of elements as there are spatial dimensions. The order
|
|
/// of values is the same as in the tensor: depth (for 3D tensors),
|
|
/// height (for 3D and 2D tensors), and width.
|
|
///
|
|
/// @param aengine Engine to use.
|
|
/// @param aprop_kind Propagation kind. Possible values are
|
|
/// #dnnl::prop_kind::forward_training, and
|
|
/// #dnnl::prop_kind::forward_inference.
|
|
/// @param aalgorithm Convolution algorithm. Possible values are
|
|
/// #dnnl::algorithm::convolution_direct,
|
|
/// #dnnl::algorithm::convolution_winograd, and
|
|
/// #dnnl::algorithm::convolution_auto.
|
|
/// @param src_desc Source memory descriptor.
|
|
/// @param weights_desc Weights memory descriptor.
|
|
/// @param bias_desc Bias memory descriptor. Passing zero memory
|
|
/// descriptor disables the bias term.
|
|
/// @param dst_desc Destination memory descriptor.
|
|
/// @param strides Strides for each spatial dimension.
|
|
/// @param dilates Dilations for each spatial dimension. A zero value
|
|
/// means no dilation in the corresponding dimension.
|
|
/// @param padding_l Vector of padding values for low indices for each
|
|
/// spatial dimension `([[front,] top,] left)`.
|
|
/// @param padding_r Vector of padding values for high indices for
|
|
/// each spatial dimension `([[back,] bottom,] right)`.
|
|
/// @param attr Primitive attributes to use. Attributes are optional
|
|
/// and default to empty attributes.
|
|
/// @param allow_empty A flag signifying whether construction is
|
|
/// allowed to fail without throwing an exception. In this case an
|
|
/// empty object will be produced. This flag is optional and
|
|
/// defaults to false.
|
|
primitive_desc(const engine &aengine, prop_kind aprop_kind,
|
|
algorithm aalgorithm, const memory::desc &src_desc,
|
|
const memory::desc &weights_desc, const memory::desc &bias_desc,
|
|
const memory::desc &dst_desc, const memory::dims &strides,
|
|
const memory::dims &dilates, const memory::dims &padding_l,
|
|
const memory::dims &padding_r,
|
|
const primitive_attr &attr = default_attr(),
|
|
bool allow_empty = false)
|
|
: primitive_desc(aengine, aprop_kind, aalgorithm, src_desc,
|
|
weights_desc, &bias_desc, dst_desc, strides, &dilates,
|
|
padding_l, padding_r, attr, allow_empty) {}
|
|
|
|
/// Constructs a primitive descriptor for a convolution forward
|
|
/// propagation primitive without bias.
|
|
///
|
|
/// @note
|
|
/// All the memory descriptors may be initialized with the
|
|
/// #dnnl::memory::format_tag::any value of @p format_tag.
|
|
///
|
|
/// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r
|
|
/// contain values for spatial dimensions only and hence must have the
|
|
/// same number of elements as there are spatial dimensions. The order
|
|
/// of values is the same as in the tensor: depth (for 3D tensors),
|
|
/// height (for 3D and 2D tensors), and width.
|
|
///
|
|
/// @param aengine Engine to use.
|
|
/// @param aprop_kind Propagation kind. Possible values are
|
|
/// #dnnl::prop_kind::forward_training, and
|
|
/// #dnnl::prop_kind::forward_inference.
|
|
/// @param aalgorithm Convolution algorithm. Possible values are
|
|
/// #dnnl::algorithm::convolution_direct,
|
|
/// #dnnl::algorithm::convolution_winograd, and
|
|
/// #dnnl::algorithm::convolution_auto.
|
|
/// @param src_desc Source memory descriptor.
|
|
/// @param weights_desc Weights memory descriptor.
|
|
/// @param dst_desc Destination memory descriptor.
|
|
/// @param strides Strides for each spatial dimension.
|
|
/// @param dilates Dilations for each spatial dimension. A zero value
|
|
/// means no dilation in the corresponding dimension.
|
|
/// @param padding_l Vector of padding values for low indices for each
|
|
/// spatial dimension `([[front,] top,] left)`.
|
|
/// @param padding_r Vector of padding values for high indices for
|
|
/// each spatial dimension `([[back,] bottom,] right)`.
|
|
/// @param attr Primitive attributes to use. Attributes are optional
|
|
/// and default to empty attributes.
|
|
/// @param allow_empty A flag signifying whether construction is
|
|
/// allowed to fail without throwing an exception. In this case an
|
|
/// empty object will be produced. This flag is optional and
|
|
/// defaults to false.
|
|
primitive_desc(const engine &aengine, prop_kind aprop_kind,
|
|
algorithm aalgorithm, const memory::desc &src_desc,
|
|
const memory::desc &weights_desc, const memory::desc &dst_desc,
|
|
const memory::dims &strides, const memory::dims &dilates,
|
|
const memory::dims &padding_l, const memory::dims &padding_r,
|
|
const primitive_attr &attr = default_attr(),
|
|
bool allow_empty = false)
|
|
: primitive_desc(aengine, aprop_kind, aalgorithm, src_desc,
|
|
weights_desc, nullptr, dst_desc, strides, &dilates,
|
|
padding_l, padding_r, attr, allow_empty) {}
|
|
|
|
/// Constructs a primitive descriptor for a convolution forward
|
|
/// propagation primitive from a C API primitive descriptor that must
|
|
/// have a matching kind.
|
|
///
|
|
/// @param pd C API primitive descriptor for a convolution forward
|
|
/// propagation primitive.
|
|
primitive_desc(dnnl_primitive_desc_t pd)
|
|
: dnnl::primitive_desc(pd, dnnl::primitive::kind::convolution,
|
|
dnnl::prop_kind::forward_training,
|
|
dnnl::prop_kind::forward_inference) {}
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::src_desc()const
|
|
memory::desc src_desc() const { return base::src_desc(0); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::weights_desc()const
|
|
memory::desc weights_desc() const { return base::weights_desc(0); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::dst_desc()const
|
|
memory::desc dst_desc() const { return base::dst_desc(0); }
|
|
|
|
/// Returns the bias memory descriptor.
|
|
/// @returns The bias memory descriptor.
|
|
/// @returns A zero memory descriptor of the primitive does not have a
|
|
/// bias parameter.
|
|
memory::desc bias_desc() const { return base::weights_desc(1); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_algorithm()const
|
|
algorithm get_algorithm() const { return base::get_algorithm(); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
|
|
prop_kind get_prop_kind() const { return base::get_prop_kind(); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_strides()const
|
|
memory::dims get_strides() const { return base::get_strides(); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_dilations()const
|
|
memory::dims get_dilations() const { return base::get_dilations(); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_padding_l()const
|
|
memory::dims get_padding_l() const { return base::get_padding_l(); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_padding_r()const
|
|
memory::dims get_padding_r() const { return base::get_padding_r(); }
|
|
|
|
private:
|
|
primitive_desc(const engine &aengine, prop_kind aprop_kind,
|
|
algorithm aalgorithm, const memory::desc &src_desc,
|
|
const memory::desc &weights_desc, const memory::desc *bias_desc,
|
|
const memory::desc &dst_desc, const memory::dims &strides,
|
|
const memory::dims *dilates, const memory::dims &padding_l,
|
|
const memory::dims &padding_r, const primitive_attr &attr,
|
|
bool allow_empty) {
|
|
|
|
memory::validate_dims(strides, src_desc.get_ndims() - 2);
|
|
memory::validate_dims(padding_l, src_desc.get_ndims() - 2);
|
|
memory::validate_dims(padding_r, src_desc.get_ndims() - 2);
|
|
|
|
if (dilates)
|
|
memory::validate_dims(*dilates, src_desc.get_ndims() - 2);
|
|
|
|
dnnl_primitive_desc_t pd = nullptr;
|
|
dnnl_status_t status
|
|
= dnnl_convolution_forward_primitive_desc_create(&pd,
|
|
aengine.get(), dnnl::convert_to_c(aprop_kind),
|
|
convert_to_c(aalgorithm), src_desc.get(),
|
|
weights_desc.get(), optional_arg(bias_desc),
|
|
dst_desc.get(), &strides[0], optional_arg(dilates),
|
|
&padding_l[0], &padding_r[0], attr.get());
|
|
if (!allow_empty)
|
|
error::wrap_c_api(status,
|
|
"could not create a primitive descriptor for "
|
|
"the convolution forward propagation primitive. Run "
|
|
"workload with environment variable ONEDNN_VERBOSE=all "
|
|
"to get additional diagnostic information.");
|
|
reset(pd);
|
|
}
|
|
};
|
|
|
|
/// Default constructor. Produces an empty object.
|
|
convolution_forward() = default;
|
|
|
|
/// Constructs a convolution forward propagation primitive.
|
|
/// @param pd Primitive descriptor for a convolution forward propagation
|
|
/// primitive.
|
|
convolution_forward(const primitive_desc &pd) : primitive(pd) {}
|
|
|
|
/// Constructs a convolution forward propagation primitive from a cache
|
|
/// blob.
|
|
/// @param pd Primitive descriptor for a convolution forward propagation
|
|
/// primitive.
|
|
/// @param cache_blob Cache blob.
|
|
convolution_forward(
|
|
const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
|
|
: primitive(pd, cache_blob) {}
|
|
};
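
// A minimal sketch of creating and executing the forward primitive, given an
// engine `eng` and a primitive descriptor `conv_pd` built as in the sketch
// above; all names are placeholders:
//
//     convolution_forward conv(conv_pd);
//     stream s(eng);
//     // Allocate memory using the descriptors the primitive descriptor chose
//     // so that no extra reorders are required.
//     memory src_mem(conv_pd.src_desc(), eng);
//     memory wei_mem(conv_pd.weights_desc(), eng);
//     memory dst_mem(conv_pd.dst_desc(), eng);
//     conv.execute(s,
//             {{DNNL_ARG_SRC, src_mem}, {DNNL_ARG_WEIGHTS, wei_mem},
//                     {DNNL_ARG_DST, dst_mem}});
//     s.wait();
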
/// Convolution backward propagation primitive.
|
|
struct convolution_backward_data : public primitive {
|
|
/// Primitive descriptor for a convolution backward propagation primitive.
|
|
struct primitive_desc : public dnnl::primitive_desc {
|
|
/// Default constructor. Produces an empty object.
|
|
primitive_desc() = default;
|
|
|
|
/// Constructs a primitive descriptor for a convolution backward
|
|
/// propagation primitive.
|
|
///
|
|
/// @note
|
|
/// All the memory descriptors may be initialized with the
|
|
/// #dnnl::memory::format_tag::any value of @p format_tag.
|
|
///
|
|
/// Arrays @p strides, @p padding_l, and @p padding_r contain values
|
|
/// for spatial dimensions only and hence must have the same number of
|
|
/// elements as there are spatial dimensions. The order of values is
|
|
/// the same as in the tensor: depth (for 3D tensors), height (for 3D
|
|
/// and 2D tensors), and width.
|
|
///
|
|
/// @param aengine Engine to use.
|
|
/// @param aalgorithm Convolution algorithm. Possible values are
|
|
/// #dnnl::algorithm::convolution_direct,
|
|
/// #dnnl::algorithm::convolution_winograd, and
|
|
/// #dnnl::algorithm::convolution_auto.
|
|
/// @param diff_src_desc Diff source memory descriptor.
|
|
/// @param weights_desc Weights memory descriptor.
|
|
/// @param diff_dst_desc Diff destination memory descriptor.
|
|
/// @param strides Strides for each spatial dimension.
|
|
/// @param padding_l Vector of padding values for low indices for each
|
|
/// spatial dimension `([[front,] top,] left)`.
|
|
/// @param padding_r Vector of padding values for high indices for
|
|
/// each spatial dimension `([[back,] bottom,] right)`.
|
|
/// @param hint_fwd_pd Primitive descriptor for a convolution
|
|
/// forward propagation primitive. It is used as a hint for
|
|
/// deciding which memory format to use.
|
|
/// @param attr Primitive attributes to use. Attributes are optional
|
|
/// and default to empty attributes.
|
|
/// @param allow_empty A flag signifying whether construction is
|
|
/// allowed to fail without throwing an exception. In this case an
|
|
/// empty object will be produced. This flag is optional and
|
|
/// defaults to false.
|
|
primitive_desc(const engine &aengine, algorithm aalgorithm,
|
|
const memory::desc &diff_src_desc,
|
|
const memory::desc &weights_desc,
|
|
const memory::desc &diff_dst_desc, const memory::dims &strides,
|
|
const memory::dims &padding_l, const memory::dims &padding_r,
|
|
const convolution_forward::primitive_desc &hint_fwd_pd,
|
|
const primitive_attr &attr = default_attr(),
|
|
bool allow_empty = false)
|
|
: primitive_desc(aengine, aalgorithm, diff_src_desc, weights_desc,
|
|
diff_dst_desc, strides, nullptr, padding_l, padding_r,
|
|
hint_fwd_pd, attr, allow_empty) {}
|
|
|
|
/// Constructs a primitive descriptor for a convolution backward
|
|
/// propagation primitive.
|
|
///
|
|
/// @note
|
|
/// All the memory descriptors may be initialized with the
|
|
/// #dnnl::memory::format_tag::any value of @p format_tag.
|
|
///
|
|
/// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r
|
|
/// contain values for spatial dimensions only and hence must have the
|
|
/// same number of elements as there are spatial dimensions. The order
|
|
/// of values is the same as in the tensor: depth (for 3D tensors),
|
|
/// height (for 3D and 2D tensors), and width.
|
|
///
|
|
/// @param aengine Engine to use.
|
|
/// @param aalgorithm Convolution algorithm. Possible values are
|
|
/// #dnnl::algorithm::convolution_direct,
|
|
/// #dnnl::algorithm::convolution_winograd, and
|
|
/// #dnnl::algorithm::convolution_auto.
|
|
/// @param diff_src_desc Diff source memory descriptor.
|
|
/// @param weights_desc Weights memory descriptor.
|
|
/// @param diff_dst_desc Diff destination memory descriptor.
|
|
/// @param strides Strides for each spatial dimension.
|
|
/// @param dilates Dilations for each spatial dimension. A zero value
|
|
/// means no dilation in the corresponding dimension.
|
|
/// @param padding_l Vector of padding values for low indices for each
|
|
/// spatial dimension `([[front,] top,] left)`.
|
|
/// @param padding_r Vector of padding values for high indices for
|
|
/// each spatial dimension `([[back,] bottom,] right)`.
|
|
/// @param hint_fwd_pd Primitive descriptor for a convolution
|
|
/// forward propagation primitive. It is used as a hint for
|
|
/// deciding which memory format to use.
|
|
/// @param attr Primitive attributes to use. Attributes are optional
|
|
/// and default to empty attributes.
|
|
/// @param allow_empty A flag signifying whether construction is
|
|
/// allowed to fail without throwing an exception. In this case an
|
|
/// empty object will be produced. This flag is optional and
|
|
/// defaults to false.
|
|
primitive_desc(const engine &aengine, algorithm aalgorithm,
|
|
const memory::desc &diff_src_desc,
|
|
const memory::desc &weights_desc,
|
|
const memory::desc &diff_dst_desc, const memory::dims &strides,
|
|
const memory::dims &dilates, const memory::dims &padding_l,
|
|
const memory::dims &padding_r,
|
|
const convolution_forward::primitive_desc &hint_fwd_pd,
|
|
const primitive_attr &attr = default_attr(),
|
|
bool allow_empty = false)
|
|
: primitive_desc(aengine, aalgorithm, diff_src_desc, weights_desc,
|
|
diff_dst_desc, strides, &dilates, padding_l, padding_r,
|
|
hint_fwd_pd, attr, allow_empty) {}
|
|
|
|
/// Constructs a primitive descriptor for a convolution backward
|
|
/// propagation primitive from a C API primitive descriptor that must
|
|
/// have a matching kind.
|
|
///
|
|
/// @param pd C API primitive descriptor for a convolution backward
|
|
/// propagation primitive.
|
|
primitive_desc(dnnl_primitive_desc_t pd)
|
|
: dnnl::primitive_desc(pd, dnnl::primitive::kind::convolution,
|
|
dnnl::prop_kind::backward_data) {}
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
|
|
memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::weights_desc()const
|
|
memory::desc weights_desc() const { return base::weights_desc(0); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
|
|
memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_algorithm()const
|
|
algorithm get_algorithm() const { return base::get_algorithm(); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
|
|
prop_kind get_prop_kind() const { return base::get_prop_kind(); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_strides()const
|
|
memory::dims get_strides() const { return base::get_strides(); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_dilations()const
|
|
memory::dims get_dilations() const { return base::get_dilations(); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_padding_l()const
|
|
memory::dims get_padding_l() const { return base::get_padding_l(); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_padding_r()const
|
|
memory::dims get_padding_r() const { return base::get_padding_r(); }
|
|
|
|
private:
|
|
primitive_desc(const engine &aengine, algorithm aalgorithm,
|
|
const memory::desc &diff_src_desc,
|
|
const memory::desc &weights_desc,
|
|
const memory::desc &diff_dst_desc, const memory::dims &strides,
|
|
const memory::dims *dilates, const memory::dims &padding_l,
|
|
const memory::dims &padding_r,
|
|
const convolution_forward::primitive_desc &hint_fwd_pd,
|
|
const primitive_attr &attr, bool allow_empty) {
|
|
|
|
memory::validate_dims(strides, diff_src_desc.get_ndims() - 2);
|
|
memory::validate_dims(padding_l, diff_src_desc.get_ndims() - 2);
|
|
memory::validate_dims(padding_r, diff_src_desc.get_ndims() - 2);
|
|
|
|
if (dilates)
|
|
memory::validate_dims(*dilates, diff_src_desc.get_ndims() - 2);
|
|
|
|
dnnl_primitive_desc_t pd = nullptr;
|
|
dnnl_status_t status
|
|
= dnnl_convolution_backward_data_primitive_desc_create(&pd,
|
|
aengine.get(), convert_to_c(aalgorithm),
|
|
diff_src_desc.get(), weights_desc.get(),
|
|
diff_dst_desc.get(), &strides[0],
|
|
optional_arg(dilates), &padding_l[0], &padding_r[0],
|
|
hint_fwd_pd.get(), attr.get());
|
|
if (!allow_empty)
|
|
error::wrap_c_api(status,
|
|
"could not create a primitive descriptor for "
|
|
"the convolution backward propagation primitive. Run "
|
|
"workload with environment variable ONEDNN_VERBOSE=all "
|
|
"to get additional diagnostic information.");
|
|
reset(pd);
|
|
}
|
|
};
|
|
|
|
/// Default constructor. Produces an empty object.
|
|
convolution_backward_data() = default;
|
|
|
|
/// Constructs a convolution backward propagation primitive.
|
|
/// @param pd Primitive descriptor for a convolution backward propagation
|
|
/// primitive.
|
|
convolution_backward_data(const primitive_desc &pd) : primitive(pd) {}
|
|
|
|
/// Constructs a convolution backward propagation primitive from a cache
|
|
/// blob.
|
|
/// @param pd Primitive descriptor for a convolution backward propagation
|
|
/// primitive.
|
|
/// @param cache_blob Cache blob.
|
|
convolution_backward_data(
|
|
const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
|
|
: primitive(pd, cache_blob) {}
|
|
};
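
// A minimal sketch of constructing the backward-data primitive descriptor;
// `eng`, the diff/weights memory descriptors and the forward hint
// `conv_fwd_pd` are placeholders assumed to mirror the forward pass:
//
//     convolution_backward_data::primitive_desc conv_bwd_data_pd(eng,
//             algorithm::convolution_direct, diff_src_md, wei_md, diff_dst_md,
//             /*strides=*/ {1, 1}, /*padding_l=*/ {1, 1},
//             /*padding_r=*/ {1, 1}, conv_fwd_pd);
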
/// Convolution weights gradient primitive.
|
|
struct convolution_backward_weights : public primitive {
|
|
/// Primitive descriptor for a convolution weights gradient primitive.
|
|
struct primitive_desc : public dnnl::primitive_desc {
|
|
/// Default constructor. Produces an empty object.
|
|
primitive_desc() = default;
|
|
|
|
/// Constructs a primitive descriptor for a convolution weights gradient
|
|
/// primitive with bias.
|
|
///
|
|
/// @note
|
|
/// All the memory descriptors may be initialized with the
|
|
/// #dnnl::memory::format_tag::any value of @p format_tag.
|
|
///
|
|
/// Arrays @p strides, @p padding_l, and @p padding_r contain values
|
|
/// for spatial dimensions only and hence must have the same number of
|
|
/// elements as there are spatial dimensions. The order of values is
|
|
/// the same as in the tensor: depth (for 3D tensors), height (for 3D
|
|
/// and 2D tensors), and width.
|
|
///
|
|
/// @param aengine Engine to use.
|
|
/// @param aalgorithm Convolution algorithm. Possible values are
|
|
/// #dnnl::algorithm::convolution_direct,
|
|
/// #dnnl::algorithm::convolution_winograd, and
|
|
/// #dnnl::algorithm::convolution_auto.
|
|
/// @param src_desc Source memory descriptor.
|
|
/// @param diff_weights_desc Diff weights memory descriptor.
|
|
/// @param diff_bias_desc Diff bias memory descriptor. Passing zero
|
|
/// memory descriptor disables the bias term.
|
|
/// @param diff_dst_desc Diff destination memory descriptor.
|
|
/// @param strides Strides for each spatial dimension.
|
|
/// @param padding_l Vector of padding values for low indices for each
|
|
/// spatial dimension `([[front,] top,] left)`.
|
|
/// @param padding_r Vector of padding values for high indices for
|
|
/// each spatial dimension `([[back,] bottom,] right)`.
|
|
/// @param hint_fwd_pd Primitive descriptor for a convolution
|
|
/// forward propagation primitive. It is used as a hint for
|
|
/// deciding which memory format to use.
|
|
/// @param attr Primitive attributes to use. Attributes are optional
|
|
/// and default to empty attributes.
|
|
/// @param allow_empty A flag signifying whether construction is
|
|
/// allowed to fail without throwing an exception. In this case an
|
|
/// empty object will be produced. This flag is optional and
|
|
/// defaults to false.
|
|
primitive_desc(const engine &aengine, algorithm aalgorithm,
|
|
const memory::desc &src_desc,
|
|
const memory::desc &diff_weights_desc,
|
|
const memory::desc &diff_bias_desc,
|
|
const memory::desc &diff_dst_desc, const memory::dims &strides,
|
|
const memory::dims &padding_l, const memory::dims &padding_r,
|
|
const convolution_forward::primitive_desc &hint_fwd_pd,
|
|
const primitive_attr &attr = default_attr(),
|
|
bool allow_empty = false)
|
|
: primitive_desc(aengine, aalgorithm, src_desc, diff_weights_desc,
|
|
&diff_bias_desc, diff_dst_desc, strides, nullptr, padding_l,
|
|
padding_r, hint_fwd_pd, attr, allow_empty) {}
|
|
|
|
/// Constructs a primitive descriptor for a convolution weights gradient
|
|
/// primitive without bias.
|
|
///
|
|
/// @note
|
|
/// All the memory descriptors may be initialized with the
|
|
/// #dnnl::memory::format_tag::any value of @p format_tag.
|
|
///
|
|
/// Arrays @p strides, @p padding_l, and @p padding_r contain values
|
|
/// for spatial dimensions only and hence must have the same number of
|
|
/// elements as there are spatial dimensions. The order of values is
|
|
/// the same as in the tensor: depth (for 3D tensors), height (for 3D
|
|
/// and 2D tensors), and width.
|
|
///
|
|
/// @param aengine Engine to use.
|
|
/// @param aalgorithm Convolution algorithm. Possible values are
|
|
/// #dnnl::algorithm::convolution_direct,
|
|
/// #dnnl::algorithm::convolution_winograd, and
|
|
/// #dnnl::algorithm::convolution_auto.
|
|
/// @param src_desc Source memory descriptor.
|
|
/// @param diff_weights_desc Diff weights memory descriptor.
|
|
/// @param diff_dst_desc Diff destination memory descriptor.
|
|
/// @param strides Strides for each spatial dimension.
|
|
/// @param padding_l Vector of padding values for low indices for each
|
|
/// spatial dimension `([[front,] top,] left)`.
|
|
/// @param padding_r Vector of padding values for high indices for
|
|
/// each spatial dimension `([[back,] bottom,] right)`.
|
|
/// @param hint_fwd_pd Primitive descriptor for a convolution
|
|
/// forward propagation primitive. It is used as a hint for
|
|
/// deciding which memory format to use.
|
|
/// @param attr Primitive attributes to use. Attributes are optional
|
|
/// and default to empty attributes.
|
|
/// @param allow_empty A flag signifying whether construction is
|
|
/// allowed to fail without throwing an exception. In this case an
|
|
/// empty object will be produced. This flag is optional and
|
|
/// defaults to false.
|
|
primitive_desc(const engine &aengine, algorithm aalgorithm,
|
|
const memory::desc &src_desc,
|
|
const memory::desc &diff_weights_desc,
|
|
const memory::desc &diff_dst_desc, const memory::dims &strides,
|
|
const memory::dims &padding_l, const memory::dims &padding_r,
|
|
const convolution_forward::primitive_desc &hint_fwd_pd,
|
|
const primitive_attr &attr = default_attr(),
|
|
bool allow_empty = false)
|
|
: primitive_desc(aengine, aalgorithm, src_desc, diff_weights_desc,
|
|
nullptr, diff_dst_desc, strides, nullptr, padding_l,
|
|
padding_r, hint_fwd_pd, attr, allow_empty) {}
|
|
|
|
/// Constructs a primitive descriptor for a convolution weights
|
|
/// gradient primitive with bias.
|
|
///
|
|
/// @note
|
|
/// All the memory descriptors may be initialized with the
|
|
/// #dnnl::memory::format_tag::any value of @p format_tag.
|
|
///
|
|
/// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r
|
|
/// contain values for spatial dimensions only and hence must have the
|
|
/// same number of elements as there are spatial dimensions. The order
|
|
/// of values is the same as in the tensor: depth (for 3D tensors),
|
|
/// height (for 3D and 2D tensors), and width.
|
|
///
|
|
/// @param aengine Engine to use.
|
|
/// @param aalgorithm Convolution algorithm. Possible values are
|
|
/// #dnnl::algorithm::convolution_direct,
|
|
/// #dnnl::algorithm::convolution_winograd, and
|
|
/// #dnnl::algorithm::convolution_auto.
|
|
/// @param src_desc Source memory descriptor.
|
|
/// @param diff_weights_desc Diff weights memory descriptor.
|
|
/// @param diff_bias_desc Diff bias memory descriptor. Passing zero
|
|
/// memory descriptor disables the bias term.
|
|
/// @param diff_dst_desc Diff destination memory descriptor.
|
|
/// @param strides Strides for each spatial dimension.
|
|
/// @param dilates Dilations for each spatial dimension. A zero value
|
|
/// means no dilation in the corresponding dimension.
|
|
/// @param padding_l Vector of padding values for low indices for each
|
|
/// spatial dimension `([[front,] top,] left)`.
|
|
/// @param padding_r Vector of padding values for high indices for
|
|
/// each spatial dimension `([[back,] bottom,] right)`.
|
|
/// @param hint_fwd_pd Primitive descriptor for a convolution
|
|
/// forward propagation primitive. It is used as a hint for
|
|
/// deciding which memory format to use.
|
|
/// @param attr Primitive attributes to use. Attributes are optional
|
|
/// and default to empty attributes.
|
|
/// @param allow_empty A flag signifying whether construction is
|
|
/// allowed to fail without throwing an exception. In this case an
|
|
/// empty object will be produced. This flag is optional and
|
|
/// defaults to false.
|
|
primitive_desc(const engine &aengine, algorithm aalgorithm,
|
|
const memory::desc &src_desc,
|
|
const memory::desc &diff_weights_desc,
|
|
const memory::desc &diff_bias_desc,
|
|
const memory::desc &diff_dst_desc, const memory::dims &strides,
|
|
const memory::dims &dilates, const memory::dims &padding_l,
|
|
const memory::dims &padding_r,
|
|
const convolution_forward::primitive_desc &hint_fwd_pd,
|
|
const primitive_attr &attr = default_attr(),
|
|
bool allow_empty = false)
|
|
: primitive_desc(aengine, aalgorithm, src_desc, diff_weights_desc,
|
|
&diff_bias_desc, diff_dst_desc, strides, &dilates,
|
|
padding_l, padding_r, hint_fwd_pd, attr, allow_empty) {}
|
|
|
|
/// Constructs a primitive descriptor for a convolution weights
|
|
/// gradient primitive without bias.
|
|
///
|
|
/// @note
|
|
/// All the memory descriptors may be initialized with the
|
|
/// #dnnl::memory::format_tag::any value of @p format_tag.
|
|
///
|
|
/// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r
|
|
/// contain values for spatial dimensions only and hence must have the
|
|
/// same number of elements as there are spatial dimensions. The order
|
|
/// of values is the same as in the tensor: depth (for 3D tensors),
|
|
/// height (for 3D and 2D tensors), and width.
|
|
///
|
|
/// @param aengine Engine to use.
|
|
/// @param aalgorithm Convolution algorithm. Possible values are
|
|
/// #dnnl::algorithm::convolution_direct,
|
|
/// #dnnl::algorithm::convolution_winograd, and
|
|
/// #dnnl::algorithm::convolution_auto.
|
|
/// @param src_desc Source memory descriptor.
|
|
/// @param diff_weights_desc Diff weights memory descriptor.
|
|
/// @param diff_dst_desc Diff destination memory descriptor.
|
|
/// @param strides Strides for each spatial dimension.
|
|
/// @param dilates Dilations for each spatial dimension. A zero value
|
|
/// means no dilation in the corresponding dimension.
|
|
/// @param padding_l Vector of padding values for low indices for each
|
|
/// spatial dimension `([[front,] top,] left)`.
|
|
/// @param padding_r Vector of padding values for high indices for
|
|
/// each spatial dimension `([[back,] bottom,] right)`.
|
|
/// @param hint_fwd_pd Primitive descriptor for a convolution
|
|
/// forward propagation primitive. It is used as a hint for
|
|
/// deciding which memory format to use.
|
|
/// @param attr Primitive attributes to use. Attributes are optional
|
|
/// and default to empty attributes.
|
|
/// @param allow_empty A flag signifying whether construction is
|
|
/// allowed to fail without throwing an exception. In this case an
|
|
/// empty object will be produced. This flag is optional and
|
|
/// defaults to false.
|
|
primitive_desc(const engine &aengine, algorithm aalgorithm,
|
|
const memory::desc &src_desc,
|
|
const memory::desc &diff_weights_desc,
|
|
const memory::desc &diff_dst_desc, const memory::dims &strides,
|
|
const memory::dims &dilates, const memory::dims &padding_l,
|
|
const memory::dims &padding_r,
|
|
const convolution_forward::primitive_desc &hint_fwd_pd,
|
|
const primitive_attr &attr = default_attr(),
|
|
bool allow_empty = false)
|
|
: primitive_desc(aengine, aalgorithm, src_desc, diff_weights_desc,
|
|
nullptr, diff_dst_desc, strides, &dilates, padding_l,
|
|
padding_r, hint_fwd_pd, attr, allow_empty) {}
|
|
|
|
/// Constructs a primitive descriptor for a convolution weights gradient
|
|
/// primitive from a C API primitive descriptor that must have a
|
|
/// matching kind.
|
|
///
|
|
/// @param pd C API primitive descriptor for a convolution weights
|
|
/// gradient primitive.
|
|
primitive_desc(dnnl_primitive_desc_t pd)
|
|
: dnnl::primitive_desc(pd, dnnl::primitive::kind::convolution,
|
|
dnnl::prop_kind::backward_weights) {}
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::src_desc()const
|
|
memory::desc src_desc() const { return base::src_desc(0); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::diff_weights_desc()const
|
|
memory::desc diff_weights_desc() const {
|
|
return base::diff_weights_desc(0);
|
|
}
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
|
|
memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
|
|
|
|
/// Returns the diff bias memory descriptor.
/// @returns The diff bias memory descriptor.
/// @returns A zero memory descriptor if the primitive does not have a
/// diff bias parameter.
memory::desc diff_bias_desc() const {
|
|
return base::diff_weights_desc(1);
|
|
}
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_algorithm()const
|
|
algorithm get_algorithm() const { return base::get_algorithm(); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
|
|
prop_kind get_prop_kind() const { return base::get_prop_kind(); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_strides()const
|
|
memory::dims get_strides() const { return base::get_strides(); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_dilations()const
|
|
memory::dims get_dilations() const { return base::get_dilations(); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_padding_l()const
|
|
memory::dims get_padding_l() const { return base::get_padding_l(); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_padding_r()const
|
|
memory::dims get_padding_r() const { return base::get_padding_r(); }
|
|
|
|
private:
|
|
primitive_desc(const engine &aengine, algorithm aalgorithm,
|
|
const memory::desc &src_desc,
|
|
const memory::desc &diff_weights_desc,
|
|
const memory::desc *diff_bias_desc,
|
|
const memory::desc &diff_dst_desc, const memory::dims &strides,
|
|
const memory::dims *dilates, const memory::dims &padding_l,
|
|
const memory::dims &padding_r,
|
|
const convolution_forward::primitive_desc &hint_fwd_pd,
|
|
const primitive_attr &attr, bool allow_empty) {
|
|
|
|
memory::validate_dims(strides, src_desc.get_ndims() - 2);
|
|
memory::validate_dims(padding_l, src_desc.get_ndims() - 2);
|
|
memory::validate_dims(padding_r, src_desc.get_ndims() - 2);
|
|
|
|
if (dilates)
|
|
memory::validate_dims(*dilates, src_desc.get_ndims() - 2);
|
|
|
|
dnnl_primitive_desc_t pd = nullptr;
|
|
dnnl_status_t status
|
|
= dnnl_convolution_backward_weights_primitive_desc_create(
|
|
&pd, aengine.get(), convert_to_c(aalgorithm),
|
|
src_desc.get(), diff_weights_desc.get(),
|
|
optional_arg(diff_bias_desc), diff_dst_desc.get(),
|
|
&strides[0], optional_arg(dilates), &padding_l[0],
|
|
&padding_r[0], hint_fwd_pd.get(), attr.get());
|
|
if (!allow_empty)
|
|
error::wrap_c_api(status,
|
|
"could not create a primitive descriptor for "
|
|
"the convolution weights update primitive. Run "
|
|
"workload with environment variable ONEDNN_VERBOSE=all "
|
|
"to get additional diagnostic information.");
|
|
reset(pd);
|
|
}
|
|
};
|
|
|
|
/// Default constructor. Produces an empty object.
|
|
convolution_backward_weights() = default;
|
|
|
|
/// Constructs a convolution weights gradient primitive.
|
|
/// @param pd Primitive descriptor for a convolution weights gradient
|
|
/// primitive.
|
|
convolution_backward_weights(const primitive_desc &pd) : primitive(pd) {}
|
|
|
|
/// Constructs a convolution weights gradient primitive from a cache blob.
|
|
/// @param pd Primitive descriptor for a convolution weights gradient
|
|
/// primitive.
|
|
/// @param cache_blob Cache blob.
|
|
convolution_backward_weights(
|
|
const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
|
|
: primitive(pd, cache_blob) {}
|
|
};
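
// A minimal sketch of constructing the weights-gradient primitive descriptor
// with a bias term; `eng`, the memory descriptors and the forward hint
// `conv_fwd_pd` are placeholders assumed to match the forward pass:
//
//     memory::desc diff_bias_md({32}, memory::data_type::f32,
//             memory::format_tag::x);
//     convolution_backward_weights::primitive_desc conv_bwd_weights_pd(eng,
//             algorithm::convolution_direct, src_md, diff_wei_md,
//             diff_bias_md, diff_dst_md, /*strides=*/ {1, 1},
//             /*padding_l=*/ {1, 1}, /*padding_r=*/ {1, 1}, conv_fwd_pd);
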
/// @} dnnl_api_convolution

/// @addtogroup dnnl_api_deconvolution Deconvolution
///
/// A primitive to perform 1D, 2D or 3D deconvolution. Supported variants are
/// forward propagation, backward propagation, and weights gradient with or
/// without bias.
///
/// @{

/// Deconvolution forward propagation primitive.
|
|
struct deconvolution_forward : public primitive {
|
|
/// Primitive descriptor for a deconvolution forward propagation primitive.
|
|
struct primitive_desc : public dnnl::primitive_desc {
|
|
/// Default constructor. Produces an empty object.
|
|
primitive_desc() = default;
|
|
|
|
/// Constructs a primitive descriptor for a deconvolution forward
|
|
/// propagation primitive with bias.
|
|
///
|
|
/// @note
|
|
/// All the memory descriptors may be initialized with the
|
|
/// #dnnl::memory::format_tag::any value of @p format_tag.
|
|
///
|
|
/// Arrays @p strides, @p padding_l, and @p padding_r contain values
|
|
/// for spatial dimensions only and hence must have the same number of
|
|
/// elements as there are spatial dimensions. The order of values is
|
|
/// the same as in the tensor: depth (for 3D tensors), height (for 3D
|
|
/// and 2D tensors), and width.
|
|
///
|
|
/// @param aengine Engine to use.
|
|
/// @param aprop_kind Propagation kind. Possible values are
|
|
/// #dnnl::prop_kind::forward_training, and
|
|
/// #dnnl::prop_kind::forward_inference.
|
|
/// @param aalgorithm Deconvolution algorithm:
|
|
/// #dnnl::algorithm::deconvolution_direct, and
|
|
/// #dnnl::algorithm::deconvolution_winograd.
|
|
/// @param src_desc Source memory descriptor.
|
|
/// @param weights_desc Weights memory descriptor.
|
|
/// @param bias_desc Bias memory descriptor. Passing zero memory
|
|
/// descriptor disables the bias term.
|
|
/// @param dst_desc Destination memory descriptor.
|
|
/// @param strides Vector of strides for spatial dimension.
|
|
/// @param padding_l Vector of padding values for low indices for each
|
|
/// spatial dimension `([[front,] top,] left)`.
|
|
/// @param padding_r Vector of padding values for high indices for
|
|
/// each spatial dimension `([[back,] bottom,] right)`.
|
|
/// @param attr Primitive attributes to use. Attributes are optional
|
|
/// and default to empty attributes.
|
|
/// @param allow_empty A flag signifying whether construction is
|
|
/// allowed to fail without throwing an exception. In this case an
|
|
/// empty object will be produced. This flag is optional and
|
|
/// defaults to false.
|
|
primitive_desc(const engine &aengine, prop_kind aprop_kind,
|
|
algorithm aalgorithm, const memory::desc &src_desc,
|
|
const memory::desc &weights_desc, const memory::desc &bias_desc,
|
|
const memory::desc &dst_desc, const memory::dims &strides,
|
|
const memory::dims &padding_l, const memory::dims &padding_r,
|
|
const primitive_attr &attr = default_attr(),
|
|
bool allow_empty = false)
|
|
: primitive_desc(aengine, aprop_kind, aalgorithm, src_desc,
|
|
weights_desc, &bias_desc, dst_desc, strides, nullptr,
|
|
padding_l, padding_r, attr, allow_empty) {}
|
|
|
|
/// Constructs a primitive descriptor for a deconvolution forward
|
|
/// propagation primitive without bias.
|
|
///
|
|
/// @note
|
|
/// All the memory descriptors may be initialized with the
|
|
/// #dnnl::memory::format_tag::any value of @p format_tag.
|
|
///
|
|
/// Arrays @p strides, @p padding_l, and @p padding_r contain values
|
|
/// for spatial dimensions only and hence must have the same number of
|
|
/// elements as there are spatial dimensions. The order of values is
|
|
/// the same as in the tensor: depth (for 3D tensors), height (for 3D
|
|
/// and 2D tensors), and width.
|
|
///
|
|
/// @param aengine Engine to use.
|
|
/// @param aprop_kind Propagation kind. Possible values are
|
|
/// #dnnl::prop_kind::forward_training, and
|
|
/// #dnnl::prop_kind::forward_inference.
|
|
/// @param aalgorithm Deconvolution algorithm:
|
|
/// #dnnl::algorithm::deconvolution_direct, and
|
|
/// #dnnl::algorithm::deconvolution_winograd.
|
|
/// @param src_desc Source memory descriptor.
|
|
/// @param weights_desc Weights memory descriptor.
|
|
/// @param dst_desc Destination memory descriptor.
|
|
/// @param strides Vector of strides for spatial dimension.
|
|
/// @param padding_l Vector of padding values for low indices for each
|
|
/// spatial dimension `([[front,] top,] left)`.
|
|
/// @param padding_r Vector of padding values for high indices for
|
|
/// each spatial dimension `([[back,] bottom,] right)`.
|
|
/// @param attr Primitive attributes to use. Attributes are optional
|
|
/// and default to empty attributes.
|
|
/// @param allow_empty A flag signifying whether construction is
|
|
/// allowed to fail without throwing an exception. In this case an
|
|
/// empty object will be produced. This flag is optional and
|
|
/// defaults to false.
|
|
primitive_desc(const engine &aengine, prop_kind aprop_kind,
|
|
algorithm aalgorithm, const memory::desc &src_desc,
|
|
const memory::desc &weights_desc, const memory::desc &dst_desc,
|
|
const memory::dims &strides, const memory::dims &padding_l,
|
|
const memory::dims &padding_r,
|
|
const primitive_attr &attr = default_attr(),
|
|
bool allow_empty = false)
|
|
: primitive_desc(aengine, aprop_kind, aalgorithm, src_desc,
|
|
weights_desc, nullptr, dst_desc, strides, nullptr,
|
|
padding_l, padding_r, attr, allow_empty) {}
|
|
|
|
/// Constructs a primitive descriptor for a deconvolution forward
|
|
/// propagation primitive with bias.
|
|
///
|
|
/// @note
|
|
/// All the memory descriptors may be initialized with the
|
|
/// #dnnl::memory::format_tag::any value of @p format_tag.
|
|
///
|
|
/// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r
|
|
/// contain values for spatial dimensions only and hence must have the
|
|
/// same number of elements as there are spatial dimensions. The order
|
|
/// of values is the same as in the tensor: depth (for 3D tensors),
|
|
/// height (for 3D and 2D tensors), and width.
|
|
///
|
|
/// @param aengine Engine to use.
|
|
/// @param aprop_kind Propagation kind. Possible values are
|
|
/// #dnnl::prop_kind::forward_training, and
|
|
/// #dnnl::prop_kind::forward_inference.
|
|
/// @param aalgorithm Deconvolution algorithm:
|
|
/// #dnnl::algorithm::deconvolution_direct, and
|
|
/// #dnnl::algorithm::deconvolution_winograd.
|
|
/// @param src_desc Source memory descriptor.
|
|
/// @param weights_desc Weights memory descriptor.
|
|
/// @param bias_desc Bias memory descriptor. Passing zero memory
|
|
/// descriptor disables the bias term.
|
|
/// @param dst_desc Destination memory descriptor.
|
|
/// @param strides Vector of strides for spatial dimension.
|
|
/// @param dilates Dilations for each spatial dimension. A zero value
|
|
/// means no dilation in the corresponding dimension.
|
|
/// @param padding_l Vector of padding values for low indices for each
|
|
/// spatial dimension `([[front,] top,] left)`.
|
|
/// @param padding_r Vector of padding values for high indices for
|
|
/// each spatial dimension `([[back,] bottom,] right)`.
|
|
/// @param attr Primitive attributes to use. Attributes are optional
|
|
/// and default to empty attributes.
|
|
/// @param allow_empty A flag signifying whether construction is
|
|
/// allowed to fail without throwing an exception. In this case an
|
|
/// empty object will be produced. This flag is optional and
|
|
/// defaults to false.
|
|
primitive_desc(const engine &aengine, prop_kind aprop_kind,
|
|
algorithm aalgorithm, const memory::desc &src_desc,
|
|
const memory::desc &weights_desc, const memory::desc &bias_desc,
|
|
const memory::desc &dst_desc, const memory::dims &strides,
|
|
const memory::dims &dilates, const memory::dims &padding_l,
|
|
const memory::dims &padding_r,
|
|
const primitive_attr &attr = default_attr(),
|
|
bool allow_empty = false)
|
|
: primitive_desc(aengine, aprop_kind, aalgorithm, src_desc,
|
|
weights_desc, &bias_desc, dst_desc, strides, &dilates,
|
|
padding_l, padding_r, attr, allow_empty) {}
|
|
|
|
/// Constructs a primitive descriptor for a deconvolution forward
|
|
/// propagation primitive without bias.
|
|
///
|
|
/// @note
|
|
/// All the memory descriptors may be initialized with the
|
|
/// #dnnl::memory::format_tag::any value of @p format_tag.
|
|
///
|
|
/// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r
|
|
/// contain values for spatial dimensions only and hence must have the
|
|
/// same number of elements as there are spatial dimensions. The order
|
|
/// of values is the same as in the tensor: depth (for 3D tensors),
|
|
/// height (for 3D and 2D tensors), and width.
|
|
///
|
|
/// @param aengine Engine to use.
|
|
/// @param aprop_kind Propagation kind. Possible values are
|
|
/// #dnnl::prop_kind::forward_training, and
|
|
/// #dnnl::prop_kind::forward_inference.
|
|
/// @param aalgorithm Deconvolution algorithm:
|
|
/// #dnnl::algorithm::deconvolution_direct, and
|
|
/// #dnnl::algorithm::deconvolution_winograd.
|
|
/// @param src_desc Source memory descriptor.
|
|
/// @param weights_desc Weights memory descriptor.
|
|
/// @param dst_desc Destination memory descriptor.
|
|
/// @param strides Vector of strides for spatial dimension.
|
|
/// @param dilates Dilations for each spatial dimension. A zero value
|
|
/// means no dilation in the corresponding dimension.
|
|
/// @param padding_l Vector of padding values for low indices for each
|
|
/// spatial dimension `([[front,] top,] left)`.
|
|
/// @param padding_r Vector of padding values for high indices for
|
|
/// each spatial dimension `([[back,] bottom,] right)`.
|
|
/// @param attr Primitive attributes to use. Attributes are optional
|
|
/// and default to empty attributes.
|
|
/// @param allow_empty A flag signifying whether construction is
|
|
/// allowed to fail without throwing an exception. In this case an
|
|
/// empty object will be produced. This flag is optional and
|
|
/// defaults to false.
|
|
primitive_desc(const engine &aengine, prop_kind aprop_kind,
|
|
algorithm aalgorithm, const memory::desc &src_desc,
|
|
const memory::desc &weights_desc, const memory::desc &dst_desc,
|
|
const memory::dims &strides, const memory::dims &dilates,
|
|
const memory::dims &padding_l, const memory::dims &padding_r,
|
|
const primitive_attr &attr = default_attr(),
|
|
bool allow_empty = false)
|
|
: primitive_desc(aengine, aprop_kind, aalgorithm, src_desc,
|
|
weights_desc, nullptr, dst_desc, strides, &dilates,
|
|
padding_l, padding_r, attr, allow_empty) {}
|
|
|
|
/// Constructs a primitive descriptor for a deconvolution forward
|
|
/// propagation primitive from a C API primitive descriptor that must
|
|
/// have a matching kind.
|
|
///
|
|
/// @param pd C API primitive descriptor for a deconvolution forward
|
|
/// propagation primitive.
|
|
primitive_desc(dnnl_primitive_desc_t pd)
|
|
: dnnl::primitive_desc(pd, dnnl::primitive::kind::deconvolution,
|
|
dnnl::prop_kind::forward_training,
|
|
dnnl::prop_kind::forward_inference) {}
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::src_desc()const
|
|
memory::desc src_desc() const { return base::src_desc(0); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::weights_desc()const
|
|
memory::desc weights_desc() const { return base::weights_desc(0); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::dst_desc()const
|
|
memory::desc dst_desc() const { return base::dst_desc(0); }
|
|
|
|
/// @copydoc dnnl::convolution_forward::primitive_desc::bias_desc()const
|
|
memory::desc bias_desc() const { return base::weights_desc(1); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_algorithm()const
|
|
algorithm get_algorithm() const { return base::get_algorithm(); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
|
|
prop_kind get_prop_kind() const { return base::get_prop_kind(); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_strides()const
|
|
memory::dims get_strides() const { return base::get_strides(); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_dilations()const
|
|
memory::dims get_dilations() const { return base::get_dilations(); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_padding_l()const
|
|
memory::dims get_padding_l() const { return base::get_padding_l(); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_padding_r()const
|
|
memory::dims get_padding_r() const { return base::get_padding_r(); }
|
|
|
|
private:
|
|
primitive_desc(const engine &aengine, prop_kind aprop_kind,
|
|
algorithm aalgorithm, const memory::desc &src_desc,
|
|
const memory::desc &weights_desc, const memory::desc *bias_desc,
|
|
const memory::desc &dst_desc, const memory::dims &strides,
|
|
const memory::dims *dilates, const memory::dims &padding_l,
|
|
const memory::dims &padding_r, const primitive_attr &attr,
|
|
bool allow_empty) {
|
|
|
|
memory::validate_dims(strides, src_desc.get_ndims() - 2);
|
|
memory::validate_dims(padding_l, src_desc.get_ndims() - 2);
|
|
memory::validate_dims(padding_r, src_desc.get_ndims() - 2);
|
|
|
|
if (dilates)
|
|
memory::validate_dims(*dilates, src_desc.get_ndims() - 2);
|
|
|
|
dnnl_primitive_desc_t pd = nullptr;
|
|
dnnl_status_t status
|
|
= dnnl_deconvolution_forward_primitive_desc_create(&pd,
|
|
aengine.get(), dnnl::convert_to_c(aprop_kind),
|
|
convert_to_c(aalgorithm), src_desc.get(),
|
|
weights_desc.get(), optional_arg(bias_desc),
|
|
dst_desc.get(), &strides[0], optional_arg(dilates),
|
|
&padding_l[0], &padding_r[0], attr.get());
|
|
if (!allow_empty)
|
|
error::wrap_c_api(status,
|
|
"could not create a primitive descriptor for "
|
|
"the deconvolution forward propagation primitive. Run "
|
|
"workload with environment variable ONEDNN_VERBOSE=all "
|
|
"to get additional diagnostic information.");
|
|
reset(pd);
|
|
}
|
|
};
|
|
|
|
/// Default constructor. Produces an empty object.
|
|
deconvolution_forward() = default;
|
|
|
|
/// Constructs a deconvolution forward propagation primitive.
|
|
/// @param pd Primitive descriptor for a deconvolution forward propagation
|
|
/// primitive.
|
|
deconvolution_forward(const primitive_desc &pd) : primitive(pd) {}
|
|
|
|
/// Constructs a deconvolution forward propagation primitive from a cache
|
|
/// blob.
|
|
/// @param pd Primitive descriptor for a deconvolution forward propagation
|
|
/// primitive.
|
|
/// @param cache_blob Cache blob.
|
|
deconvolution_forward(
|
|
const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
|
|
: primitive(pd, cache_blob) {}
|
|
};
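
// A minimal usage sketch with hypothetical shapes: a 2D deconvolution with
// stride 2 that upsamples 13x13 activations to 27x27 ((13 - 1) * 2 + 3 = 27);
// `eng` is a placeholder engine:
//
//     memory::desc src_md({2, 32, 13, 13}, memory::data_type::f32,
//             memory::format_tag::any);
//     memory::desc wei_md({16, 32, 3, 3}, memory::data_type::f32,
//             memory::format_tag::any);
//     memory::desc dst_md({2, 16, 27, 27}, memory::data_type::f32,
//             memory::format_tag::any);
//     deconvolution_forward::primitive_desc deconv_pd(eng,
//             prop_kind::forward_training, algorithm::deconvolution_direct,
//             src_md, wei_md, dst_md, /*strides=*/ {2, 2},
//             /*padding_l=*/ {0, 0}, /*padding_r=*/ {0, 0});
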
/// Deconvolution backward propagation primitive.
|
|
struct deconvolution_backward_data : public primitive {
|
|
/// Primitive descriptor for a deconvolution backward propagation primitive.
|
|
struct primitive_desc : public dnnl::primitive_desc {
|
|
/// Default constructor. Produces an empty object.
|
|
primitive_desc() = default;
|
|
|
|
/// Constructs a primitive descriptor for a deconvolution backward
|
|
/// propagation primitive.
|
|
///
|
|
/// @note
|
|
/// All the memory descriptors may be initialized with the
|
|
/// #dnnl::memory::format_tag::any value of @p format_tag.
|
|
///
|
|
/// Arrays @p strides, @p padding_l, and @p padding_r contain values
|
|
/// for spatial dimensions only and hence must have the same number of
|
|
/// elements as there are spatial dimensions. The order of values is
|
|
/// the same as in the tensor: depth (for 3D tensors), height (for 3D
|
|
/// and 2D tensors), and width.
|
|
///
|
|
/// @param aengine Engine to use.
|
|
/// @param aalgorithm Deconvolution algorithm. Possible values are
/// #dnnl::algorithm::deconvolution_direct, and
/// #dnnl::algorithm::deconvolution_winograd.
/// @param diff_src_desc Diff source memory descriptor.
|
|
/// @param weights_desc Weights memory descriptor.
|
|
/// @param diff_dst_desc Diff destination memory descriptor.
|
|
/// @param strides Strides for each spatial dimension.
|
|
/// @param padding_l Vector of padding values for low indices for each
|
|
/// spatial dimension `([[front,] top,] left)`.
|
|
/// @param padding_r Vector of padding values for high indices for
|
|
/// each spatial dimension `([[back,] bottom,] right)`.
|
|
/// @param hint_fwd_pd Primitive descriptor for a deconvolution
|
|
/// forward propagation primitive. It is used as a hint for
|
|
/// deciding which memory format to use.
|
|
/// @param attr Primitive attributes to use. Attributes are optional
|
|
/// and default to empty attributes.
|
|
/// @param allow_empty A flag signifying whether construction is
|
|
/// allowed to fail without throwing an exception. In this case an
|
|
/// empty object will be produced. This flag is optional and
|
|
/// defaults to false.
|
|
primitive_desc(const engine &aengine, algorithm aalgorithm,
|
|
const memory::desc &diff_src_desc,
|
|
const memory::desc &weights_desc,
|
|
const memory::desc &diff_dst_desc, const memory::dims &strides,
|
|
const memory::dims &padding_l, const memory::dims &padding_r,
|
|
const deconvolution_forward::primitive_desc &hint_fwd_pd,
|
|
const primitive_attr &attr = default_attr(),
|
|
bool allow_empty = false)
|
|
: primitive_desc(aengine, aalgorithm, diff_src_desc, weights_desc,
|
|
diff_dst_desc, strides, nullptr, padding_l, padding_r,
|
|
hint_fwd_pd, attr, allow_empty) {}
|
|
|
|
/// Constructs a primitive descriptor for a deconvolution backward
|
|
/// propagation primitive.
|
|
///
|
|
/// @note
|
|
/// All the memory descriptors may be initialized with the
|
|
/// #dnnl::memory::format_tag::any value of @p format_tag.
|
|
///
|
|
/// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r
|
|
/// contain values for spatial dimensions only and hence must have the
|
|
/// same number of elements as there are spatial dimensions. The order
|
|
/// of values is the same as in the tensor: depth (for 3D tensors),
|
|
/// height (for 3D and 2D tensors), and width.
|
|
///
|
|
/// @param aengine Engine to use.
|
|
/// @param aalgorithm Deconvolution algorithm. Possible values are
/// #dnnl::algorithm::deconvolution_direct, and
/// #dnnl::algorithm::deconvolution_winograd.
/// @param diff_src_desc Diff source memory descriptor.
|
|
/// @param weights_desc Weights memory descriptor.
|
|
/// @param diff_dst_desc Diff destination memory descriptor.
|
|
/// @param strides Strides for each spatial dimension.
|
|
/// @param dilates Dilations for each spatial dimension. A zero value
|
|
/// means no dilation in the corresponding dimension.
|
|
/// @param padding_l Vector of padding values for low indices for each
|
|
/// spatial dimension `([[front,] top,] left)`.
|
|
/// @param padding_r Vector of padding values for high indices for
|
|
/// each spatial dimension `([[back,] bottom,] right)`.
|
|
/// @param hint_fwd_pd Primitive descriptor for a deconvolution
|
|
/// forward propagation primitive. It is used as a hint for
|
|
/// deciding which memory format to use.
|
|
/// @param attr Primitive attributes to use. Attributes are optional
|
|
/// and default to empty attributes.
|
|
/// @param allow_empty A flag signifying whether construction is
|
|
/// allowed to fail without throwing an exception. In this case an
|
|
/// empty object will be produced. This flag is optional and
|
|
/// defaults to false.
|
|
primitive_desc(const engine &aengine, algorithm aalgorithm,
|
|
const memory::desc &diff_src_desc,
|
|
const memory::desc &weights_desc,
|
|
const memory::desc &diff_dst_desc, const memory::dims &strides,
|
|
const memory::dims &dilates, const memory::dims &padding_l,
|
|
const memory::dims &padding_r,
|
|
const deconvolution_forward::primitive_desc &hint_fwd_pd,
|
|
const primitive_attr &attr = default_attr(),
|
|
bool allow_empty = false)
|
|
: primitive_desc(aengine, aalgorithm, diff_src_desc, weights_desc,
|
|
diff_dst_desc, strides, &dilates, padding_l, padding_r,
|
|
hint_fwd_pd, attr, allow_empty) {}
|
|
|
|
/// Constructs a primitive descriptor for a deconvolution backward
|
|
/// propagation primitive from a C API primitive descriptor that must
|
|
/// have a matching kind.
|
|
///
|
|
/// @param pd C API primitive descriptor for a deconvolution backward
|
|
/// propagation primitive.
|
|
primitive_desc(dnnl_primitive_desc_t pd)
|
|
: dnnl::primitive_desc(pd, dnnl::primitive::kind::deconvolution,
|
|
dnnl::prop_kind::backward_data) {}
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
|
|
memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::weights_desc()const
|
|
memory::desc weights_desc() const { return base::weights_desc(0); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
|
|
memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_algorithm()const
|
|
algorithm get_algorithm() const { return base::get_algorithm(); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
|
|
prop_kind get_prop_kind() const { return base::get_prop_kind(); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_strides()const
|
|
memory::dims get_strides() const { return base::get_strides(); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_dilations()const
|
|
memory::dims get_dilations() const { return base::get_dilations(); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_padding_l()const
|
|
memory::dims get_padding_l() const { return base::get_padding_l(); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_padding_r()const
|
|
memory::dims get_padding_r() const { return base::get_padding_r(); }
|
|
|
|
private:
|
|
primitive_desc(const engine &aengine, algorithm aalgorithm,
|
|
const memory::desc &diff_src_desc,
|
|
const memory::desc &weights_desc,
|
|
const memory::desc &diff_dst_desc, const memory::dims &strides,
|
|
const memory::dims *dilates, const memory::dims &padding_l,
|
|
const memory::dims &padding_r,
|
|
const deconvolution_forward::primitive_desc &hint_fwd_pd,
|
|
const primitive_attr &attr, bool allow_empty) {
|
|
|
|
memory::validate_dims(strides, diff_src_desc.get_ndims() - 2);
|
|
memory::validate_dims(padding_l, diff_src_desc.get_ndims() - 2);
|
|
memory::validate_dims(padding_r, diff_src_desc.get_ndims() - 2);
|
|
|
|
if (dilates)
|
|
memory::validate_dims(*dilates, diff_src_desc.get_ndims() - 2);
|
|
|
|
dnnl_primitive_desc_t pd = nullptr;
|
|
dnnl_status_t status
|
|
= dnnl_deconvolution_backward_data_primitive_desc_create(
|
|
&pd, aengine.get(), convert_to_c(aalgorithm),
|
|
diff_src_desc.get(), weights_desc.get(),
|
|
diff_dst_desc.get(), &strides[0],
|
|
optional_arg(dilates), &padding_l[0], &padding_r[0],
|
|
hint_fwd_pd.get(), attr.get());
|
|
if (!allow_empty)
|
|
error::wrap_c_api(status,
|
|
"could not create a primitive descriptor for "
|
|
"the deconvolution backward propagation primitive. Run "
|
|
"workload with environment variable ONEDNN_VERBOSE=all "
|
|
"to get additional diagnostic information.");
|
|
reset(pd);
|
|
}
|
|
};
|
|
|
|
/// Default constructor. Produces an empty object.
|
|
deconvolution_backward_data() = default;
|
|
|
|
/// Constructs a deconvolution backward propagation primitive.
|
|
/// @param pd Primitive descriptor for a deconvolution backward propagation
|
|
/// primitive.
|
|
deconvolution_backward_data(const primitive_desc &pd) : primitive(pd) {}
|
|
|
|
/// Constructs a deconvolution backward propagation primitive from a cache
|
|
/// blob.
|
|
/// @param pd Primitive descriptor for a deconvolution backward propagation
|
|
/// primitive.
|
|
/// @param cache_blob Cache blob.
|
|
deconvolution_backward_data(
|
|
const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
|
|
: primitive(pd, cache_blob) {}
|
|
};
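
// A minimal sketch of the backward-data primitive descriptor; the memory
// descriptors and the forward hint `deconv_fwd_pd` are placeholders assumed
// to match the forward pass:
//
//     deconvolution_backward_data::primitive_desc deconv_bwd_data_pd(eng,
//             algorithm::deconvolution_direct, diff_src_md, wei_md,
//             diff_dst_md, /*strides=*/ {2, 2}, /*padding_l=*/ {0, 0},
//             /*padding_r=*/ {0, 0}, deconv_fwd_pd);
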
/// Deconvolution weights gradient primitive.
|
|
struct deconvolution_backward_weights : public primitive {
|
|
/// Primitive descriptor for a deconvolution weights gradient primitive.
|
|
struct primitive_desc : public dnnl::primitive_desc {
|
|
/// Default constructor. Produces an empty object.
|
|
primitive_desc() = default;
|
|
|
|
/// Constructs a primitive descriptor for a deconvolution weights
|
|
/// gradient primitive with bias.
|
|
///
|
|
/// @note
|
|
/// All the memory descriptors may be initialized with the
|
|
/// #dnnl::memory::format_tag::any value of @p format_tag.
|
|
///
|
|
/// Arrays @p strides, @p padding_l, and @p padding_r contain values
|
|
/// for spatial dimensions only and hence must have the same number of
|
|
/// elements as there are spatial dimensions. The order of values is
|
|
/// the same as in the tensor: depth (for 3D tensors), height (for 3D
|
|
/// and 2D tensors), and width.
|
|
///
|
|
/// @param aengine Engine to use.
|
|
/// @param aalgorithm Deconvolution algorithm. Possible values are
|
|
/// #dnnl::algorithm::deconvolution_direct, and
|
|
/// #dnnl::algorithm::deconvolution_winograd.
|
|
/// @param src_desc Source memory descriptor.
|
|
/// @param diff_weights_desc Diff weights memory descriptor.
|
|
/// @param diff_bias_desc Diff bias memory descriptor. Passing zero
|
|
/// memory descriptor disables the bias term.
|
|
/// @param diff_dst_desc Diff destination memory descriptor.
|
|
/// @param strides Strides for each spatial dimension.
|
|
/// @param padding_l Vector of padding values for low indices for each
|
|
/// spatial dimension `([[front,] top,] left)`.
|
|
/// @param padding_r Vector of padding values for high indices for
|
|
/// each spatial dimension `([[back,] bottom,] right)`.
|
|
/// @param hint_fwd_pd Primitive descriptor for a deconvolution
|
|
/// forward propagation primitive. It is used as a hint for
|
|
/// deciding which memory format to use.
|
|
/// @param attr Primitive attributes to use. Attributes are optional
|
|
/// and default to empty attributes.
|
|
/// @param allow_empty A flag signifying whether construction is
|
|
/// allowed to fail without throwing an exception. In this case an
|
|
/// empty object will be produced. This flag is optional and
|
|
/// defaults to false.
|
|
primitive_desc(const engine &aengine, algorithm aalgorithm,
|
|
const memory::desc &src_desc,
|
|
const memory::desc &diff_weights_desc,
|
|
const memory::desc &diff_bias_desc,
|
|
const memory::desc &diff_dst_desc, const memory::dims &strides,
|
|
const memory::dims &padding_l, const memory::dims &padding_r,
|
|
const deconvolution_forward::primitive_desc &hint_fwd_pd,
|
|
const primitive_attr &attr = default_attr(),
|
|
bool allow_empty = false)
|
|
: primitive_desc(aengine, aalgorithm, src_desc, diff_weights_desc,
|
|
&diff_bias_desc, diff_dst_desc, strides, nullptr, padding_l,
|
|
padding_r, hint_fwd_pd, attr, allow_empty) {}
|
|
|
|
/// Constructs a primitive descriptor for a deconvolution weights
|
|
/// gradient primitive without bias.
|
|
///
|
|
/// @note
|
|
/// All the memory descriptors may be initialized with the
|
|
/// #dnnl::memory::format_tag::any value of @p format_tag.
|
|
///
|
|
/// Arrays @p strides, @p padding_l, and @p padding_r contain values
|
|
/// for spatial dimensions only and hence must have the same number of
|
|
/// elements as there are spatial dimensions. The order of values is
|
|
/// the same as in the tensor: depth (for 3D tensors), height (for 3D
|
|
/// and 2D tensors), and width.
|
|
///
|
|
/// @param aengine Engine to use.
|
|
/// @param aalgorithm Deconvolution algorithm. Possible values are
|
|
/// #dnnl::algorithm::deconvolution_direct, and
|
|
/// #dnnl::algorithm::deconvolution_winograd.
|
|
/// @param src_desc Source memory descriptor.
|
|
/// @param diff_weights_desc Diff weights memory descriptor.
|
|
/// @param diff_dst_desc Diff destination memory descriptor.
|
|
/// @param strides Strides for each spatial dimension.
|
|
/// @param padding_l Vector of padding values for low indices for each
|
|
/// spatial dimension `([[front,] top,] left)`.
|
|
/// @param padding_r Vector of padding values for high indices for
|
|
/// each spatial dimension `([[back,] bottom,] right)`.
|
|
/// @param hint_fwd_pd Primitive descriptor for a deconvolution
|
|
/// forward propagation primitive. It is used as a hint for
|
|
/// deciding which memory format to use.
|
|
/// @param attr Primitive attributes to use. Attributes are optional
|
|
/// and default to empty attributes.
|
|
/// @param allow_empty A flag signifying whether construction is
|
|
/// allowed to fail without throwing an exception. In this case an
|
|
/// empty object will be produced. This flag is optional and
|
|
/// defaults to false.
|
|
primitive_desc(const engine &aengine, algorithm aalgorithm,
|
|
const memory::desc &src_desc,
|
|
const memory::desc &diff_weights_desc,
|
|
const memory::desc &diff_dst_desc, const memory::dims &strides,
|
|
const memory::dims &padding_l, const memory::dims &padding_r,
|
|
const deconvolution_forward::primitive_desc &hint_fwd_pd,
|
|
const primitive_attr &attr = default_attr(),
|
|
bool allow_empty = false)
|
|
: primitive_desc(aengine, aalgorithm, src_desc, diff_weights_desc,
|
|
nullptr, diff_dst_desc, strides, nullptr, padding_l,
|
|
padding_r, hint_fwd_pd, attr, allow_empty) {}
|
|
|
|
/// Constructs a primitive descriptor for a deconvolution weights
|
|
/// gradient primitive with bias.
|
|
///
|
|
/// @note
|
|
/// All the memory descriptors may be initialized with the
|
|
/// #dnnl::memory::format_tag::any value of @p format_tag.
|
|
///
|
|
/// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r
|
|
/// contain values for spatial dimensions only and hence must have the
|
|
/// same number of elements as there are spatial dimensions. The order
|
|
/// of values is the same as in the tensor: depth (for 3D tensors),
|
|
/// height (for 3D and 2D tensors), and width.
|
|
///
|
|
/// @param aengine Engine to use.
|
|
/// @param aalgorithm Deconvolution algorithm. Possible values are
|
|
/// #dnnl::algorithm::deconvolution_direct, and
|
|
/// #dnnl::algorithm::deconvolution_winograd.
|
|
/// @param src_desc Source memory descriptor.
|
|
/// @param diff_weights_desc Diff weights memory descriptor.
|
|
/// @param diff_bias_desc Diff bias memory descriptor. Passing zero
|
|
/// memory descriptor disables the bias term.
|
|
/// @param diff_dst_desc Diff destination memory descriptor.
|
|
/// @param strides Strides for each spatial dimension.
|
|
/// @param dilates Dilations for each spatial dimension. A zero value
|
|
/// means no dilation in the corresponding dimension.
|
|
/// @param padding_l Vector of padding values for low indices for each
|
|
/// spatial dimension `([[front,] top,] left)`.
|
|
/// @param padding_r Vector of padding values for high indices for
|
|
/// each spatial dimension `([[back,] bottom,] right)`.
|
|
/// @param hint_fwd_pd Primitive descriptor for a deconvolution
|
|
/// forward propagation primitive. It is used as a hint for
|
|
/// deciding which memory format to use.
|
|
/// @param attr Primitive attributes to use. Attributes are optional
|
|
/// and default to empty attributes.
|
|
/// @param allow_empty A flag signifying whether construction is
|
|
/// allowed to fail without throwing an exception. In this case an
|
|
/// empty object will be produced. This flag is optional and
|
|
/// defaults to false.
|
|
primitive_desc(const engine &aengine, algorithm aalgorithm,
|
|
const memory::desc &src_desc,
|
|
const memory::desc &diff_weights_desc,
|
|
const memory::desc &diff_bias_desc,
|
|
const memory::desc &diff_dst_desc, const memory::dims &strides,
|
|
const memory::dims &dilates, const memory::dims &padding_l,
|
|
const memory::dims &padding_r,
|
|
const deconvolution_forward::primitive_desc &hint_fwd_pd,
|
|
const primitive_attr &attr = default_attr(),
|
|
bool allow_empty = false)
|
|
: primitive_desc(aengine, aalgorithm, src_desc, diff_weights_desc,
|
|
&diff_bias_desc, diff_dst_desc, strides, &dilates,
|
|
padding_l, padding_r, hint_fwd_pd, attr, allow_empty) {}
|
|
|
|
/// Constructs a primitive descriptor for a deconvolution weights
|
|
/// gradient primitive without bias.
|
|
///
|
|
/// @note
|
|
/// All the memory descriptors may be initialized with the
|
|
/// #dnnl::memory::format_tag::any value of @p format_tag.
|
|
///
|
|
/// Arrays @p strides, @p dilates, @p padding_l, and @p padding_r
|
|
/// contain values for spatial dimensions only and hence must have the
|
|
/// same number of elements as there are spatial dimensions. The order
|
|
/// of values is the same as in the tensor: depth (for 3D tensors),
|
|
/// height (for 3D and 2D tensors), and width.
|
|
///
|
|
/// @param aengine Engine to use.
|
|
/// @param aalgorithm Deconvolution algorithm. Possible values are
|
|
/// #dnnl::algorithm::deconvolution_direct, and
|
|
/// #dnnl::algorithm::deconvolution_winograd.
|
|
/// @param src_desc Source memory descriptor.
|
|
/// @param diff_weights_desc Diff weights memory descriptor.
|
|
/// @param diff_dst_desc Diff destination memory descriptor.
|
|
/// @param strides Strides for each spatial dimension.
|
|
/// @param dilates Dilations for each spatial dimension. A zero value
|
|
/// means no dilation in the corresponding dimension.
|
|
/// @param padding_l Vector of padding values for low indices for each
|
|
/// spatial dimension `([[front,] top,] left)`.
|
|
/// @param padding_r Vector of padding values for high indices for
|
|
/// each spatial dimension `([[back,] bottom,] right)`.
|
|
/// @param hint_fwd_pd Primitive descriptor for a deconvolution
|
|
/// forward propagation primitive. It is used as a hint for
|
|
/// deciding which memory format to use.
|
|
/// @param attr Primitive attributes to use. Attributes are optional
|
|
/// and default to empty attributes.
|
|
/// @param allow_empty A flag signifying whether construction is
|
|
/// allowed to fail without throwing an exception. In this case an
|
|
/// empty object will be produced. This flag is optional and
|
|
/// defaults to false.
|
|
primitive_desc(const engine &aengine, algorithm aalgorithm,
|
|
const memory::desc &src_desc,
|
|
const memory::desc &diff_weights_desc,
|
|
const memory::desc &diff_dst_desc, const memory::dims &strides,
|
|
const memory::dims &dilates, const memory::dims &padding_l,
|
|
const memory::dims &padding_r,
|
|
const deconvolution_forward::primitive_desc &hint_fwd_pd,
|
|
const primitive_attr &attr = default_attr(),
|
|
bool allow_empty = false)
|
|
: primitive_desc(aengine, aalgorithm, src_desc, diff_weights_desc,
|
|
nullptr, diff_dst_desc, strides, &dilates, padding_l,
|
|
padding_r, hint_fwd_pd, attr, allow_empty) {}
|
|
|
|
/// Constructs a primitive descriptor for a deconvolution weights
|
|
/// gradient primitive from a C API primitive descriptor that must
|
|
/// have a matching kind.
|
|
///
|
|
/// @param pd C API primitive descriptor for a deconvolution weights
|
|
/// gradient primitive.
|
|
primitive_desc(dnnl_primitive_desc_t pd)
|
|
: dnnl::primitive_desc(pd, dnnl::primitive::kind::deconvolution,
|
|
dnnl::prop_kind::backward_weights) {}
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::src_desc()const
|
|
memory::desc src_desc() const { return base::src_desc(0); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::diff_weights_desc()const
|
|
memory::desc diff_weights_desc() const {
|
|
return base::diff_weights_desc(0);
|
|
}
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
|
|
memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
|
|
|
|
/// @copydoc dnnl::convolution_backward_weights::primitive_desc::diff_bias_desc()const
|
|
memory::desc diff_bias_desc() const {
|
|
return base::diff_weights_desc(1);
|
|
}
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_algorithm()const
|
|
algorithm get_algorithm() const { return base::get_algorithm(); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
|
|
prop_kind get_prop_kind() const { return base::get_prop_kind(); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_strides()const
|
|
memory::dims get_strides() const { return base::get_strides(); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_dilations()const
|
|
memory::dims get_dilations() const { return base::get_dilations(); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_padding_l()const
|
|
memory::dims get_padding_l() const { return base::get_padding_l(); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_padding_r()const
|
|
memory::dims get_padding_r() const { return base::get_padding_r(); }
|
|
|
|
private:
|
|
primitive_desc(const engine &aengine, algorithm aalgorithm,
|
|
const memory::desc &src_desc,
|
|
const memory::desc &diff_weights_desc,
|
|
const memory::desc *diff_bias_desc,
|
|
const memory::desc &diff_dst_desc, const memory::dims &strides,
|
|
const memory::dims *dilates, const memory::dims &padding_l,
|
|
const memory::dims &padding_r,
|
|
const deconvolution_forward::primitive_desc &hint_fwd_pd,
|
|
const primitive_attr &attr, bool allow_empty) {
|
|
|
|
memory::validate_dims(strides, src_desc.get_ndims() - 2);
|
|
memory::validate_dims(padding_l, src_desc.get_ndims() - 2);
|
|
memory::validate_dims(padding_r, src_desc.get_ndims() - 2);
|
|
|
|
if (dilates)
|
|
memory::validate_dims(*dilates, src_desc.get_ndims() - 2);
|
|
|
|
dnnl_primitive_desc_t pd = nullptr;
|
|
dnnl_status_t status
|
|
= dnnl_deconvolution_backward_weights_primitive_desc_create(
|
|
&pd, aengine.get(), convert_to_c(aalgorithm),
|
|
src_desc.get(), diff_weights_desc.get(),
|
|
optional_arg(diff_bias_desc), diff_dst_desc.get(),
|
|
&strides[0], optional_arg(dilates), &padding_l[0],
|
|
&padding_r[0], hint_fwd_pd.get(), attr.get());
|
|
if (!allow_empty)
|
|
error::wrap_c_api(status,
|
|
"could not create a primitive descriptor for "
|
|
"the deconvolution weights update primitive. Run "
|
|
"workload with environment variable ONEDNN_VERBOSE=all "
|
|
"to get additional diagnostic information.");
|
|
reset(pd);
|
|
}
|
|
};
|
|
|
|
/// Default constructor. Produces an empty object.
|
|
deconvolution_backward_weights() = default;
|
|
|
|
/// Constructs a deconvolution weights gradient primitive.
|
|
/// @param pd Primitive descriptor for a deconvolution weights gradient
|
|
/// primitive.
|
|
deconvolution_backward_weights(const primitive_desc &pd) : primitive(pd) {}
|
|
|
|
/// Constructs a deconvolution weights gradient primitive from a cache
|
|
/// blob.
|
|
/// @param pd Primitive descriptor for a deconvolution weights gradient
|
|
/// primitive.
|
|
/// @param cache_blob Cache blob.
|
|
deconvolution_backward_weights(
|
|
const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
|
|
: primitive(pd, cache_blob) {}
|
|
};
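
/// @par Example
///     A minimal sketch of creating and executing a deconvolution weights
///     gradient primitive. The engine, tensor shapes, strides, and padding
///     below are illustrative assumptions only.
/// @code
/// #include "oneapi/dnnl/dnnl.hpp"
///
/// void deconv_backward_weights_sketch() {
///     using namespace dnnl;
///     engine eng(engine::kind::cpu, 0);
///     stream strm(eng);
///
///     // N=2, IC=16, OC=32, 3x3 kernel, unit strides, no padding:
///     // a 14x14 source maps to a 16x16 destination.
///     memory::desc src_md({2, 16, 14, 14}, memory::data_type::f32,
///             memory::format_tag::any);
///     memory::desc wei_md({32, 16, 3, 3}, memory::data_type::f32,
///             memory::format_tag::any);
///     memory::desc dst_md({2, 32, 16, 16}, memory::data_type::f32,
///             memory::format_tag::any);
///     memory::dims strides = {1, 1}, padding = {0, 0};
///
///     // The forward primitive descriptor is used as the format hint.
///     deconvolution_forward::primitive_desc fwd_pd(eng,
///             prop_kind::forward_training, algorithm::deconvolution_direct,
///             src_md, wei_md, dst_md, strides, padding, padding);
///     deconvolution_backward_weights::primitive_desc bwd_pd(eng,
///             algorithm::deconvolution_direct, src_md, wei_md, dst_md,
///             strides, padding, padding, fwd_pd);
///
///     // Allocate memory using the formats chosen by the library.
///     memory src_mem(bwd_pd.src_desc(), eng);
///     memory diff_wei_mem(bwd_pd.diff_weights_desc(), eng);
///     memory diff_dst_mem(bwd_pd.diff_dst_desc(), eng);
///     deconvolution_backward_weights(bwd_pd).execute(strm,
///             {{DNNL_ARG_SRC, src_mem}, {DNNL_ARG_DIFF_DST, diff_dst_mem},
///                     {DNNL_ARG_DIFF_WEIGHTS, diff_wei_mem}});
///     strm.wait();
/// }
/// @endcode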

/// @} dnnl_api_deconvolution

/// @addtogroup dnnl_api_lrn LRN
///
/// A primitive to perform local response normalization (LRN) across or within
/// channels.
///
/// @sa @ref dev_guide_lrn in developer guide
///
/// @{

/// Local response normalization (LRN) forward propagation primitive.
|
|
struct lrn_forward : public primitive {
|
|
/// Primitive descriptor for an LRN forward propagation primitive.
|
|
struct primitive_desc : public dnnl::primitive_desc {
|
|
/// Default constructor. Produces an empty object.
|
|
primitive_desc() = default;
|
|
|
|
/// Constructs a primitive descriptor for an LRN forward propagation
|
|
/// primitive.
|
|
///
|
|
/// @param aengine Engine to use.
|
|
/// @param aprop_kind Propagation kind. Possible values are
|
|
/// #dnnl::prop_kind::forward_training, and
|
|
/// #dnnl::prop_kind::forward_inference.
|
|
/// @param aalgorithm LRN algorithm kind: either
|
|
/// #dnnl::algorithm::lrn_across_channels, or
|
|
/// #dnnl::algorithm::lrn_within_channel.
|
|
/// @param src_desc Source memory descriptor.
|
|
/// @param dst_desc Destination memory descriptor.
|
|
/// @param local_size Regularization local size.
|
|
/// @param alpha The alpha regularization parameter.
|
|
/// @param beta The beta regularization parameter.
|
|
/// @param k The k regularization parameter.
|
|
/// @param attr Primitive attributes to use. Attributes are optional
|
|
/// and default to empty attributes.
|
|
/// @param allow_empty A flag signifying whether construction is
|
|
/// allowed to fail without throwing an exception. In this case an
|
|
/// empty object will be produced. This flag is optional and
|
|
/// defaults to false.
|
|
primitive_desc(const engine &aengine, prop_kind aprop_kind,
|
|
algorithm aalgorithm, const memory::desc &src_desc,
|
|
const memory::desc &dst_desc, memory::dim local_size,
|
|
float alpha, float beta, float k,
|
|
const primitive_attr &attr = default_attr(),
|
|
bool allow_empty = false) {
|
|
|
|
dnnl_primitive_desc_t pd = nullptr;
|
|
dnnl_status_t status = dnnl_lrn_forward_primitive_desc_create(&pd,
|
|
aengine.get(), dnnl::convert_to_c(aprop_kind),
|
|
convert_to_c(aalgorithm), src_desc.get(), dst_desc.get(),
|
|
local_size, alpha, beta, k, attr.get());
|
|
|
|
if (!allow_empty)
|
|
error::wrap_c_api(status,
|
|
"could not create a primitive descriptor for "
|
|
"the lrn forward propagation primitive. Run workload "
|
|
"with environment variable ONEDNN_VERBOSE=all to get "
|
|
"additional diagnostic information.");
|
|
reset(pd);
|
|
}
|
|
|
|
/// Constructs a primitive descriptor for an LRN forward propagation
|
|
/// primitive from a C API primitive descriptor that must have a
|
|
/// matching kind.
|
|
///
|
|
/// @param pd C API primitive descriptor for an LRN forward
|
|
/// propagation primitive.
|
|
primitive_desc(dnnl_primitive_desc_t pd)
|
|
: dnnl::primitive_desc(pd, dnnl::primitive::kind::lrn,
|
|
dnnl::prop_kind::forward_training,
|
|
dnnl::prop_kind::forward_inference) {}
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::src_desc()const
|
|
memory::desc src_desc() const { return base::src_desc(0); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::dst_desc()const
|
|
memory::desc dst_desc() const { return base::dst_desc(0); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::workspace_desc()const
|
|
memory::desc workspace_desc() const { return base::workspace_desc(); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_algorithm()const
|
|
algorithm get_algorithm() const { return base::get_algorithm(); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
|
|
prop_kind get_prop_kind() const { return base::get_prop_kind(); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_alpha()const
|
|
float get_alpha() const { return base::get_alpha(); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_beta()const
|
|
float get_beta() const { return base::get_beta(); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_local_size()const
|
|
memory::dim get_local_size() const { return base::get_local_size(); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_k()const
|
|
float get_k() const { return base::get_k(); }
|
|
};
|
|
|
|
/// Default constructor. Produces an empty object.
|
|
lrn_forward() = default;
|
|
|
|
/// Constructs an LRN forward propagation primitive.
|
|
/// @param pd Primitive descriptor for an LRN forward propagation
|
|
/// primitive.
|
|
lrn_forward(const primitive_desc &pd) : primitive(pd) {}
|
|
|
|
/// Constructs an LRN forward propagation primitive from a cache blob.
|
|
/// @param pd Primitive descriptor for an LRN forward propagation
|
|
/// primitive.
|
|
/// @param cache_blob Cache blob.
|
|
lrn_forward(
|
|
const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
|
|
: primitive(pd, cache_blob) {}
|
|
};
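
/// @par Example
///     A minimal sketch of running LRN forward propagation for inference.
///     The engine, tensor shape, and the AlexNet-style parameters below are
///     illustrative assumptions only.
/// @code
/// #include "oneapi/dnnl/dnnl.hpp"
///
/// void lrn_forward_sketch() {
///     using namespace dnnl;
///     engine eng(engine::kind::cpu, 0);
///     stream strm(eng);
///
///     memory::desc md({1, 64, 28, 28}, memory::data_type::f32,
///             memory::format_tag::nchw);
///     memory src_mem(md, eng), dst_mem(md, eng);
///
///     // local_size = 5, alpha = 1e-4, beta = 0.75, k = 1.
///     lrn_forward::primitive_desc pd(eng, prop_kind::forward_inference,
///             algorithm::lrn_across_channels, md, md, 5, 1.e-4f, 0.75f, 1.f);
///     lrn_forward(pd).execute(
///             strm, {{DNNL_ARG_SRC, src_mem}, {DNNL_ARG_DST, dst_mem}});
///     strm.wait();
/// }
/// @endcode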
|
|
|
|
/// Local response normalization (LRN) backward propagation primitive.
|
|
struct lrn_backward : public primitive {
|
|
/// Primitive descriptor for an LRN backward propagation primitive.
|
|
struct primitive_desc : public dnnl::primitive_desc {
|
|
/// Default constructor. Produces an empty object.
|
|
primitive_desc() = default;
|
|
|
|
/// Constructs a primitive descriptor for an LRN backward propagation
|
|
/// primitive.
|
|
///
|
|
/// @param aengine Engine to use.
|
|
/// @param aalgorithm LRN algorithm kind: either
|
|
/// #dnnl::algorithm::lrn_across_channels, or
|
|
/// #dnnl::algorithm::lrn_within_channel.
|
|
/// @param diff_src_desc Diff source memory descriptor.
|
|
/// @param diff_dst_desc Diff destination memory descriptor.
|
|
/// @param src_desc Source memory descriptor.
|
|
/// @param local_size Regularization local size.
|
|
/// @param alpha The alpha regularization parameter.
|
|
/// @param beta The beta regularization parameter.
|
|
/// @param k The k regularization parameter.
|
|
/// @param hint_fwd_pd Primitive descriptor for an LRN forward
|
|
/// propagation primitive. It is used as a hint for deciding which
|
|
/// memory format to use.
|
|
/// @param attr Primitive attributes to use. Attributes are optional
|
|
/// and default to empty attributes.
|
|
/// @param allow_empty A flag signifying whether construction is
|
|
/// allowed to fail without throwing an exception. In this case an
|
|
/// empty object will be produced. This flag is optional and
|
|
/// defaults to false.
|
|
primitive_desc(const engine &aengine, algorithm aalgorithm,
|
|
const memory::desc &diff_src_desc,
|
|
const memory::desc &diff_dst_desc, const memory::desc &src_desc,
|
|
memory::dim local_size, float alpha, float beta, float k,
|
|
const lrn_forward::primitive_desc &hint_fwd_pd,
|
|
const primitive_attr &attr = default_attr(),
|
|
bool allow_empty = false) {
|
|
|
|
dnnl_primitive_desc_t pd = nullptr;
|
|
dnnl_status_t status = dnnl_lrn_backward_primitive_desc_create(&pd,
|
|
aengine.get(), convert_to_c(aalgorithm),
|
|
diff_src_desc.get(), diff_dst_desc.get(), src_desc.get(),
|
|
local_size, alpha, beta, k, hint_fwd_pd.get(), attr.get());
|
|
|
|
if (!allow_empty)
|
|
error::wrap_c_api(status,
|
|
"could not create a primitive descriptor for "
|
|
"the lrn backward propagation primitive. Run workload "
|
|
"with environment variable ONEDNN_VERBOSE=all to get "
|
|
"additional diagnostic information.");
|
|
reset(pd);
|
|
}
|
|
|
|
/// Constructs a primitive descriptor for an LRN backward propagation
|
|
/// primitive from a C API primitive descriptor that must have a
|
|
/// matching kind.
|
|
///
|
|
/// @param pd C API primitive descriptor for an LRN backward
|
|
/// propagation primitive.
|
|
primitive_desc(dnnl_primitive_desc_t pd)
|
|
: dnnl::primitive_desc(pd, dnnl::primitive::kind::lrn,
|
|
dnnl::prop_kind::backward_data) {}
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
|
|
memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::workspace_desc()const
|
|
memory::desc workspace_desc() const { return base::workspace_desc(); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_algorithm()const
|
|
algorithm get_algorithm() const { return base::get_algorithm(); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
|
|
prop_kind get_prop_kind() const { return base::get_prop_kind(); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_alpha()const
|
|
float get_alpha() const { return base::get_alpha(); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_beta()const
|
|
float get_beta() const { return base::get_beta(); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_local_size()const
|
|
memory::dim get_local_size() const { return base::get_local_size(); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_k()const
|
|
float get_k() const { return base::get_k(); }
|
|
};
|
|
|
|
/// Default constructor. Produces an empty object.
|
|
lrn_backward() = default;
|
|
|
|
/// Constructs an LRN backward propagation primitive.
|
|
/// @param pd Primitive descriptor for an LRN backward propagation
|
|
/// primitive.
|
|
lrn_backward(const primitive_desc &pd) : primitive(pd) {}
|
|
|
|
/// Constructs an LRN backward propagation primitive from a cache blob.
|
|
/// @param pd Primitive descriptor for an LRN backward propagation
|
|
/// primitive.
|
|
/// @param cache_blob Cache blob.
|
|
lrn_backward(
|
|
const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
|
|
: primitive(pd, cache_blob) {}
|
|
};
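
/// @par Example
///     A minimal sketch of the LRN training flow: the forward (training)
///     primitive produces a workspace that the backward primitive consumes.
///     The engine, shape, and parameters below are illustrative assumptions.
/// @code
/// #include "oneapi/dnnl/dnnl.hpp"
///
/// void lrn_backward_sketch() {
///     using namespace dnnl;
///     engine eng(engine::kind::cpu, 0);
///     stream strm(eng);
///
///     memory::desc md({1, 64, 28, 28}, memory::data_type::f32,
///             memory::format_tag::nchw);
///     lrn_forward::primitive_desc fwd_pd(eng, prop_kind::forward_training,
///             algorithm::lrn_across_channels, md, md, 5, 1.e-4f, 0.75f, 1.f);
///     memory src_mem(md, eng), dst_mem(md, eng);
///     memory ws_mem(fwd_pd.workspace_desc(), eng);
///     lrn_forward(fwd_pd).execute(strm,
///             {{DNNL_ARG_SRC, src_mem}, {DNNL_ARG_DST, dst_mem},
///                     {DNNL_ARG_WORKSPACE, ws_mem}});
///
///     lrn_backward::primitive_desc bwd_pd(eng,
///             algorithm::lrn_across_channels, md, md, md, 5, 1.e-4f, 0.75f,
///             1.f, fwd_pd);
///     memory diff_src_mem(md, eng), diff_dst_mem(md, eng);
///     lrn_backward(bwd_pd).execute(strm,
///             {{DNNL_ARG_SRC, src_mem}, {DNNL_ARG_DIFF_DST, diff_dst_mem},
///                     {DNNL_ARG_DIFF_SRC, diff_src_mem},
///                     {DNNL_ARG_WORKSPACE, ws_mem}});
///     strm.wait();
/// }
/// @endcode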

/// @} dnnl_api_lrn

/// @addtogroup dnnl_api_eltwise Eltwise
///
/// A primitive to perform elementwise operations such as the
/// rectifier linear unit (ReLU).
///
/// Both forward and backward propagation primitives support in-place
/// operation; that is, src and dst can refer to the same memory for forward
/// propagation, and diff_dst and diff_src can refer to the same memory for
/// backward propagation.
///
/// @warning
///     Because the original source data is required for backward propagation,
///     in-place forward propagation is not generally supported in training
///     mode. However, for algorithms that support using the destination as
///     input memory, dst can be used for backward propagation, which makes it
///     possible to get a performance benefit even in training mode.
///
/// @sa @ref dev_guide_eltwise in developer guide
///
/// @{

/// Elementwise unary operation forward propagation primitive.
|
|
struct eltwise_forward : public primitive {
|
|
/// Primitive descriptor for an elementwise forward propagation primitive.
|
|
struct primitive_desc : public dnnl::primitive_desc {
|
|
/// Default constructor. Produces an empty object.
|
|
primitive_desc() = default;
|
|
|
|
/// Constructs a primitive descriptor for an elementwise forward
|
|
/// propagation primitive.
|
|
///
|
|
/// @param aengine Engine to use.
|
|
/// @param aprop_kind Propagation kind. Possible values are
|
|
/// #dnnl::prop_kind::forward_training, and
|
|
/// #dnnl::prop_kind::forward_inference.
|
|
/// @param aalgorithm Elementwise algorithm kind.
|
|
/// @param src_desc Source memory descriptor.
|
|
/// @param dst_desc Destination memory descriptor.
|
|
/// @param attr Primitive attributes to use. Attributes are optional
|
|
/// and default to empty attributes.
|
|
/// @param allow_empty A flag signifying whether construction is
|
|
/// allowed to fail without throwing an exception. In this case an
|
|
/// empty object will be produced. This flag is optional and
|
|
/// defaults to false.
|
|
primitive_desc(const engine &aengine, prop_kind aprop_kind,
|
|
algorithm aalgorithm, const memory::desc &src_desc,
|
|
const memory::desc &dst_desc,
|
|
const primitive_attr &attr = default_attr(),
|
|
bool allow_empty = false)
|
|
: primitive_desc(aengine, aprop_kind, aalgorithm, src_desc,
|
|
dst_desc, nullptr, nullptr, attr, allow_empty) {}
|
|
|
|
/// Constructs a primitive descriptor for an elementwise forward
|
|
/// propagation primitive with an alpha parameter.
|
|
///
|
|
/// @param aengine Engine to use.
|
|
/// @param aprop_kind Propagation kind. Possible values are
|
|
/// #dnnl::prop_kind::forward_training, and
|
|
/// #dnnl::prop_kind::forward_inference.
|
|
/// @param aalgorithm Elementwise algorithm kind.
|
|
/// @param src_desc Source memory descriptor.
|
|
/// @param dst_desc Destination memory descriptor.
|
|
/// @param alpha The alpha parameter for the elementwise operation.
|
|
/// Specific meaning depends on the algorithm.
|
|
/// @param attr Primitive attributes to use. Attributes are optional
|
|
/// and default to empty attributes.
|
|
/// @param allow_empty A flag signifying whether construction is
|
|
/// allowed to fail without throwing an exception. In this case an
|
|
/// empty object will be produced. This flag is optional and
|
|
/// defaults to false.
|
|
primitive_desc(const engine &aengine, prop_kind aprop_kind,
|
|
algorithm aalgorithm, const memory::desc &src_desc,
|
|
const memory::desc &dst_desc, float alpha,
|
|
const primitive_attr &attr = default_attr(),
|
|
bool allow_empty = false)
|
|
: primitive_desc(aengine, aprop_kind, aalgorithm, src_desc,
|
|
dst_desc, &alpha, nullptr, attr, allow_empty) {}
|
|
|
|
/// Constructs a primitive descriptor for an elementwise forward
/// propagation primitive with alpha and beta parameters.
|
|
///
|
|
/// @param aengine Engine to use.
|
|
/// @param aprop_kind Propagation kind. Possible values are
|
|
/// #dnnl::prop_kind::forward_training, and
|
|
/// #dnnl::prop_kind::forward_inference.
|
|
/// @param aalgorithm Elementwise algorithm kind.
|
|
/// @param src_desc Source memory descriptor.
|
|
/// @param dst_desc Destination memory descriptor.
|
|
/// @param alpha The alpha parameter for the elementwise operation.
|
|
/// Specific meaning depends on the algorithm.
|
|
/// @param beta The beta parameter for the elementwise operation.
|
|
/// Specific meaning depends on the algorithm.
|
|
/// @param attr Primitive attributes to use. Attributes are optional
|
|
/// and default to empty attributes.
|
|
/// @param allow_empty A flag signifying whether construction is
|
|
/// allowed to fail without throwing an exception. In this case an
|
|
/// empty object will be produced. This flag is optional and
|
|
/// defaults to false.
|
|
primitive_desc(const engine &aengine, prop_kind aprop_kind,
|
|
algorithm aalgorithm, const memory::desc &src_desc,
|
|
const memory::desc &dst_desc, float alpha, float beta,
|
|
const primitive_attr &attr = default_attr(),
|
|
bool allow_empty = false)
|
|
: primitive_desc(aengine, aprop_kind, aalgorithm, src_desc,
|
|
dst_desc, &alpha, &beta, attr, allow_empty) {}
|
|
|
|
/// Constructs a primitive descriptor for an eltwise forward
|
|
/// propagation primitive from a C API primitive descriptor that must
|
|
/// have a matching kind.
|
|
///
|
|
/// @param pd C API primitive descriptor for an eltwise forward
|
|
/// propagation primitive.
|
|
primitive_desc(dnnl_primitive_desc_t pd)
|
|
: dnnl::primitive_desc(pd, dnnl::primitive::kind::eltwise,
|
|
dnnl::prop_kind::forward_training,
|
|
dnnl::prop_kind::forward_inference) {}
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::src_desc()const
|
|
memory::desc src_desc() const { return base::src_desc(0); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::dst_desc()const
|
|
memory::desc dst_desc() const { return base::dst_desc(0); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_algorithm()const
|
|
dnnl::algorithm get_algorithm() const { return base::get_algorithm(); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
|
|
dnnl::prop_kind get_prop_kind() const { return base::get_prop_kind(); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_alpha()const
|
|
float get_alpha() const { return base::get_alpha(); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_beta()const
|
|
float get_beta() const { return base::get_beta(); }
|
|
|
|
private:
|
|
primitive_desc(const engine &aengine, prop_kind aprop_kind,
|
|
algorithm aalgorithm, const memory::desc &src_desc,
|
|
const memory::desc &dst_desc, const float *alpha,
|
|
const float *beta, const primitive_attr &attr,
|
|
bool allow_empty) {
|
|
|
|
dnnl_primitive_desc_t pd = nullptr;
|
|
dnnl_status_t status = dnnl_eltwise_forward_primitive_desc_create(
|
|
&pd, aengine.get(), dnnl::convert_to_c(aprop_kind),
|
|
dnnl::convert_to_c(aalgorithm), src_desc.get(),
|
|
dst_desc.get(), alpha ? *alpha : 0.0f, beta ? *beta : 0.0f,
|
|
attr.get());
|
|
|
|
if (!allow_empty)
|
|
error::wrap_c_api(status,
|
|
"could not create a primitive descriptor for "
|
|
"the eltwise forward propagation primitive. Run "
|
|
"workload with environment variable ONEDNN_VERBOSE=all "
|
|
"to get additional diagnostic information.");
|
|
reset(pd);
|
|
}
|
|
};
|
|
|
|
/// Default constructor. Produces an empty object.
|
|
eltwise_forward() = default;
|
|
|
|
/// Constructs an eltwise forward propagation primitive.
|
|
/// @param pd Primitive descriptor for an eltwise forward propagation
|
|
/// primitive.
|
|
eltwise_forward(const primitive_desc &pd) : primitive(pd) {}
|
|
|
|
/// Constructs an eltwise forward propagation primitive from a cache blob.
|
|
/// @param pd Primitive descriptor for an eltwise forward propagation
|
|
/// primitive.
|
|
/// @param cache_blob Cache blob.
|
|
eltwise_forward(
|
|
const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
|
|
: primitive(pd, cache_blob) {}
|
|
};
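
/// @par Example
///     A minimal sketch of an in-place ReLU for inference. The engine and
///     tensor shape are illustrative assumptions; alpha = 0 selects the
///     standard ReLU.
/// @code
/// #include "oneapi/dnnl/dnnl.hpp"
///
/// void relu_forward_sketch() {
///     using namespace dnnl;
///     engine eng(engine::kind::cpu, 0);
///     stream strm(eng);
///
///     memory::desc md({2, 16, 8, 8}, memory::data_type::f32,
///             memory::format_tag::nchw);
///     memory data_mem(md, eng);
///
///     eltwise_forward::primitive_desc pd(eng, prop_kind::forward_inference,
///             algorithm::eltwise_relu, md, md, 0.f);
///     // In-place: src and dst refer to the same memory object.
///     eltwise_forward(pd).execute(
///             strm, {{DNNL_ARG_SRC, data_mem}, {DNNL_ARG_DST, data_mem}});
///     strm.wait();
/// }
/// @endcode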
|
|
|
|
/// Elementwise unary operation backward propagation primitive.
|
|
struct eltwise_backward : public primitive {
|
|
/// Primitive descriptor for an eltwise backward propagation primitive.
|
|
struct primitive_desc : public dnnl::primitive_desc {
|
|
/// Default constructor. Produces an empty object.
|
|
primitive_desc() = default;
|
|
|
|
/// Constructs a primitive descriptor for an elementwise backward
/// propagation primitive.
|
|
///
|
|
/// @param aengine Engine to use.
|
|
/// @param aalgorithm Elementwise algorithm kind.
|
|
/// @param diff_src_desc Diff source memory descriptor.
|
|
/// @param diff_dst_desc Diff destination memory descriptor.
|
|
/// @param data_desc Destination memory descriptor if one of the
|
|
/// "use_dst_for_bwd" algorithms are used (such as
|
|
/// #dnnl_eltwise_relu_use_dst_for_bwd), source memory descriptor
|
|
/// otherwise.
|
|
/// @param hint_fwd_pd Primitive descriptor for an elementwise
|
|
/// forward propagation primitive. It is used as a hint for
|
|
/// deciding which memory format to use.
|
|
/// @param attr Primitive attributes to use. Attributes are optional
|
|
/// and default to empty attributes.
|
|
/// @param allow_empty A flag signifying whether construction is
|
|
/// allowed to fail without throwing an exception. In this case an
|
|
/// empty object will be produced. This flag is optional and
|
|
/// defaults to false.
|
|
primitive_desc(const engine &aengine, algorithm aalgorithm,
|
|
const memory::desc &diff_src_desc,
|
|
const memory::desc &diff_dst_desc,
|
|
const memory::desc &data_desc,
|
|
const eltwise_forward::primitive_desc &hint_fwd_pd,
|
|
const primitive_attr &attr = default_attr(),
|
|
bool allow_empty = false)
|
|
: primitive_desc(aengine, aalgorithm, diff_src_desc, diff_dst_desc,
|
|
data_desc, nullptr, nullptr, hint_fwd_pd, attr,
|
|
allow_empty) {}
|
|
|
|
/// Constructs a primitive descriptor for an elementwise backward
|
|
/// propagation primitive with an alpha parameter.
|
|
///
|
|
/// @param aengine Engine to use.
|
|
/// @param aalgorithm Elementwise algorithm kind.
|
|
/// @param diff_src_desc Diff source memory descriptor.
|
|
/// @param diff_dst_desc Diff destination memory descriptor.
|
|
/// @param data_desc Destination memory descriptor if one of the
|
|
/// "use_dst_for_bwd" algorithms are used (such as
|
|
/// #dnnl_eltwise_relu_use_dst_for_bwd), source memory descriptor
|
|
/// otherwise.
|
|
/// @param alpha The alpha parameter for the elementwise operation.
|
|
/// Specific meaning depends on the algorithm.
|
|
/// @param hint_fwd_pd Primitive descriptor for an elementwise
|
|
/// forward propagation primitive. It is used as a hint for
|
|
/// deciding which memory format to use.
|
|
/// @param attr Primitive attributes to use. Attributes are optional
|
|
/// and default to empty attributes.
|
|
/// @param allow_empty A flag signifying whether construction is
|
|
/// allowed to fail without throwing an exception. In this case an
|
|
/// empty object will be produced. This flag is optional and
|
|
/// defaults to false.
|
|
primitive_desc(const engine &aengine, algorithm aalgorithm,
|
|
const memory::desc &diff_src_desc,
|
|
const memory::desc &diff_dst_desc,
|
|
const memory::desc &data_desc, float alpha,
|
|
const eltwise_forward::primitive_desc &hint_fwd_pd,
|
|
const primitive_attr &attr = default_attr(),
|
|
bool allow_empty = false)
|
|
: primitive_desc(aengine, aalgorithm, diff_src_desc, diff_dst_desc,
|
|
data_desc, &alpha, nullptr, hint_fwd_pd, attr,
|
|
allow_empty) {}
|
|
|
|
/// Constructs a primitive descriptor for an elementwise backward
/// propagation primitive with alpha and beta parameters.
|
|
///
|
|
/// @param aengine Engine to use.
|
|
/// @param aalgorithm Elementwise algorithm kind.
|
|
/// @param diff_src_desc Diff source memory descriptor.
|
|
/// @param diff_dst_desc Diff destination memory descriptor.
|
|
/// @param data_desc Destination memory descriptor if one of the
|
|
/// "use_dst_for_bwd" algorithms are used (such as
|
|
/// #dnnl_eltwise_relu_use_dst_for_bwd), source memory descriptor
|
|
/// otherwise.
|
|
/// @param alpha The alpha parameter for the elementwise operation.
|
|
/// Specific meaning depends on the algorithm.
|
|
/// @param beta The beta parameter for the elementwise operation.
|
|
/// Specific meaning depends on the algorithm.
|
|
/// @param hint_fwd_pd Primitive descriptor for an elementwise
|
|
/// forward propagation primitive. It is used as a hint for
|
|
/// deciding which memory format to use.
|
|
/// @param attr Primitive attributes to use. Attributes are optional
|
|
/// and default to empty attributes.
|
|
/// @param allow_empty A flag signifying whether construction is
|
|
/// allowed to fail without throwing an exception. In this case an
|
|
/// empty object will be produced. This flag is optional and
|
|
/// defaults to false.
|
|
primitive_desc(const engine &aengine, algorithm aalgorithm,
|
|
const memory::desc &diff_src_desc,
|
|
const memory::desc &diff_dst_desc,
|
|
const memory::desc &data_desc, float alpha, float beta,
|
|
const eltwise_forward::primitive_desc &hint_fwd_pd,
|
|
const primitive_attr &attr = default_attr(),
|
|
bool allow_empty = false)
|
|
: primitive_desc(aengine, aalgorithm, diff_src_desc, diff_dst_desc,
|
|
data_desc, &alpha, &beta, hint_fwd_pd, attr, allow_empty) {}
|
|
|
|
/// Constructs a primitive descriptor for an eltwise backward
|
|
/// propagation primitive from a C API primitive descriptor that must
|
|
/// have a matching kind.
|
|
///
|
|
/// @param pd C API primitive descriptor for an eltwise backward
|
|
/// propagation primitive.
|
|
primitive_desc(dnnl_primitive_desc_t pd)
|
|
: dnnl::primitive_desc(pd, dnnl::primitive::kind::eltwise,
|
|
dnnl::prop_kind::backward_data) {}
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::src_desc()const
|
|
memory::desc src_desc() const { return base::src_desc(0); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
|
|
memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
|
|
memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_algorithm()const
|
|
dnnl::algorithm get_algorithm() const { return base::get_algorithm(); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
|
|
dnnl::prop_kind get_prop_kind() const { return base::get_prop_kind(); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_alpha()const
|
|
float get_alpha() const { return base::get_alpha(); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_beta()const
|
|
float get_beta() const { return base::get_beta(); }
|
|
|
|
private:
|
|
primitive_desc(const engine &aengine, algorithm aalgorithm,
|
|
const memory::desc &diff_src_desc,
|
|
const memory::desc &diff_dst_desc,
|
|
const memory::desc &data_desc, const float *alpha,
|
|
const float *beta,
|
|
const eltwise_forward::primitive_desc &hint_fwd_pd,
|
|
const primitive_attr &attr, bool allow_empty) {
|
|
|
|
dnnl_primitive_desc_t pd = nullptr;
|
|
dnnl_status_t status = dnnl_eltwise_backward_primitive_desc_create(
|
|
&pd, aengine.get(), dnnl::convert_to_c(aalgorithm),
|
|
diff_src_desc.get(), diff_dst_desc.get(), data_desc.get(),
|
|
alpha ? *alpha : 0.0f, beta ? *beta : 0.0f,
|
|
hint_fwd_pd.get(), attr.get());
|
|
|
|
if (!allow_empty)
|
|
error::wrap_c_api(status,
|
|
"could not create a primitive descriptor for "
|
|
"the eltwise backward propagation primitive. Run "
|
|
"workload with environment variable ONEDNN_VERBOSE=all "
|
|
"to get additional diagnostic information.");
|
|
reset(pd);
|
|
}
|
|
};
|
|
|
|
/// Default constructor. Produces an empty object.
|
|
eltwise_backward() = default;
|
|
|
|
/// Constructs an eltwise backward propagation primitive.
|
|
/// @param pd Primitive descriptor for an eltwise backward propagation
|
|
/// primitive.
|
|
eltwise_backward(const primitive_desc &pd) : primitive(pd) {}
|
|
|
|
/// Constructs an eltwise backward propagation primitive from a cache blob.
|
|
/// @param pd Primitive descriptor for an eltwise backward propagation
|
|
/// primitive.
|
|
/// @param cache_blob Cache blob.
|
|
eltwise_backward(
|
|
const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
|
|
: primitive(pd, cache_blob) {}
|
|
};
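
/// @par Example
///     A minimal sketch of ReLU backward propagation, where the original src
///     serves as @p data_desc. The engine and shape are illustrative
///     assumptions.
/// @code
/// #include "oneapi/dnnl/dnnl.hpp"
///
/// void relu_backward_sketch() {
///     using namespace dnnl;
///     engine eng(engine::kind::cpu, 0);
///     stream strm(eng);
///
///     memory::desc md({2, 16, 8, 8}, memory::data_type::f32,
///             memory::format_tag::nchw);
///     eltwise_forward::primitive_desc fwd_pd(eng,
///             prop_kind::forward_training, algorithm::eltwise_relu, md, md,
///             0.f);
///     eltwise_backward::primitive_desc bwd_pd(eng, algorithm::eltwise_relu,
///             md, md, md, 0.f, fwd_pd);
///
///     memory src_mem(md, eng), diff_src_mem(md, eng), diff_dst_mem(md, eng);
///     eltwise_backward(bwd_pd).execute(strm,
///             {{DNNL_ARG_SRC, src_mem}, {DNNL_ARG_DIFF_DST, diff_dst_mem},
///                     {DNNL_ARG_DIFF_SRC, diff_src_mem}});
///     strm.wait();
/// }
/// @endcode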

/// @} dnnl_api_eltwise

/// @addtogroup dnnl_api_softmax Softmax
///
/// A primitive to perform softmax.
///
/// @sa @ref dev_guide_softmax in developer guide
///
/// @{

/// Softmax forward propagation primitive.
|
|
struct softmax_forward : public primitive {
|
|
/// Primitive descriptor for a softmax forward propagation primitive.
|
|
struct primitive_desc : public dnnl::primitive_desc {
|
|
/// Default constructor. Produces an empty object.
|
|
primitive_desc() = default;
|
|
|
|
/// Constructs a primitive descriptor for a softmax forward propagation
|
|
/// primitive.
|
|
///
|
|
/// @param aengine Engine to use.
|
|
/// @param aprop_kind Propagation kind. Possible values are
|
|
/// #dnnl::prop_kind::forward_training, and
|
|
/// #dnnl::prop_kind::forward_inference.
|
|
/// @param aalgorithm Softmax algorithm kind: either
|
|
/// #dnnl::algorithm::softmax_accurate,
|
|
/// or #dnnl::algorithm::softmax_log.
|
|
/// @param src_desc Source memory descriptor.
|
|
/// @param dst_desc Destination memory descriptor.
|
|
/// @param axis Axis over which softmax is computed.
|
|
/// @param attr Primitive attributes to use. Attributes are optional
|
|
/// and default to empty attributes.
|
|
/// @param allow_empty A flag signifying whether construction is
|
|
/// allowed to fail without throwing an exception. In this case an
|
|
/// empty object will be produced. This flag is optional and
|
|
/// defaults to false.
|
|
primitive_desc(const engine &aengine, prop_kind aprop_kind,
|
|
algorithm aalgorithm, const memory::desc &src_desc,
|
|
const memory::desc &dst_desc, int axis,
|
|
const primitive_attr &attr = default_attr(),
|
|
bool allow_empty = false) {
|
|
|
|
dnnl_primitive_desc_t pd = nullptr;
|
|
dnnl_status_t status = dnnl_softmax_forward_primitive_desc_create(
|
|
&pd, aengine.get(), dnnl::convert_to_c(aprop_kind),
|
|
dnnl::convert_to_c(aalgorithm), src_desc.get(),
|
|
dst_desc.get(), axis, attr.get());
|
|
|
|
if (!allow_empty)
|
|
error::wrap_c_api(status,
|
|
"could not create a primitive descriptor for "
|
|
"the softmax forward propagation primitive. Run "
|
|
"workload with environment variable ONEDNN_VERBOSE=all "
|
|
"to get additional diagnostic information.");
|
|
reset(pd);
|
|
}
|
|
|
|
/// Constructs a primitive descriptor for a softmax forward
|
|
/// propagation primitive from a C API primitive descriptor that must
|
|
/// have a matching kind.
|
|
///
|
|
/// @param pd C API primitive descriptor for a softmax forward
|
|
/// propagation primitive.
|
|
primitive_desc(dnnl_primitive_desc_t pd)
|
|
: dnnl::primitive_desc(pd, dnnl::primitive::kind::softmax,
|
|
dnnl::prop_kind::forward_training,
|
|
dnnl::prop_kind::forward_inference) {}
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::src_desc()const
|
|
memory::desc src_desc() const { return base::src_desc(0); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::dst_desc()const
|
|
memory::desc dst_desc() const { return base::dst_desc(0); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_algorithm()const
|
|
dnnl::algorithm get_algorithm() const { return base::get_algorithm(); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
|
|
dnnl::prop_kind get_prop_kind() const { return base::get_prop_kind(); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_axis()const
|
|
int get_axis() const { return base::get_axis(); }
|
|
};
|
|
|
|
/// Default constructor. Produces an empty object.
|
|
softmax_forward() = default;
|
|
|
|
/// Constructs a softmax forward propagation primitive.
|
|
/// @param pd Primitive descriptor for a softmax forward propagation
|
|
/// primitive.
|
|
softmax_forward(const primitive_desc &pd) : primitive(pd) {}
|
|
|
|
/// Constructs a softmax forward propagation primitive from a cache blob.
|
|
/// @param pd Primitive descriptor for a softmax forward propagation
|
|
/// primitive.
|
|
/// @param cache_blob Cache blob.
|
|
softmax_forward(
|
|
const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
|
|
: primitive(pd, cache_blob) {}
|
|
};
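
/// @par Example
///     A minimal sketch of softmax over the class dimension of a batch of
///     logits. The engine and shape are illustrative assumptions.
/// @code
/// #include "oneapi/dnnl/dnnl.hpp"
///
/// void softmax_forward_sketch() {
///     using namespace dnnl;
///     engine eng(engine::kind::cpu, 0);
///     stream strm(eng);
///
///     // 32 rows of 1000 logits each; softmax is computed over axis 1.
///     memory::desc md({32, 1000}, memory::data_type::f32,
///             memory::format_tag::nc);
///     memory src_mem(md, eng), dst_mem(md, eng);
///
///     softmax_forward::primitive_desc pd(eng, prop_kind::forward_inference,
///             algorithm::softmax_accurate, md, md, /*axis=*/1);
///     softmax_forward(pd).execute(
///             strm, {{DNNL_ARG_SRC, src_mem}, {DNNL_ARG_DST, dst_mem}});
///     strm.wait();
/// }
/// @endcode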
|
|
|
|
/// Softmax backward propagation primitive.
|
|
struct softmax_backward : public primitive {
|
|
/// Primitive descriptor for a softmax backward propagation primitive.
|
|
struct primitive_desc : public dnnl::primitive_desc {
|
|
/// Default constructor. Produces an empty object.
|
|
primitive_desc() = default;
|
|
|
|
/// Constructs a primitive descriptor for a softmax backward propagation
|
|
/// primitive.
|
|
///
|
|
/// @param aengine Engine to use.
|
|
/// @param aalgorithm Softmax algorithm kind: either
|
|
/// #dnnl::algorithm::softmax_accurate,
|
|
/// or #dnnl::algorithm::softmax_log.
|
|
/// @param diff_src_desc Diff source memory descriptor.
|
|
/// @param diff_dst_desc Diff destination memory descriptor.
|
|
/// @param dst_desc Destination memory descriptor.
|
|
/// @param axis Axis over which softmax is computed.
|
|
/// @param hint_fwd_pd Primitive descriptor for a softmax
|
|
/// forward propagation primitive. It is used as a hint for
|
|
/// deciding which memory format to use.
|
|
/// @param attr Primitive attributes to use. Attributes are optional
|
|
/// and default to empty attributes.
|
|
/// @param allow_empty A flag signifying whether construction is
|
|
/// allowed to fail without throwing an exception. In this case an
|
|
/// empty object will be produced. This flag is optional and
|
|
/// defaults to false.
|
|
primitive_desc(const engine &aengine, algorithm aalgorithm,
|
|
const memory::desc &diff_src_desc,
|
|
const memory::desc &diff_dst_desc, const memory::desc &dst_desc,
|
|
int axis, const softmax_forward::primitive_desc &hint_fwd_pd,
|
|
const primitive_attr &attr = default_attr(),
|
|
bool allow_empty = false) {
|
|
|
|
dnnl_primitive_desc_t pd = nullptr;
|
|
dnnl_status_t status = dnnl_softmax_backward_primitive_desc_create(
|
|
&pd, aengine.get(), dnnl::convert_to_c(aalgorithm),
|
|
diff_src_desc.get(), diff_dst_desc.get(), dst_desc.get(),
|
|
axis, hint_fwd_pd.get(), attr.get());
|
|
|
|
if (!allow_empty)
|
|
error::wrap_c_api(status,
|
|
"could not create a primitive descriptor for "
|
|
"the softmax backward propagation primitive. Run "
|
|
"workload with environment variable ONEDNN_VERBOSE=all "
|
|
"to get additional diagnostic information.");
|
|
reset(pd);
|
|
}
|
|
|
|
/// Constructs a primitive descriptor for a softmax backward
|
|
/// propagation primitive from a C API primitive descriptor that must
|
|
/// have a matching kind.
|
|
///
|
|
/// @param pd C API primitive descriptor for a softmax backward
|
|
/// propagation primitive.
|
|
primitive_desc(dnnl_primitive_desc_t pd)
|
|
: dnnl::primitive_desc(pd, dnnl::primitive::kind::softmax,
|
|
dnnl::prop_kind::backward_data) {}
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::dst_desc()const
|
|
memory::desc dst_desc() const { return base::dst_desc(0); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
|
|
memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_algorithm()const
|
|
dnnl::algorithm get_algorithm() const { return base::get_algorithm(); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
|
|
dnnl::prop_kind get_prop_kind() const { return base::get_prop_kind(); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_axis()const
|
|
int get_axis() const { return base::get_axis(); }
|
|
};
|
|
|
|
/// Default constructor. Produces an empty object.
|
|
softmax_backward() = default;
|
|
|
|
/// Constructs a softmax backward propagation primitive.
|
|
/// @param pd Primitive descriptor for a softmax backward propagation
|
|
/// primitive.
|
|
softmax_backward(const primitive_desc &pd) : primitive(pd) {}
|
|
|
|
/// Constructs a softmax backward propagation primitive from a cache blob.
|
|
/// @param pd Primitive descriptor for a softmax backward propagation
|
|
/// primitive.
|
|
/// @param cache_blob Cache blob.
|
|
softmax_backward(
|
|
const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
|
|
: primitive(pd, cache_blob) {}
|
|
};
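
/// @par Example
///     A minimal sketch of softmax backward propagation; note that it
///     consumes the forward destination rather than the source. The engine
///     and shape are illustrative assumptions.
/// @code
/// #include "oneapi/dnnl/dnnl.hpp"
///
/// void softmax_backward_sketch() {
///     using namespace dnnl;
///     engine eng(engine::kind::cpu, 0);
///     stream strm(eng);
///
///     memory::desc md({32, 1000}, memory::data_type::f32,
///             memory::format_tag::nc);
///     softmax_forward::primitive_desc fwd_pd(eng,
///             prop_kind::forward_training, algorithm::softmax_accurate, md,
///             md, 1);
///     softmax_backward::primitive_desc bwd_pd(eng,
///             algorithm::softmax_accurate, md, md, md, 1, fwd_pd);
///
///     memory dst_mem(md, eng), diff_src_mem(md, eng), diff_dst_mem(md, eng);
///     softmax_backward(bwd_pd).execute(strm,
///             {{DNNL_ARG_DST, dst_mem}, {DNNL_ARG_DIFF_DST, diff_dst_mem},
///                     {DNNL_ARG_DIFF_SRC, diff_src_mem}});
///     strm.wait();
/// }
/// @endcode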

/// @} dnnl_api_softmax

/// @addtogroup dnnl_api_batch_normalization Batch Normalization
///
/// A primitive to perform batch normalization.
///
/// Both forward and backward propagation primitives support in-place
/// operation; that is, src and dst can refer to the same memory for forward
/// propagation, and diff_dst and diff_src can refer to the same memory for
/// backward propagation.
///
/// The batch normalization primitive computations can be controlled by
/// specifying different @ref dnnl::normalization_flags values. For example,
/// batch normalization forward propagation can be configured to either
/// compute the mean and variance or take them as arguments. It can either
/// perform scaling and shifting using gamma and beta parameters or not.
/// Optionally, it can also perform a fused ReLU, which in the case of
/// training also requires a workspace.
///
/// @sa @ref dev_guide_batch_normalization in developer guide
///
/// @{

/// Batch normalization forward propagation primitive.
|
|
struct batch_normalization_forward : public primitive {
|
|
/// Primitive descriptor for a batch normalization forward propagation
|
|
/// primitive.
|
|
struct primitive_desc : public dnnl::primitive_desc {
|
|
/// Default constructor. Produces an empty object.
|
|
primitive_desc() = default;
|
|
|
|
/// Constructs a primitive descriptor for a batch normalization forward
|
|
/// propagation primitive.
|
|
///
|
|
/// @note
|
|
/// In-place operation is supported: the dst can refer to the same
|
|
/// memory as the src.
|
|
///
|
|
/// @param aengine Engine to use.
|
|
/// @param aprop_kind Propagation kind. Possible values are
|
|
/// #dnnl::prop_kind::forward_training and
|
|
/// #dnnl::prop_kind::forward_inference.
|
|
/// @param src_desc Source memory descriptor.
|
|
/// @param dst_desc Destination memory descriptor.
|
|
/// @param epsilon Batch normalization epsilon parameter.
|
|
/// @param flags Batch normalization flags (@ref
|
|
/// dnnl::normalization_flags).
|
|
/// @param attr Primitive attributes to use. Attributes are optional
|
|
/// and default to empty attributes.
|
|
/// @param allow_empty A flag signifying whether construction is
|
|
/// allowed to fail without throwing an exception. In this case an
|
|
/// empty object will be produced. This flag is optional and
|
|
/// defaults to false.
|
|
primitive_desc(const engine &aengine, prop_kind aprop_kind,
|
|
const memory::desc &src_desc, const memory::desc &dst_desc,
|
|
float epsilon, normalization_flags flags,
|
|
const primitive_attr &attr = default_attr(),
|
|
bool allow_empty = false) {
|
|
dnnl_primitive_desc_t pd = nullptr;
|
|
dnnl_status_t status
|
|
= dnnl_batch_normalization_forward_primitive_desc_create(
|
|
&pd, aengine.get(), dnnl::convert_to_c(aprop_kind),
|
|
src_desc.get(), dst_desc.get(), epsilon,
|
|
convert_to_c(flags), attr.get());
|
|
|
|
if (!allow_empty)
|
|
error::wrap_c_api(status,
|
|
"could not create a primitive descriptor for "
|
|
"the batch normalization forward propagation "
|
|
"primitive. Run workload with environment variable "
|
|
"ONEDNN_VERBOSE=all to get additional diagnostic "
|
|
"information.");
|
|
reset(pd);
|
|
}
|
|
|
|
/// Constructs a primitive descriptor for a batch normalization
|
|
/// forward propagation primitive from a C API primitive descriptor
|
|
/// that must have a matching kind.
|
|
///
|
|
/// @param pd C API primitive descriptor for a batch normalization
|
|
/// forward propagation primitive.
|
|
primitive_desc(dnnl_primitive_desc_t pd)
|
|
: dnnl::primitive_desc(pd,
|
|
dnnl::primitive::kind::batch_normalization,
|
|
dnnl::prop_kind::forward_training,
|
|
dnnl::prop_kind::forward_inference) {}
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::src_desc()const
|
|
memory::desc src_desc() const { return base::src_desc(0); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::dst_desc()const
|
|
memory::desc dst_desc() const { return base::dst_desc(0); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::weights_desc()const
|
|
memory::desc weights_desc() const { return base::weights_desc(0); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::workspace_desc()const
|
|
memory::desc workspace_desc() const { return base::workspace_desc(); }
|
|
|
|
/// Returns memory descriptor for mean.
|
|
/// @returns Memory descriptor for mean.
|
|
memory::desc mean_desc() const { return stat_desc(mean); }
|
|
|
|
/// Returns memory descriptor for variance.
|
|
/// @returns Memory descriptor for variance.
|
|
memory::desc variance_desc() const { return stat_desc(var); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
|
|
dnnl::prop_kind get_prop_kind() const { return base::get_prop_kind(); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_epsilon()const
|
|
float get_epsilon() const { return base::get_epsilon(); }
|
|
|
|
/// Returns normalization flags.
|
|
/// @return Normalization flags.
|
|
normalization_flags get_flags() const {
|
|
return base::get_flags<normalization_flags>();
|
|
}
|
|
|
|
private:
|
|
enum {
|
|
mean = 1,
|
|
var = 2,
|
|
};
|
|
memory::desc stat_desc(int kind) const {
|
|
const bool use_global_stats
|
|
= (get_flags() & normalization_flags::use_global_stats)
|
|
!= normalization_flags::none;
|
|
return query_md(
|
|
use_global_stats ? query::src_md : query::dst_md, kind);
|
|
}
|
|
};
|
|
|
|
/// Default constructor. Produces an empty object.
|
|
batch_normalization_forward() = default;
|
|
|
|
/// Constructs a batch normalization forward propagation primitive.
|
|
/// @param pd Primitive descriptor for a batch normalization forward
|
|
/// propagation primitive.
|
|
batch_normalization_forward(const primitive_desc &pd) : primitive(pd) {}
|
|
|
|
/// Constructs a batch normalization forward propagation primitive from
|
|
/// a cache blob.
|
|
/// @param pd Primitive descriptor for a batch normalization forward
|
|
/// propagation primitive.
|
|
/// @param cache_blob Cache blob.
|
|
batch_normalization_forward(
|
|
const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
|
|
: primitive(pd, cache_blob) {}
|
|
};
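
/// @par Example
///     A minimal sketch of batch normalization inference with pre-computed
///     statistics and learned scale/shift. The engine, shapes, and epsilon
///     are illustrative assumptions.
/// @code
/// #include "oneapi/dnnl/dnnl.hpp"
///
/// void batch_norm_inference_sketch() {
///     using namespace dnnl;
///     engine eng(engine::kind::cpu, 0);
///     stream strm(eng);
///
///     const memory::dim C = 16;
///     memory::desc md({2, C, 8, 8}, memory::data_type::f32,
///             memory::format_tag::nchw);
///     memory::desc param_md({C}, memory::data_type::f32,
///             memory::format_tag::x);
///
///     auto flags = normalization_flags::use_global_stats
///             | normalization_flags::use_scale
///             | normalization_flags::use_shift;
///     batch_normalization_forward::primitive_desc pd(
///             eng, prop_kind::forward_inference, md, md, 1.e-5f, flags);
///
///     memory src_mem(md, eng), dst_mem(md, eng);
///     memory mean_mem(param_md, eng), var_mem(param_md, eng);
///     memory scale_mem(param_md, eng), shift_mem(param_md, eng);
///     batch_normalization_forward(pd).execute(strm,
///             {{DNNL_ARG_SRC, src_mem}, {DNNL_ARG_DST, dst_mem},
///                     {DNNL_ARG_MEAN, mean_mem},
///                     {DNNL_ARG_VARIANCE, var_mem},
///                     {DNNL_ARG_SCALE, scale_mem},
///                     {DNNL_ARG_SHIFT, shift_mem}});
///     strm.wait();
/// }
/// @endcode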
|
|
|
|
/// Batch normalization backward propagation primitive.
|
|
struct batch_normalization_backward : public primitive {
|
|
/// Primitive descriptor for a batch normalization backward propagation
|
|
/// primitive.
|
|
struct primitive_desc : public dnnl::primitive_desc {
|
|
/// Default constructor. Produces an empty object.
|
|
primitive_desc() = default;
|
|
|
|
/// Constructs a primitive descriptor for a batch normalization backward
|
|
/// propagation primitive.
|
|
///
|
|
/// @param aengine Engine to use.
|
|
/// @param aprop_kind Propagation kind. Possible values are
|
|
/// #dnnl::prop_kind::backward_data and #dnnl::prop_kind::backward
|
|
/// (diffs for all parameters are computed in this case).
|
|
/// @param diff_src_desc Diff source memory descriptor.
|
|
/// @param diff_dst_desc Diff destination memory descriptor.
|
|
/// @param src_desc Source memory descriptor.
|
|
/// @param epsilon Batch normalization epsilon parameter.
|
|
/// @param flags Batch normalization flags (@ref
|
|
/// dnnl::normalization_flags).
|
|
/// @param hint_fwd_pd Primitive descriptor for a batch normalization
|
|
/// forward propagation primitive. It is used as a hint for
|
|
/// deciding which memory format to use.
|
|
/// @param attr Primitive attributes to use. Attributes are optional
|
|
/// and default to empty attributes.
|
|
/// @param allow_empty A flag signifying whether construction is
|
|
/// allowed to fail without throwing an exception. In this case an
|
|
/// empty object will be produced. This flag is optional and
|
|
/// defaults to false.
|
|
primitive_desc(const engine &aengine, prop_kind aprop_kind,
|
|
const memory::desc &diff_src_desc,
|
|
const memory::desc &diff_dst_desc, const memory::desc &src_desc,
|
|
float epsilon, normalization_flags flags,
|
|
const batch_normalization_forward::primitive_desc &hint_fwd_pd,
|
|
const primitive_attr &attr = default_attr(),
|
|
bool allow_empty = false) {
|
|
dnnl_primitive_desc_t pd = nullptr;
|
|
dnnl_status_t status
|
|
= dnnl_batch_normalization_backward_primitive_desc_create(
|
|
&pd, aengine.get(), dnnl::convert_to_c(aprop_kind),
|
|
diff_src_desc.get(), diff_dst_desc.get(),
|
|
src_desc.get(), epsilon, convert_to_c(flags),
|
|
hint_fwd_pd.get(), attr.get());
|
|
|
|
if (!allow_empty)
|
|
error::wrap_c_api(status,
|
|
"could not create a primitive descriptor for "
|
|
"the batch normalization backward propagation "
|
|
"primitive. Run workload with environment variable "
|
|
"ONEDNN_VERBOSE=all to get additional diagnostic "
|
|
"information.");
|
|
reset(pd);
|
|
}
|
|
|
|
/// Constructs a primitive descriptor for a batch normalization
|
|
/// backward propagation primitive from a C API primitive descriptor
|
|
/// that must have a matching kind.
|
|
///
|
|
/// @param pd C API primitive descriptor for a batch normalization
|
|
/// backward propagation primitive.
|
|
primitive_desc(dnnl_primitive_desc_t pd)
|
|
: dnnl::primitive_desc(pd,
|
|
dnnl::primitive::kind::batch_normalization,
|
|
dnnl::prop_kind::backward, dnnl::prop_kind::backward_data) {
|
|
}
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::src_desc()const
|
|
memory::desc src_desc() const { return base::src_desc(0); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::weights_desc()const
|
|
memory::desc weights_desc() const { return base::weights_desc(0); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::dst_desc()const
|
|
memory::desc dst_desc() const { return base::dst_desc(0); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
|
|
memory::desc diff_src_desc() const { return base::diff_src_desc(0); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
|
|
memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::diff_weights_desc()const
|
|
memory::desc diff_weights_desc() const {
|
|
return base::diff_weights_desc(0);
|
|
}
|
|
|
|
/// @copydoc dnnl::batch_normalization_forward::primitive_desc::mean_desc()const
|
|
memory::desc mean_desc() const { return query_md(query::src_md, 1); }
|
|
|
|
/// @copydoc dnnl::batch_normalization_forward::primitive_desc::variance_desc()const
|
|
memory::desc variance_desc() const {
|
|
return query_md(query::src_md, 2);
|
|
}
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::workspace_desc()const
|
|
memory::desc workspace_desc() const { return base::workspace_desc(); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
|
|
dnnl::prop_kind get_prop_kind() const { return base::get_prop_kind(); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_epsilon()const
|
|
float get_epsilon() const { return base::get_epsilon(); }
|
|
|
|
/// Returns normalization flags.
|
|
/// @return Normalization flags.
|
|
normalization_flags get_flags() const {
|
|
return base::get_flags<normalization_flags>();
|
|
}
|
|
};
|
|
|
|
/// Default constructor. Produces an empty object.
|
|
batch_normalization_backward() = default;
|
|
|
|
/// Constructs a batch normalization backward propagation primitive.
|
|
/// @param pd Primitive descriptor for a batch normalization backward
|
|
/// propagation primitive.
|
|
batch_normalization_backward(const primitive_desc &pd) : primitive(pd) {}
|
|
|
|
/// Constructs a batch normalization backward propagation primitive from
|
|
/// a cache blob.
|
|
/// @param pd Primitive descriptor for a batch normalization backward
|
|
/// propagation primitive.
|
|
/// @param cache_blob Cache blob.
|
|
batch_normalization_backward(
|
|
const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
|
|
: primitive(pd, cache_blob) {}
|
|
};
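
/// @par Example
///     A minimal sketch of batch normalization backward propagation that
///     computes only diff_src (no scale/shift). The engine, shape, and
///     epsilon are illustrative assumptions.
/// @code
/// #include "oneapi/dnnl/dnnl.hpp"
///
/// void batch_norm_backward_sketch() {
///     using namespace dnnl;
///     engine eng(engine::kind::cpu, 0);
///     stream strm(eng);
///
///     memory::desc md({2, 16, 8, 8}, memory::data_type::f32,
///             memory::format_tag::nchw);
///     auto flags = normalization_flags::none;
///
///     batch_normalization_forward::primitive_desc fwd_pd(
///             eng, prop_kind::forward_training, md, md, 1.e-5f, flags);
///     batch_normalization_backward::primitive_desc bwd_pd(eng,
///             prop_kind::backward_data, md, md, md, 1.e-5f, flags, fwd_pd);
///
///     memory src_mem(md, eng), diff_src_mem(md, eng), diff_dst_mem(md, eng);
///     memory mean_mem(bwd_pd.mean_desc(), eng);
///     memory var_mem(bwd_pd.variance_desc(), eng);
///     batch_normalization_backward(bwd_pd).execute(strm,
///             {{DNNL_ARG_SRC, src_mem}, {DNNL_ARG_MEAN, mean_mem},
///                     {DNNL_ARG_VARIANCE, var_mem},
///                     {DNNL_ARG_DIFF_DST, diff_dst_mem},
///                     {DNNL_ARG_DIFF_SRC, diff_src_mem}});
///     strm.wait();
/// }
/// @endcode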

/// @} dnnl_api_batch_normalization

/// @addtogroup dnnl_api_group_normalization Group Normalization
///
/// A primitive to perform group normalization.
///
/// Both forward and backward propagation primitives support in-place
/// operation; that is, src and dst can refer to the same memory for forward
/// propagation, and diff_dst and diff_src can refer to the same memory for
/// backward propagation.
///
/// The group normalization primitive computations can be controlled by
/// specifying different @ref dnnl::normalization_flags values. For example,
/// group normalization forward propagation can be configured to either
/// compute the mean and variance or take them as arguments. It can either
/// perform scaling and shifting using gamma and beta parameters or not.
///
/// @sa @ref dev_guide_group_normalization in developer guide
///
/// @{

/// Group normalization forward propagation primitive.
|
|
struct group_normalization_forward : public primitive {
|
|
/// Primitive descriptor for a group normalization forward propagation
|
|
/// primitive.
|
|
struct primitive_desc : public dnnl::primitive_desc {
|
|
/// Default constructor. Produces an empty object.
|
|
primitive_desc() = default;
|
|
|
|
/// Constructs a primitive descriptor for a group normalization forward
|
|
/// propagation primitive.
|
|
///
|
|
/// @note
|
|
/// In-place operation is supported: the dst can refer to the same
|
|
/// memory as the src.
|
|
///
|
|
/// @param aengine Engine to use.
|
|
/// @param aprop_kind Propagation kind. Possible values are
|
|
/// #dnnl::prop_kind::forward_training and
|
|
/// #dnnl::prop_kind::forward_inference.
|
|
/// @param src_desc Source memory descriptor.
|
|
/// @param dst_desc Destination memory descriptor.
|
|
/// @param groups Group normalization groups parameter.
|
|
/// @param epsilon Group normalization epsilon parameter.
|
|
/// @param flags Group normalization flags (@ref
|
|
/// dnnl::normalization_flags).
|
|
/// @param attr Primitive attributes to use. Attributes are optional
|
|
/// and default to empty attributes.
|
|
/// @param allow_empty A flag signifying whether construction is
|
|
/// allowed to fail without throwing an exception. In this case an
|
|
/// empty object will be produced. This flag is optional and
|
|
/// defaults to false.
|
|
primitive_desc(const engine &aengine, prop_kind aprop_kind,
|
|
const memory::desc &src_desc, const memory::desc &dst_desc,
|
|
memory::dim groups, float epsilon, normalization_flags flags,
|
|
const primitive_attr &attr = default_attr(),
|
|
bool allow_empty = false) {
|
|
dnnl_primitive_desc_t pd = nullptr;
|
|
dnnl_status_t status
|
|
= dnnl_group_normalization_forward_primitive_desc_create(
|
|
&pd, aengine.get(), dnnl::convert_to_c(aprop_kind),
|
|
src_desc.get(), dst_desc.get(), groups, epsilon,
|
|
convert_to_c(flags), attr.get());
|
|
|
|
if (!allow_empty)
|
|
error::wrap_c_api(status,
|
|
"could not create a primitive descriptor for "
|
|
"the group normalization forward propagation "
|
|
"primitive. Run workload with environment variable "
|
|
"ONEDNN_VERBOSE=all to get additional diagnostic "
|
|
"information.");
|
|
reset(pd);
|
|
}
|
|
|
|
/// Constructs a primitive descriptor for a group normalization
|
|
/// forward propagation primitive from a C API primitive descriptor
|
|
/// that must have a matching kind.
|
|
///
|
|
/// @param pd C API primitive descriptor for a group normalization
|
|
/// forward propagation primitive.
|
|
primitive_desc(dnnl_primitive_desc_t pd)
|
|
: dnnl::primitive_desc(pd,
|
|
dnnl::primitive::kind::group_normalization,
|
|
dnnl::prop_kind::forward_training,
|
|
dnnl::prop_kind::forward_inference) {}
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::src_desc()const
|
|
memory::desc src_desc() const { return base::src_desc(0); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::dst_desc()const
|
|
memory::desc dst_desc() const { return base::dst_desc(0); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::weights_desc()const
|
|
memory::desc weights_desc() const { return base::weights_desc(0); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::workspace_desc()const
|
|
memory::desc workspace_desc() const { return base::workspace_desc(); }
|
|
|
|
/// Returns memory descriptor for mean.
|
|
/// @returns Memory descriptor for mean.
|
|
memory::desc mean_desc() const { return stat_desc(mean); }
|
|
|
|
/// Returns memory descriptor for variance.
|
|
/// @returns Memory descriptor for variance.
|
|
memory::desc variance_desc() const { return stat_desc(var); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
|
|
dnnl::prop_kind get_prop_kind() const { return base::get_prop_kind(); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_group_size()const
|
|
memory::dim get_group_size() const { return base::get_group_size(); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_epsilon()const
|
|
float get_epsilon() const { return base::get_epsilon(); }
|
|
|
|
/// Returns normalization flags.
|
|
/// @return Normalization flags.
|
|
normalization_flags get_flags() const {
|
|
return base::get_flags<normalization_flags>();
|
|
}
|
|
|
|
private:
|
|
enum {
|
|
mean = 1,
|
|
var = 2,
|
|
};
|
|
memory::desc stat_desc(int kind) const {
|
|
const bool use_global_stats
|
|
= (get_flags() & normalization_flags::use_global_stats)
|
|
!= normalization_flags::none;
|
|
return query_md(
|
|
use_global_stats ? query::src_md : query::dst_md, kind);
|
|
}
|
|
};
|
|
|
|
/// Default constructor. Produces an empty object.
|
|
group_normalization_forward() = default;
|
|
|
|
/// Constructs a group normalization forward propagation primitive.
|
|
/// @param pd Primitive descriptor for a group normalization forward
|
|
/// propagation primitive.
|
|
group_normalization_forward(const primitive_desc &pd) : primitive(pd) {}
|
|
|
|
/// Constructs a group normalization forward propagation primitive from
|
|
/// a cache blob.
|
|
/// @param pd Primitive descriptor for a group normalization forward
|
|
/// propagation primitive.
|
|
/// @param cache_blob Cache blob.
|
|
group_normalization_forward(
|
|
const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
|
|
: primitive(pd, cache_blob) {}
|
|
};
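
// A hedged execution sketch, continuing the note above: `gnorm_pd`/`gnorm`,
// a stream `strm`, and the src/dst/scale/shift memory objects are assumed to
// exist. Mean and variance are outputs here because the primitive descriptor
// was created without normalization_flags::use_global_stats.
//
//     std::unordered_map<int, memory> gnorm_args {
//             {DNNL_ARG_SRC, src_mem}, {DNNL_ARG_DST, dst_mem},
//             {DNNL_ARG_SCALE, scale_mem}, {DNNL_ARG_SHIFT, shift_mem},
//             {DNNL_ARG_MEAN, memory(gnorm_pd.mean_desc(), eng)},
//             {DNNL_ARG_VARIANCE, memory(gnorm_pd.variance_desc(), eng)}};
//     gnorm.execute(strm, gnorm_args);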

/// Group normalization backward propagation primitive.
struct group_normalization_backward : public primitive {
    /// Primitive descriptor for a group normalization backward propagation
    /// primitive.
    struct primitive_desc : public dnnl::primitive_desc {
        /// Default constructor. Produces an empty object.
        primitive_desc() = default;

        /// Constructs a primitive descriptor for a group normalization backward
        /// propagation primitive.
        ///
        /// @param aengine Engine to use.
        /// @param aprop_kind Propagation kind. Possible values are
        ///     #dnnl::prop_kind::backward_data and #dnnl::prop_kind::backward
        ///     (diffs for all parameters are computed in this case).
        /// @param diff_src_desc Diff source memory descriptor.
        /// @param diff_dst_desc Diff destination memory descriptor.
        /// @param src_desc Source memory descriptor.
        /// @param groups Group normalization groups parameter.
        /// @param epsilon Group normalization epsilon parameter.
        /// @param flags Group normalization flags (@ref
        ///     dnnl::normalization_flags).
        /// @param hint_fwd_pd Primitive descriptor for a group normalization
        ///     forward propagation primitive. It is used as a hint for
        ///     deciding which memory format to use.
        /// @param attr Primitive attributes to use. Attributes are optional
        ///     and default to empty attributes.
        /// @param allow_empty A flag signifying whether construction is
        ///     allowed to fail without throwing an exception. In this case an
        ///     empty object will be produced. This flag is optional and
        ///     defaults to false.
        primitive_desc(const engine &aengine, prop_kind aprop_kind,
                const memory::desc &diff_src_desc,
                const memory::desc &diff_dst_desc, const memory::desc &src_desc,
                memory::dim groups, float epsilon, normalization_flags flags,
                const group_normalization_forward::primitive_desc &hint_fwd_pd,
                const primitive_attr &attr = default_attr(),
                bool allow_empty = false) {
            dnnl_primitive_desc_t pd = nullptr;
            dnnl_status_t status
                    = dnnl_group_normalization_backward_primitive_desc_create(
                            &pd, aengine.get(), dnnl::convert_to_c(aprop_kind),
                            diff_src_desc.get(), diff_dst_desc.get(),
                            src_desc.get(), groups, epsilon,
                            convert_to_c(flags), hint_fwd_pd.get(), attr.get());

            if (!allow_empty)
                error::wrap_c_api(status,
                        "could not create a primitive descriptor for "
                        "the group normalization backward propagation "
                        "primitive. Run workload with environment variable "
                        "ONEDNN_VERBOSE=all to get additional diagnostic "
                        "information.");
            reset(pd);
        }

        /// Constructs a primitive descriptor for a group normalization
        /// backward propagation primitive from a C API primitive descriptor
        /// that must have a matching kind.
        ///
        /// @param pd C API primitive descriptor for a group normalization
        ///     backward propagation primitive.
        primitive_desc(dnnl_primitive_desc_t pd)
            : dnnl::primitive_desc(pd,
                    dnnl::primitive::kind::group_normalization,
                    dnnl::prop_kind::backward, dnnl::prop_kind::backward_data) {
        }

        /// @copydoc dnnl::primitive_desc_base::src_desc()const
        memory::desc src_desc() const { return base::src_desc(0); }

        /// @copydoc dnnl::primitive_desc_base::weights_desc()const
        memory::desc weights_desc() const { return base::weights_desc(0); }

        /// @copydoc dnnl::primitive_desc_base::dst_desc()const
        memory::desc dst_desc() const { return base::dst_desc(0); }

        /// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
        memory::desc diff_src_desc() const { return base::diff_src_desc(0); }

        /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
        memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }

        /// @copydoc dnnl::primitive_desc_base::diff_weights_desc()const
        memory::desc diff_weights_desc() const {
            return base::diff_weights_desc(0);
        }

        /// @copydoc dnnl::group_normalization_forward::primitive_desc::mean_desc()const
        memory::desc mean_desc() const { return query_md(query::src_md, 1); }

        /// @copydoc dnnl::group_normalization_forward::primitive_desc::variance_desc()const
        memory::desc variance_desc() const {
            return query_md(query::src_md, 2);
        }

        /// @copydoc dnnl::primitive_desc_base::workspace_desc()const
        memory::desc workspace_desc() const { return base::workspace_desc(); }

        /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
        dnnl::prop_kind get_prop_kind() const { return base::get_prop_kind(); }

        /// @copydoc dnnl::primitive_desc_base::get_group_size()const
        memory::dim get_group_size() const { return base::get_group_size(); }

        /// @copydoc dnnl::primitive_desc_base::get_epsilon()const
        float get_epsilon() const { return base::get_epsilon(); }

        /// Returns normalization flags.
        /// @return Normalization flags.
        normalization_flags get_flags() const {
            return base::get_flags<normalization_flags>();
        }
    };

    /// Default constructor. Produces an empty object.
    group_normalization_backward() = default;

    /// Constructs a group normalization backward propagation primitive.
    /// @param pd Primitive descriptor for a group normalization backward
    ///     propagation primitive.
    group_normalization_backward(const primitive_desc &pd) : primitive(pd) {}

    /// Constructs a group normalization backward propagation primitive from
    /// a cache blob.
    /// @param pd Primitive descriptor for a group normalization backward
    ///     propagation primitive.
    /// @param cache_blob Cache blob.
    group_normalization_backward(
            const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
        : primitive(pd, cache_blob) {}
};
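
// A hedged construction sketch: the backward primitive descriptor reuses the
// data memory descriptors and takes the forward primitive descriptor as a
// format hint. `eng`, `src_md`, and `gnorm_pd` are assumed from the earlier
// notes; the flags must match the ones used on the forward pass.
//
//     auto gnorm_bwd_pd = group_normalization_backward::primitive_desc(eng,
//             prop_kind::backward, /*diff_src_desc=*/src_md,
//             /*diff_dst_desc=*/src_md, src_md, /*groups=*/8,
//             /*epsilon=*/1e-5f,
//             normalization_flags::use_scale | normalization_flags::use_shift,
//             gnorm_pd);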

/// @} dnnl_api_group_normalization

/// @addtogroup dnnl_api_layer_normalization Layer Normalization
///
/// A primitive to perform layer normalization. Normalization is performed
/// within the last logical dimension of the data tensor.
///
/// Both forward and backward propagation primitives support in-place
/// operation; that is, src and dst can refer to the same memory for forward
/// propagation, and diff_dst and diff_src can refer to the same memory for
/// backward propagation.
///
/// The computations of the layer normalization primitives can be controlled
/// by specifying different @ref dnnl::normalization_flags values. For
/// example, layer normalization forward propagation can be configured to
/// either compute the mean and variance or take them as arguments. It can
/// either perform scaling and shifting using gamma and beta parameters or
/// not.
///
/// @sa @ref dev_guide_layer_normalization in developer guide
///
/// @{
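
// An illustrative sketch, not normative documentation: a layer normalization
// forward descriptor over a (seq, batch, channels) f32 tensor with f16 scale
// and shift parameters. The shape, the engine `eng`, and the data types are
// assumptions made for this example only.
//
//     memory::desc ln_src_md({128, 16, 768}, memory::data_type::f32,
//             memory::format_tag::tnc);
//     auto lnorm_pd = layer_normalization_forward::primitive_desc(eng,
//             prop_kind::forward_inference, ln_src_md, ln_src_md,
//             memory::data_type::f16, /*epsilon=*/1e-5f,
//             normalization_flags::use_scale | normalization_flags::use_shift);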

/// Layer normalization forward propagation primitive.
struct layer_normalization_forward : public primitive {
    /// Primitive descriptor for a layer normalization forward propagation
    /// primitive.
    struct primitive_desc : public dnnl::primitive_desc {
        /// Default constructor. Produces an empty object.
        primitive_desc() = default;

        /// Constructs a primitive descriptor for a layer normalization forward
        /// propagation primitive.
        ///
        /// @param aengine Engine to use.
        /// @param aprop_kind Propagation kind. Possible values are
        ///     #dnnl::prop_kind::forward_training, and
        ///     #dnnl::prop_kind::forward_inference.
        /// @param src_desc Source memory descriptor.
        /// @param dst_desc Destination memory descriptor.
        /// @param stat_desc Statistics memory descriptor.
        /// @param epsilon Layer normalization epsilon parameter.
        /// @param flags Layer normalization flags (@ref
        ///     dnnl::normalization_flags).
        /// @param attr Primitive attributes to use. Attributes are optional
        ///     and default to empty attributes.
        /// @param allow_empty A flag signifying whether construction is
        ///     allowed to fail without throwing an exception. In this case an
        ///     empty object will be produced. This flag is optional and
        ///     defaults to false.
        primitive_desc(const engine &aengine, prop_kind aprop_kind,
                const memory::desc &src_desc, const memory::desc &dst_desc,
                const memory::desc &stat_desc, float epsilon,
                normalization_flags flags,
                const primitive_attr &attr = default_attr(),
                bool allow_empty = false)
            : primitive_desc(aengine, aprop_kind, src_desc, dst_desc,
                    &stat_desc, memory::data_type::f32, epsilon, flags, attr,
                    allow_empty) {}

        /// Constructs a primitive descriptor for a layer normalization forward
        /// propagation primitive.
        ///
        /// @param aengine Engine to use.
        /// @param aprop_kind Propagation kind. Possible values are
        ///     #dnnl::prop_kind::forward_training, and
        ///     #dnnl::prop_kind::forward_inference.
        /// @param src_desc Source memory descriptor.
        /// @param dst_desc Destination memory descriptor.
        /// @param epsilon Layer normalization epsilon parameter.
        /// @param flags Layer normalization flags (@ref
        ///     dnnl::normalization_flags).
        /// @param attr Primitive attributes to use. Attributes are optional
        ///     and default to empty attributes.
        /// @param allow_empty A flag signifying whether construction is
        ///     allowed to fail without throwing an exception. In this case an
        ///     empty object will be produced. This flag is optional and
        ///     defaults to false.
        primitive_desc(const engine &aengine, prop_kind aprop_kind,
                const memory::desc &src_desc, const memory::desc &dst_desc,
                float epsilon, normalization_flags flags,
                const primitive_attr &attr = default_attr(),
                bool allow_empty = false)
            : primitive_desc(aengine, aprop_kind, src_desc, dst_desc, nullptr,
                    memory::data_type::f32, epsilon, flags, attr, allow_empty) {
        }

        /// Constructs a primitive descriptor for a layer normalization forward
        /// propagation primitive with a user-provided data type for the scale
        /// and shift memory objects.
        ///
        /// @param aengine Engine to use.
        /// @param aprop_kind Propagation kind. Possible values are
        ///     #dnnl::prop_kind::forward_training, and
        ///     #dnnl::prop_kind::forward_inference.
        /// @param src_desc Source memory descriptor.
        /// @param dst_desc Destination memory descriptor.
        /// @param stat_desc Statistics memory descriptor.
        /// @param scale_shift_data_type Data type of scale and shift memory.
        ///     If neither the scale nor the shift flag is specified, the
        ///     parameter is ignored.
        /// @param epsilon Layer normalization epsilon parameter.
        /// @param flags Layer normalization flags (@ref
        ///     dnnl::normalization_flags).
        /// @param attr Primitive attributes to use. Attributes are optional
        ///     and default to empty attributes.
        /// @param allow_empty A flag signifying whether construction is
        ///     allowed to fail without throwing an exception. In this case an
        ///     empty object will be produced. This flag is optional and
        ///     defaults to false.
        primitive_desc(const engine &aengine, prop_kind aprop_kind,
                const memory::desc &src_desc, const memory::desc &dst_desc,
                const memory::desc &stat_desc,
                memory::data_type scale_shift_data_type, float epsilon,
                normalization_flags flags,
                const primitive_attr &attr = default_attr(),
                bool allow_empty = false)
            : primitive_desc(aengine, aprop_kind, src_desc, dst_desc,
                    &stat_desc, scale_shift_data_type, epsilon, flags, attr,
                    allow_empty) {}

        /// Constructs a primitive descriptor for a layer normalization forward
        /// propagation primitive with a user-provided data type for the scale
        /// and shift memory objects.
        ///
        /// @param aengine Engine to use.
        /// @param aprop_kind Propagation kind. Possible values are
        ///     #dnnl::prop_kind::forward_training, and
        ///     #dnnl::prop_kind::forward_inference.
        /// @param src_desc Source memory descriptor.
        /// @param dst_desc Destination memory descriptor.
        /// @param scale_shift_data_type Data type of scale and shift memory.
        ///     If neither the scale nor the shift flag is specified, the
        ///     parameter is ignored.
        /// @param epsilon Layer normalization epsilon parameter.
        /// @param flags Layer normalization flags (@ref
        ///     dnnl::normalization_flags).
        /// @param attr Primitive attributes to use. Attributes are optional
        ///     and default to empty attributes.
        /// @param allow_empty A flag signifying whether construction is
        ///     allowed to fail without throwing an exception. In this case an
        ///     empty object will be produced. This flag is optional and
        ///     defaults to false.
        primitive_desc(const engine &aengine, prop_kind aprop_kind,
                const memory::desc &src_desc, const memory::desc &dst_desc,
                memory::data_type scale_shift_data_type, float epsilon,
                normalization_flags flags,
                const primitive_attr &attr = default_attr(),
                bool allow_empty = false)
            : primitive_desc(aengine, aprop_kind, src_desc, dst_desc, nullptr,
                    scale_shift_data_type, epsilon, flags, attr, allow_empty) {}

        /// Constructs a primitive descriptor for a layer normalization
        /// forward propagation primitive from a C API primitive descriptor
        /// that must have a matching kind.
        ///
        /// @param pd C API primitive descriptor for a layer normalization
        ///     forward propagation primitive.
        primitive_desc(dnnl_primitive_desc_t pd)
            : dnnl::primitive_desc(pd,
                    dnnl::primitive::kind::layer_normalization,
                    dnnl::prop_kind::forward_training,
                    dnnl::prop_kind::forward_inference) {}

        /// @copydoc dnnl::primitive_desc_base::src_desc()const
        memory::desc src_desc() const { return base::src_desc(0); }

        /// @copydoc dnnl::primitive_desc_base::dst_desc()const
        memory::desc dst_desc() const { return base::dst_desc(0); }

        /// @copydoc dnnl::primitive_desc_base::weights_desc()const
        memory::desc weights_desc() const { return base::weights_desc(0); }

        /// @copydoc dnnl::primitive_desc_base::workspace_desc()const
        memory::desc workspace_desc() const { return base::workspace_desc(); }

        /// @copydoc dnnl::batch_normalization_forward::primitive_desc::mean_desc()const
        memory::desc mean_desc() const { return stat_desc(mean); }

        /// @copydoc dnnl::batch_normalization_forward::primitive_desc::variance_desc()const
        memory::desc variance_desc() const { return stat_desc(var); }

        /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
        dnnl::prop_kind get_prop_kind() const { return base::get_prop_kind(); }

        /// @copydoc dnnl::primitive_desc_base::get_epsilon()const
        float get_epsilon() const { return base::get_epsilon(); }

        /// Returns normalization flags.
        /// @return Normalization flags.
        normalization_flags get_flags() const {
            return base::get_flags<normalization_flags>();
        }

    private:
        enum {
            mean = 1,
            var = 2,
        };
        memory::desc stat_desc(int kind) const {
            const bool use_global_stats
                    = (get_flags() & normalization_flags::use_global_stats)
                    != normalization_flags::none;
            return query_md(
                    use_global_stats ? query::src_md : query::dst_md, kind);
        }

        primitive_desc(const engine &aengine, prop_kind aprop_kind,
                const memory::desc &src_desc, const memory::desc &dst_desc,
                const memory::desc *stat_desc,
                memory::data_type scale_shift_data_type, float epsilon,
                normalization_flags flags, const primitive_attr &attr,
                bool allow_empty) {

            dnnl_primitive_desc_t pd = nullptr;
            dnnl_status_t status
                    = dnnl_layer_normalization_forward_primitive_desc_create_v2(
                            &pd, aengine.get(), dnnl::convert_to_c(aprop_kind),
                            src_desc.get(), dst_desc.get(),
                            optional_arg(stat_desc),
                            memory::convert_to_c(scale_shift_data_type),
                            epsilon, convert_to_c(flags), attr.get());

            if (!allow_empty)
                error::wrap_c_api(status,
                        "could not create a primitive descriptor for "
                        "the layer normalization forward propagation "
                        "primitive. Run workload with environment variable "
                        "ONEDNN_VERBOSE=all to get additional diagnostic "
                        "information.");
            reset(pd);
        }
    };

    /// Default constructor. Produces an empty object.
    layer_normalization_forward() = default;

    /// Constructs a layer normalization forward propagation primitive.
    /// @param pd Primitive descriptor for a layer normalization forward
    ///     propagation primitive.
    layer_normalization_forward(const primitive_desc &pd) : primitive(pd) {}

    /// Constructs a layer normalization forward propagation primitive from
    /// a cache blob.
    /// @param pd Primitive descriptor for a layer normalization forward
    ///     propagation primitive.
    /// @param cache_blob Cache blob.
    layer_normalization_forward(
            const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
        : primitive(pd, cache_blob) {}
};
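
// Note on the statistics queries above (a summary of the stat_desc() helper,
// not additional API): when normalization_flags::use_global_stats is set,
// mean and variance are user-provided inputs, so their descriptors are
// reported among the source parameters; otherwise they are reported among the
// destination parameters (and are populated by forward training). A hedged
// query sketch, reusing the assumed `lnorm_pd` from the earlier note:
//
//     memory::desc mean_md = lnorm_pd.mean_desc();
//     memory::desc var_md = lnorm_pd.variance_desc();
//     bool stats_are_inputs
//             = (lnorm_pd.get_flags() & normalization_flags::use_global_stats)
//             != normalization_flags::none;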

/// Layer normalization backward propagation primitive.
struct layer_normalization_backward : public primitive {
    /// Primitive descriptor for a layer normalization backward propagation
    /// primitive.
    struct primitive_desc : public dnnl::primitive_desc {
        /// Default constructor. Produces an empty object.
        primitive_desc() = default;

        /// Constructs a primitive descriptor for a layer normalization backward
        /// propagation primitive.
        ///
        /// @param aengine Engine to use.
        /// @param aprop_kind Propagation kind. Possible values are
        ///     #dnnl::prop_kind::backward_data and #dnnl::prop_kind::backward
        ///     (diffs for all parameters are computed in this case).
        /// @param diff_src_desc Diff source memory descriptor.
        /// @param diff_dst_desc Diff destination memory descriptor.
        /// @param src_desc Source memory descriptor.
        /// @param stat_desc Statistics memory descriptor.
        /// @param epsilon Layer normalization epsilon parameter.
        /// @param flags Layer normalization flags (@ref
        ///     dnnl::normalization_flags).
        /// @param attr Primitive attributes to use. Attributes are optional
        ///     and default to empty attributes.
        /// @param hint_fwd_pd Primitive descriptor for a layer normalization
        ///     forward propagation primitive. It is used as a hint for
        ///     deciding which memory format to use.
        /// @param allow_empty A flag signifying whether construction is
        ///     allowed to fail without throwing an exception. In this case an
        ///     empty object will be produced. This flag is optional and
        ///     defaults to false.
        primitive_desc(const engine &aengine, prop_kind aprop_kind,
                const memory::desc &diff_src_desc,
                const memory::desc &diff_dst_desc, const memory::desc &src_desc,
                const memory::desc &stat_desc, float epsilon,
                normalization_flags flags,
                const layer_normalization_forward::primitive_desc &hint_fwd_pd,
                const primitive_attr &attr = default_attr(),
                bool allow_empty = false)
            : primitive_desc(aengine, aprop_kind, diff_src_desc, diff_dst_desc,
                    src_desc, &stat_desc, memory::data_type::f32,
                    memory::data_type::f32, epsilon, flags, hint_fwd_pd, attr,
                    allow_empty) {}

        /// Constructs a primitive descriptor for a layer normalization backward
        /// propagation primitive.
        ///
        /// @param aengine Engine to use.
        /// @param aprop_kind Propagation kind. Possible values are
        ///     #dnnl::prop_kind::backward_data and #dnnl::prop_kind::backward
        ///     (diffs for all parameters are computed in this case).
        /// @param diff_src_desc Diff source memory descriptor.
        /// @param diff_dst_desc Diff destination memory descriptor.
        /// @param src_desc Source memory descriptor.
        /// @param epsilon Layer normalization epsilon parameter.
        /// @param flags Layer normalization flags (@ref
        ///     dnnl::normalization_flags).
        /// @param attr Primitive attributes to use. Attributes are optional
        ///     and default to empty attributes.
        /// @param hint_fwd_pd Primitive descriptor for a layer normalization
        ///     forward propagation primitive. It is used as a hint for
        ///     deciding which memory format to use.
        /// @param allow_empty A flag signifying whether construction is
        ///     allowed to fail without throwing an exception. In this case an
        ///     empty object will be produced. This flag is optional and
        ///     defaults to false.
        primitive_desc(const engine &aengine, prop_kind aprop_kind,
                const memory::desc &diff_src_desc,
                const memory::desc &diff_dst_desc, const memory::desc &src_desc,
                float epsilon, normalization_flags flags,
                const layer_normalization_forward::primitive_desc &hint_fwd_pd,
                const primitive_attr &attr = default_attr(),
                bool allow_empty = false)
            : primitive_desc(aengine, aprop_kind, diff_src_desc, diff_dst_desc,
                    src_desc, nullptr, memory::data_type::f32,
                    memory::data_type::f32, epsilon, flags, hint_fwd_pd, attr,
                    allow_empty) {}

        /// Constructs a primitive descriptor for a layer normalization backward
        /// propagation primitive with a user-provided data type for the scale
        /// and shift memory objects.
        ///
        /// @param aengine Engine to use.
        /// @param aprop_kind Propagation kind. Possible values are
        ///     #dnnl::prop_kind::backward_data and #dnnl::prop_kind::backward
        ///     (diffs for all parameters are computed in this case).
        /// @param diff_src_desc Diff source memory descriptor.
        /// @param diff_dst_desc Diff destination memory descriptor.
        /// @param src_desc Source memory descriptor.
        /// @param stat_desc Statistics memory descriptor.
        /// @param diff_scale_shift_data_type Data type of diff scale and shift
        ///     memory. If neither the scale nor the shift flag is specified,
        ///     the parameter is ignored.
        /// @param scale_shift_data_type Data type of scale and shift memory.
        ///     If neither the scale nor the shift flag is specified, the
        ///     parameter is ignored.
        /// @param epsilon Layer normalization epsilon parameter.
        /// @param flags Layer normalization flags (@ref
        ///     dnnl::normalization_flags).
        /// @param attr Primitive attributes to use. Attributes are optional
        ///     and default to empty attributes.
        /// @param hint_fwd_pd Primitive descriptor for a layer normalization
        ///     forward propagation primitive. It is used as a hint for
        ///     deciding which memory format to use.
        /// @param allow_empty A flag signifying whether construction is
        ///     allowed to fail without throwing an exception. In this case an
        ///     empty object will be produced. This flag is optional and
        ///     defaults to false.
        primitive_desc(const engine &aengine, prop_kind aprop_kind,
                const memory::desc &diff_src_desc,
                const memory::desc &diff_dst_desc, const memory::desc &src_desc,
                const memory::desc &stat_desc,
                memory::data_type diff_scale_shift_data_type,
                memory::data_type scale_shift_data_type, float epsilon,
                normalization_flags flags,
                const layer_normalization_forward::primitive_desc &hint_fwd_pd,
                const primitive_attr &attr = default_attr(),
                bool allow_empty = false)
            : primitive_desc(aengine, aprop_kind, diff_src_desc, diff_dst_desc,
                    src_desc, &stat_desc, diff_scale_shift_data_type,
                    scale_shift_data_type, epsilon, flags, hint_fwd_pd, attr,
                    allow_empty) {}

        /// Constructs a primitive descriptor for a layer normalization backward
        /// propagation primitive with a user-provided data type for the scale
        /// and shift memory objects.
        ///
        /// @param aengine Engine to use.
        /// @param aprop_kind Propagation kind. Possible values are
        ///     #dnnl::prop_kind::backward_data and #dnnl::prop_kind::backward
        ///     (diffs for all parameters are computed in this case).
        /// @param diff_src_desc Diff source memory descriptor.
        /// @param diff_dst_desc Diff destination memory descriptor.
        /// @param src_desc Source memory descriptor.
        /// @param diff_scale_shift_data_type Data type of diff scale and shift
        ///     memory. If neither the scale nor the shift flag is specified,
        ///     the parameter is ignored.
        /// @param scale_shift_data_type Data type of scale and shift memory.
        ///     If neither the scale nor the shift flag is specified, the
        ///     parameter is ignored.
        /// @param epsilon Layer normalization epsilon parameter.
        /// @param flags Layer normalization flags (@ref
        ///     dnnl::normalization_flags).
        /// @param attr Primitive attributes to use. Attributes are optional
        ///     and default to empty attributes.
        /// @param hint_fwd_pd Primitive descriptor for a layer normalization
        ///     forward propagation primitive. It is used as a hint for
        ///     deciding which memory format to use.
        /// @param allow_empty A flag signifying whether construction is
        ///     allowed to fail without throwing an exception. In this case an
        ///     empty object will be produced. This flag is optional and
        ///     defaults to false.
        primitive_desc(const engine &aengine, prop_kind aprop_kind,
                const memory::desc &diff_src_desc,
                const memory::desc &diff_dst_desc, const memory::desc &src_desc,
                memory::data_type diff_scale_shift_data_type,
                memory::data_type scale_shift_data_type, float epsilon,
                normalization_flags flags,
                const layer_normalization_forward::primitive_desc &hint_fwd_pd,
                const primitive_attr &attr = default_attr(),
                bool allow_empty = false)
            : primitive_desc(aengine, aprop_kind, diff_src_desc, diff_dst_desc,
                    src_desc, nullptr, diff_scale_shift_data_type,
                    scale_shift_data_type, epsilon, flags, hint_fwd_pd, attr,
                    allow_empty) {}

        /// Constructs a primitive descriptor for a layer normalization
        /// backward propagation primitive from a C API primitive descriptor
        /// that must have a matching kind.
        ///
        /// @param pd C API primitive descriptor for a layer normalization
        ///     backward propagation primitive.
        primitive_desc(dnnl_primitive_desc_t pd)
            : dnnl::primitive_desc(pd,
                    dnnl::primitive::kind::layer_normalization,
                    dnnl::prop_kind::backward, dnnl::prop_kind::backward_data) {
        }

        /// @copydoc dnnl::primitive_desc_base::src_desc()const
        memory::desc src_desc() const { return base::src_desc(0); }

        /// @copydoc dnnl::primitive_desc_base::weights_desc()const
        memory::desc weights_desc() const { return base::weights_desc(0); }

        /// @copydoc dnnl::primitive_desc_base::dst_desc()const
        memory::desc dst_desc() const { return base::dst_desc(0); }

        /// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
        memory::desc diff_src_desc() const { return base::diff_src_desc(0); }

        /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
        memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }

        /// @copydoc dnnl::primitive_desc_base::diff_weights_desc()const
        memory::desc diff_weights_desc() const {
            return base::diff_weights_desc(0);
        }

        /// @copydoc dnnl::batch_normalization_forward::primitive_desc::mean_desc()const
        memory::desc mean_desc() const { return query_md(query::src_md, 1); }

        /// @copydoc dnnl::batch_normalization_forward::primitive_desc::variance_desc()const
        memory::desc variance_desc() const {
            return query_md(query::src_md, 2);
        }

        /// @copydoc dnnl::primitive_desc_base::workspace_desc()const
        memory::desc workspace_desc() const { return base::workspace_desc(); }

        /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
        dnnl::prop_kind get_prop_kind() const { return base::get_prop_kind(); }

        /// @copydoc dnnl::primitive_desc_base::get_epsilon()const
        float get_epsilon() const { return base::get_epsilon(); }

        /// Returns normalization flags.
        /// @return Normalization flags.
        normalization_flags get_flags() const {
            return base::get_flags<normalization_flags>();
        }

    private:
        primitive_desc(const engine &aengine, prop_kind aprop_kind,
                const memory::desc &diff_src_desc,
                const memory::desc &diff_dst_desc, const memory::desc &src_desc,
                const memory::desc *stat_desc,
                memory::data_type diff_scale_shift_data_type,
                memory::data_type scale_shift_data_type, float epsilon,
                normalization_flags flags,
                const layer_normalization_forward::primitive_desc &hint_fwd_pd,
                const primitive_attr &attr, bool allow_empty) {

            dnnl_primitive_desc_t pd = nullptr;
            dnnl_status_t status
                    = dnnl_layer_normalization_backward_primitive_desc_create_v2(
                            &pd, aengine.get(), dnnl::convert_to_c(aprop_kind),
                            diff_src_desc.get(), diff_dst_desc.get(),
                            src_desc.get(), optional_arg(stat_desc),
                            memory::convert_to_c(diff_scale_shift_data_type),
                            memory::convert_to_c(scale_shift_data_type),
                            epsilon, convert_to_c(flags), hint_fwd_pd.get(),
                            attr.get());

            if (!allow_empty)
                error::wrap_c_api(status,
                        "could not create a primitive descriptor for "
                        "the layer normalization backward propagation "
                        "primitive. Run workload with environment variable "
                        "ONEDNN_VERBOSE=all to get additional diagnostic "
                        "information.");
            reset(pd);
        }
    };

    /// Default constructor. Produces an empty object.
    layer_normalization_backward() = default;

    /// Constructs a layer normalization backward propagation primitive.
    /// @param pd Primitive descriptor for a layer normalization backward
    ///     propagation primitive.
    layer_normalization_backward(const primitive_desc &pd) : primitive(pd) {}

    /// Constructs a layer normalization backward propagation primitive from
    /// a cache blob.
    /// @param pd Primitive descriptor for a layer normalization backward
    ///     propagation primitive.
    /// @param cache_blob Cache blob.
    layer_normalization_backward(
            const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
        : primitive(pd, cache_blob) {}
};
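
// A hedged argument-map sketch for the layer normalization backward pass. The
// memory objects are assumed to have been created from the corresponding
// queried descriptors; the diff scale/shift entries are only needed with
// prop_kind::backward when the scale/shift flags were used on forward.
//
//     std::unordered_map<int, memory> lnorm_bwd_args {
//             {DNNL_ARG_SRC, src_mem}, {DNNL_ARG_DIFF_DST, diff_dst_mem},
//             {DNNL_ARG_MEAN, mean_mem}, {DNNL_ARG_VARIANCE, variance_mem},
//             {DNNL_ARG_SCALE, scale_mem}, {DNNL_ARG_DIFF_SRC, diff_src_mem},
//             {DNNL_ARG_DIFF_SCALE, diff_scale_mem},
//             {DNNL_ARG_DIFF_SHIFT, diff_shift_mem}};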

/// @} dnnl_api_layer_normalization

/// @addtogroup dnnl_api_inner_product Inner Product
///
/// A primitive to compute an inner product.
///
/// @sa @ref dev_guide_inner_product in developer guide
///
/// @{
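
// An illustrative sketch of the #dnnl::memory::format_tag::any pattern noted
// in the constructors below: let the library pick the weights layout and
// reorder user data if needed. `eng`, the shapes, and the data types are
// assumptions made for this example only.
//
//     memory::desc ip_src_md({128, 512}, memory::data_type::f32,
//             memory::format_tag::any);
//     memory::desc ip_wei_md({1000, 512}, memory::data_type::f32,
//             memory::format_tag::any);
//     memory::desc ip_dst_md({128, 1000}, memory::data_type::f32,
//             memory::format_tag::any);
//     auto ip_pd = inner_product_forward::primitive_desc(eng,
//             prop_kind::forward_inference, ip_src_md, ip_wei_md, ip_dst_md);
//     // ip_pd.weights_desc() now reports the layout the implementation chose.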

/// Inner product forward propagation primitive.
struct inner_product_forward : public primitive {
    /// Primitive descriptor for an inner product forward propagation primitive.
    struct primitive_desc : public dnnl::primitive_desc {
        /// Default constructor. Produces an empty object.
        primitive_desc() = default;

        /// Constructs a primitive descriptor for an inner product forward
        /// propagation primitive with bias.
        ///
        /// @note
        ///     All the memory descriptors may be initialized with the
        ///     #dnnl::memory::format_tag::any value of @p format_tag.
        ///
        /// @param aengine Engine to use.
        /// @param aprop_kind Propagation kind. Possible values are
        ///     #dnnl::prop_kind::forward_training, and
        ///     #dnnl::prop_kind::forward_inference.
        /// @param src_desc Memory descriptor for src.
        /// @param weights_desc Memory descriptor for weights.
        /// @param bias_desc Memory descriptor for bias.
        /// @param dst_desc Memory descriptor for dst.
        /// @param attr Primitive attributes to use. Attributes are optional
        ///     and default to empty attributes.
        /// @param allow_empty A flag signifying whether construction is
        ///     allowed to fail without throwing an exception. In this case an
        ///     empty object will be produced. This flag is optional and
        ///     defaults to false.
        primitive_desc(const engine &aengine, prop_kind aprop_kind,
                const memory::desc &src_desc, const memory::desc &weights_desc,
                const memory::desc &bias_desc, const memory::desc &dst_desc,
                const primitive_attr &attr = default_attr(),
                bool allow_empty = false)
            : primitive_desc(aengine, aprop_kind, src_desc, weights_desc,
                    &bias_desc, dst_desc, attr, allow_empty) {}

        /// Constructs a primitive descriptor for an inner product forward
        /// propagation primitive.
        ///
        /// @note
        ///     All the memory descriptors may be initialized with the
        ///     #dnnl::memory::format_tag::any value of @p format_tag.
        ///
        /// @param aengine Engine to use.
        /// @param aprop_kind Propagation kind. Possible values are
        ///     #dnnl::prop_kind::forward_training, and
        ///     #dnnl::prop_kind::forward_inference.
        /// @param src_desc Memory descriptor for src.
        /// @param weights_desc Memory descriptor for weights.
        /// @param dst_desc Memory descriptor for dst.
        /// @param attr Primitive attributes to use. Attributes are optional
        ///     and default to empty attributes.
        /// @param allow_empty A flag signifying whether construction is
        ///     allowed to fail without throwing an exception. In this case an
        ///     empty object will be produced. This flag is optional and
        ///     defaults to false.
        primitive_desc(const engine &aengine, prop_kind aprop_kind,
                const memory::desc &src_desc, const memory::desc &weights_desc,
                const memory::desc &dst_desc,
                const primitive_attr &attr = default_attr(),
                bool allow_empty = false)
            : primitive_desc(aengine, aprop_kind, src_desc, weights_desc,
                    nullptr, dst_desc, attr, allow_empty) {}

        /// Constructs a primitive descriptor for an inner product forward
        /// propagation primitive from a C API primitive descriptor that must
        /// have a matching kind.
        ///
        /// @param pd C API primitive descriptor for an inner product forward
        ///     propagation primitive.
        primitive_desc(dnnl_primitive_desc_t pd)
            : dnnl::primitive_desc(pd, dnnl::primitive::kind::inner_product,
                    dnnl::prop_kind::forward_training,
                    dnnl::prop_kind::forward_inference) {}

        /// @copydoc dnnl::primitive_desc_base::src_desc()const
        memory::desc src_desc() const { return base::src_desc(0); }

        /// @copydoc dnnl::primitive_desc_base::weights_desc()const
        memory::desc weights_desc() const { return base::weights_desc(0); }

        /// @copydoc dnnl::primitive_desc_base::dst_desc()const
        memory::desc dst_desc() const { return base::dst_desc(0); }

        /// @copydoc dnnl::convolution_forward::primitive_desc::bias_desc()const
        memory::desc bias_desc() const { return base::weights_desc(1); }

        /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
        prop_kind get_prop_kind() const { return base::get_prop_kind(); }

    private:
        primitive_desc(const engine &aengine, prop_kind aprop_kind,
                const memory::desc &src_desc, const memory::desc &weights_desc,
                const memory::desc *bias_desc, const memory::desc &dst_desc,
                const primitive_attr &attr, bool allow_empty) {

            dnnl_primitive_desc_t pd = nullptr;
            dnnl_status_t status
                    = dnnl_inner_product_forward_primitive_desc_create(&pd,
                            aengine.get(), dnnl::convert_to_c(aprop_kind),
                            src_desc.get(), weights_desc.get(),
                            optional_arg(bias_desc), dst_desc.get(),
                            attr.get());

            if (!allow_empty)
                error::wrap_c_api(status,
                        "could not create a primitive descriptor for "
                        "the inner product forward propagation primitive. Run "
                        "workload with environment variable ONEDNN_VERBOSE=all "
                        "to get additional diagnostic information.");
            reset(pd);
        }
    };

    /// Default constructor. Produces an empty object.
    inner_product_forward() = default;

    /// Constructs an inner product forward propagation primitive.
    /// @param pd Primitive descriptor for an inner product forward
    ///     propagation primitive.
    inner_product_forward(const primitive_desc &pd) : primitive(pd) {}

    /// Constructs an inner product forward propagation primitive from
    /// a cache blob.
    /// @param pd Primitive descriptor for an inner product forward
    ///     propagation primitive.
    /// @param cache_blob Cache blob.
    inner_product_forward(
            const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
        : primitive(pd, cache_blob) {}
};
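
// A hedged follow-up to the format_tag::any sketch above: if the layout the
// implementation chose differs from the user's plain layout, insert a reorder.
// `user_wei_mem` is an assumed f32 memory object in a plain "ab" layout, and
// `strm` is an assumed stream.
//
//     memory ip_wei_mem(ip_pd.weights_desc(), eng);
//     if (ip_pd.weights_desc() != user_wei_mem.get_desc())
//         reorder(user_wei_mem, ip_wei_mem)
//                 .execute(strm, user_wei_mem, ip_wei_mem);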

/// Inner product backward propagation primitive.
struct inner_product_backward_data : public primitive {
    /// Primitive descriptor for an inner product backward propagation
    /// primitive.
    struct primitive_desc : public dnnl::primitive_desc {
        /// Default constructor. Produces an empty object.
        primitive_desc() = default;

        /// Constructs a primitive descriptor for an inner product backward
        /// propagation primitive.
        ///
        /// @note
        ///     All the memory descriptors may be initialized with the
        ///     #dnnl::memory::format_tag::any value of @p format_tag.
        ///
        /// @param aengine Engine to use.
        /// @param diff_src_desc Memory descriptor for diff src.
        /// @param weights_desc Memory descriptor for weights.
        /// @param diff_dst_desc Memory descriptor for diff dst.
        /// @param hint_fwd_pd Primitive descriptor for an inner product
        ///     forward propagation primitive. It is used as a hint for
        ///     deciding which memory format to use.
        /// @param attr Primitive attributes to use. Attributes are optional
        ///     and default to empty attributes.
        /// @param allow_empty A flag signifying whether construction is
        ///     allowed to fail without throwing an exception. In this case an
        ///     empty object will be produced. This flag is optional and
        ///     defaults to false.
        primitive_desc(const engine &aengine, const memory::desc &diff_src_desc,
                const memory::desc &weights_desc,
                const memory::desc &diff_dst_desc,
                const inner_product_forward::primitive_desc &hint_fwd_pd,
                const primitive_attr &attr = default_attr(),
                bool allow_empty = false) {
            dnnl_primitive_desc_t pd = nullptr;
            dnnl_status_t status
                    = dnnl_inner_product_backward_data_primitive_desc_create(
                            &pd, aengine.get(), diff_src_desc.get(),
                            weights_desc.get(), diff_dst_desc.get(),
                            hint_fwd_pd.get(), attr.get());

            if (!allow_empty)
                error::wrap_c_api(status,
                        "could not create a primitive descriptor for "
                        "the inner product backward propagation primitive. Run "
                        "workload with environment variable ONEDNN_VERBOSE=all "
                        "to get additional diagnostic information.");
            reset(pd);
        }

        /// Constructs a primitive descriptor for an inner product backward
        /// propagation primitive from a C API primitive descriptor that must
        /// have a matching kind.
        ///
        /// @param pd C API primitive descriptor for an inner product backward
        ///     propagation primitive.
        primitive_desc(dnnl_primitive_desc_t pd)
            : dnnl::primitive_desc(pd, dnnl::primitive::kind::inner_product,
                    dnnl::prop_kind::backward_data) {}

        /// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
        memory::desc diff_src_desc() const { return base::diff_src_desc(0); }

        /// @copydoc dnnl::primitive_desc_base::weights_desc()const
        memory::desc weights_desc() const { return base::weights_desc(0); }

        /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
        memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }

        /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
        prop_kind get_prop_kind() const { return base::get_prop_kind(); }
    };

    /// Default constructor. Produces an empty object.
    inner_product_backward_data() = default;

    /// Constructs an inner product backward propagation primitive.
    /// @param pd Primitive descriptor for an inner product backward
    ///     propagation primitive.
    inner_product_backward_data(const primitive_desc &pd) : primitive(pd) {}

    /// Constructs an inner product backward propagation primitive from
    /// a cache blob.
    /// @param pd Primitive descriptor for an inner product backward
    ///     propagation primitive.
    /// @param cache_blob Cache blob.
    inner_product_backward_data(
            const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
        : primitive(pd, cache_blob) {}
};

/// Inner product weights gradient primitive.
struct inner_product_backward_weights : public primitive {
    /// Primitive descriptor for an inner product weights gradient primitive.
    struct primitive_desc : public dnnl::primitive_desc {
        /// Default constructor. Produces an empty object.
        primitive_desc() = default;

        /// Constructs a primitive descriptor for an inner product weights
        /// update primitive with bias.
        ///
        /// @note
        ///     All the memory descriptors may be initialized with the
        ///     #dnnl::memory::format_tag::any value of @p format_tag.
        ///
        /// @param aengine Engine to use.
        /// @param src_desc Memory descriptor for src.
        /// @param diff_weights_desc Memory descriptor for diff weights.
        /// @param diff_bias_desc Memory descriptor for diff bias.
        /// @param diff_dst_desc Memory descriptor for diff dst.
        /// @param hint_fwd_pd Primitive descriptor for an inner product
        ///     forward propagation primitive. It is used as a hint for
        ///     deciding which memory format to use.
        /// @param attr Primitive attributes to use. Attributes are optional
        ///     and default to empty attributes.
        /// @param allow_empty A flag signifying whether construction is
        ///     allowed to fail without throwing an exception. In this case an
        ///     empty object will be produced. This flag is optional and
        ///     defaults to false.
        primitive_desc(const engine &aengine, const memory::desc &src_desc,
                const memory::desc &diff_weights_desc,
                const memory::desc &diff_bias_desc,
                const memory::desc &diff_dst_desc,
                const inner_product_forward::primitive_desc &hint_fwd_pd,
                const primitive_attr &attr = default_attr(),
                bool allow_empty = false)
            : primitive_desc(aengine, src_desc, diff_weights_desc,
                    &diff_bias_desc, diff_dst_desc, hint_fwd_pd, attr,
                    allow_empty) {}

        /// Constructs a primitive descriptor for an inner product weights
        /// update primitive.
        ///
        /// @note
        ///     All the memory descriptors may be initialized with the
        ///     #dnnl::memory::format_tag::any value of @p format_tag.
        ///
        /// @param aengine Engine to use.
        /// @param src_desc Memory descriptor for src.
        /// @param diff_weights_desc Memory descriptor for diff weights.
        /// @param diff_dst_desc Memory descriptor for diff dst.
        /// @param attr Primitive attributes to use. Attributes are optional
        ///     and default to empty attributes.
        /// @param hint_fwd_pd Primitive descriptor for an inner product
        ///     forward propagation primitive. It is used as a hint for
        ///     deciding which memory format to use.
        /// @param allow_empty A flag signifying whether construction is
        ///     allowed to fail without throwing an exception. In this case an
        ///     empty object will be produced. This flag is optional and
        ///     defaults to false.
        primitive_desc(const engine &aengine, const memory::desc &src_desc,
                const memory::desc &diff_weights_desc,
                const memory::desc &diff_dst_desc,
                const inner_product_forward::primitive_desc &hint_fwd_pd,
                const primitive_attr &attr = default_attr(),
                bool allow_empty = false)
            : primitive_desc(aengine, src_desc, diff_weights_desc, nullptr,
                    diff_dst_desc, hint_fwd_pd, attr, allow_empty) {}

        /// Constructs a primitive descriptor for an inner product weights
        /// update primitive from a C API primitive descriptor that must
        /// have a matching kind.
        ///
        /// @param pd C API primitive descriptor for an inner product weights
        ///     gradient primitive.
        primitive_desc(dnnl_primitive_desc_t pd)
            : dnnl::primitive_desc(pd, dnnl::primitive::kind::inner_product,
                    dnnl::prop_kind::backward_weights) {}

        /// @copydoc dnnl::primitive_desc_base::src_desc()const
        memory::desc src_desc() const { return base::src_desc(0); }

        /// @copydoc dnnl::primitive_desc_base::diff_weights_desc()const
        memory::desc diff_weights_desc() const {
            return base::diff_weights_desc(0);
        }

        /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
        memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }

        /// @copydoc dnnl::convolution_backward_weights::primitive_desc::diff_bias_desc()const
        memory::desc diff_bias_desc() const {
            return base::diff_weights_desc(1);
        }

        /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
        prop_kind get_prop_kind() const { return base::get_prop_kind(); }

    private:
        primitive_desc(const engine &aengine, const memory::desc &src_desc,
                const memory::desc &diff_weights_desc,
                const memory::desc *diff_bias_desc,
                const memory::desc &diff_dst_desc,
                const inner_product_forward::primitive_desc &hint_fwd_pd,
                const primitive_attr &attr, bool allow_empty) {

            dnnl_primitive_desc_t pd = nullptr;
            dnnl_status_t status
                    = dnnl_inner_product_backward_weights_primitive_desc_create(
                            &pd, aengine.get(), src_desc.get(),
                            diff_weights_desc.get(),
                            optional_arg(diff_bias_desc), diff_dst_desc.get(),
                            hint_fwd_pd.get(), attr.get());

            if (!allow_empty)
                error::wrap_c_api(status,
                        "could not create a primitive descriptor for "
                        "the inner product weights gradient primitive. Run "
                        "workload with environment variable ONEDNN_VERBOSE=all "
                        "to get additional diagnostic information.");
            reset(pd);
        }
    };

    /// Default constructor. Produces an empty object.
    inner_product_backward_weights() = default;

    /// Constructs an inner product weights gradient primitive.
    /// @param pd Primitive descriptor for an inner product weights gradient
    ///     primitive.
    inner_product_backward_weights(const primitive_desc &pd) : primitive(pd) {}

    /// Constructs an inner product weights gradient primitive from a cache
    /// blob.
    /// @param pd Primitive descriptor for an inner product weights gradient
    ///     primitive.
    /// @param cache_blob Cache blob.
    inner_product_backward_weights(
            const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
        : primitive(pd, cache_blob) {}
};

/// @} dnnl_api_inner_product

/// @addtogroup dnnl_api_rnn RNN
///
/// A primitive to compute recurrent neural network layers.
///
/// @sa @ref dev_guide_rnn in developer guide
///
/// @{
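
// A hedged query sketch: the per-argument descriptor getters defined on
// rnn_primitive_desc_base below are shared by the concrete RNN primitives
// (for example lstm_forward, defined later in this header). `lstm_pd` is an
// assumed, already constructed lstm_forward::primitive_desc.
//
//     memory::desc wl_md = lstm_pd.weights_layer_desc();
//     memory::desc wi_md = lstm_pd.weights_iter_desc();
//     memory::desc bias_md = lstm_pd.bias_desc();
//     // A zero (empty) descriptor means the corresponding argument is absent.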

/// Base class for primitive descriptors for RNN primitives.
struct rnn_primitive_desc_base : public primitive_desc {
    using primitive_desc::primitive_desc;

    /// Default constructor. Produces an empty object.
    rnn_primitive_desc_base() = default;

    /// Constructs an RNN primitive descriptor base from a C API primitive
    /// descriptor while checking that it actually describes the expected
    /// primitive by comparing propagation and primitive kinds.
    ///
    /// @param pd C API primitive descriptor.
    /// @param aprop_kind Expected propagation kind.
    /// @param cell_kind Expected cell kind.
    rnn_primitive_desc_base(dnnl_primitive_desc_t pd,
            dnnl::prop_kind aprop_kind, dnnl::algorithm cell_kind)
        : rnn_primitive_desc_base(pd, aprop_kind, aprop_kind, cell_kind) {}

    /// Returns source layer memory descriptor.
    /// @returns Source layer memory descriptor.
    memory::desc src_layer_desc() const {
        return base::query_md(query::exec_arg_md, DNNL_ARG_SRC_LAYER);
    }

    /// Returns AUGRU attention memory descriptor.
    /// @returns AUGRU attention memory descriptor.
    memory::desc augru_attention_desc() const {
        return base::query_md(query::exec_arg_md, DNNL_ARG_AUGRU_ATTENTION);
    }

    /// Returns source iteration memory descriptor.
    /// @returns Source iteration memory descriptor.
    /// @returns A zero memory descriptor if the primitive does not have a
    ///     source iteration parameter.
    memory::desc src_iter_desc() const {
        return base::query_md(query::exec_arg_md, DNNL_ARG_SRC_ITER);
    }

    /// Returns source recurrent cell state memory descriptor.
    /// @returns Source recurrent cell state memory descriptor.
    memory::desc src_iter_c_desc() const {
        return base::query_md(query::exec_arg_md, DNNL_ARG_SRC_ITER_C);
    }

    /// Returns weights layer memory descriptor.
    /// @returns Weights layer memory descriptor.
    memory::desc weights_layer_desc() const {
        return base::query_md(query::exec_arg_md, DNNL_ARG_WEIGHTS_LAYER);
    }

    /// Returns weights iteration memory descriptor.
    /// @returns Weights iteration memory descriptor.
    memory::desc weights_iter_desc() const {
        return base::query_md(query::exec_arg_md, DNNL_ARG_WEIGHTS_ITER);
    }

    /// Returns weights peephole memory descriptor.
    /// @returns Weights peephole memory descriptor.
    memory::desc weights_peephole_desc() const {
        return base::query_md(query::exec_arg_md, DNNL_ARG_WEIGHTS_PEEPHOLE);
    }

    /// Returns weights projection memory descriptor.
    /// @returns Weights projection memory descriptor.
    memory::desc weights_projection_desc() const {
        return base::query_md(query::exec_arg_md, DNNL_ARG_WEIGHTS_PROJECTION);
    }

    /// Returns bias memory descriptor.
    /// @returns Bias memory descriptor.
    /// @returns A zero memory descriptor if the primitive does not have a
    ///     bias parameter.
    memory::desc bias_desc() const {
        return base::query_md(query::exec_arg_md, DNNL_ARG_BIAS);
    }

    /// Returns destination layer memory descriptor.
    /// @returns Destination layer memory descriptor.
    memory::desc dst_layer_desc() const {
        return base::query_md(query::exec_arg_md, DNNL_ARG_DST_LAYER);
    }

    /// Returns destination iteration memory descriptor.
    /// @returns Destination iteration memory descriptor.
    /// @returns A zero memory descriptor if the primitive does not have a
    ///     destination iteration parameter.
    memory::desc dst_iter_desc() const {
        return base::query_md(query::exec_arg_md, DNNL_ARG_DST_ITER);
    }

    /// Returns destination recurrent cell state memory descriptor.
    /// @returns Destination recurrent cell state memory descriptor.
    memory::desc dst_iter_c_desc() const {
        return base::query_md(query::exec_arg_md, DNNL_ARG_DST_ITER_C);
    }

    /// Returns diff source layer memory descriptor.
    /// @returns Diff source layer memory descriptor.
    memory::desc diff_src_layer_desc() const {
        return base::query_md(query::exec_arg_md, DNNL_ARG_DIFF_SRC_LAYER);
    }

    /// Returns diff AUGRU attention memory descriptor.
    /// @returns Diff AUGRU attention memory descriptor.
    memory::desc diff_augru_attention_desc() const {
        return base::query_md(
                query::exec_arg_md, DNNL_ARG_DIFF_AUGRU_ATTENTION);
    }

    /// Returns diff source iteration memory descriptor.
    /// @returns Diff source iteration memory descriptor.
    /// @returns A zero memory descriptor if the primitive does not have a
    ///     diff source iteration parameter.
    memory::desc diff_src_iter_desc() const {
        return base::query_md(query::exec_arg_md, DNNL_ARG_DIFF_SRC_ITER);
    }

    /// Returns diff source recurrent cell state memory descriptor.
    /// @returns Diff source recurrent cell state memory descriptor.
    memory::desc diff_src_iter_c_desc() const {
        return base::query_md(query::exec_arg_md, DNNL_ARG_DIFF_SRC_ITER_C);
    }

    /// Returns diff weights layer memory descriptor.
    /// @returns Diff weights layer memory descriptor.
    memory::desc diff_weights_layer_desc() const {
        return base::query_md(query::exec_arg_md, DNNL_ARG_DIFF_WEIGHTS_LAYER);
    }

    /// Returns diff weights iteration memory descriptor.
    /// @returns Diff weights iteration memory descriptor.
    memory::desc diff_weights_iter_desc() const {
        return base::query_md(query::exec_arg_md, DNNL_ARG_DIFF_WEIGHTS_ITER);
    }

    /// Returns diff weights peephole memory descriptor.
    /// @returns Diff weights peephole memory descriptor.
    memory::desc diff_weights_peephole_desc() const {
        return base::query_md(
                query::exec_arg_md, DNNL_ARG_DIFF_WEIGHTS_PEEPHOLE);
    }

    /// Returns diff weights projection memory descriptor.
    /// @returns Diff weights projection memory descriptor.
    memory::desc diff_weights_projection_desc() const {
        return base::query_md(
                query::exec_arg_md, DNNL_ARG_DIFF_WEIGHTS_PROJECTION);
    }

    /// Returns diff bias memory descriptor.
    /// @returns Diff bias memory descriptor.
    /// @returns A zero memory descriptor if the primitive does not have a
    ///     diff bias parameter.
    memory::desc diff_bias_desc() const {
        return base::query_md(query::exec_arg_md, DNNL_ARG_DIFF_BIAS);
    }

    /// Returns diff destination layer memory descriptor.
    /// @returns Diff destination layer memory descriptor.
    memory::desc diff_dst_layer_desc() const {
        return base::query_md(query::exec_arg_md, DNNL_ARG_DIFF_DST_LAYER);
    }

    /// Returns diff destination iteration memory descriptor.
    /// @returns Diff destination iteration memory descriptor.
    /// @returns A zero memory descriptor if the primitive does not have a
    ///     diff destination iteration parameter.
    memory::desc diff_dst_iter_desc() const {
        return base::query_md(query::exec_arg_md, DNNL_ARG_DIFF_DST_ITER);
    }

    /// Returns diff destination recurrent cell state memory descriptor.
    /// @returns Diff destination recurrent cell state memory descriptor.
    memory::desc diff_dst_iter_c_desc() const {
        return base::query_md(query::exec_arg_md, DNNL_ARG_DIFF_DST_ITER_C);
    }

protected:
|
|
using rnn_base = rnn_primitive_desc_base;
|
|
|
|
// (Deliberately not using doxygen comments)
|
|
//
|
|
// Constructs an RNN primitive descriptor base from a C API primitive
|
|
// descriptor while checking that it actually describes the expected
|
|
// primitive by comparing propagation and primitive kinds. Caller can
|
|
// pass two options propagation kinds. This is typically used to check
|
|
// that propagation kind is inference or training forward propagation.
|
|
//
|
|
// @param pd C API primitive descriptor.
|
|
// @param prop_kind1 Expected propagation kind.
|
|
// @param prop_kind2 Expected propagation kind.
|
|
// @param cell_kind Expected cell kind.
|
|
rnn_primitive_desc_base(dnnl_primitive_desc_t pd,
|
|
dnnl::prop_kind prop_kind1, dnnl::prop_kind prop_kind2,
|
|
dnnl::algorithm cell_kind) {
|
|
|
|
dnnl_status_t rc;
|
|
|
|
dnnl_primitive_kind_t q_primitive_kind;
|
|
rc = dnnl_primitive_desc_query(
|
|
pd, dnnl_query_primitive_kind, 0, &q_primitive_kind);
|
|
error::wrap_c_api(rc,
|
|
"could not retrieve a primitive kind from a primitive "
|
|
"descriptor for an RNN primitive");
|
|
|
|
dnnl_prop_kind_t q_prop_kind;
|
|
rc = dnnl_primitive_desc_query(
|
|
pd, dnnl_query_prop_kind, 0, &q_prop_kind);
|
|
error::wrap_c_api(rc,
|
|
"could not retrieve a propagation kind from a primitive "
|
|
"descriptor for an RNN primitive");
|
|
|
|
dnnl_alg_kind_t q_cell_kind;
|
|
rc = dnnl_primitive_desc_query(
|
|
pd, dnnl_query_cell_kind, 0, &q_cell_kind);
|
|
error::wrap_c_api(rc,
|
|
"could not retrieve a cell kind from a primitive descriptor "
|
|
"for an RNN primitive");
|
|
|
|
dnnl_prop_kind_t c_prop_kind1 = convert_to_c(prop_kind1);
|
|
dnnl_prop_kind_t c_prop_kind2 = convert_to_c(prop_kind2);
|
|
dnnl_alg_kind_t c_cell_kind = convert_to_c(cell_kind);
|
|
|
|
bool ok = q_primitive_kind == dnnl_rnn
|
|
&& (q_prop_kind == c_prop_kind1 || q_prop_kind == c_prop_kind2)
|
|
&& q_cell_kind == c_cell_kind;
|
|
|
|
if (!ok)
|
|
DNNL_THROW_ERROR(dnnl_invalid_arguments,
|
|
"mismatch between expected and provided descriptors for an "
|
|
"RNN primitive");
|
|
|
|
reset_with_clone(pd);
|
|
}
|
|
|
|
// Constructs an RNN forward propagation primitive descriptor base for
|
|
// any cell kind.
|
|
rnn_primitive_desc_base(const engine &aengine, algorithm cell_kind,
|
|
prop_kind aprop_kind, algorithm activation, rnn_direction direction,
|
|
const memory::desc &src_layer_desc,
|
|
const memory::desc &src_iter_desc,
|
|
const memory::desc *src_iter_c_desc,
|
|
const memory::desc *attention_desc,
|
|
const memory::desc &weights_layer_desc,
|
|
const memory::desc &weights_iter_desc,
|
|
const memory::desc *weights_peephole_desc,
|
|
const memory::desc *weights_projection_desc,
|
|
const memory::desc &bias_desc, const memory::desc &dst_layer_desc,
|
|
const memory::desc &dst_iter_desc,
|
|
const memory::desc *dst_iter_c_desc, rnn_flags flags, float alpha,
|
|
float beta, const primitive_attr &attr, bool allow_empty) {
|
|
|
|
dnnl_status_t status = dnnl_success;
|
|
const char *msg
|
|
= "could not create a primitive descriptor for a requested "
|
|
"cell kind";
|
|
|
|
dnnl_primitive_desc_t pd = nullptr;
|
|
switch (cell_kind) {
|
|
case algorithm::vanilla_rnn:
|
|
status = dnnl_vanilla_rnn_forward_primitive_desc_create(&pd,
|
|
aengine.get(), dnnl::convert_to_c(aprop_kind),
|
|
dnnl::convert_to_c(activation),
|
|
dnnl::convert_to_c(direction), src_layer_desc.get(),
|
|
src_iter_desc.get(), weights_layer_desc.get(),
|
|
weights_iter_desc.get(), bias_desc.get(),
|
|
dst_layer_desc.get(), dst_iter_desc.get(),
|
|
convert_to_c(flags), alpha, beta, attr.get());
|
|
msg = "could not create a primitive descriptor for "
|
|
"the vanilla RNN forward propagation primitive. Run "
|
|
"workload with environment variable ONEDNN_VERBOSE=all "
|
|
"to get additional diagnostic information.";
|
|
break;
|
|
case algorithm::vanilla_lstm:
|
|
status = dnnl_lstm_forward_primitive_desc_create(&pd,
|
|
aengine.get(), dnnl::convert_to_c(aprop_kind),
|
|
dnnl::convert_to_c(direction), src_layer_desc.get(),
|
|
src_iter_desc.get(), optional_arg(src_iter_c_desc),
|
|
weights_layer_desc.get(), weights_iter_desc.get(),
|
|
optional_arg(weights_peephole_desc),
|
|
optional_arg(weights_projection_desc), bias_desc.get(),
|
|
dst_layer_desc.get(), dst_iter_desc.get(),
|
|
optional_arg(dst_iter_c_desc), convert_to_c(flags),
|
|
attr.get());
|
|
msg = "could not create a primitive descriptor for "
|
|
"the LSTM forward propagation primitive. Run workload "
|
|
"with environment variable ONEDNN_VERBOSE=all to get "
|
|
"additional diagnostic information.";
|
|
break;
|
|
case algorithm::vanilla_gru:
|
|
status = dnnl_gru_forward_primitive_desc_create(&pd,
|
|
aengine.get(), dnnl::convert_to_c(aprop_kind),
|
|
dnnl::convert_to_c(direction), src_layer_desc.get(),
|
|
src_iter_desc.get(), weights_layer_desc.get(),
|
|
weights_iter_desc.get(), bias_desc.get(),
|
|
dst_layer_desc.get(), dst_iter_desc.get(),
|
|
convert_to_c(flags), attr.get());
|
|
msg = "could not create a primitive descriptor for "
|
|
"the GRU forward propagation primitive. Run workload "
|
|
"with environment variable ONEDNN_VERBOSE=all to get "
|
|
"additional diagnostic information.";
|
|
break;
|
|
case algorithm::lbr_gru:
|
|
status = dnnl_lbr_gru_forward_primitive_desc_create(&pd,
|
|
aengine.get(), dnnl::convert_to_c(aprop_kind),
|
|
dnnl::convert_to_c(direction), src_layer_desc.get(),
|
|
src_iter_desc.get(), weights_layer_desc.get(),
|
|
weights_iter_desc.get(), bias_desc.get(),
|
|
dst_layer_desc.get(), dst_iter_desc.get(),
|
|
convert_to_c(flags), attr.get());
|
|
msg = "could not create a primitive descriptor for "
|
|
"the LBR GRU forward propagation primitive. Run workload "
|
|
"with environment variable ONEDNN_VERBOSE=all to get "
|
|
"additional diagnostic information.";
|
|
break;
|
|
case algorithm::vanilla_augru:
|
|
status = dnnl_augru_forward_primitive_desc_create(&pd,
|
|
aengine.get(), dnnl::convert_to_c(aprop_kind),
|
|
dnnl::convert_to_c(direction), src_layer_desc.get(),
|
|
src_iter_desc.get(), optional_arg(attention_desc),
|
|
weights_layer_desc.get(), weights_iter_desc.get(),
|
|
bias_desc.get(), dst_layer_desc.get(),
|
|
dst_iter_desc.get(), convert_to_c(flags), attr.get());
|
|
msg = "could not create a primitive descriptor for "
|
|
"the AUGRU forward propagation primitive. Run workload "
|
|
"with environment variable ONEDNN_VERBOSE=all to get "
|
|
"additional diagnostic information.";
|
|
break;
|
|
case algorithm::lbr_augru:
|
|
status = dnnl_lbr_augru_forward_primitive_desc_create(&pd,
|
|
aengine.get(), dnnl::convert_to_c(aprop_kind),
|
|
dnnl::convert_to_c(direction), src_layer_desc.get(),
|
|
src_iter_desc.get(), optional_arg(attention_desc),
|
|
weights_layer_desc.get(), weights_iter_desc.get(),
|
|
bias_desc.get(), dst_layer_desc.get(),
|
|
dst_iter_desc.get(), convert_to_c(flags), attr.get());
|
|
msg = "could not create a primitive descriptor for "
|
|
"the LBR AUGRU forward propagation primitive. Run "
|
|
"workload with environment variable ONEDNN_VERBOSE=all "
|
|
"to get additional diagnostic information.";
|
|
break;
|
|
default: status = dnnl_unimplemented;
|
|
}
|
|
|
|
if (!allow_empty) error::wrap_c_api(status, msg);
|
|
reset(pd);
|
|
}
|
|
|
|
// Constructs an RNN backward propagation primitive descriptor base for
|
|
// any cell kind.
|
|
rnn_primitive_desc_base(const engine &aengine, algorithm cell_kind,
|
|
prop_kind aprop_kind, algorithm activation, rnn_direction direction,
|
|
const memory::desc &src_layer_desc,
|
|
const memory::desc &src_iter_desc,
|
|
const memory::desc *src_iter_c_desc,
|
|
const memory::desc *attention_desc,
|
|
const memory::desc &weights_layer_desc,
|
|
const memory::desc &weights_iter_desc,
|
|
const memory::desc *weights_peephole_desc,
|
|
const memory::desc *weights_projection_desc,
|
|
const memory::desc &bias_desc, const memory::desc &dst_layer_desc,
|
|
const memory::desc &dst_iter_desc,
|
|
const memory::desc *dst_iter_c_desc,
|
|
const memory::desc &diff_src_layer_desc,
|
|
const memory::desc &diff_src_iter_desc,
|
|
const memory::desc *diff_src_iter_c_desc,
|
|
const memory::desc *diff_attention_desc,
|
|
const memory::desc &diff_weights_layer_desc,
|
|
const memory::desc &diff_weights_iter_desc,
|
|
const memory::desc *diff_weights_peephole_desc,
|
|
const memory::desc *diff_weights_projection_desc,
|
|
const memory::desc &diff_bias_desc,
|
|
const memory::desc &diff_dst_layer_desc,
|
|
const memory::desc &diff_dst_iter_desc,
|
|
const memory::desc *diff_dst_iter_c_desc, rnn_flags flags,
|
|
float alpha, float beta, const rnn_primitive_desc_base &hint_fwd_pd,
|
|
const primitive_attr &attr, bool allow_empty) {
|
|
|
|
dnnl_status_t status = dnnl_success;
|
|
const char *msg = "";
|
|
|
|
dnnl_primitive_desc_t pd = nullptr;
|
|
switch (cell_kind) {
|
|
case algorithm::vanilla_rnn:
|
|
status = dnnl_vanilla_rnn_backward_primitive_desc_create(&pd,
|
|
aengine.get(), dnnl::convert_to_c(aprop_kind),
|
|
dnnl::convert_to_c(activation),
|
|
dnnl::convert_to_c(direction), src_layer_desc.get(),
|
|
src_iter_desc.get(), weights_layer_desc.get(),
|
|
weights_iter_desc.get(), bias_desc.get(),
|
|
dst_layer_desc.get(), dst_iter_desc.get(),
|
|
diff_src_layer_desc.get(), diff_src_iter_desc.get(),
|
|
diff_weights_layer_desc.get(),
|
|
diff_weights_iter_desc.get(), diff_bias_desc.get(),
|
|
diff_dst_layer_desc.get(), diff_dst_iter_desc.get(),
|
|
convert_to_c(flags), alpha, beta, hint_fwd_pd.get(),
|
|
attr.get());
|
|
msg = "could not create a primitive descriptor for "
|
|
"the vanilla RNN backward propagation primitive. Run "
|
|
"workload with environment variable ONEDNN_VERBOSE=all "
|
|
"to get additional diagnostic information.";
|
|
break;
|
|
case algorithm::vanilla_lstm:
|
|
status = dnnl_lstm_backward_primitive_desc_create(&pd,
|
|
aengine.get(), dnnl::convert_to_c(aprop_kind),
|
|
dnnl::convert_to_c(direction), src_layer_desc.get(),
|
|
src_iter_desc.get(), optional_arg(src_iter_c_desc),
|
|
weights_layer_desc.get(), weights_iter_desc.get(),
|
|
optional_arg(weights_peephole_desc),
|
|
optional_arg(weights_projection_desc), bias_desc.get(),
|
|
dst_layer_desc.get(), dst_iter_desc.get(),
|
|
optional_arg(dst_iter_c_desc),
|
|
diff_src_layer_desc.get(), diff_src_iter_desc.get(),
|
|
optional_arg(diff_src_iter_c_desc),
|
|
diff_weights_layer_desc.get(),
|
|
diff_weights_iter_desc.get(),
|
|
optional_arg(diff_weights_peephole_desc),
|
|
optional_arg(diff_weights_projection_desc),
|
|
diff_bias_desc.get(), diff_dst_layer_desc.get(),
|
|
diff_dst_iter_desc.get(),
|
|
optional_arg(diff_dst_iter_c_desc), convert_to_c(flags),
|
|
hint_fwd_pd.get(), attr.get());
|
|
msg = "could not create a primitive descriptor for "
|
|
"the LSTM backward propagation primitive. Run workload "
|
|
"with environment variable ONEDNN_VERBOSE=all to get "
|
|
"additional diagnostic information.";
|
|
break;
|
|
case algorithm::vanilla_gru:
|
|
status = dnnl_gru_backward_primitive_desc_create(&pd,
|
|
aengine.get(), dnnl::convert_to_c(aprop_kind),
|
|
dnnl::convert_to_c(direction), src_layer_desc.get(),
|
|
src_iter_desc.get(), weights_layer_desc.get(),
|
|
weights_iter_desc.get(), bias_desc.get(),
|
|
dst_layer_desc.get(), dst_iter_desc.get(),
|
|
diff_src_layer_desc.get(), diff_src_iter_desc.get(),
|
|
diff_weights_layer_desc.get(),
|
|
diff_weights_iter_desc.get(), diff_bias_desc.get(),
|
|
diff_dst_layer_desc.get(), diff_dst_iter_desc.get(),
|
|
convert_to_c(flags), hint_fwd_pd.get(), attr.get());
|
|
msg = "could not create a primitive descriptor for "
|
|
"the GRU backward propagation primitive. Run workload "
|
|
"with environment variable ONEDNN_VERBOSE=all to get "
|
|
"additional diagnostic information.";
|
|
break;
|
|
case algorithm::lbr_gru:
|
|
status = dnnl_lbr_gru_backward_primitive_desc_create(&pd,
|
|
aengine.get(), dnnl::convert_to_c(aprop_kind),
|
|
dnnl::convert_to_c(direction), src_layer_desc.get(),
|
|
src_iter_desc.get(), weights_layer_desc.get(),
|
|
weights_iter_desc.get(), bias_desc.get(),
|
|
dst_layer_desc.get(), dst_iter_desc.get(),
|
|
diff_src_layer_desc.get(), diff_src_iter_desc.get(),
|
|
diff_weights_layer_desc.get(),
|
|
diff_weights_iter_desc.get(), diff_bias_desc.get(),
|
|
diff_dst_layer_desc.get(), diff_dst_iter_desc.get(),
|
|
convert_to_c(flags), hint_fwd_pd.get(), attr.get());
|
|
msg = "could not create a primitive descriptor for "
|
|
"the LBR GRU backward propagation primitive. Run "
|
|
"workload with environment variable ONEDNN_VERBOSE=all "
|
|
"to get additional diagnostic information.";
|
|
break;
|
|
case algorithm::vanilla_augru:
|
|
status = dnnl_augru_backward_primitive_desc_create(&pd,
|
|
aengine.get(), dnnl::convert_to_c(aprop_kind),
|
|
dnnl::convert_to_c(direction), src_layer_desc.get(),
|
|
src_iter_desc.get(), optional_arg(attention_desc),
|
|
weights_layer_desc.get(), weights_iter_desc.get(),
|
|
bias_desc.get(), dst_layer_desc.get(),
|
|
dst_iter_desc.get(), diff_src_layer_desc.get(),
|
|
diff_src_iter_desc.get(),
|
|
optional_arg(diff_attention_desc),
|
|
diff_weights_layer_desc.get(),
|
|
diff_weights_iter_desc.get(), diff_bias_desc.get(),
|
|
diff_dst_layer_desc.get(), diff_dst_iter_desc.get(),
|
|
convert_to_c(flags), hint_fwd_pd.get(), attr.get());
|
|
msg = "could not create a primitive descriptor for "
|
|
"the AUGRU backward propagation primitive. Run workload "
|
|
"with environment variable ONEDNN_VERBOSE=all to get "
|
|
"additional diagnostic information.";
|
|
break;
|
|
case algorithm::lbr_augru:
|
|
status = dnnl_lbr_augru_backward_primitive_desc_create(&pd,
|
|
aengine.get(), dnnl::convert_to_c(aprop_kind),
|
|
dnnl::convert_to_c(direction), src_layer_desc.get(),
|
|
src_iter_desc.get(), optional_arg(attention_desc),
|
|
weights_layer_desc.get(), weights_iter_desc.get(),
|
|
bias_desc.get(), dst_layer_desc.get(),
|
|
dst_iter_desc.get(), diff_src_layer_desc.get(),
|
|
diff_src_iter_desc.get(),
|
|
optional_arg(diff_attention_desc),
|
|
diff_weights_layer_desc.get(),
|
|
diff_weights_iter_desc.get(), diff_bias_desc.get(),
|
|
diff_dst_layer_desc.get(), diff_dst_iter_desc.get(),
|
|
convert_to_c(flags), hint_fwd_pd.get(), attr.get());
|
|
msg = "could not create a primitive descriptor for "
|
|
"the LBR AUGRU backward propagation primitive. Run "
|
|
"workload with environment variable ONEDNN_VERBOSE=all "
|
|
"to get additional diagnostic information.";
|
|
break;
|
|
default: status = dnnl_unimplemented;
|
|
}
|
|
if (!allow_empty) error::wrap_c_api(status, msg);
|
|
reset(pd);
|
|
}
|
|
};
|
|
|
|
/// Vanilla RNN forward propagation primitive.
|
|
struct vanilla_rnn_forward : public primitive {
|
|
/// Primitive descriptor for a vanilla RNN forward propagation primitive.
|
|
struct primitive_desc : public rnn_primitive_desc_base {
|
|
/// Default constructor. Produces an empty object.
|
|
primitive_desc() = default;
|
|
|
|
/// Constructs a primitive descriptor for a vanilla RNN forward
|
|
/// propagation primitive.
|
|
///
|
|
/// The following arguments may point to a zero memory descriptor:
|
|
/// - @p src_iter_desc,
|
|
/// - @p bias_desc,
|
|
/// - @p dst_iter_desc.
|
|
///
|
|
/// This would then indicate that the RNN forward propagation primitive
|
|
/// should not use them and should default to zero values instead.
|
|
///
|
|
/// @note
|
|
/// All memory descriptors except @p src_iter_desc can be
|
|
/// initialized with an #dnnl::memory::format_tag::any value of @p
|
|
/// format_tag.
|
|
///
|
|
/// @param aengine Engine to use.
|
|
/// @param aprop_kind Propagation kind. Possible values are
|
|
/// #dnnl::prop_kind::forward_training, and
|
|
/// #dnnl::prop_kind::forward_inference.
|
|
/// @param activation Activation kind. Possible values are
|
|
/// #dnnl::algorithm::eltwise_relu,
|
|
/// #dnnl::algorithm::eltwise_tanh, or
|
|
/// #dnnl::algorithm::eltwise_logistic.
|
|
/// @param direction RNN direction. See @ref dnnl::rnn_direction for
|
|
/// more info.
|
|
/// @param src_layer_desc Memory descriptor for the input vector.
|
|
/// @param src_iter_desc Memory descriptor for the input recurrent
|
|
/// hidden state vector.
|
|
/// @param weights_layer_desc Memory descriptor for the weights
|
|
/// applied to the layer input.
|
|
/// @param weights_iter_desc Memory descriptor for the weights applied
|
|
/// to the recurrent input.
|
|
/// @param bias_desc Bias memory descriptor.
|
|
/// @param dst_layer_desc Memory descriptor for the output vector.
|
|
/// @param dst_iter_desc Memory descriptor for the output recurrent
|
|
/// hidden state vector.
|
|
/// @param attr Primitive attributes to use. Attributes are optional
|
|
/// and default to empty attributes.
|
|
/// @param allow_empty A flag signifying whether construction is
|
|
/// allowed to fail without throwing an exception. In this case an
|
|
/// empty object will be produced. This flag is optional and
|
|
/// defaults to false.
|
|
primitive_desc(const engine &aengine, prop_kind aprop_kind,
|
|
algorithm activation, rnn_direction direction,
|
|
const memory::desc &src_layer_desc,
|
|
const memory::desc &src_iter_desc,
|
|
const memory::desc &weights_layer_desc,
|
|
const memory::desc &weights_iter_desc,
|
|
const memory::desc &bias_desc,
|
|
const memory::desc &dst_layer_desc,
|
|
const memory::desc &dst_iter_desc,
|
|
const primitive_attr &attr = default_attr(),
|
|
bool allow_empty = false)
|
|
: rnn_primitive_desc_base(aengine, algorithm::vanilla_rnn,
|
|
aprop_kind, activation, direction, src_layer_desc,
|
|
src_iter_desc, nullptr, nullptr, weights_layer_desc,
|
|
weights_iter_desc, nullptr, nullptr, bias_desc,
|
|
dst_layer_desc, dst_iter_desc, nullptr, rnn_flags::undef,
|
|
0.0f, 0.0f, attr, allow_empty) {}
|
|
|
|
/// Constructs a primitive descriptor for a vanilla RNN forward
|
|
/// propagation primitive with alpha parameter.
|
|
///
|
|
/// The following arguments may point to a zero memory descriptor:
|
|
/// - @p src_iter_desc,
|
|
/// - @p bias_desc,
|
|
/// - @p dst_iter_desc.
|
|
///
|
|
/// This would then indicate that the RNN forward propagation primitive
|
|
/// should not use them and should default to zero values instead.
|
|
///
|
|
/// @note
|
|
/// All memory descriptors except @p src_iter_desc can be
|
|
/// initialized with an #dnnl::memory::format_tag::any value of @p
|
|
/// format_tag.
|
|
///
|
|
/// @param aengine Engine to use.
|
|
/// @param aprop_kind Propagation kind. Possible values are
|
|
/// #dnnl::prop_kind::forward_training, and
|
|
/// #dnnl::prop_kind::forward_inference.
|
|
/// @param activation Activation kind. Possible values are
|
|
/// #dnnl::algorithm::eltwise_relu,
|
|
/// #dnnl::algorithm::eltwise_tanh, or
|
|
/// #dnnl::algorithm::eltwise_logistic.
|
|
/// @param direction RNN direction. See @ref dnnl::rnn_direction for
|
|
/// more info.
|
|
/// @param src_layer_desc Memory descriptor for the input vector.
|
|
/// @param src_iter_desc Memory descriptor for the input recurrent
|
|
/// hidden state vector.
|
|
/// @param weights_layer_desc Memory descriptor for the weights
|
|
/// applied to the layer input.
|
|
/// @param weights_iter_desc Memory descriptor for the weights applied
|
|
/// to the recurrent input.
|
|
/// @param bias_desc Bias memory descriptor.
|
|
/// @param dst_layer_desc Memory descriptor for the output vector.
|
|
/// @param dst_iter_desc Memory descriptor for the output recurrent
|
|
/// hidden state vector.
|
|
/// @param alpha Negative slope if activation is
|
|
/// #dnnl::algorithm::eltwise_relu.
|
|
/// @param attr Primitive attributes to use. Attributes are optional
|
|
/// and default to empty attributes.
|
|
/// @param allow_empty A flag signifying whether construction is
|
|
/// allowed to fail without throwing an exception. In this case an
|
|
/// empty object will be produced. This flag is optional and
|
|
/// defaults to false.
|
|
primitive_desc(const engine &aengine, prop_kind aprop_kind,
|
|
algorithm activation, rnn_direction direction,
|
|
const memory::desc &src_layer_desc,
|
|
const memory::desc &src_iter_desc,
|
|
const memory::desc &weights_layer_desc,
|
|
const memory::desc &weights_iter_desc,
|
|
const memory::desc &bias_desc,
|
|
const memory::desc &dst_layer_desc,
|
|
const memory::desc &dst_iter_desc, float alpha,
|
|
const primitive_attr &attr = default_attr(),
|
|
bool allow_empty = false)
|
|
: rnn_primitive_desc_base(aengine, algorithm::vanilla_rnn,
|
|
aprop_kind, activation, direction, src_layer_desc,
|
|
src_iter_desc, nullptr, nullptr, weights_layer_desc,
|
|
weights_iter_desc, nullptr, nullptr, bias_desc,
|
|
dst_layer_desc, dst_iter_desc, nullptr, rnn_flags::undef,
|
|
alpha, 0.0f, attr, allow_empty) {}
|
|
|
|
/// Constructs a primitive descriptor for a vanilla RNN forward
|
|
/// propagation primitive from a C API primitive descriptor that must
|
|
/// have a matching kind.
|
|
///
|
|
/// @param pd C API primitive descriptor for a vanilla RNN forward
|
|
/// propagation primitive.
|
|
primitive_desc(dnnl_primitive_desc_t pd)
|
|
: rnn_primitive_desc_base(pd, dnnl::prop_kind::forward_training,
|
|
dnnl::prop_kind::forward_inference,
|
|
dnnl::algorithm::vanilla_rnn) {}
|
|
|
|
/// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const
|
|
memory::desc src_layer_desc() const {
|
|
return rnn_base::src_layer_desc();
|
|
}
|
|
|
|
/// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
|
|
memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); }
|
|
|
|
/// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const
|
|
memory::desc weights_layer_desc() const {
|
|
return rnn_base::weights_layer_desc();
|
|
}
|
|
|
|
/// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const
|
|
memory::desc weights_iter_desc() const {
|
|
return rnn_base::weights_iter_desc();
|
|
}
|
|
|
|
/// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const
|
|
memory::desc bias_desc() const { return rnn_base::bias_desc(); }
|
|
|
|
/// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const
|
|
memory::desc dst_layer_desc() const {
|
|
return rnn_base::dst_layer_desc();
|
|
}
|
|
|
|
/// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const
|
|
memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); }
|
|
|
|
/// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const
|
|
memory::desc workspace_desc() const {
|
|
return rnn_base::workspace_desc();
|
|
}
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_cell_kind()const
|
|
algorithm get_cell_kind() const { return base::get_cell_kind(); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
|
|
prop_kind get_prop_kind() const { return base::get_prop_kind(); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_activation_kind()const
|
|
algorithm get_activation_kind() const {
|
|
return base::get_activation_kind();
|
|
}
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_direction()const
|
|
rnn_direction get_direction() const { return base::get_direction(); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_alpha()const
|
|
float get_alpha() const { return base::get_alpha(); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_beta()const
|
|
float get_beta() const { return base::get_beta(); }
|
|
};
|
|
|
|
/// Default constructor. Produces an empty object.
|
|
vanilla_rnn_forward() = default;
|
|
|
|
/// Constructs a vanilla RNN forward propagation primitive.
|
|
/// @param pd Primitive descriptor for a vanilla RNN forward
|
|
/// propagation primitive.
|
|
vanilla_rnn_forward(const primitive_desc &pd) : primitive(pd) {}
|
|
|
|
/// Constructs a vanilla RNN forward propagation primitive from
|
|
/// a cache blob.
|
|
/// @param pd Primitive descriptor for a vanilla RNN forward
|
|
/// propagation primitive.
|
|
/// @param cache_blob Cache blob.
|
|
vanilla_rnn_forward(
|
|
const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
|
|
: primitive(pd, cache_blob) {}
|
|
};
|
|
|
|
/// Vanilla RNN backward propagation primitive.
|
|
struct vanilla_rnn_backward : public primitive {
|
|
/// Primitive descriptor for an RNN backward propagation primitive.
|
|
struct primitive_desc : public rnn_primitive_desc_base {
|
|
/// Default constructor. Produces an empty object.
|
|
primitive_desc() = default;
|
|
|
|
/// Constructs a primitive descriptor for a vanilla RNN backward
|
|
/// propagation primitive.
|
|
///
|
|
/// The following arguments may point to a zero memory descriptor:
|
|
/// - @p src_iter_desc together with @p diff_src_iter_desc,
|
|
/// - @p bias_desc together with @p diff_bias_desc,
|
|
/// - @p dst_iter_desc together with @p diff_dst_iter_desc.
|
|
///
|
|
/// This would then indicate that the RNN backward propagation
|
|
/// primitive should not use the respective data and should use zero
|
|
/// values instead.
|
|
///
|
|
/// @note
|
|
/// All the memory descriptors may be initialized with the
|
|
/// #dnnl::memory::format_tag::any value of @p format_tag.
|
|
///
|
|
/// @param aengine Engine to use.
|
|
/// @param aprop_kind Propagation kind. Must be
|
|
/// #dnnl::prop_kind::backward.
|
|
/// @param activation Activation kind. Possible values are
|
|
/// #dnnl::algorithm::eltwise_relu,
|
|
/// #dnnl::algorithm::eltwise_tanh, or
|
|
/// #dnnl::algorithm::eltwise_logistic.
|
|
/// @param direction RNN direction. See @ref dnnl::rnn_direction for
|
|
/// more info.
|
|
/// @param src_layer_desc Memory descriptor for the input vector.
|
|
/// @param src_iter_desc Memory descriptor for the input recurrent
|
|
/// hidden state vector.
|
|
/// @param weights_layer_desc Memory descriptor for the weights
|
|
/// applied to the layer input.
|
|
/// @param weights_iter_desc Memory descriptor for the weights applied
|
|
/// to the recurrent input.
|
|
/// @param bias_desc Bias memory descriptor.
|
|
/// @param dst_layer_desc Memory descriptor for the output vector.
|
|
/// @param dst_iter_desc Memory descriptor for the output recurrent
|
|
/// hidden state vector.
|
|
/// @param diff_src_layer_desc Memory descriptor for the diff of input
|
|
/// vector.
|
|
/// @param diff_src_iter_desc Memory descriptor for the diff of input
|
|
/// recurrent hidden state vector.
|
|
/// @param diff_weights_layer_desc Memory descriptor for the diff of
|
|
/// weights applied to the layer input.
|
|
/// @param diff_weights_iter_desc Memory descriptor for the diff of
|
|
/// weights applied to the recurrent input.
|
|
/// @param diff_bias_desc Diff bias memory descriptor.
|
|
/// @param diff_dst_layer_desc Memory descriptor for the diff of
|
|
/// output vector.
|
|
/// @param diff_dst_iter_desc Memory descriptor for the diff of output
|
|
/// recurrent hidden state vector.
|
|
/// @param hint_fwd_pd Primitive descriptor for a vanilla RNN
|
|
/// forward propagation primitive. It is used as a hint for
|
|
/// deciding which memory format to use.
|
|
/// @param attr Primitive attributes to use. Attributes are optional
|
|
/// and default to empty attributes.
|
|
/// @param allow_empty A flag signifying whether construction is
|
|
/// allowed to fail without throwing an exception. In this case an
|
|
/// empty object will be produced. This flag is optional and
|
|
/// defaults to false.
|
|
primitive_desc(const engine &aengine, prop_kind aprop_kind,
|
|
algorithm activation, rnn_direction direction,
|
|
const memory::desc &src_layer_desc,
|
|
const memory::desc &src_iter_desc,
|
|
const memory::desc &weights_layer_desc,
|
|
const memory::desc &weights_iter_desc,
|
|
const memory::desc &bias_desc,
|
|
const memory::desc &dst_layer_desc,
|
|
const memory::desc &dst_iter_desc,
|
|
const memory::desc &diff_src_layer_desc,
|
|
const memory::desc &diff_src_iter_desc,
|
|
const memory::desc &diff_weights_layer_desc,
|
|
const memory::desc &diff_weights_iter_desc,
|
|
const memory::desc &diff_bias_desc,
|
|
const memory::desc &diff_dst_layer_desc,
|
|
const memory::desc &diff_dst_iter_desc,
|
|
const vanilla_rnn_forward::primitive_desc &hint_fwd_pd,
|
|
const primitive_attr &attr = default_attr(),
|
|
bool allow_empty = false)
|
|
: rnn_primitive_desc_base(aengine, algorithm::vanilla_rnn,
|
|
aprop_kind, activation, direction, src_layer_desc,
|
|
src_iter_desc, nullptr, nullptr, weights_layer_desc,
|
|
weights_iter_desc, nullptr, nullptr, bias_desc,
|
|
dst_layer_desc, dst_iter_desc, nullptr, diff_src_layer_desc,
|
|
diff_src_iter_desc, nullptr, nullptr,
|
|
diff_weights_layer_desc, diff_weights_iter_desc, nullptr,
|
|
nullptr, diff_bias_desc, diff_dst_layer_desc,
|
|
diff_dst_iter_desc, nullptr, rnn_flags::undef, 0.0f, 0.0f,
|
|
hint_fwd_pd, attr, allow_empty) {}
|
|
|
|
/// Constructs a primitive descriptor for a vanilla RNN backward
|
|
/// propagation primitive with an alpha parameter.
|
|
///
|
|
/// The following arguments may point to a zero memory descriptor:
|
|
/// - @p src_iter_desc together with @p diff_src_iter_desc,
|
|
/// - @p bias_desc together with @p diff_bias_desc,
|
|
/// - @p dst_iter_desc together with @p diff_dst_iter_desc.
|
|
///
|
|
/// This would then indicate that the RNN backward propagation
|
|
/// primitive should not use the respective data and should use zero
|
|
/// values instead.
|
|
///
|
|
/// @note
|
|
/// All the memory descriptors may be initialized with the
|
|
/// #dnnl::memory::format_tag::any value of @p format_tag.
|
|
///
|
|
/// @param aengine Engine to use.
|
|
/// @param aprop_kind Propagation kind. Must be
|
|
/// #dnnl::prop_kind::backward.
|
|
/// @param activation Activation kind. Possible values are
|
|
/// #dnnl::algorithm::eltwise_relu,
|
|
/// #dnnl::algorithm::eltwise_tanh, or
|
|
/// #dnnl::algorithm::eltwise_logistic.
|
|
/// @param direction RNN direction. See @ref dnnl::rnn_direction for
|
|
/// more info.
|
|
/// @param src_layer_desc Memory descriptor for the input vector.
|
|
/// @param src_iter_desc Memory descriptor for the input recurrent
|
|
/// hidden state vector.
|
|
/// @param weights_layer_desc Memory descriptor for the weights
|
|
/// applied to the layer input.
|
|
/// @param weights_iter_desc Memory descriptor for the weights applied
|
|
/// to the recurrent input.
|
|
/// @param bias_desc Bias memory descriptor.
|
|
/// @param dst_layer_desc Memory descriptor for the output vector.
|
|
/// @param dst_iter_desc Memory descriptor for the output recurrent
|
|
/// hidden state vector.
|
|
/// @param diff_src_layer_desc Memory descriptor for the diff of input
|
|
/// vector.
|
|
/// @param diff_src_iter_desc Memory descriptor for the diff of input
|
|
/// recurrent hidden state vector.
|
|
/// @param diff_weights_layer_desc Memory descriptor for the diff of
|
|
/// weights applied to the layer input.
|
|
/// @param diff_weights_iter_desc Memory descriptor for the diff of
|
|
/// weights applied to the recurrent input.
|
|
/// @param diff_bias_desc Diff bias memory descriptor.
|
|
/// @param diff_dst_layer_desc Memory descriptor for the diff of
|
|
/// output vector.
|
|
/// @param diff_dst_iter_desc Memory descriptor for the diff of output
|
|
/// recurrent hidden state vector.
|
|
/// @param alpha Negative slope if activation is
|
|
/// #dnnl::algorithm::eltwise_relu.
|
|
/// @param hint_fwd_pd Primitive descriptor for a vanilla RNN
|
|
/// forward propagation primitive. It is used as a hint for
|
|
/// deciding which memory format to use.
|
|
/// @param attr Primitive attributes to use. Attributes are optional
|
|
/// and default to empty attributes.
|
|
/// @param allow_empty A flag signifying whether construction is
|
|
/// allowed to fail without throwing an exception. In this case an
|
|
/// empty object will be produced. This flag is optional and
|
|
/// defaults to false.
|
|
primitive_desc(const engine &aengine, prop_kind aprop_kind,
|
|
algorithm activation, rnn_direction direction,
|
|
const memory::desc &src_layer_desc,
|
|
const memory::desc &src_iter_desc,
|
|
const memory::desc &weights_layer_desc,
|
|
const memory::desc &weights_iter_desc,
|
|
const memory::desc &bias_desc,
|
|
const memory::desc &dst_layer_desc,
|
|
const memory::desc &dst_iter_desc,
|
|
const memory::desc &diff_src_layer_desc,
|
|
const memory::desc &diff_src_iter_desc,
|
|
const memory::desc &diff_weights_layer_desc,
|
|
const memory::desc &diff_weights_iter_desc,
|
|
const memory::desc &diff_bias_desc,
|
|
const memory::desc &diff_dst_layer_desc,
|
|
const memory::desc &diff_dst_iter_desc, float alpha,
|
|
const vanilla_rnn_forward::primitive_desc &hint_fwd_pd,
|
|
const primitive_attr &attr = default_attr(),
|
|
bool allow_empty = false)
|
|
: rnn_primitive_desc_base(aengine, algorithm::vanilla_rnn,
|
|
aprop_kind, activation, direction, src_layer_desc,
|
|
src_iter_desc, nullptr, nullptr, weights_layer_desc,
|
|
weights_iter_desc, nullptr, nullptr, bias_desc,
|
|
dst_layer_desc, dst_iter_desc, nullptr, diff_src_layer_desc,
|
|
diff_src_iter_desc, nullptr, nullptr,
|
|
diff_weights_layer_desc, diff_weights_iter_desc, nullptr,
|
|
nullptr, diff_bias_desc, diff_dst_layer_desc,
|
|
diff_dst_iter_desc, nullptr, rnn_flags::undef, alpha, 0.0f,
|
|
hint_fwd_pd, attr, allow_empty) {}
|
|
|
|
/// Constructs a primitive descriptor for a vanilla RNN backward
|
|
/// propagation primitive from a C API primitive descriptor that must
|
|
/// have a matching kind.
|
|
///
|
|
/// @param pd C API primitive descriptor for a vanilla RNN backward
|
|
/// propagation primitive.
|
|
primitive_desc(dnnl_primitive_desc_t pd)
|
|
: rnn_primitive_desc_base(pd, dnnl::prop_kind::backward,
|
|
dnnl::algorithm::vanilla_rnn) {}
|
|
|
|
/// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const
|
|
memory::desc src_layer_desc() const {
|
|
return rnn_base::src_layer_desc();
|
|
}
|
|
|
|
/// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
|
|
memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); }
|
|
|
|
/// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const
|
|
memory::desc weights_layer_desc() const {
|
|
return rnn_base::weights_layer_desc();
|
|
}
|
|
|
|
/// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const
|
|
memory::desc weights_iter_desc() const {
|
|
return rnn_base::weights_iter_desc();
|
|
}
|
|
|
|
/// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const
|
|
memory::desc bias_desc() const { return rnn_base::bias_desc(); }
|
|
|
|
/// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const
|
|
memory::desc dst_layer_desc() const {
|
|
return rnn_base::dst_layer_desc();
|
|
}
|
|
|
|
/// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const
|
|
memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); }
|
|
|
|
/// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const
|
|
memory::desc workspace_desc() const {
|
|
return rnn_base::workspace_desc();
|
|
}
|
|
|
|
/// @copydoc dnnl::rnn_primitive_desc_base::diff_src_layer_desc()const
|
|
memory::desc diff_src_layer_desc() const {
|
|
return rnn_base::diff_src_layer_desc();
|
|
}
|
|
|
|
/// @copydoc dnnl::rnn_primitive_desc_base::diff_src_iter_desc()const
|
|
memory::desc diff_src_iter_desc() const {
|
|
return rnn_base::diff_src_iter_desc();
|
|
}
|
|
|
|
/// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_layer_desc()const
|
|
memory::desc diff_weights_layer_desc() const {
|
|
return rnn_base::diff_weights_layer_desc();
|
|
}
|
|
|
|
/// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_iter_desc()const
|
|
memory::desc diff_weights_iter_desc() const {
|
|
return rnn_base::diff_weights_iter_desc();
|
|
}
|
|
|
|
/// @copydoc dnnl::rnn_primitive_desc_base::diff_bias_desc()const
|
|
memory::desc diff_bias_desc() const {
|
|
return rnn_base::diff_bias_desc();
|
|
}
|
|
|
|
/// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_layer_desc()const
|
|
memory::desc diff_dst_layer_desc() const {
|
|
return rnn_base::diff_dst_layer_desc();
|
|
}
|
|
|
|
/// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_iter_desc()const
|
|
memory::desc diff_dst_iter_desc() const {
|
|
return rnn_base::diff_dst_iter_desc();
|
|
}
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_cell_kind()const
|
|
algorithm get_cell_kind() const { return base::get_cell_kind(); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
|
|
prop_kind get_prop_kind() const { return base::get_prop_kind(); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_activation_kind()const
|
|
algorithm get_activation_kind() const {
|
|
return base::get_activation_kind();
|
|
}
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_direction()const
|
|
rnn_direction get_direction() const { return base::get_direction(); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_alpha()const
|
|
float get_alpha() const { return base::get_alpha(); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_beta()const
|
|
float get_beta() const { return base::get_beta(); }
|
|
};
|
|
|
|
/// Default constructor. Produces an empty object.
|
|
vanilla_rnn_backward() = default;
|
|
|
|
/// Constructs a vanilla RNN backward propagation primitive.
|
|
/// @param pd Primitive descriptor for a vanilla RNN backward
|
|
/// propagation primitive.
|
|
vanilla_rnn_backward(const primitive_desc &pd) : primitive(pd) {}
|
|
|
|
/// Constructs a vanilla RNN backward propagation primitive from
|
|
/// a cache blob.
|
|
/// @param pd Primitive descriptor for a vanilla RNN backward
|
|
/// propagation primitive.
|
|
/// @param cache_blob Cache blob.
|
|
vanilla_rnn_backward(
|
|
const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
|
|
: primitive(pd, cache_blob) {}
|
|
};
|
|
|
|
/// LSTM forward propagation primitive.
|
|
struct lstm_forward : public primitive {
|
|
/// Primitive descriptor for an LSTM forward propagation primitive.
|
|
struct primitive_desc : public rnn_primitive_desc_base {
|
|
/// Default constructor. Produces an empty object.
|
|
primitive_desc() = default;
|
|
|
|
/// Constructs a primitive descriptor for an LSTM (with or without
|
|
/// peephole and with or without projection) forward propagation
|
|
/// primitive.
|
|
///
|
|
/// The following arguments may point to a zero memory descriptor:
|
|
/// - @p src_iter_desc together with @p src_iter_c_desc,
|
|
/// - @p weights_peephole_desc,
|
|
/// - @p bias_desc,
|
|
/// - @p dst_iter_desc together with @p dst_iter_c_desc.
|
|
///
|
|
/// This would then indicate that the LSTM forward propagation
|
|
/// primitive should not use them and should default to zero values
|
|
/// instead.
|
|
///
|
|
/// The @p weights_projection_desc may point to a zero memory
|
|
/// descriptor. This would then indicate that the LSTM doesn't have
|
|
/// recurrent projection layer.
|
|
///
|
|
/// @note
|
|
/// All memory descriptors can be initialized with an
|
|
/// #dnnl::memory::format_tag::any value of @p format_tag.
|
|
///
|
|
/// @param aengine Engine to use.
|
|
/// @param aprop_kind Propagation kind. Possible values are
|
|
/// #dnnl::prop_kind::forward_training, and
|
|
/// #dnnl::prop_kind::forward_inference.
|
|
/// @param direction RNN direction. See @ref dnnl::rnn_direction for
|
|
/// more info.
|
|
/// @param src_layer_desc Memory descriptor for the input vector.
|
|
/// @param src_iter_desc Memory descriptor for the input recurrent
|
|
/// hidden state vector.
|
|
/// @param src_iter_c_desc Memory descriptor for the input recurrent
|
|
/// cell state vector.
|
|
/// @param weights_layer_desc Memory descriptor for the weights
|
|
/// applied to the layer input.
|
|
/// @param weights_iter_desc Memory descriptor for the weights applied
|
|
/// to the recurrent input.
|
|
/// @param weights_peephole_desc Memory descriptor for the weights
|
|
/// applied to the cell states (according to the Peephole LSTM
|
|
/// formula).
|
|
/// @param weights_projection_desc Memory descriptor for the weights
|
|
/// applied to the hidden states to get the recurrent projection
|
|
/// (according to the Projection LSTM formula).
|
|
/// @param bias_desc Bias memory descriptor.
|
|
/// @param dst_layer_desc Memory descriptor for the output vector.
|
|
/// @param dst_iter_desc Memory descriptor for the output recurrent
|
|
/// hidden state vector.
|
|
/// @param dst_iter_c_desc Memory descriptor for the output recurrent
|
|
/// cell state vector.
|
|
/// @param attr Primitive attributes to use. Attributes are optional
|
|
/// and default to empty attributes.
|
|
/// @param allow_empty A flag signifying whether construction is
|
|
/// allowed to fail without throwing an exception. In this case an
|
|
/// empty object will be produced. This flag is optional and
|
|
/// defaults to false.
|
|
primitive_desc(const engine &aengine, prop_kind aprop_kind,
|
|
rnn_direction direction, const memory::desc &src_layer_desc,
|
|
const memory::desc &src_iter_desc,
|
|
const memory::desc &src_iter_c_desc,
|
|
const memory::desc &weights_layer_desc,
|
|
const memory::desc &weights_iter_desc,
|
|
const memory::desc &weights_peephole_desc,
|
|
const memory::desc &weights_projection_desc,
|
|
const memory::desc &bias_desc,
|
|
const memory::desc &dst_layer_desc,
|
|
const memory::desc &dst_iter_desc,
|
|
const memory::desc &dst_iter_c_desc,
|
|
const primitive_attr &attr = default_attr(),
|
|
bool allow_empty = false)
|
|
: rnn_primitive_desc_base(aengine, algorithm::vanilla_lstm,
|
|
aprop_kind, algorithm::undef, direction, src_layer_desc,
|
|
src_iter_desc, &src_iter_c_desc, nullptr,
|
|
weights_layer_desc, weights_iter_desc,
|
|
&weights_peephole_desc, &weights_projection_desc, bias_desc,
|
|
dst_layer_desc, dst_iter_desc, &dst_iter_c_desc,
|
|
rnn_flags::undef, 0.0f, 0.0f, attr, allow_empty) {}
|
|
|
|
/// Constructs a primitive descriptor for an LSTM (with or without
|
|
/// peephole) forward propagation primitive.
|
|
///
|
|
/// The following arguments may point to a zero memory descriptor:
|
|
/// - @p src_iter_desc together with @p src_iter_c_desc,
|
|
/// - @p weights_peephole_desc,
|
|
/// - @p bias_desc,
|
|
/// - @p dst_iter_desc together with @p dst_iter_c_desc.
|
|
///
|
|
/// This would then indicate that the LSTM forward propagation
|
|
/// primitive should not use them and should default to zero values
|
|
/// instead.
|
|
///
|
|
/// @note
|
|
/// All memory descriptors can be initialized with an
|
|
/// #dnnl::memory::format_tag::any value of @p format_tag.
|
|
///
|
|
/// @param aengine Engine to use.
|
|
/// @param aprop_kind Propagation kind. Possible values are
|
|
/// #dnnl::prop_kind::forward_training, and
|
|
/// #dnnl::prop_kind::forward_inference.
|
|
/// @param direction RNN direction. See @ref dnnl::rnn_direction for
|
|
/// more info.
|
|
/// @param src_layer_desc Memory descriptor for the input vector.
|
|
/// @param src_iter_desc Memory descriptor for the input recurrent
|
|
/// hidden state vector.
|
|
/// @param src_iter_c_desc Memory descriptor for the input recurrent
|
|
/// cell state vector.
|
|
/// @param weights_layer_desc Memory descriptor for the weights
|
|
/// applied to the layer input.
|
|
/// @param weights_iter_desc Memory descriptor for the weights applied
|
|
/// to the recurrent input.
|
|
/// @param weights_peephole_desc Memory descriptor for the weights
|
|
/// applied to the cell states (according to the Peephole LSTM
|
|
/// formula).
|
|
/// @param bias_desc Bias memory descriptor.
|
|
/// @param dst_layer_desc Memory descriptor for the output vector.
|
|
/// @param dst_iter_desc Memory descriptor for the output recurrent
|
|
/// hidden state vector.
|
|
/// @param dst_iter_c_desc Memory descriptor for the output recurrent
|
|
/// cell state vector.
|
|
/// @param attr Primitive attributes to use. Attributes are optional
|
|
/// and default to empty attributes.
|
|
/// @param allow_empty A flag signifying whether construction is
|
|
/// allowed to fail without throwing an exception. In this case an
|
|
/// empty object will be produced. This flag is optional and
|
|
/// defaults to false.
|
|
primitive_desc(const engine &aengine, prop_kind aprop_kind,
|
|
rnn_direction direction, const memory::desc &src_layer_desc,
|
|
const memory::desc &src_iter_desc,
|
|
const memory::desc &src_iter_c_desc,
|
|
const memory::desc &weights_layer_desc,
|
|
const memory::desc &weights_iter_desc,
|
|
const memory::desc &weights_peephole_desc,
|
|
const memory::desc &bias_desc,
|
|
const memory::desc &dst_layer_desc,
|
|
const memory::desc &dst_iter_desc,
|
|
const memory::desc &dst_iter_c_desc,
|
|
const primitive_attr &attr = default_attr(),
|
|
bool allow_empty = false)
|
|
: rnn_primitive_desc_base(aengine, algorithm::vanilla_lstm,
|
|
aprop_kind, algorithm::undef, direction, src_layer_desc,
|
|
src_iter_desc, &src_iter_c_desc, nullptr,
|
|
weights_layer_desc, weights_iter_desc,
|
|
&weights_peephole_desc, nullptr, bias_desc, dst_layer_desc,
|
|
dst_iter_desc, &dst_iter_c_desc, rnn_flags::undef, 0.0f,
|
|
0.0f, attr, allow_empty) {}
|
|
|
|
/// Constructs a primitive descriptor for an LSTM forward propagation
|
|
/// primitive.
|
|
///
|
|
/// The following arguments may point to a zero memory descriptor:
|
|
/// - @p src_iter_desc together with @p src_iter_c_desc,
|
|
/// - @p bias_desc,
|
|
/// - @p dst_iter_desc together with @p dst_iter_c_desc.
|
|
///
|
|
/// This would then indicate that the LSTM forward propagation
|
|
/// primitive should not use them and should default to zero values
|
|
/// instead.
|
|
///
|
|
/// @note
|
|
/// All memory descriptors can be initialized with an
|
|
/// #dnnl::memory::format_tag::any value of @p format_tag.
|
|
///
|
|
/// @param aengine Engine to use.
|
|
/// @param aprop_kind Propagation kind. Possible values are
|
|
/// #dnnl::prop_kind::forward_training, and
|
|
/// #dnnl::prop_kind::forward_inference.
|
|
/// @param direction RNN direction. See @ref dnnl::rnn_direction for
|
|
/// more info.
|
|
/// @param src_layer_desc Memory descriptor for the input vector.
|
|
/// @param src_iter_desc Memory descriptor for the input recurrent
|
|
/// hidden state vector.
|
|
/// @param src_iter_c_desc Memory descriptor for the input recurrent
|
|
/// cell state vector.
|
|
/// @param weights_layer_desc Memory descriptor for the weights
|
|
/// applied to the layer input.
|
|
/// @param weights_iter_desc Memory descriptor for the weights applied
|
|
/// to the recurrent input.
|
|
/// @param bias_desc Bias memory descriptor.
|
|
/// @param dst_layer_desc Memory descriptor for the output vector.
|
|
/// @param dst_iter_desc Memory descriptor for the output recurrent
|
|
/// hidden state vector.
|
|
/// @param dst_iter_c_desc Memory descriptor for the output recurrent
|
|
/// cell state vector.
|
|
/// @param attr Primitive attributes to use. Attributes are optional
|
|
/// and default to empty attributes.
|
|
/// @param allow_empty A flag signifying whether construction is
|
|
/// allowed to fail without throwing an exception. In this case an
|
|
/// empty object will be produced. This flag is optional and
|
|
/// defaults to false.
|
|
primitive_desc(const engine &aengine, prop_kind aprop_kind,
|
|
rnn_direction direction, const memory::desc &src_layer_desc,
|
|
const memory::desc &src_iter_desc,
|
|
const memory::desc &src_iter_c_desc,
|
|
const memory::desc &weights_layer_desc,
|
|
const memory::desc &weights_iter_desc,
|
|
const memory::desc &bias_desc,
|
|
const memory::desc &dst_layer_desc,
|
|
const memory::desc &dst_iter_desc,
|
|
const memory::desc &dst_iter_c_desc,
|
|
const primitive_attr &attr = default_attr(),
|
|
bool allow_empty = false)
|
|
: rnn_primitive_desc_base(aengine, algorithm::vanilla_lstm,
|
|
aprop_kind, algorithm::undef, direction, src_layer_desc,
|
|
src_iter_desc, &src_iter_c_desc, nullptr,
|
|
weights_layer_desc, weights_iter_desc, nullptr, nullptr,
|
|
bias_desc, dst_layer_desc, dst_iter_desc, &dst_iter_c_desc,
|
|
rnn_flags::undef, 0.0f, 0.0f, attr, allow_empty) {}
|
|
|
|
/// Constructs a primitive descriptor for an LSTM forward propagation
|
|
/// primitive from a C API primitive descriptor that must have a
|
|
/// matching kind.
|
|
///
|
|
/// @param pd C API primitive descriptor for an LSTM forward
|
|
/// propagation primitive.
|
|
primitive_desc(dnnl_primitive_desc_t pd)
|
|
: rnn_primitive_desc_base(pd, dnnl::prop_kind::forward_training,
|
|
dnnl::prop_kind::forward_inference,
|
|
dnnl::algorithm::vanilla_lstm) {}
|
|
|
|
/// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const
|
|
memory::desc src_layer_desc() const {
|
|
return rnn_base::src_layer_desc();
|
|
}
|
|
|
|
/// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
|
|
memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); }
|
|
|
|
/// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
|
|
memory::desc src_iter_c_desc() const {
|
|
return rnn_base::src_iter_c_desc();
|
|
}
|
|
|
|
/// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const
|
|
memory::desc weights_layer_desc() const {
|
|
return rnn_base::weights_layer_desc();
|
|
}
|
|
|
|
/// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const
|
|
memory::desc weights_iter_desc() const {
|
|
return rnn_base::weights_iter_desc();
|
|
}
|
|
|
|
/// @copydoc dnnl::rnn_primitive_desc_base::weights_peephole_desc()const
|
|
memory::desc weights_peephole_desc() const {
|
|
return rnn_base::weights_peephole_desc();
|
|
}
|
|
|
|
/// @copydoc dnnl::rnn_primitive_desc_base::weights_projection_desc()const
|
|
memory::desc weights_projection_desc() const {
|
|
return rnn_base::weights_projection_desc();
|
|
}
|
|
|
|
/// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const
|
|
memory::desc bias_desc() const { return rnn_base::bias_desc(); }
|
|
|
|
/// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const
|
|
memory::desc dst_layer_desc() const {
|
|
return rnn_base::dst_layer_desc();
|
|
}
|
|
|
|
/// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const
|
|
memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); }
|
|
|
|
/// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
|
|
memory::desc dst_iter_c_desc() const {
|
|
return rnn_base::dst_iter_c_desc();
|
|
}
|
|
|
|
/// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const
|
|
memory::desc workspace_desc() const {
|
|
return rnn_base::workspace_desc();
|
|
}
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_cell_kind()const
|
|
algorithm get_cell_kind() const { return base::get_cell_kind(); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
|
|
prop_kind get_prop_kind() const { return base::get_prop_kind(); }
|
|
|
|
/// @copydoc dnnl::primitive_desc_base::get_direction()const
|
|
rnn_direction get_direction() const { return base::get_direction(); }
|
|
};
|
|
|
|
/// Default constructor. Produces an empty object.
|
|
lstm_forward() = default;
|
|
|
|
/// Constructs an LSTM forward propagation primitive.
|
|
/// @param pd Primitive descriptor for an LSTM forward propagation
|
|
/// primitive.
|
|
lstm_forward(const primitive_desc &pd) : primitive(pd) {}
|
|
|
|
/// Constructs an LSTM forward propagation primitive from a cache blob.
|
|
/// @param pd Primitive descriptor for an LSTM forward propagation
|
|
/// primitive.
|
|
/// @param cache_blob Cache blob.
|
|
lstm_forward(
|
|
const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
|
|
: primitive(pd, cache_blob) {}
|
|
};
|
|
|
|
/// LSTM backward propagation primitive.
|
|
struct lstm_backward : public primitive {
|
|
/// Primitive descriptor for an LSTM backward propagation primitive.
|
|
struct primitive_desc : public rnn_primitive_desc_base {
|
|
/// Default constructor. Produces an empty object.
|
|
primitive_desc() = default;
|
|
|
|
/// Constructs an LSTM (with or without peephole and with or without
|
|
/// projection) primitive descriptor for backward propagation
|
|
/// using @p prop_kind, @p direction, and memory descriptors.
|
|
///
|
|
/// The following arguments may point to a zero memory descriptor:
|
|
/// - @p src_iter_desc together with @p src_iter_c_desc,
|
|
/// @p diff_src_iter_desc, and @p diff_src_iter_c_desc,
|
|
/// - @p weights_peephole_desc together with
|
|
/// @p diff_weights_peephole_desc
|
|
/// - @p bias_desc together with @p diff_bias_desc,
|
|
/// - @p dst_iter_desc together with @p dst_iter_c_desc,
|
|
/// @p diff_dst_iter_desc, and @p diff_dst_iter_c_desc.
|
|
///
|
|
/// This would then indicate that the LSTM backward propagation
|
|
/// primitive should not use them and should default to zero values
|
|
/// instead.
|
|
///
|
|
/// The @p weights_projection_desc together with @p
|
|
/// diff_weights_projection_desc may point to a zero memory descriptor.
|
|
/// This would then indicate that the LSTM doesn't have recurrent
|
|
/// projection layer.
|
|
///
|
|
/// @note
|
|
/// All memory descriptors can be initialized with
|
|
/// #dnnl::memory::format_tag::any value of @p format_tag.
|
|
///
|
|
/// @param aengine Engine to use.
|
|
/// @param aprop_kind Propagation kind. Must be
|
|
        ///     #dnnl::prop_kind::backward.
        /// @param direction RNN direction. See @ref dnnl::rnn_direction for
        ///     more info.
        /// @param src_layer_desc Memory descriptor for the input vector.
        /// @param src_iter_desc Memory descriptor for the input recurrent
        ///     hidden state vector.
        /// @param src_iter_c_desc Memory descriptor for the input recurrent
        ///     cell state vector.
        /// @param weights_layer_desc Memory descriptor for the weights
        ///     applied to the layer input.
        /// @param weights_iter_desc Memory descriptor for the weights applied
        ///     to the recurrent input.
        /// @param weights_peephole_desc Memory descriptor for the weights
        ///     applied to the cell states (according to the Peephole LSTM
        ///     formula).
        /// @param weights_projection_desc Memory descriptor for the weights
        ///     applied to the hidden states to get the recurrent projection
        ///     (according to the Projection LSTM formula).
        /// @param bias_desc Bias memory descriptor.
        /// @param dst_layer_desc Memory descriptor for the output vector.
        /// @param dst_iter_desc Memory descriptor for the output recurrent
        ///     hidden state vector.
        /// @param dst_iter_c_desc Memory descriptor for the output recurrent
        ///     cell state vector.
        /// @param diff_src_layer_desc Memory descriptor for the diff of input
        ///     vector.
        /// @param diff_src_iter_desc Memory descriptor for the diff of input
        ///     recurrent hidden state vector.
        /// @param diff_src_iter_c_desc Memory descriptor for the diff of
        ///     input recurrent cell state vector.
        /// @param diff_weights_layer_desc Memory descriptor for the diff of
        ///     weights applied to the layer input.
        /// @param diff_weights_iter_desc Memory descriptor for the diff of
        ///     weights applied to the recurrent input.
        /// @param diff_weights_peephole_desc Memory descriptor for the diff of
        ///     weights applied to the cell states (according to the Peephole
        ///     LSTM formula).
        /// @param diff_weights_projection_desc Memory descriptor for the diff
        ///     of weights applied to the hidden states to get the recurrent
        ///     projection (according to the Projection LSTM formula).
        /// @param diff_bias_desc Diff bias memory descriptor.
        /// @param diff_dst_layer_desc Memory descriptor for the diff of
        ///     output vector.
        /// @param diff_dst_iter_desc Memory descriptor for the diff of output
        ///     recurrent hidden state vector.
        /// @param diff_dst_iter_c_desc Memory descriptor for the diff of
        ///     output recurrent cell state vector.
        /// @param hint_fwd_pd Primitive descriptor for an LSTM
        ///     forward propagation primitive. It is used as a hint for
        ///     deciding which memory format to use.
        /// @param attr Primitive attributes to use. Attributes are optional
        ///     and default to empty attributes.
        /// @param allow_empty A flag signifying whether construction is
        ///     allowed to fail without throwing an exception. In this case an
        ///     empty object will be produced. This flag is optional and
        ///     defaults to false.
        primitive_desc(const engine &aengine, prop_kind aprop_kind,
                rnn_direction direction, const memory::desc &src_layer_desc,
                const memory::desc &src_iter_desc,
                const memory::desc &src_iter_c_desc,
                const memory::desc &weights_layer_desc,
                const memory::desc &weights_iter_desc,
                const memory::desc &weights_peephole_desc,
                const memory::desc &weights_projection_desc,
                const memory::desc &bias_desc,
                const memory::desc &dst_layer_desc,
                const memory::desc &dst_iter_desc,
                const memory::desc &dst_iter_c_desc,
                const memory::desc &diff_src_layer_desc,
                const memory::desc &diff_src_iter_desc,
                const memory::desc &diff_src_iter_c_desc,
                const memory::desc &diff_weights_layer_desc,
                const memory::desc &diff_weights_iter_desc,
                const memory::desc &diff_weights_peephole_desc,
                const memory::desc &diff_weights_projection_desc,
                const memory::desc &diff_bias_desc,
                const memory::desc &diff_dst_layer_desc,
                const memory::desc &diff_dst_iter_desc,
                const memory::desc &diff_dst_iter_c_desc,
                const lstm_forward::primitive_desc &hint_fwd_pd,
                const primitive_attr &attr = default_attr(),
                bool allow_empty = false)
            : rnn_primitive_desc_base(aengine, algorithm::vanilla_lstm,
                    aprop_kind, algorithm::undef, direction, src_layer_desc,
                    src_iter_desc, &src_iter_c_desc, nullptr,
                    weights_layer_desc, weights_iter_desc,
                    &weights_peephole_desc, &weights_projection_desc, bias_desc,
                    dst_layer_desc, dst_iter_desc, &dst_iter_c_desc,
                    diff_src_layer_desc, diff_src_iter_desc,
                    &diff_src_iter_c_desc, nullptr, diff_weights_layer_desc,
                    diff_weights_iter_desc, &diff_weights_peephole_desc,
                    &diff_weights_projection_desc, diff_bias_desc,
                    diff_dst_layer_desc, diff_dst_iter_desc,
                    &diff_dst_iter_c_desc, rnn_flags::undef, 0.0f, 0.0f,
                    hint_fwd_pd, attr, allow_empty) {}

        /// Constructs an LSTM (with or without peephole) primitive descriptor
        /// for backward propagation using @p prop_kind, @p direction,
        /// and memory descriptors.
        ///
        /// The following arguments may point to a zero memory descriptor:
        /// - @p src_iter_desc together with @p src_iter_c_desc,
        ///   @p diff_src_iter_desc, and @p diff_src_iter_c_desc,
        /// - @p weights_peephole_desc together with
        ///   @p diff_weights_peephole_desc
        /// - @p bias_desc together with @p diff_bias_desc,
        /// - @p dst_iter_desc together with @p dst_iter_c_desc,
        ///   @p diff_dst_iter_desc, and @p diff_dst_iter_c_desc.
        ///
        /// This would then indicate that the LSTM backward propagation
        /// primitive should not use them and should default to zero values
        /// instead.
        ///
        /// @note
        ///     All memory descriptors may be initialized with
        ///     #dnnl::memory::format_tag::any value of @p format_tag.
        ///
        /// @param aengine Engine to use.
        /// @param aprop_kind Propagation kind. Must be
        ///     #dnnl::prop_kind::backward.
        /// @param direction RNN direction. See @ref dnnl::rnn_direction for
        ///     more info.
        /// @param src_layer_desc Memory descriptor for the input vector.
        /// @param src_iter_desc Memory descriptor for the input recurrent
        ///     hidden state vector.
        /// @param src_iter_c_desc Memory descriptor for the input recurrent
        ///     cell state vector.
        /// @param weights_layer_desc Memory descriptor for the weights
        ///     applied to the layer input.
        /// @param weights_iter_desc Memory descriptor for the weights applied
        ///     to the recurrent input.
        /// @param weights_peephole_desc Memory descriptor for the weights
        ///     applied to the cell states (according to the Peephole LSTM
        ///     formula).
        /// @param bias_desc Bias memory descriptor.
        /// @param dst_layer_desc Memory descriptor for the output vector.
        /// @param dst_iter_desc Memory descriptor for the output recurrent
        ///     hidden state vector.
        /// @param dst_iter_c_desc Memory descriptor for the output recurrent
        ///     cell state vector.
        /// @param diff_src_layer_desc Memory descriptor for the diff of input
        ///     vector.
        /// @param diff_src_iter_desc Memory descriptor for the diff of input
        ///     recurrent hidden state vector.
        /// @param diff_src_iter_c_desc Memory descriptor for the diff of
        ///     input recurrent cell state vector.
        /// @param diff_weights_layer_desc Memory descriptor for the diff of
        ///     weights applied to the layer input.
        /// @param diff_weights_iter_desc Memory descriptor for the diff of
        ///     weights applied to the recurrent input.
        /// @param diff_weights_peephole_desc Memory descriptor for the diff of
        ///     weights applied to the cell states (according to the Peephole
        ///     LSTM formula).
        /// @param diff_bias_desc Diff bias memory descriptor.
        /// @param diff_dst_layer_desc Memory descriptor for the diff of
        ///     output vector.
        /// @param diff_dst_iter_desc Memory descriptor for the diff of output
        ///     recurrent hidden state vector.
        /// @param diff_dst_iter_c_desc Memory descriptor for the diff of
        ///     output recurrent cell state vector.
        /// @param hint_fwd_pd Primitive descriptor for an LSTM
        ///     forward propagation primitive. It is used as a hint for
        ///     deciding which memory format to use.
        /// @param attr Primitive attributes to use. Attributes are optional
        ///     and default to empty attributes.
        /// @param allow_empty A flag signifying whether construction is
        ///     allowed to fail without throwing an exception. In this case an
        ///     empty object will be produced. This flag is optional and
        ///     defaults to false.
        primitive_desc(const engine &aengine, prop_kind aprop_kind,
                rnn_direction direction, const memory::desc &src_layer_desc,
                const memory::desc &src_iter_desc,
                const memory::desc &src_iter_c_desc,
                const memory::desc &weights_layer_desc,
                const memory::desc &weights_iter_desc,
                const memory::desc &weights_peephole_desc,
                const memory::desc &bias_desc,
                const memory::desc &dst_layer_desc,
                const memory::desc &dst_iter_desc,
                const memory::desc &dst_iter_c_desc,
                const memory::desc &diff_src_layer_desc,
                const memory::desc &diff_src_iter_desc,
                const memory::desc &diff_src_iter_c_desc,
                const memory::desc &diff_weights_layer_desc,
                const memory::desc &diff_weights_iter_desc,
                const memory::desc &diff_weights_peephole_desc,
                const memory::desc &diff_bias_desc,
                const memory::desc &diff_dst_layer_desc,
                const memory::desc &diff_dst_iter_desc,
                const memory::desc &diff_dst_iter_c_desc,
                const lstm_forward::primitive_desc &hint_fwd_pd,
                const primitive_attr &attr = default_attr(),
                bool allow_empty = false)
            : rnn_primitive_desc_base(aengine, algorithm::vanilla_lstm,
                    aprop_kind, algorithm::undef, direction, src_layer_desc,
                    src_iter_desc, &src_iter_c_desc, nullptr,
                    weights_layer_desc, weights_iter_desc,
                    &weights_peephole_desc, nullptr, bias_desc, dst_layer_desc,
                    dst_iter_desc, &dst_iter_c_desc, diff_src_layer_desc,
                    diff_src_iter_desc, &diff_src_iter_c_desc, nullptr,
                    diff_weights_layer_desc, diff_weights_iter_desc,
                    &diff_weights_peephole_desc, nullptr, diff_bias_desc,
                    diff_dst_layer_desc, diff_dst_iter_desc,
                    &diff_dst_iter_c_desc, rnn_flags::undef, 0.0f, 0.0f,
                    hint_fwd_pd, attr, allow_empty) {}

        /// Constructs an LSTM primitive descriptor for backward propagation
        /// using @p prop_kind, @p direction, and memory descriptors.
        ///
        /// The following arguments may point to a zero memory descriptor:
        /// - @p src_iter_desc together with @p src_iter_c_desc,
        ///   @p diff_src_iter_desc, and @p diff_src_iter_c_desc,
        /// - @p bias_desc together with @p diff_bias_desc,
        /// - @p dst_iter_desc together with @p dst_iter_c_desc,
        ///   @p diff_dst_iter_desc, and @p diff_dst_iter_c_desc.
        ///
        /// This would then indicate that the LSTM backward propagation
        /// primitive should not use them and should default to zero values
        /// instead.
        ///
        /// @note
        ///     All memory descriptors may be initialized with
        ///     #dnnl::memory::format_tag::any value of @p format_tag.
        ///
        /// @param aengine Engine to use.
        /// @param aprop_kind Propagation kind. Must be
        ///     #dnnl::prop_kind::backward.
        /// @param direction RNN direction. See @ref dnnl::rnn_direction for
        ///     more info.
        /// @param src_layer_desc Memory descriptor for the input vector.
        /// @param src_iter_desc Memory descriptor for the input recurrent
        ///     hidden state vector.
        /// @param src_iter_c_desc Memory descriptor for the input recurrent
        ///     cell state vector.
        /// @param weights_layer_desc Memory descriptor for the weights
        ///     applied to the layer input.
        /// @param weights_iter_desc Memory descriptor for the weights applied
        ///     to the recurrent input.
        /// @param bias_desc Bias memory descriptor.
        /// @param dst_layer_desc Memory descriptor for the output vector.
        /// @param dst_iter_desc Memory descriptor for the output recurrent
        ///     hidden state vector.
        /// @param dst_iter_c_desc Memory descriptor for the output recurrent
        ///     cell state vector.
        /// @param diff_src_layer_desc Memory descriptor for the diff of input
        ///     vector.
        /// @param diff_src_iter_desc Memory descriptor for the diff of input
        ///     recurrent hidden state vector.
        /// @param diff_src_iter_c_desc Memory descriptor for the diff of
        ///     input recurrent cell state vector.
        /// @param diff_weights_layer_desc Memory descriptor for the diff of
        ///     weights applied to the layer input.
        /// @param diff_weights_iter_desc Memory descriptor for the diff of
        ///     weights applied to the recurrent input.
        /// @param diff_bias_desc Diff bias memory descriptor.
        /// @param diff_dst_layer_desc Memory descriptor for the diff of
        ///     output vector.
        /// @param diff_dst_iter_desc Memory descriptor for the diff of output
        ///     recurrent hidden state vector.
        /// @param diff_dst_iter_c_desc Memory descriptor for the diff of
        ///     output recurrent cell state vector.
        /// @param hint_fwd_pd Primitive descriptor for an LSTM
        ///     forward propagation primitive. It is used as a hint for
        ///     deciding which memory format to use.
        /// @param attr Primitive attributes to use. Attributes are optional
        ///     and default to empty attributes.
        /// @param allow_empty A flag signifying whether construction is
        ///     allowed to fail without throwing an exception. In this case an
        ///     empty object will be produced. This flag is optional and
        ///     defaults to false.
        primitive_desc(const engine &aengine, prop_kind aprop_kind,
                rnn_direction direction, const memory::desc &src_layer_desc,
                const memory::desc &src_iter_desc,
                const memory::desc &src_iter_c_desc,
                const memory::desc &weights_layer_desc,
                const memory::desc &weights_iter_desc,
                const memory::desc &bias_desc,
                const memory::desc &dst_layer_desc,
                const memory::desc &dst_iter_desc,
                const memory::desc &dst_iter_c_desc,
                const memory::desc &diff_src_layer_desc,
                const memory::desc &diff_src_iter_desc,
                const memory::desc &diff_src_iter_c_desc,
                const memory::desc &diff_weights_layer_desc,
                const memory::desc &diff_weights_iter_desc,
                const memory::desc &diff_bias_desc,
                const memory::desc &diff_dst_layer_desc,
                const memory::desc &diff_dst_iter_desc,
                const memory::desc &diff_dst_iter_c_desc,
                const lstm_forward::primitive_desc &hint_fwd_pd,
                const primitive_attr &attr = default_attr(),
                bool allow_empty = false)
            : rnn_primitive_desc_base(aengine, algorithm::vanilla_lstm,
                    aprop_kind, algorithm::undef, direction, src_layer_desc,
                    src_iter_desc, &src_iter_c_desc, nullptr,
                    weights_layer_desc, weights_iter_desc, nullptr, nullptr,
                    bias_desc, dst_layer_desc, dst_iter_desc, &dst_iter_c_desc,
                    diff_src_layer_desc, diff_src_iter_desc,
                    &diff_src_iter_c_desc, nullptr, diff_weights_layer_desc,
                    diff_weights_iter_desc, nullptr, nullptr, diff_bias_desc,
                    diff_dst_layer_desc, diff_dst_iter_desc,
                    &diff_dst_iter_c_desc, rnn_flags::undef, 0.0f, 0.0f,
                    hint_fwd_pd, attr, allow_empty) {}

        /// Constructs a primitive descriptor for an LSTM backward propagation
        /// primitive from a C API primitive descriptor that must have a
        /// matching kind.
        ///
        /// @param pd C API primitive descriptor for an LSTM backward
        ///     propagation primitive.
        primitive_desc(dnnl_primitive_desc_t pd)
            : rnn_primitive_desc_base(pd, dnnl::prop_kind::backward,
                    dnnl::algorithm::vanilla_lstm) {}

        /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const
        memory::desc src_layer_desc() const {
            return rnn_base::src_layer_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
        memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); }

        /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
        memory::desc src_iter_c_desc() const {
            return rnn_base::src_iter_c_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const
        memory::desc weights_layer_desc() const {
            return rnn_base::weights_layer_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const
        memory::desc weights_iter_desc() const {
            return rnn_base::weights_iter_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::weights_peephole_desc()const
        memory::desc weights_peephole_desc() const {
            return rnn_base::weights_peephole_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::weights_projection_desc()const
        memory::desc weights_projection_desc() const {
            return rnn_base::weights_projection_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const
        memory::desc bias_desc() const { return rnn_base::bias_desc(); }

        /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const
        memory::desc dst_layer_desc() const {
            return rnn_base::dst_layer_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const
        memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); }

        /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
        memory::desc dst_iter_c_desc() const {
            return rnn_base::dst_iter_c_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const
        memory::desc workspace_desc() const {
            return rnn_base::workspace_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_layer_desc()const
        memory::desc diff_src_layer_desc() const {
            return rnn_base::diff_src_layer_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_iter_desc()const
        memory::desc diff_src_iter_desc() const {
            return rnn_base::diff_src_iter_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_iter_c_desc()const
        memory::desc diff_src_iter_c_desc() const {
            return rnn_base::diff_src_iter_c_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_layer_desc()const
        memory::desc diff_weights_layer_desc() const {
            return rnn_base::diff_weights_layer_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_iter_desc()const
        memory::desc diff_weights_iter_desc() const {
            return rnn_base::diff_weights_iter_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_peephole_desc()const
        memory::desc diff_weights_peephole_desc() const {
            return rnn_base::diff_weights_peephole_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_projection_desc()const
        memory::desc diff_weights_projection_desc() const {
            return rnn_base::diff_weights_projection_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::diff_bias_desc()const
        memory::desc diff_bias_desc() const {
            return rnn_base::diff_bias_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_layer_desc()const
        memory::desc diff_dst_layer_desc() const {
            return rnn_base::diff_dst_layer_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_iter_desc()const
        memory::desc diff_dst_iter_desc() const {
            return rnn_base::diff_dst_iter_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_iter_c_desc()const
        memory::desc diff_dst_iter_c_desc() const {
            return rnn_base::diff_dst_iter_c_desc();
        }

        /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const
        algorithm get_cell_kind() const { return base::get_cell_kind(); }

        /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
        prop_kind get_prop_kind() const { return base::get_prop_kind(); }

        /// @copydoc dnnl::primitive_desc_base::get_direction()const
        rnn_direction get_direction() const { return base::get_direction(); }
    };

    /// Default constructor. Produces an empty object.
    lstm_backward() = default;

    /// Constructs an LSTM backward propagation primitive.
    /// @param pd Primitive descriptor for an LSTM backward propagation
    ///     primitive.
    lstm_backward(const primitive_desc &pd) : primitive(pd) {}

    /// Constructs an LSTM backward propagation primitive from a cache blob.
    /// @param pd Primitive descriptor for an LSTM backward propagation
    ///     primitive.
    /// @param cache_blob Cache blob.
    lstm_backward(
            const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
        : primitive(pd, cache_blob) {}
};
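
// A minimal usage sketch (illustrative only, not part of the library): an
// LSTM backward primitive descriptor is typically created from the same
// memory descriptors as the forward training pass plus their diff
// counterparts, with the forward primitive descriptor passed as a hint.
// Names such as `eng`, `fwd_pd`, and the `*_md` descriptors below are
// assumptions made for the example.
//
//     auto fwd_pd = lstm_forward::primitive_desc(eng,
//             prop_kind::forward_training,
//             rnn_direction::unidirectional_left2right, src_layer_md,
//             src_iter_md, src_iter_c_md, weights_layer_md, weights_iter_md,
//             bias_md, dst_layer_md, dst_iter_md, dst_iter_c_md);
//     auto bwd_pd = lstm_backward::primitive_desc(eng, prop_kind::backward,
//             rnn_direction::unidirectional_left2right, src_layer_md,
//             src_iter_md, src_iter_c_md, weights_layer_md, weights_iter_md,
//             bias_md, dst_layer_md, dst_iter_md, dst_iter_c_md,
//             diff_src_layer_md, diff_src_iter_md, diff_src_iter_c_md,
//             diff_weights_layer_md, diff_weights_iter_md, diff_bias_md,
//             diff_dst_layer_md, diff_dst_iter_md, diff_dst_iter_c_md,
//             fwd_pd);
//     auto lstm_bwd = lstm_backward(bwd_pd);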

/// GRU forward propagation primitive.
struct gru_forward : public primitive {
    /// Primitive descriptor for a GRU forward propagation primitive.
    struct primitive_desc : public rnn_primitive_desc_base {
        /// Default constructor. Produces an empty object.
        primitive_desc() = default;

        /// Constructs a primitive descriptor for a GRU forward propagation
        /// primitive.
        ///
        /// The following arguments may point to a zero memory descriptor:
        /// - @p src_iter_desc,
        /// - @p bias_desc,
        /// - @p dst_iter_desc.
        ///
        /// This would then indicate that the GRU forward propagation primitive
        /// should not use them and should default to zero values instead.
        ///
        /// @note
        ///     All memory descriptors except @p src_iter_desc may be
        ///     initialized with an #dnnl::memory::format_tag::any value of @p
        ///     format_tag.
        ///
        /// @param aengine Engine to use.
        /// @param aprop_kind Propagation kind. Possible values are
        ///     #dnnl::prop_kind::forward_training, and
        ///     #dnnl::prop_kind::forward_inference.
        /// @param direction RNN direction. See @ref dnnl::rnn_direction for
        ///     more info.
        /// @param src_layer_desc Memory descriptor for the input vector.
        /// @param src_iter_desc Memory descriptor for the input recurrent
        ///     hidden state vector.
        /// @param weights_layer_desc Memory descriptor for the weights
        ///     applied to the layer input.
        /// @param weights_iter_desc Memory descriptor for the weights applied
        ///     to the recurrent input.
        /// @param bias_desc Bias memory descriptor.
        /// @param dst_layer_desc Memory descriptor for the output vector.
        /// @param dst_iter_desc Memory descriptor for the output recurrent
        ///     hidden state vector.
        /// @param attr Primitive attributes to use. Attributes are optional
        ///     and default to empty attributes.
        /// @param allow_empty A flag signifying whether construction is
        ///     allowed to fail without throwing an exception. In this case an
        ///     empty object will be produced. This flag is optional and
        ///     defaults to false.
        primitive_desc(const engine &aengine, prop_kind aprop_kind,
                rnn_direction direction, const memory::desc &src_layer_desc,
                const memory::desc &src_iter_desc,
                const memory::desc &weights_layer_desc,
                const memory::desc &weights_iter_desc,
                const memory::desc &bias_desc,
                const memory::desc &dst_layer_desc,
                const memory::desc &dst_iter_desc,
                const primitive_attr &attr = default_attr(),
                bool allow_empty = false)
            : rnn_primitive_desc_base(aengine, algorithm::vanilla_gru,
                    aprop_kind, algorithm::undef, direction, src_layer_desc,
                    src_iter_desc, nullptr, nullptr, weights_layer_desc,
                    weights_iter_desc, nullptr, nullptr, bias_desc,
                    dst_layer_desc, dst_iter_desc, nullptr, rnn_flags::undef,
                    0.0f, 0.0f, attr, allow_empty) {}

        /// Constructs a primitive descriptor for a GRU forward propagation
        /// primitive from a C API primitive descriptor that must have a
        /// matching kind.
        ///
        /// @param pd C API primitive descriptor for a GRU forward
        ///     propagation primitive.
        primitive_desc(dnnl_primitive_desc_t pd)
            : rnn_primitive_desc_base(pd, dnnl::prop_kind::forward_training,
                    dnnl::prop_kind::forward_inference,
                    dnnl::algorithm::vanilla_gru) {}

        /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const
        memory::desc src_layer_desc() const {
            return rnn_base::src_layer_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
        memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); }

        /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const
        memory::desc weights_layer_desc() const {
            return rnn_base::weights_layer_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const
        memory::desc weights_iter_desc() const {
            return rnn_base::weights_iter_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const
        memory::desc bias_desc() const { return rnn_base::bias_desc(); }

        /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const
        memory::desc dst_layer_desc() const {
            return rnn_base::dst_layer_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const
        memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); }

        /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const
        memory::desc workspace_desc() const {
            return rnn_base::workspace_desc();
        }

        /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const
        algorithm get_cell_kind() const { return base::get_cell_kind(); }

        /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
        prop_kind get_prop_kind() const { return base::get_prop_kind(); }

        /// @copydoc dnnl::primitive_desc_base::get_direction()const
        rnn_direction get_direction() const { return base::get_direction(); }
    };

    /// Default constructor. Produces an empty object.
    gru_forward() = default;

    /// Constructs a GRU forward propagation primitive.
    /// @param pd Primitive descriptor for a GRU forward propagation
    ///     primitive.
    gru_forward(const primitive_desc &pd) : primitive(pd) {}

    /// Constructs a GRU forward propagation primitive from a cache blob.
    /// @param pd Primitive descriptor for a GRU forward propagation
    ///     primitive.
    /// @param cache_blob Cache blob.
    gru_forward(
            const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
        : primitive(pd, cache_blob) {}
};
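
// A minimal usage sketch (illustrative only, not part of the library):
// creating a GRU forward inference primitive descriptor. The dimensions and
// all names below (`eng`, T, N, C, the `*_md` descriptors) are assumptions;
// weights use memory::format_tag::any so the implementation can choose its
// preferred layout, and the optional src_iter/dst_iter descriptors are left
// zero (default-constructed).
//
//     const memory::dim T = 10, N = 32, C = 64, L = 1, D = 1, G = 3;
//     auto src_layer_md = memory::desc(
//             {T, N, C}, memory::data_type::f32, memory::format_tag::tnc);
//     auto weights_layer_md = memory::desc({L, D, C, G, C},
//             memory::data_type::f32, memory::format_tag::any);
//     auto weights_iter_md = memory::desc({L, D, C, G, C},
//             memory::data_type::f32, memory::format_tag::any);
//     auto bias_md = memory::desc(
//             {L, D, G, C}, memory::data_type::f32, memory::format_tag::ldgo);
//     auto dst_layer_md = memory::desc(
//             {T, N, C}, memory::data_type::f32, memory::format_tag::tnc);
//     auto pd = gru_forward::primitive_desc(eng,
//             prop_kind::forward_inference,
//             rnn_direction::unidirectional_left2right, src_layer_md,
//             memory::desc(), weights_layer_md, weights_iter_md, bias_md,
//             dst_layer_md, memory::desc());
//     auto gru = gru_forward(pd);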

/// GRU backward propagation primitive.
struct gru_backward : public primitive {
    /// Primitive descriptor for a GRU backward propagation primitive.
    struct primitive_desc : public rnn_primitive_desc_base {
        /// Default constructor. Produces an empty object.
        primitive_desc() = default;

        /// Constructs a primitive descriptor for a GRU backward propagation
        /// primitive.
        ///
        /// The following arguments may point to a zero memory descriptor:
        /// - @p src_iter_desc together with @p diff_src_iter_desc,
        /// - @p bias_desc together with @p diff_bias_desc,
        /// - @p dst_iter_desc together with @p diff_dst_iter_desc.
        ///
        /// This would then indicate that the GRU backward propagation
        /// primitive should not use them and should default to zero values
        /// instead.
        ///
        /// @note
        ///     All memory descriptors may be initialized with
        ///     #dnnl::memory::format_tag::any value of @p format_tag.
        ///
        /// @param aengine Engine to use.
        /// @param aprop_kind Propagation kind. Must be
        ///     #dnnl::prop_kind::backward.
        /// @param direction RNN direction. See @ref dnnl::rnn_direction for
        ///     more info.
        /// @param src_layer_desc Memory descriptor for the input vector.
        /// @param src_iter_desc Memory descriptor for the input recurrent
        ///     hidden state vector.
        /// @param weights_layer_desc Memory descriptor for the weights
        ///     applied to the layer input.
        /// @param weights_iter_desc Memory descriptor for the weights applied
        ///     to the recurrent input.
        /// @param bias_desc Bias memory descriptor.
        /// @param dst_layer_desc Memory descriptor for the output vector.
        /// @param dst_iter_desc Memory descriptor for the output recurrent
        ///     hidden state vector.
        /// @param diff_src_layer_desc Memory descriptor for the diff of input
        ///     vector.
        /// @param diff_src_iter_desc Memory descriptor for the diff of input
        ///     recurrent hidden state vector.
        /// @param diff_weights_layer_desc Memory descriptor for the diff of
        ///     weights applied to the layer input.
        /// @param diff_weights_iter_desc Memory descriptor for the diff of
        ///     weights applied to the recurrent input.
        /// @param diff_bias_desc Diff bias memory descriptor.
        /// @param diff_dst_layer_desc Memory descriptor for the diff of
        ///     output vector.
        /// @param diff_dst_iter_desc Memory descriptor for the diff of output
        ///     recurrent hidden state vector.
        /// @param hint_fwd_pd Primitive descriptor for a GRU
        ///     forward propagation primitive. It is used as a hint for
        ///     deciding which memory format to use.
        /// @param attr Primitive attributes to use. Attributes are optional
        ///     and default to empty attributes.
        /// @param allow_empty A flag signifying whether construction is
        ///     allowed to fail without throwing an exception. In this case an
        ///     empty object will be produced. This flag is optional and
        ///     defaults to false.
        primitive_desc(const engine &aengine, prop_kind aprop_kind,
                rnn_direction direction, const memory::desc &src_layer_desc,
                const memory::desc &src_iter_desc,
                const memory::desc &weights_layer_desc,
                const memory::desc &weights_iter_desc,
                const memory::desc &bias_desc,
                const memory::desc &dst_layer_desc,
                const memory::desc &dst_iter_desc,
                const memory::desc &diff_src_layer_desc,
                const memory::desc &diff_src_iter_desc,
                const memory::desc &diff_weights_layer_desc,
                const memory::desc &diff_weights_iter_desc,
                const memory::desc &diff_bias_desc,
                const memory::desc &diff_dst_layer_desc,
                const memory::desc &diff_dst_iter_desc,
                const gru_forward::primitive_desc &hint_fwd_pd,
                const primitive_attr &attr = default_attr(),
                bool allow_empty = false)
            : rnn_primitive_desc_base(aengine, algorithm::vanilla_gru,
                    aprop_kind, algorithm::undef, direction, src_layer_desc,
                    src_iter_desc, nullptr, nullptr, weights_layer_desc,
                    weights_iter_desc, nullptr, nullptr, bias_desc,
                    dst_layer_desc, dst_iter_desc, nullptr, diff_src_layer_desc,
                    diff_src_iter_desc, nullptr, nullptr,
                    diff_weights_layer_desc, diff_weights_iter_desc, nullptr,
                    nullptr, diff_bias_desc, diff_dst_layer_desc,
                    diff_dst_iter_desc, nullptr, rnn_flags::undef, 0.0f, 0.0f,
                    hint_fwd_pd, attr, allow_empty) {}

        /// Constructs a primitive descriptor for a GRU backward propagation
        /// primitive from a C API primitive descriptor that must have a
        /// matching kind.
        ///
        /// @param pd C API primitive descriptor for a GRU backward
        ///     propagation primitive.
        primitive_desc(dnnl_primitive_desc_t pd)
            : rnn_primitive_desc_base(pd, dnnl::prop_kind::backward,
                    dnnl::algorithm::vanilla_gru) {}

        /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const
        memory::desc src_layer_desc() const {
            return rnn_base::src_layer_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
        memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); }

        /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const
        memory::desc weights_layer_desc() const {
            return rnn_base::weights_layer_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const
        memory::desc weights_iter_desc() const {
            return rnn_base::weights_iter_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const
        memory::desc bias_desc() const { return rnn_base::bias_desc(); }

        /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const
        memory::desc dst_layer_desc() const {
            return rnn_base::dst_layer_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const
        memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); }

        /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const
        memory::desc workspace_desc() const {
            return rnn_base::workspace_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_layer_desc()const
        memory::desc diff_src_layer_desc() const {
            return rnn_base::diff_src_layer_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_iter_desc()const
        memory::desc diff_src_iter_desc() const {
            return rnn_base::diff_src_iter_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_layer_desc()const
        memory::desc diff_weights_layer_desc() const {
            return rnn_base::diff_weights_layer_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_iter_desc()const
        memory::desc diff_weights_iter_desc() const {
            return rnn_base::diff_weights_iter_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::diff_bias_desc()const
        memory::desc diff_bias_desc() const {
            return rnn_base::diff_bias_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_layer_desc()const
        memory::desc diff_dst_layer_desc() const {
            return rnn_base::diff_dst_layer_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_iter_desc()const
        memory::desc diff_dst_iter_desc() const {
            return rnn_base::diff_dst_iter_desc();
        }

        /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const
        algorithm get_cell_kind() const { return base::get_cell_kind(); }

        /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
        prop_kind get_prop_kind() const { return base::get_prop_kind(); }

        /// @copydoc dnnl::primitive_desc_base::get_direction()const
        rnn_direction get_direction() const { return base::get_direction(); }
    };

    /// Default constructor. Produces an empty object.
    gru_backward() = default;

    /// Constructs a GRU backward propagation primitive.
    /// @param pd Primitive descriptor for a GRU backward propagation
    ///     primitive.
    gru_backward(const primitive_desc &pd) : primitive(pd) {}

    /// Constructs a GRU backward propagation primitive from a cache blob.
    /// @param pd Primitive descriptor for a GRU backward propagation
    ///     primitive.
    /// @param cache_blob Cache blob.
    gru_backward(
            const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
        : primitive(pd, cache_blob) {}
};
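
// A minimal usage sketch (illustrative only, not part of the library): the
// backward descriptor mirrors the forward one, adds the diff_* descriptors,
// and takes the forward (training) primitive descriptor as a hint. At
// execution time, the workspace produced by the forward training primitive
// must also be passed to the backward primitive. Names below are
// assumptions.
//
//     auto bwd_pd = gru_backward::primitive_desc(eng, prop_kind::backward,
//             rnn_direction::unidirectional_left2right, src_layer_md,
//             src_iter_md, weights_layer_md, weights_iter_md, bias_md,
//             dst_layer_md, dst_iter_md, diff_src_layer_md, diff_src_iter_md,
//             diff_weights_layer_md, diff_weights_iter_md, diff_bias_md,
//             diff_dst_layer_md, diff_dst_iter_md, fwd_pd);
//     auto ws_md = bwd_pd.workspace_desc(); // must match the forward workspace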

/// LBR GRU forward propagation primitive.
struct lbr_gru_forward : public primitive {
    /// Primitive descriptor for an LBR GRU forward propagation primitive.
    struct primitive_desc : public rnn_primitive_desc_base {
        /// Default constructor. Produces an empty object.
        primitive_desc() = default;

        /// Constructs a primitive descriptor for an LBR GRU forward propagation
        /// primitive.
        ///
        /// The following arguments may point to a zero memory descriptor:
        /// - @p src_iter_desc,
        /// - @p bias_desc,
        /// - @p dst_iter_desc.
        ///
        /// This would then indicate that the LBR GRU forward propagation
        /// primitive should not use them and should default to zero values
        /// instead.
        ///
        /// @note
        ///     All memory descriptors except @p src_iter_desc may be
        ///     initialized with an #dnnl::memory::format_tag::any value of @p
        ///     format_tag.
        ///
        /// @param aengine Engine to use.
        /// @param aprop_kind Propagation kind. Possible values are
        ///     #dnnl::prop_kind::forward_training, and
        ///     #dnnl::prop_kind::forward_inference.
        /// @param direction RNN direction. See @ref dnnl::rnn_direction for
        ///     more info.
        /// @param src_layer_desc Memory descriptor for the input vector.
        /// @param src_iter_desc Memory descriptor for the input recurrent
        ///     hidden state vector.
        /// @param weights_layer_desc Memory descriptor for the weights
        ///     applied to the layer input.
        /// @param weights_iter_desc Memory descriptor for the weights applied
        ///     to the recurrent input.
        /// @param bias_desc Bias memory descriptor.
        /// @param dst_layer_desc Memory descriptor for the output vector.
        /// @param dst_iter_desc Memory descriptor for the output recurrent
        ///     hidden state vector.
        /// @param attr Primitive attributes to use. Attributes are optional
        ///     and default to empty attributes.
        /// @param allow_empty A flag signifying whether construction is
        ///     allowed to fail without throwing an exception. In this case an
        ///     empty object will be produced. This flag is optional and
        ///     defaults to false.
        primitive_desc(const engine &aengine, prop_kind aprop_kind,
                rnn_direction direction, const memory::desc &src_layer_desc,
                const memory::desc &src_iter_desc,
                const memory::desc &weights_layer_desc,
                const memory::desc &weights_iter_desc,
                const memory::desc &bias_desc,
                const memory::desc &dst_layer_desc,
                const memory::desc &dst_iter_desc,
                const primitive_attr &attr = default_attr(),
                bool allow_empty = false)
            : rnn_primitive_desc_base(aengine, algorithm::lbr_gru, aprop_kind,
                    algorithm::undef, direction, src_layer_desc, src_iter_desc,
                    nullptr, nullptr, weights_layer_desc, weights_iter_desc,
                    nullptr, nullptr, bias_desc, dst_layer_desc, dst_iter_desc,
                    nullptr, rnn_flags::undef, 0.0f, 0.0f, attr, allow_empty) {}

        /// Constructs a primitive descriptor for an LBR GRU forward propagation
        /// primitive from a C API primitive descriptor that must have a
        /// matching kind.
        ///
        /// @param pd C API primitive descriptor for an LBR GRU forward
        ///     propagation primitive.
        primitive_desc(dnnl_primitive_desc_t pd)
            : rnn_primitive_desc_base(pd, dnnl::prop_kind::forward_training,
                    dnnl::prop_kind::forward_inference,
                    dnnl::algorithm::lbr_gru) {}

        /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const
        memory::desc src_layer_desc() const {
            return rnn_base::src_layer_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
        memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); }

        /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const
        memory::desc weights_layer_desc() const {
            return rnn_base::weights_layer_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const
        memory::desc weights_iter_desc() const {
            return rnn_base::weights_iter_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const
        memory::desc bias_desc() const { return rnn_base::bias_desc(); }

        /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const
        memory::desc dst_layer_desc() const {
            return rnn_base::dst_layer_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const
        memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); }

        /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const
        memory::desc workspace_desc() const {
            return rnn_base::workspace_desc();
        }

        /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const
        algorithm get_cell_kind() const { return base::get_cell_kind(); }

        /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
        prop_kind get_prop_kind() const { return base::get_prop_kind(); }

        /// @copydoc dnnl::primitive_desc_base::get_direction()const
        rnn_direction get_direction() const { return base::get_direction(); }
    };

    /// Default constructor. Produces an empty object.
    lbr_gru_forward() = default;

    /// Constructs an LBR GRU forward propagation primitive.
    /// @param pd Primitive descriptor for an LBR GRU forward propagation
    ///     primitive.
    lbr_gru_forward(const primitive_desc &pd) : primitive(pd) {}

    /// Constructs an LBR GRU forward propagation primitive from a cache blob.
    /// @param pd Primitive descriptor for an LBR GRU forward propagation
    ///     primitive.
    /// @param cache_blob Cache blob.
    lbr_gru_forward(
            const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
        : primitive(pd, cache_blob) {}
};
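
// A minimal usage sketch (illustrative only, not part of the library): the
// LBR (linear-before-reset) GRU variant uses the same constructor shape as
// vanilla GRU; only the cell kind differs (algorithm::lbr_gru). Names below
// are assumptions; see the oneDNN RNN documentation for the exact gate
// layout expected in the weights and bias of this cell.
//
//     auto pd = lbr_gru_forward::primitive_desc(eng,
//             prop_kind::forward_training,
//             rnn_direction::unidirectional_left2right, src_layer_md,
//             src_iter_md, weights_layer_md, weights_iter_md, bias_md,
//             dst_layer_md, dst_iter_md);
//     auto workspace_md = pd.workspace_desc(); // required for training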

/// LBR GRU backward propagation primitive.
struct lbr_gru_backward : public primitive {
    /// Primitive descriptor for an LBR GRU backward propagation primitive.
    struct primitive_desc : public rnn_primitive_desc_base {
        /// Default constructor. Produces an empty object.
        primitive_desc() = default;

        /// Constructs a primitive descriptor for an LBR GRU backward propagation
        /// primitive.
        ///
        /// The following arguments may point to a zero memory descriptor:
        /// - @p src_iter_desc together with @p diff_src_iter_desc,
        /// - @p bias_desc together with @p diff_bias_desc,
        /// - @p dst_iter_desc together with @p diff_dst_iter_desc.
        ///
        /// This would then indicate that the LBR GRU backward propagation
        /// primitive should not use them and should default to zero values
        /// instead.
        ///
        /// @note
        ///     All memory descriptors may be initialized with
        ///     #dnnl::memory::format_tag::any value of @p format_tag.
        ///
        /// @param aengine Engine to use.
        /// @param aprop_kind Propagation kind. Must be
        ///     #dnnl::prop_kind::backward.
        /// @param direction RNN direction. See @ref dnnl::rnn_direction for
        ///     more info.
        /// @param src_layer_desc Memory descriptor for the input vector.
        /// @param src_iter_desc Memory descriptor for the input recurrent
        ///     hidden state vector.
        /// @param weights_layer_desc Memory descriptor for the weights
        ///     applied to the layer input.
        /// @param weights_iter_desc Memory descriptor for the weights applied
        ///     to the recurrent input.
        /// @param bias_desc Bias memory descriptor.
        /// @param dst_layer_desc Memory descriptor for the output vector.
        /// @param dst_iter_desc Memory descriptor for the output recurrent
        ///     hidden state vector.
        /// @param diff_src_layer_desc Memory descriptor for the diff of input
        ///     vector.
        /// @param diff_src_iter_desc Memory descriptor for the diff of input
        ///     recurrent hidden state vector.
        /// @param diff_weights_layer_desc Memory descriptor for the diff of
        ///     weights applied to the layer input.
        /// @param diff_weights_iter_desc Memory descriptor for the diff of
        ///     weights applied to the recurrent input.
        /// @param diff_bias_desc Diff bias memory descriptor.
        /// @param diff_dst_layer_desc Memory descriptor for the diff of
        ///     output vector.
        /// @param diff_dst_iter_desc Memory descriptor for the diff of output
        ///     recurrent hidden state vector.
        /// @param hint_fwd_pd Primitive descriptor for an LBR GRU
        ///     forward propagation primitive. It is used as a hint for
        ///     deciding which memory format to use.
        /// @param attr Primitive attributes to use. Attributes are optional
        ///     and default to empty attributes.
        /// @param allow_empty A flag signifying whether construction is
        ///     allowed to fail without throwing an exception. In this case an
        ///     empty object will be produced. This flag is optional and
        ///     defaults to false.
        primitive_desc(const engine &aengine, prop_kind aprop_kind,
                rnn_direction direction, const memory::desc &src_layer_desc,
                const memory::desc &src_iter_desc,
                const memory::desc &weights_layer_desc,
                const memory::desc &weights_iter_desc,
                const memory::desc &bias_desc,
                const memory::desc &dst_layer_desc,
                const memory::desc &dst_iter_desc,
                const memory::desc &diff_src_layer_desc,
                const memory::desc &diff_src_iter_desc,
                const memory::desc &diff_weights_layer_desc,
                const memory::desc &diff_weights_iter_desc,
                const memory::desc &diff_bias_desc,
                const memory::desc &diff_dst_layer_desc,
                const memory::desc &diff_dst_iter_desc,
                const lbr_gru_forward::primitive_desc &hint_fwd_pd,
                const primitive_attr &attr = default_attr(),
                bool allow_empty = false)
            : rnn_primitive_desc_base(aengine, algorithm::lbr_gru, aprop_kind,
                    algorithm::undef, direction, src_layer_desc, src_iter_desc,
                    nullptr, nullptr, weights_layer_desc, weights_iter_desc,
                    nullptr, nullptr, bias_desc, dst_layer_desc, dst_iter_desc,
                    nullptr, diff_src_layer_desc, diff_src_iter_desc, nullptr,
                    nullptr, diff_weights_layer_desc, diff_weights_iter_desc,
                    nullptr, nullptr, diff_bias_desc, diff_dst_layer_desc,
                    diff_dst_iter_desc, nullptr, rnn_flags::undef, 0.0f, 0.0f,
                    hint_fwd_pd, attr, allow_empty) {}

        /// Constructs a primitive descriptor for an LBR GRU backward propagation
        /// primitive from a C API primitive descriptor that must have a
        /// matching kind.
        ///
        /// @param pd C API primitive descriptor for an LBR GRU backward
        ///     propagation primitive.
        primitive_desc(dnnl_primitive_desc_t pd)
            : rnn_primitive_desc_base(
                    pd, dnnl::prop_kind::backward, dnnl::algorithm::lbr_gru) {}

        /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const
        memory::desc src_layer_desc() const {
            return rnn_base::src_layer_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
        memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); }

        /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const
        memory::desc weights_layer_desc() const {
            return rnn_base::weights_layer_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const
        memory::desc weights_iter_desc() const {
            return rnn_base::weights_iter_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const
        memory::desc bias_desc() const { return rnn_base::bias_desc(); }

        /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const
        memory::desc dst_layer_desc() const {
            return rnn_base::dst_layer_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const
        memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); }

        /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const
        memory::desc workspace_desc() const {
            return rnn_base::workspace_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_layer_desc()const
        memory::desc diff_src_layer_desc() const {
            return rnn_base::diff_src_layer_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_iter_desc()const
        memory::desc diff_src_iter_desc() const {
            return rnn_base::diff_src_iter_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_layer_desc()const
        memory::desc diff_weights_layer_desc() const {
            return rnn_base::diff_weights_layer_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_iter_desc()const
        memory::desc diff_weights_iter_desc() const {
            return rnn_base::diff_weights_iter_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::diff_bias_desc()const
        memory::desc diff_bias_desc() const {
            return rnn_base::diff_bias_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_layer_desc()const
        memory::desc diff_dst_layer_desc() const {
            return rnn_base::diff_dst_layer_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_iter_desc()const
        memory::desc diff_dst_iter_desc() const {
            return rnn_base::diff_dst_iter_desc();
        }

        /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const
        algorithm get_cell_kind() const { return base::get_cell_kind(); }

        /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
        prop_kind get_prop_kind() const { return base::get_prop_kind(); }

        /// @copydoc dnnl::primitive_desc_base::get_direction()const
        rnn_direction get_direction() const { return base::get_direction(); }
    };

    /// Default constructor. Produces an empty object.
    lbr_gru_backward() = default;

    /// Constructs an LBR GRU backward propagation primitive.
    /// @param pd Primitive descriptor for an LBR GRU backward propagation
    ///     primitive.
    lbr_gru_backward(const primitive_desc &pd) : primitive(pd) {}

    /// Constructs an LBR GRU backward propagation primitive from a cache blob.
    /// @param pd Primitive descriptor for an LBR GRU backward propagation
    ///     primitive.
    /// @param cache_blob Cache blob.
    lbr_gru_backward(
            const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
        : primitive(pd, cache_blob) {}
};

/// AUGRU forward propagation primitive.
struct augru_forward : public primitive {
    /// Primitive descriptor for an AUGRU forward propagation primitive.
    struct primitive_desc : public rnn_primitive_desc_base {
        /// Default constructor. Produces an empty object.
        primitive_desc() = default;

        /// Constructs a primitive descriptor for an AUGRU forward propagation
        /// primitive.
        ///
        /// The following arguments may point to a zero memory descriptor:
        /// - @p src_iter_desc,
        /// - @p bias_desc,
        /// - @p dst_iter_desc.
        ///
        /// This would then indicate that the AUGRU forward propagation
        /// primitive should not use them and should default to zero values
        /// instead.
        ///
        /// @note
        ///     All memory descriptors except @p src_iter_desc may be
        ///     initialized with an #dnnl::memory::format_tag::any value of @p
        ///     format_tag.
        ///
        /// @param aengine Engine to use.
        /// @param aprop_kind Propagation kind. Possible values are
        ///     #dnnl::prop_kind::forward_training, and
        ///     #dnnl::prop_kind::forward_inference.
        /// @param direction RNN direction. See @ref dnnl::rnn_direction for
        ///     more info.
        /// @param src_layer_desc Memory descriptor for the input vector.
        /// @param src_iter_desc Memory descriptor for the input recurrent
        ///     hidden state vector.
        /// @param attention_desc Memory descriptor for the attention vector.
        /// @param weights_layer_desc Memory descriptor for the weights
        ///     applied to the layer input.
        /// @param weights_iter_desc Memory descriptor for the weights applied
        ///     to the recurrent input.
        /// @param bias_desc Bias memory descriptor.
        /// @param dst_layer_desc Memory descriptor for the output vector.
        /// @param dst_iter_desc Memory descriptor for the output recurrent
        ///     hidden state vector.
        /// @param attr Primitive attributes to use. Attributes are optional
        ///     and default to empty attributes.
        /// @param allow_empty A flag signifying whether construction is
        ///     allowed to fail without throwing an exception. In this case an
        ///     empty object will be produced. This flag is optional and
        ///     defaults to false.
        primitive_desc(const engine &aengine, prop_kind aprop_kind,
                rnn_direction direction, const memory::desc &src_layer_desc,
                const memory::desc &src_iter_desc,
                const memory::desc &attention_desc,
                const memory::desc &weights_layer_desc,
                const memory::desc &weights_iter_desc,
                const memory::desc &bias_desc,
                const memory::desc &dst_layer_desc,
                const memory::desc &dst_iter_desc,
                const primitive_attr &attr = default_attr(),
                bool allow_empty = false)
            : rnn_primitive_desc_base(aengine, algorithm::vanilla_augru,
                    aprop_kind, algorithm::undef, direction, src_layer_desc,
                    src_iter_desc, nullptr, &attention_desc, weights_layer_desc,
                    weights_iter_desc, nullptr, nullptr, bias_desc,
                    dst_layer_desc, dst_iter_desc, nullptr, rnn_flags::undef,
                    0.0f, 0.0f, attr, allow_empty) {}

        /// Constructs a primitive descriptor for an AUGRU forward propagation
        /// primitive from a C API primitive descriptor that must have a
        /// matching kind.
        ///
        /// @param pd C API primitive descriptor for an AUGRU forward
        ///     propagation primitive.
        primitive_desc(dnnl_primitive_desc_t pd)
            : rnn_primitive_desc_base(pd, dnnl::prop_kind::forward_training,
                    dnnl::prop_kind::forward_inference,
                    dnnl::algorithm::vanilla_augru) {}

        /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const
        memory::desc src_layer_desc() const {
            return rnn_base::src_layer_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
        memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); }

        /// @copydoc dnnl::rnn_primitive_desc_base::augru_attention_desc()const
        memory::desc attention_desc() const {
            return rnn_base::augru_attention_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const
        memory::desc weights_layer_desc() const {
            return rnn_base::weights_layer_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const
        memory::desc weights_iter_desc() const {
            return rnn_base::weights_iter_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const
        memory::desc bias_desc() const { return rnn_base::bias_desc(); }

        /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const
        memory::desc dst_layer_desc() const {
            return rnn_base::dst_layer_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const
        memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); }

        /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const
        memory::desc workspace_desc() const {
            return rnn_base::workspace_desc();
        }

        /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const
        algorithm get_cell_kind() const { return base::get_cell_kind(); }

        /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
        prop_kind get_prop_kind() const { return base::get_prop_kind(); }

        /// @copydoc dnnl::primitive_desc_base::get_direction()const
        rnn_direction get_direction() const { return base::get_direction(); }
    };

    /// Default constructor. Produces an empty object.
    augru_forward() = default;

    /// Constructs an AUGRU forward propagation primitive.
    /// @param pd Primitive descriptor for an AUGRU forward propagation
    ///     primitive.
    augru_forward(const primitive_desc &pd) : primitive(pd) {}

    /// Constructs an AUGRU forward propagation primitive from a cache blob.
    /// @param pd Primitive descriptor for an AUGRU forward propagation
    ///     primitive.
    /// @param cache_blob Cache blob.
    augru_forward(
            const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
        : primitive(pd, cache_blob) {}
};
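
// A minimal usage sketch (illustrative only, not part of the library): AUGRU
// extends GRU with an attention input whose memory descriptor is passed
// right after src_iter_desc. All names below are assumptions.
//
//     auto pd = augru_forward::primitive_desc(eng,
//             prop_kind::forward_inference,
//             rnn_direction::unidirectional_left2right, src_layer_md,
//             src_iter_md, attention_md, weights_layer_md, weights_iter_md,
//             bias_md, dst_layer_md, dst_iter_md);
//     auto augru = augru_forward(pd);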

/// AUGRU backward propagation primitive.
struct augru_backward : public primitive {
    /// Primitive descriptor for an AUGRU backward propagation primitive.
    struct primitive_desc : public rnn_primitive_desc_base {
        /// Default constructor. Produces an empty object.
        primitive_desc() = default;

        /// Constructs a primitive descriptor for an AUGRU backward propagation
        /// primitive.
        ///
        /// The following arguments may point to a zero memory descriptor:
        /// - @p src_iter_desc together with @p diff_src_iter_desc,
        /// - @p bias_desc together with @p diff_bias_desc,
        /// - @p dst_iter_desc together with @p diff_dst_iter_desc.
        ///
        /// This would then indicate that the AUGRU backward propagation
        /// primitive should not use them and should default to zero values
        /// instead.
        ///
        /// @note
        ///     All memory descriptors may be initialized with
        ///     #dnnl::memory::format_tag::any value of @p format_tag.
        ///
        /// @param aengine Engine to use.
        /// @param aprop_kind Propagation kind. Must be
        ///     #dnnl::prop_kind::backward.
        /// @param direction RNN direction. See @ref dnnl::rnn_direction for
        ///     more info.
        /// @param src_layer_desc Memory descriptor for the input vector.
        /// @param src_iter_desc Memory descriptor for the input recurrent
        ///     hidden state vector.
        /// @param attention_desc Memory descriptor for the attention vector.
        /// @param weights_layer_desc Memory descriptor for the weights
        ///     applied to the layer input.
        /// @param weights_iter_desc Memory descriptor for the weights applied
        ///     to the recurrent input.
        /// @param bias_desc Bias memory descriptor.
        /// @param dst_layer_desc Memory descriptor for the output vector.
        /// @param dst_iter_desc Memory descriptor for the output recurrent
        ///     hidden state vector.
        /// @param diff_src_layer_desc Memory descriptor for the diff of input
        ///     vector.
        /// @param diff_src_iter_desc Memory descriptor for the diff of input
        ///     recurrent hidden state vector.
        /// @param diff_attention_desc Memory descriptor for the diff of
        ///     attention vector.
        /// @param diff_weights_layer_desc Memory descriptor for the diff of
        ///     weights applied to the layer input.
        /// @param diff_weights_iter_desc Memory descriptor for the diff of
        ///     weights applied to the recurrent input.
        /// @param diff_bias_desc Diff bias memory descriptor.
        /// @param diff_dst_layer_desc Memory descriptor for the diff of
        ///     output vector.
        /// @param diff_dst_iter_desc Memory descriptor for the diff of output
        ///     recurrent hidden state vector.
        /// @param hint_fwd_pd Primitive descriptor for an AUGRU
        ///     forward propagation primitive. It is used as a hint for
        ///     deciding which memory format to use.
        /// @param attr Primitive attributes to use. Attributes are optional
        ///     and default to empty attributes.
        /// @param allow_empty A flag signifying whether construction is
        ///     allowed to fail without throwing an exception. In this case an
        ///     empty object will be produced. This flag is optional and
        ///     defaults to false.
        primitive_desc(const engine &aengine, prop_kind aprop_kind,
                rnn_direction direction, const memory::desc &src_layer_desc,
                const memory::desc &src_iter_desc,
                const memory::desc &attention_desc,
                const memory::desc &weights_layer_desc,
                const memory::desc &weights_iter_desc,
                const memory::desc &bias_desc,
                const memory::desc &dst_layer_desc,
                const memory::desc &dst_iter_desc,
                const memory::desc &diff_src_layer_desc,
                const memory::desc &diff_src_iter_desc,
                const memory::desc &diff_attention_desc,
                const memory::desc &diff_weights_layer_desc,
                const memory::desc &diff_weights_iter_desc,
                const memory::desc &diff_bias_desc,
                const memory::desc &diff_dst_layer_desc,
                const memory::desc &diff_dst_iter_desc,
                const augru_forward::primitive_desc &hint_fwd_pd,
                const primitive_attr &attr = default_attr(),
                bool allow_empty = false)
            : rnn_primitive_desc_base(aengine, algorithm::vanilla_augru,
                    aprop_kind, algorithm::undef, direction, src_layer_desc,
                    src_iter_desc, nullptr, &attention_desc, weights_layer_desc,
                    weights_iter_desc, nullptr, nullptr, bias_desc,
                    dst_layer_desc, dst_iter_desc, nullptr, diff_src_layer_desc,
                    diff_src_iter_desc, nullptr, &diff_attention_desc,
                    diff_weights_layer_desc, diff_weights_iter_desc, nullptr,
                    nullptr, diff_bias_desc, diff_dst_layer_desc,
                    diff_dst_iter_desc, nullptr, rnn_flags::undef, 0.0f, 0.0f,
                    hint_fwd_pd, attr, allow_empty) {}

        /// Constructs a primitive descriptor for an AUGRU backward propagation
        /// primitive from a C API primitive descriptor that must have a
        /// matching kind.
        ///
        /// @param pd C API primitive descriptor for an AUGRU backward
        ///     propagation primitive.
        primitive_desc(dnnl_primitive_desc_t pd)
            : rnn_primitive_desc_base(pd, dnnl::prop_kind::backward,
                    dnnl::algorithm::vanilla_augru) {}

        /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const
        memory::desc src_layer_desc() const {
            return rnn_base::src_layer_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
        memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); }

        /// @copydoc dnnl::rnn_primitive_desc_base::augru_attention_desc()const
        memory::desc attention_desc() const {
            return rnn_base::augru_attention_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const
        memory::desc weights_layer_desc() const {
            return rnn_base::weights_layer_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const
        memory::desc weights_iter_desc() const {
            return rnn_base::weights_iter_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const
        memory::desc bias_desc() const { return rnn_base::bias_desc(); }

        /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const
        memory::desc dst_layer_desc() const {
            return rnn_base::dst_layer_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const
        memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); }

        /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const
        memory::desc workspace_desc() const {
            return rnn_base::workspace_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_layer_desc()const
        memory::desc diff_src_layer_desc() const {
            return rnn_base::diff_src_layer_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_iter_desc()const
        memory::desc diff_src_iter_desc() const {
            return rnn_base::diff_src_iter_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::diff_augru_attention_desc()const
        memory::desc diff_attention_desc() const {
            return rnn_base::diff_augru_attention_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_layer_desc()const
        memory::desc diff_weights_layer_desc() const {
            return rnn_base::diff_weights_layer_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_iter_desc()const
        memory::desc diff_weights_iter_desc() const {
            return rnn_base::diff_weights_iter_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::diff_bias_desc()const
        memory::desc diff_bias_desc() const {
            return rnn_base::diff_bias_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_layer_desc()const
        memory::desc diff_dst_layer_desc() const {
            return rnn_base::diff_dst_layer_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_iter_desc()const
        memory::desc diff_dst_iter_desc() const {
            return rnn_base::diff_dst_iter_desc();
        }

        /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const
        algorithm get_cell_kind() const { return base::get_cell_kind(); }

        /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
        prop_kind get_prop_kind() const { return base::get_prop_kind(); }

        /// @copydoc dnnl::primitive_desc_base::get_direction()const
        rnn_direction get_direction() const { return base::get_direction(); }
    };

    /// Default constructor. Produces an empty object.
    augru_backward() = default;

    /// Constructs an AUGRU backward propagation primitive.
    /// @param pd Primitive descriptor for an AUGRU backward propagation
    ///     primitive.
    augru_backward(const primitive_desc &pd) : primitive(pd) {}

    /// Constructs an AUGRU backward propagation primitive from a cache blob.
    /// @param pd Primitive descriptor for an AUGRU backward propagation
    ///     primitive.
    /// @param cache_blob Cache blob.
|
|
augru_backward(
|
|
const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
|
|
: primitive(pd, cache_blob) {}
|
|
};
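
/*
  Example (illustrative sketch, not part of the library API): the argument map
  an AUGRU backward primitive is typically executed with. The stream `strm`
  and the memory objects below are assumed to have been created beforehand;
  the workspace must come from a forward primitive created with
  dnnl::prop_kind::forward_training and used as the hint for this backward
  primitive descriptor.

    augru_backward augru_bwd(bwd_pd);
    augru_bwd.execute(strm,
            {{DNNL_ARG_SRC_LAYER, src_layer_mem},
                    {DNNL_ARG_SRC_ITER, src_iter_mem},
                    {DNNL_ARG_AUGRU_ATTENTION, attention_mem},
                    {DNNL_ARG_WEIGHTS_LAYER, weights_layer_mem},
                    {DNNL_ARG_WEIGHTS_ITER, weights_iter_mem},
                    {DNNL_ARG_BIAS, bias_mem},
                    {DNNL_ARG_DST_LAYER, dst_layer_mem},
                    {DNNL_ARG_DST_ITER, dst_iter_mem},
                    {DNNL_ARG_WORKSPACE, workspace_mem},
                    {DNNL_ARG_DIFF_SRC_LAYER, diff_src_layer_mem},
                    {DNNL_ARG_DIFF_SRC_ITER, diff_src_iter_mem},
                    {DNNL_ARG_DIFF_AUGRU_ATTENTION, diff_attention_mem},
                    {DNNL_ARG_DIFF_WEIGHTS_LAYER, diff_weights_layer_mem},
                    {DNNL_ARG_DIFF_WEIGHTS_ITER, diff_weights_iter_mem},
                    {DNNL_ARG_DIFF_BIAS, diff_bias_mem},
                    {DNNL_ARG_DIFF_DST_LAYER, diff_dst_layer_mem},
                    {DNNL_ARG_DIFF_DST_ITER, diff_dst_iter_mem}});
*/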

/// LBR AUGRU forward propagation primitive.
struct lbr_augru_forward : public primitive {
    /// Descriptor for an LBR AUGRU forward propagation primitive.

    /// Primitive descriptor for an LBR AUGRU forward propagation primitive.
    struct primitive_desc : public rnn_primitive_desc_base {
        /// Default constructor. Produces an empty object.
        primitive_desc() = default;

        /// Constructs a primitive descriptor for an LBR AUGRU forward
        /// propagation primitive.
        ///
        /// The following arguments may point to a zero memory descriptor:
        /// - @p src_iter_desc,
        /// - @p bias_desc,
        /// - @p dst_iter_desc.
        ///
        /// This would then indicate that the LBR AUGRU forward propagation
        /// primitive should not use them and should default to zero values
        /// instead.
        ///
        /// @note
        ///     All memory descriptors except @p src_iter_desc may be
        ///     initialized with an #dnnl::memory::format_tag::any value of @p
        ///     format_tag.
        ///
        /// @param aengine Engine to use.
        /// @param aprop_kind Propagation kind. Possible values are
        ///     #dnnl::prop_kind::forward_training, and
        ///     #dnnl::prop_kind::forward_inference.
        /// @param direction RNN direction. See @ref dnnl::rnn_direction for
        ///     more info.
        /// @param src_layer_desc Memory descriptor for the input vector.
        /// @param src_iter_desc Memory descriptor for the input recurrent
        ///     hidden state vector.
        /// @param attention_desc Memory descriptor for the attention vector.
        /// @param weights_layer_desc Memory descriptor for the weights
        ///     applied to the layer input.
        /// @param weights_iter_desc Memory descriptor for the weights applied
        ///     to the recurrent input.
        /// @param bias_desc Bias memory descriptor.
        /// @param dst_layer_desc Memory descriptor for the output vector.
        /// @param dst_iter_desc Memory descriptor for the output recurrent
        ///     hidden state vector.
        /// @param attr Primitive attributes to use. Attributes are optional
        ///     and default to empty attributes.
        /// @param allow_empty A flag signifying whether construction is
        ///     allowed to fail without throwing an exception. In this case an
        ///     empty object will be produced. This flag is optional and
        ///     defaults to false.
        primitive_desc(const engine &aengine, prop_kind aprop_kind,
                rnn_direction direction, const memory::desc &src_layer_desc,
                const memory::desc &src_iter_desc,
                const memory::desc &attention_desc,
                const memory::desc &weights_layer_desc,
                const memory::desc &weights_iter_desc,
                const memory::desc &bias_desc,
                const memory::desc &dst_layer_desc,
                const memory::desc &dst_iter_desc,
                const primitive_attr &attr = default_attr(),
                bool allow_empty = false)
            : rnn_primitive_desc_base(aengine, algorithm::lbr_augru, aprop_kind,
                    algorithm::undef, direction, src_layer_desc, src_iter_desc,
                    nullptr, &attention_desc, weights_layer_desc,
                    weights_iter_desc, nullptr, nullptr, bias_desc,
                    dst_layer_desc, dst_iter_desc, nullptr, rnn_flags::undef,
                    0.0f, 0.0f, attr, allow_empty) {}

        /// Constructs a primitive descriptor for an LBR AUGRU forward
        /// propagation primitive from a C API primitive descriptor that must
        /// have a matching kind.
        ///
        /// @param pd C API primitive descriptor for an LBR AUGRU forward
        ///     propagation primitive.
        primitive_desc(dnnl_primitive_desc_t pd)
            : rnn_primitive_desc_base(pd, dnnl::prop_kind::forward_training,
                    dnnl::prop_kind::forward_inference,
                    dnnl::algorithm::lbr_augru) {}

        /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const
        memory::desc src_layer_desc() const {
            return rnn_base::src_layer_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
        memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); }

        /// @copydoc dnnl::rnn_primitive_desc_base::augru_attention_desc()const
        memory::desc attention_desc() const {
            return rnn_base::augru_attention_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const
        memory::desc weights_layer_desc() const {
            return rnn_base::weights_layer_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const
        memory::desc weights_iter_desc() const {
            return rnn_base::weights_iter_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const
        memory::desc bias_desc() const { return rnn_base::bias_desc(); }

        /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const
        memory::desc dst_layer_desc() const {
            return rnn_base::dst_layer_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const
        memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); }

        /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const
        memory::desc workspace_desc() const {
            return rnn_base::workspace_desc();
        }

        /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const
        algorithm get_cell_kind() const { return base::get_cell_kind(); }

        /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
        prop_kind get_prop_kind() const { return base::get_prop_kind(); }

        /// @copydoc dnnl::primitive_desc_base::get_direction()const
        rnn_direction get_direction() const { return base::get_direction(); }
    };

    /// Default constructor. Produces an empty object.
    lbr_augru_forward() = default;

    /// Constructs an LBR AUGRU forward propagation primitive.
    /// @param pd Primitive descriptor for an LBR AUGRU forward propagation
    ///     primitive.
    lbr_augru_forward(const primitive_desc &pd) : primitive(pd) {}

    /// Constructs an LBR AUGRU forward propagation primitive from a cache blob.
    /// @param pd Primitive descriptor for an LBR AUGRU forward propagation
    ///     primitive.
    /// @param cache_blob Cache blob.
    lbr_augru_forward(
            const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
        : primitive(pd, cache_blob) {}
};
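
/*
  Example (illustrative sketch, not part of the library API): creating an LBR
  AUGRU forward inference primitive for a single unidirectional layer. The
  sizes (T, N, C), the plain memory formats, and the memory object names are
  assumptions made for illustration only; for best performance the weights
  descriptors are normally created with dnnl::memory::format_tag::any and the
  user data is reordered to the formats the primitive descriptor reports.

    using tag = memory::format_tag;
    using dt = memory::data_type;

    const memory::dim T = 4, N = 2, C = 16; // time steps, batch, channels
    const memory::dim L = 1, D = 1, G = 3; // layers, directions, GRU gates

    memory::desc src_layer_md({T, N, C}, dt::f32, tag::tnc);
    memory::desc src_iter_md({L, D, N, C}, dt::f32, tag::ldnc);
    memory::desc attention_md({T, N, 1}, dt::f32, tag::tnc);
    memory::desc weights_layer_md({L, D, C, G, C}, dt::f32, tag::ldigo);
    memory::desc weights_iter_md({L, D, C, G, C}, dt::f32, tag::ldigo);
    memory::desc bias_md({L, D, G + 1, C}, dt::f32, tag::ldgo); // LBR: extra bias gate
    memory::desc dst_layer_md({T, N, C}, dt::f32, tag::tnc);
    memory::desc dst_iter_md({L, D, N, C}, dt::f32, tag::ldnc);

    lbr_augru_forward::primitive_desc pd(eng, prop_kind::forward_inference,
            rnn_direction::unidirectional_left2right, src_layer_md, src_iter_md,
            attention_md, weights_layer_md, weights_iter_md, bias_md,
            dst_layer_md, dst_iter_md);
    lbr_augru_forward(pd).execute(strm,
            {{DNNL_ARG_SRC_LAYER, src_layer_mem},
                    {DNNL_ARG_SRC_ITER, src_iter_mem},
                    {DNNL_ARG_AUGRU_ATTENTION, attention_mem},
                    {DNNL_ARG_WEIGHTS_LAYER, weights_layer_mem},
                    {DNNL_ARG_WEIGHTS_ITER, weights_iter_mem},
                    {DNNL_ARG_BIAS, bias_mem},
                    {DNNL_ARG_DST_LAYER, dst_layer_mem},
                    {DNNL_ARG_DST_ITER, dst_iter_mem}});
*/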

/// LBR AUGRU backward propagation primitive.
struct lbr_augru_backward : public primitive {
    /// Primitive descriptor for an LBR AUGRU backward propagation primitive.
    struct primitive_desc : public rnn_primitive_desc_base {
        /// Default constructor. Produces an empty object.
        primitive_desc() = default;

        /// Constructs a primitive descriptor for an LBR AUGRU backward
        /// propagation primitive.
        ///
        /// The following arguments may point to a zero memory descriptor:
        /// - @p src_iter_desc together with @p diff_src_iter_desc,
        /// - @p bias_desc together with @p diff_bias_desc,
        /// - @p dst_iter_desc together with @p diff_dst_iter_desc.
        ///
        /// This would then indicate that the LBR AUGRU backward propagation
        /// primitive should not use them and should default to zero values
        /// instead.
        ///
        /// @note
        ///     All memory descriptors may be initialized with
        ///     #dnnl::memory::format_tag::any value of @p format_tag.
        ///
        /// @param aengine Engine to use.
        /// @param aprop_kind Propagation kind. Must be
        ///     #dnnl::prop_kind::backward.
        /// @param direction RNN direction. See @ref dnnl::rnn_direction for
        ///     more info.
        /// @param src_layer_desc Memory descriptor for the input vector.
        /// @param src_iter_desc Memory descriptor for the input recurrent
        ///     hidden state vector.
        /// @param attention_desc Memory descriptor for the attention vector.
        /// @param weights_layer_desc Memory descriptor for the weights
        ///     applied to the layer input.
        /// @param weights_iter_desc Memory descriptor for the weights applied
        ///     to the recurrent input.
        /// @param bias_desc Bias memory descriptor.
        /// @param dst_layer_desc Memory descriptor for the output vector.
        /// @param dst_iter_desc Memory descriptor for the output recurrent
        ///     hidden state vector.
        /// @param diff_src_layer_desc Memory descriptor for the diff of input
        ///     vector.
        /// @param diff_src_iter_desc Memory descriptor for the diff of input
        ///     recurrent hidden state vector.
        /// @param diff_attention_desc Memory descriptor for the diff of
        ///     attention vector.
        /// @param diff_weights_layer_desc Memory descriptor for the diff of
        ///     weights applied to the layer input.
        /// @param diff_weights_iter_desc Memory descriptor for the diff of
        ///     weights applied to the recurrent input.
        /// @param diff_bias_desc Diff bias memory descriptor.
        /// @param diff_dst_layer_desc Memory descriptor for the diff of
        ///     output vector.
        /// @param diff_dst_iter_desc Memory descriptor for the diff of output
        ///     recurrent hidden state vector.
        /// @param hint_fwd_pd Primitive descriptor for an LBR AUGRU
        ///     forward propagation primitive. It is used as a hint for
        ///     deciding which memory format to use.
        /// @param attr Primitive attributes to use. Attributes are optional
        ///     and default to empty attributes.
        /// @param allow_empty A flag signifying whether construction is
        ///     allowed to fail without throwing an exception. In this case an
        ///     empty object will be produced. This flag is optional and
        ///     defaults to false.
        primitive_desc(const engine &aengine, prop_kind aprop_kind,
                rnn_direction direction, const memory::desc &src_layer_desc,
                const memory::desc &src_iter_desc,
                const memory::desc &attention_desc,
                const memory::desc &weights_layer_desc,
                const memory::desc &weights_iter_desc,
                const memory::desc &bias_desc,
                const memory::desc &dst_layer_desc,
                const memory::desc &dst_iter_desc,
                const memory::desc &diff_src_layer_desc,
                const memory::desc &diff_src_iter_desc,
                const memory::desc &diff_attention_desc,
                const memory::desc &diff_weights_layer_desc,
                const memory::desc &diff_weights_iter_desc,
                const memory::desc &diff_bias_desc,
                const memory::desc &diff_dst_layer_desc,
                const memory::desc &diff_dst_iter_desc,
                const lbr_augru_forward::primitive_desc &hint_fwd_pd,
                const primitive_attr &attr = default_attr(),
                bool allow_empty = false)
            : rnn_primitive_desc_base(aengine, algorithm::lbr_augru, aprop_kind,
                    algorithm::undef, direction, src_layer_desc, src_iter_desc,
                    nullptr, &attention_desc, weights_layer_desc,
                    weights_iter_desc, nullptr, nullptr, bias_desc,
                    dst_layer_desc, dst_iter_desc, nullptr, diff_src_layer_desc,
                    diff_src_iter_desc, nullptr, &diff_attention_desc,
                    diff_weights_layer_desc, diff_weights_iter_desc, nullptr,
                    nullptr, diff_bias_desc, diff_dst_layer_desc,
                    diff_dst_iter_desc, nullptr, rnn_flags::undef, 0.0f, 0.0f,
                    hint_fwd_pd, attr, allow_empty) {}

        /// Constructs a primitive descriptor for an LBR AUGRU backward
        /// propagation primitive from a C API primitive descriptor that must
        /// have a matching kind.
        ///
        /// @param pd C API primitive descriptor for an LBR AUGRU backward
        ///     propagation primitive.
        primitive_desc(dnnl_primitive_desc_t pd)
            : rnn_primitive_desc_base(pd, dnnl::prop_kind::backward,
                    dnnl::algorithm::lbr_augru) {}

        /// @copydoc dnnl::rnn_primitive_desc_base::src_layer_desc()const
        memory::desc src_layer_desc() const {
            return rnn_base::src_layer_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::src_iter_desc()const
        memory::desc src_iter_desc() const { return rnn_base::src_iter_desc(); }

        /// @copydoc dnnl::rnn_primitive_desc_base::augru_attention_desc()const
        memory::desc attention_desc() const {
            return rnn_base::augru_attention_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::weights_layer_desc()const
        memory::desc weights_layer_desc() const {
            return rnn_base::weights_layer_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::weights_iter_desc()const
        memory::desc weights_iter_desc() const {
            return rnn_base::weights_iter_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::bias_desc()const
        memory::desc bias_desc() const { return rnn_base::bias_desc(); }

        /// @copydoc dnnl::rnn_primitive_desc_base::dst_layer_desc()const
        memory::desc dst_layer_desc() const {
            return rnn_base::dst_layer_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::dst_iter_desc()const
        memory::desc dst_iter_desc() const { return rnn_base::dst_iter_desc(); }

        /// @copydoc dnnl::rnn_primitive_desc_base::workspace_desc()const
        memory::desc workspace_desc() const {
            return rnn_base::workspace_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_layer_desc()const
        memory::desc diff_src_layer_desc() const {
            return rnn_base::diff_src_layer_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::diff_src_iter_desc()const
        memory::desc diff_src_iter_desc() const {
            return rnn_base::diff_src_iter_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::diff_augru_attention_desc()const
        memory::desc diff_attention_desc() const {
            return rnn_base::diff_augru_attention_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_layer_desc()const
        memory::desc diff_weights_layer_desc() const {
            return rnn_base::diff_weights_layer_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::diff_weights_iter_desc()const
        memory::desc diff_weights_iter_desc() const {
            return rnn_base::diff_weights_iter_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::diff_bias_desc()const
        memory::desc diff_bias_desc() const {
            return rnn_base::diff_bias_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_layer_desc()const
        memory::desc diff_dst_layer_desc() const {
            return rnn_base::diff_dst_layer_desc();
        }

        /// @copydoc dnnl::rnn_primitive_desc_base::diff_dst_iter_desc()const
        memory::desc diff_dst_iter_desc() const {
            return rnn_base::diff_dst_iter_desc();
        }

        /// @copydoc dnnl::primitive_desc_base::get_cell_kind()const
        algorithm get_cell_kind() const { return base::get_cell_kind(); }

        /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
        prop_kind get_prop_kind() const { return base::get_prop_kind(); }

        /// @copydoc dnnl::primitive_desc_base::get_direction()const
        rnn_direction get_direction() const { return base::get_direction(); }
    };

    /// Default constructor. Produces an empty object.
    lbr_augru_backward() = default;

    /// Constructs an LBR AUGRU backward propagation primitive.
    /// @param pd Primitive descriptor for an LBR AUGRU backward propagation
    ///     primitive.
    lbr_augru_backward(const primitive_desc &pd) : primitive(pd) {}

    /// Constructs an LBR AUGRU backward propagation primitive from a cache blob.
    /// @param pd Primitive descriptor for an LBR AUGRU backward propagation
    ///     primitive.
    /// @param cache_blob Cache blob.
    lbr_augru_backward(
            const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
        : primitive(pd, cache_blob) {}
};

/// @} dnnl_api_rnn

/// @addtogroup dnnl_api_shuffle Shuffle
///
/// A primitive to shuffle tensor data along an axis.
///
/// @sa @ref dev_guide_shuffle in developer guide
///
/// @{

/// Shuffle forward propagation primitive.
struct shuffle_forward : public primitive {
    /// Primitive descriptor for a shuffle forward propagation primitive.
    struct primitive_desc : public dnnl::primitive_desc {
        /// Default constructor. Produces an empty object.
        primitive_desc() = default;

        /// Constructs a primitive descriptor for a shuffle forward propagation
        /// primitive.
        ///
        /// @param aengine Engine to use.
        /// @param aprop_kind Propagation kind. Possible values are
        ///     #dnnl::prop_kind::forward_training, and
        ///     #dnnl::prop_kind::forward_inference.
        /// @param src_desc Source memory descriptor.
        /// @param dst_desc Destination memory descriptor.
        /// @param axis The axis along which the data is shuffled.
        /// @param group_size Shuffle group size.
        /// @param attr Primitive attributes to use. Attributes are optional
        ///     and default to empty attributes.
        /// @param allow_empty A flag signifying whether construction is
        ///     allowed to fail without throwing an exception. In this case an
        ///     empty object will be produced. This flag is optional and
        ///     defaults to false.
        primitive_desc(const engine &aengine, prop_kind aprop_kind,
                const memory::desc &src_desc, const memory::desc &dst_desc,
                int axis, int group_size,
                const primitive_attr &attr = default_attr(),
                bool allow_empty = false) {

            dnnl_primitive_desc_t pd = nullptr;
            dnnl_status_t status = dnnl_shuffle_forward_primitive_desc_create(
                    &pd, aengine.get(), dnnl::convert_to_c(aprop_kind),
                    src_desc.get(), dst_desc.get(), axis, group_size,
                    attr.get());

            if (!allow_empty)
                error::wrap_c_api(status,
                        "could not create a primitive descriptor for "
                        "the shuffle forward propagation primitive. Run "
                        "workload with environment variable ONEDNN_VERBOSE=all "
                        "to get additional diagnostic information.");
            reset(pd);
        }

        /// Constructs a primitive descriptor for a shuffle forward propagation
        /// primitive from a C API primitive descriptor that must have a
        /// matching kind.
        ///
        /// @param pd C API primitive descriptor for a shuffle forward
        ///     propagation primitive.
        primitive_desc(dnnl_primitive_desc_t pd)
            : dnnl::primitive_desc(pd, dnnl::primitive::kind::shuffle,
                    dnnl::prop_kind::forward_training,
                    dnnl::prop_kind::forward_inference) {}

        /// @copydoc dnnl::primitive_desc_base::src_desc()const
        memory::desc src_desc() const { return base::src_desc(0); }

        /// @copydoc dnnl::primitive_desc_base::dst_desc()const
        memory::desc dst_desc() const { return base::dst_desc(0); }

        /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
        prop_kind get_prop_kind() const { return base::get_prop_kind(); }

        /// @copydoc dnnl::primitive_desc_base::get_axis()const
        int get_axis() const { return base::get_axis(); }

        /// @copydoc dnnl::primitive_desc_base::get_group_size()const
        memory::dim get_group_size() const { return base::get_group_size(); }
    };

    /// Default constructor. Produces an empty object.
    shuffle_forward() = default;

    /// Constructs a shuffle forward propagation primitive.
    /// @param pd Primitive descriptor for a shuffle forward propagation
    ///     primitive.
    shuffle_forward(const primitive_desc &pd) : primitive(pd) {}

    /// Constructs a shuffle forward propagation primitive from a cache blob.
    /// @param pd Primitive descriptor for a shuffle forward propagation
    ///     primitive.
    /// @param cache_blob Cache blob.
    shuffle_forward(
            const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
        : primitive(pd, cache_blob) {}
};
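
/*
  Example (illustrative sketch, not part of the library API): a channel
  shuffle over axis 1 of an NCHW tensor, as used in ShuffleNet-style blocks.
  The engine `eng`, the stream `strm`, and the data behind the memory objects
  are assumed to exist already; dnnl:: qualification is omitted for brevity.

    memory::desc data_md({2, 16, 28, 28}, memory::data_type::f32,
            memory::format_tag::nchw);
    memory src_mem(data_md, eng);
    memory dst_mem(data_md, eng);

    // Split the 16 channels into groups of 4 and transpose the groups.
    const int axis = 1;
    const int group_size = 4;
    shuffle_forward::primitive_desc fwd_pd(eng, prop_kind::forward_inference,
            data_md, data_md, axis, group_size);
    shuffle_forward(fwd_pd).execute(
            strm, {{DNNL_ARG_SRC, src_mem}, {DNNL_ARG_DST, dst_mem}});
    strm.wait();
*/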

/// Shuffle backward propagation primitive.
struct shuffle_backward : public primitive {
    /// Primitive descriptor for a shuffle backward propagation primitive.
    struct primitive_desc : public dnnl::primitive_desc {
        /// Default constructor. Produces an empty object.
        primitive_desc() = default;

        /// Constructs a primitive descriptor for a shuffle backward propagation
        /// primitive.
        ///
        /// @param aengine Engine to use.
        /// @param diff_src_desc Diff source memory descriptor.
        /// @param diff_dst_desc Diff destination memory descriptor.
        /// @param axis The axis along which the data is shuffled.
        /// @param group_size Shuffle group size.
        /// @param hint_fwd_pd Primitive descriptor for a shuffle forward
        ///     propagation primitive. It is used as a hint for deciding which
        ///     memory format to use.
        /// @param attr Primitive attributes to use. Attributes are optional
        ///     and default to empty attributes.
        /// @param allow_empty A flag signifying whether construction is
        ///     allowed to fail without throwing an exception. In this case an
        ///     empty object will be produced. This flag is optional and
        ///     defaults to false.
        primitive_desc(const engine &aengine, const memory::desc &diff_src_desc,
                const memory::desc &diff_dst_desc, int axis, int group_size,
                const shuffle_forward::primitive_desc &hint_fwd_pd,
                const primitive_attr &attr = default_attr(),
                bool allow_empty = false) {

            dnnl_primitive_desc_t pd = nullptr;
            dnnl_status_t status = dnnl_shuffle_backward_primitive_desc_create(
                    &pd, aengine.get(), diff_src_desc.get(),
                    diff_dst_desc.get(), axis, group_size, hint_fwd_pd.get(),
                    attr.get());

            if (!allow_empty)
                error::wrap_c_api(status,
                        "could not create a primitive descriptor for "
                        "the shuffle backward propagation primitive. Run "
                        "workload with environment variable ONEDNN_VERBOSE=all "
                        "to get additional diagnostic information.");
            reset(pd);
        }

        /// Constructs a primitive descriptor for a shuffle backward
        /// propagation primitive from a C API primitive descriptor that must
        /// have a matching kind.
        ///
        /// @param pd C API primitive descriptor for a shuffle backward
        ///     propagation primitive.
        primitive_desc(dnnl_primitive_desc_t pd)
            : dnnl::primitive_desc(pd, dnnl::primitive::kind::shuffle,
                    dnnl::prop_kind::backward_data) {}

        /// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
        memory::desc diff_src_desc() const { return base::diff_src_desc(0); }

        /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
        memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }

        /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
        prop_kind get_prop_kind() const { return base::get_prop_kind(); }

        /// @copydoc dnnl::primitive_desc_base::get_axis()const
        int get_axis() const { return base::get_axis(); }

        /// @copydoc dnnl::primitive_desc_base::get_group_size()const
        memory::dim get_group_size() const { return base::get_group_size(); }
    };

    /// Default constructor. Produces an empty object.
    shuffle_backward() = default;

    /// Constructs a shuffle backward propagation primitive.
    /// @param pd Primitive descriptor for a shuffle backward propagation
    ///     primitive.
    shuffle_backward(const primitive_desc &pd) : primitive(pd) {}

    /// Constructs a shuffle backward propagation primitive from a cache blob.
    /// @param pd Primitive descriptor for a shuffle backward propagation
    ///     primitive.
    /// @param cache_blob Cache blob.
    shuffle_backward(
            const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
        : primitive(pd, cache_blob) {}
};
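
/*
  Example (illustrative sketch, not part of the library API): the backward
  counterpart of the channel shuffle sketched above. The forward primitive
  descriptor `fwd_pd`, the descriptors, and the gradient memory objects are
  assumed to exist; the forward primitive descriptor is passed as the
  mandatory format hint.

    shuffle_backward::primitive_desc bwd_pd(
            eng, data_md, data_md, axis, group_size, fwd_pd);
    shuffle_backward(bwd_pd).execute(strm,
            {{DNNL_ARG_DIFF_DST, diff_dst_mem},
                    {DNNL_ARG_DIFF_SRC, diff_src_mem}});
*/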

/// @} dnnl_api_shuffle

/// @addtogroup dnnl_api_binary Binary
///
/// A primitive to perform tensor operations over two tensors.
///
/// @sa @ref dev_guide_binary in developer guide
///
/// @{

/// Elementwise binary operator primitive.
struct binary : public primitive {
    /// Primitive descriptor for an elementwise binary operator primitive.
    struct primitive_desc : public dnnl::primitive_desc {
        /// Default constructor. Produces an empty object.
        primitive_desc() = default;

        /// Constructs a primitive descriptor for an elementwise binary operator
        /// primitive.
        ///
        /// @param aengine Engine to use.
        /// @param aalgorithm Elementwise binary algorithm.
        /// @param src0 Memory descriptor for source tensor #0.
        /// @param src1 Memory descriptor for source tensor #1.
        /// @param dst Memory descriptor for destination tensor.
        /// @param attr Primitive attributes to use. Attributes are optional
        ///     and default to empty attributes.
        /// @param allow_empty A flag signifying whether construction is
        ///     allowed to fail without throwing an exception. In this case an
        ///     empty object will be produced. This flag is optional and
        ///     defaults to false.
        primitive_desc(const engine &aengine, algorithm aalgorithm,
                const memory::desc &src0, const memory::desc &src1,
                const memory::desc &dst,
                const primitive_attr &attr = default_attr(),
                bool allow_empty = false) {

            dnnl_primitive_desc_t pd = nullptr;
            dnnl_status_t status = dnnl_binary_primitive_desc_create(&pd,
                    aengine.get(), dnnl::convert_to_c(aalgorithm), src0.get(),
                    src1.get(), dst.get(), attr.get());

            if (!allow_empty)
                error::wrap_c_api(status,
                        "could not create a primitive descriptor for "
                        "the binary operation primitive. Run workload with "
                        "environment variable ONEDNN_VERBOSE=all to get "
                        "additional diagnostic information.");
            reset(pd);
        }

        /// Constructs a primitive descriptor for an elementwise binary operator
        /// primitive with support of ternary operators.
        ///
        /// @param aengine Engine to use.
        /// @param aalgorithm Elementwise binary algorithm.
        /// @param src0 Memory descriptor for source tensor #0.
        /// @param src1 Memory descriptor for source tensor #1.
        /// @param src2 Memory descriptor for source tensor #2 for ternary
        ///     operations. Might be empty.
        /// @param dst Memory descriptor for destination tensor.
        /// @param attr Primitive attributes to use. Attributes are optional
        ///     and default to empty attributes.
        /// @param allow_empty A flag signifying whether construction is
        ///     allowed to fail without throwing an exception. In this case an
        ///     empty object will be produced. This flag is optional and
        ///     defaults to false.
        primitive_desc(const engine &aengine, algorithm aalgorithm,
                const memory::desc &src0, const memory::desc &src1,
                const memory::desc &src2, const memory::desc &dst,
                const primitive_attr &attr = default_attr(),
                bool allow_empty = false) {

            dnnl_primitive_desc_t pd = nullptr;
            dnnl_status_t status = dnnl_binary_primitive_desc_create_v2(&pd,
                    aengine.get(), dnnl::convert_to_c(aalgorithm), src0.get(),
                    src1.get(), src2.get(), dst.get(), attr.get());

            if (!allow_empty)
                error::wrap_c_api(status,
                        "could not create a primitive descriptor for "
                        "the binary v2 operation primitive. Run workload with "
                        "environment variable ONEDNN_VERBOSE=all to get "
                        "additional diagnostic information.");
            reset(pd);
        }

        /// Constructs a primitive descriptor for a binary primitive from a C
        /// API primitive descriptor that must have a matching kind.
        ///
        /// @param pd C API primitive descriptor for a binary primitive.
        primitive_desc(dnnl_primitive_desc_t pd)
            : dnnl::primitive_desc(pd, dnnl::primitive::kind::binary) {}

        /// @copydoc dnnl::primitive_desc_base::src_desc(int)const
        memory::desc src_desc(int idx = 0) const { return base::src_desc(idx); }

        /// Returns the memory descriptor for source #0.
        memory::desc src0_desc() const { return base::src_desc(0); }

        /// Returns the memory descriptor for source #1.
        memory::desc src1_desc() const { return base::src_desc(1); }

        /// Returns the memory descriptor for source #2.
        memory::desc src2_desc() const { return base::src_desc(2); }

        /// @copydoc dnnl::primitive_desc_base::dst_desc()const
        memory::desc dst_desc() const { return base::dst_desc(0); }

        /// @copydoc dnnl::primitive_desc_base::get_algorithm()const
        algorithm get_algorithm() const { return base::get_algorithm(); }
    };

    /// Default constructor. Produces an empty object.
    binary() = default;

    /// Constructs an elementwise binary operation primitive.
    /// @param pd Primitive descriptor for an elementwise binary operation
    ///     primitive.
    binary(const primitive_desc &pd) : primitive(pd) {}

    /// Constructs an elementwise binary operation primitive from a cache blob.
    /// @param pd Primitive descriptor for an elementwise binary operation
    ///     primitive.
    /// @param cache_blob Cache blob.
    binary(const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
        : primitive(pd, cache_blob) {}
};
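
/*
  Example (illustrative sketch, not part of the library API): elementwise
  addition of two f32 tensors, with the second source broadcast over the
  batch dimension. The engine `eng`, stream `strm`, and the memory objects
  bound to the descriptors are assumed to exist already.

    memory::desc src0_md({8, 32, 14, 14}, memory::data_type::f32,
            memory::format_tag::nchw);
    memory::desc src1_md({1, 32, 14, 14}, memory::data_type::f32,
            memory::format_tag::nchw); // broadcast along N
    memory::desc dst_md({8, 32, 14, 14}, memory::data_type::f32,
            memory::format_tag::nchw);

    binary::primitive_desc pd(
            eng, algorithm::binary_add, src0_md, src1_md, dst_md);
    binary(pd).execute(strm,
            {{DNNL_ARG_SRC_0, src0_mem}, {DNNL_ARG_SRC_1, src1_mem},
                    {DNNL_ARG_DST, dst_mem}});
*/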

/// @} dnnl_api_binary

/// @addtogroup dnnl_api_matmul Matrix Multiplication
///
/// A primitive to perform matrix-matrix multiplication. The batched mode
/// is supported with 3D tensors.
///
/// @sa @ref dev_guide_matmul in developer guide
///
/// @{

/// Matrix multiplication (matmul) primitive.
struct matmul : public primitive {
    /// Primitive descriptor for a matmul primitive.
    struct primitive_desc : public dnnl::primitive_desc {
        /// Default constructor. Produces an empty object.
        primitive_desc() = default;

        /// Constructs a primitive descriptor for a matmul primitive
        /// without bias.
        ///
        /// @param aengine Engine to use.
        /// @param src_desc Memory descriptor for source (matrix A).
        /// @param weights_desc Memory descriptor for weights (matrix B).
        /// @param dst_desc Memory descriptor for destination (matrix C).
        /// @param attr Primitive attributes to use. Attributes are optional
        ///     and default to empty attributes.
        /// @param allow_empty A flag signifying whether construction is
        ///     allowed to fail without throwing an exception. In this case an
        ///     empty object will be produced. This flag is optional and
        ///     defaults to false.
        primitive_desc(const engine &aengine, const memory::desc &src_desc,
                const memory::desc &weights_desc, const memory::desc &dst_desc,
                const primitive_attr &attr = default_attr(),
                bool allow_empty = false)
            : primitive_desc(aengine, src_desc, weights_desc, nullptr, dst_desc,
                    attr, allow_empty) {}

        /// Constructs a primitive descriptor for a matmul primitive with bias.
        ///
        /// @param aengine Engine to use.
        /// @param src_desc Memory descriptor for source (matrix A).
        /// @param weights_desc Memory descriptor for weights (matrix B).
        /// @param bias_desc Memory descriptor for bias.
        /// @param dst_desc Memory descriptor for destination (matrix C).
        /// @param attr Primitive attributes to use. Attributes are optional
        ///     and default to empty attributes.
        /// @param allow_empty A flag signifying whether construction is
        ///     allowed to fail without throwing an exception. In this case an
        ///     empty object will be produced. This flag is optional and
        ///     defaults to false.
        primitive_desc(const engine &aengine, const memory::desc &src_desc,
                const memory::desc &weights_desc, const memory::desc &bias_desc,
                const memory::desc &dst_desc,
                const primitive_attr &attr = default_attr(),
                bool allow_empty = false)
            : primitive_desc(aengine, src_desc, weights_desc, &bias_desc,
                    dst_desc, attr, allow_empty) {}

        /// Constructs a primitive descriptor for a matmul primitive from a C
        /// API primitive descriptor that must have a matching kind.
        ///
        /// @param pd C API primitive descriptor for a matmul primitive.
        primitive_desc(dnnl_primitive_desc_t pd)
            : dnnl::primitive_desc(pd, dnnl::primitive::kind::matmul) {}

        /// @copydoc dnnl::primitive_desc_base::src_desc()const
        memory::desc src_desc() const { return query_md(query::src_md, 0); }

        /// @copydoc dnnl::primitive_desc_base::weights_desc()const
        memory::desc weights_desc() const {
            return query_md(query::weights_md, 0);
        }

        /// @copydoc dnnl::convolution_forward::primitive_desc::bias_desc()const
        memory::desc bias_desc() const {
            return query_md(query::weights_md, 1);
        }

        /// @copydoc dnnl::primitive_desc_base::dst_desc()const
        memory::desc dst_desc() const { return query_md(query::dst_md, 0); }

    private:
        primitive_desc(const engine &aengine, const memory::desc &src_desc,
                const memory::desc &weights_desc, const memory::desc *bias_desc,
                const memory::desc &dst_desc, const primitive_attr &attr,
                bool allow_empty) {

            dnnl_primitive_desc_t pd = nullptr;
            dnnl_status_t status = dnnl_matmul_primitive_desc_create(&pd,
                    aengine.get(), src_desc.get(), weights_desc.get(),
                    optional_arg(bias_desc), dst_desc.get(), attr.get());

            if (!allow_empty)
                error::wrap_c_api(status,
                        "could not create a primitive descriptor for "
                        "the matmul primitive. Run workload with "
                        "environment variable ONEDNN_VERBOSE=all to get "
                        "additional diagnostic information.");
            reset(pd);
        }
    };

    /// Default constructor. Produces an empty object.
    matmul() = default;

    /// Constructs a matmul primitive.
    /// @param pd Primitive descriptor for a matmul primitive.
    matmul(const primitive_desc &pd) : primitive(pd) {}

    /// Constructs a matmul primitive from a cache blob.
    /// @param pd Primitive descriptor for a matmul primitive.
    /// @param cache_blob Cache blob.
    matmul(const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
        : primitive(pd, cache_blob) {}
};
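
/*
  Example (illustrative sketch, not part of the library API): an f32 matrix
  multiplication C = A * B + bias with M = 128, K = 256, N = 64. Row-major 2D
  layouts are expressed with memory::format_tag::ab; the bias is broadcast
  over the rows by giving it a leading dimension of 1. Engine, stream, and
  the memory objects are assumed to exist already.

    const memory::dim M = 128, K = 256, N = 64;
    memory::desc a_md({M, K}, memory::data_type::f32, memory::format_tag::ab);
    memory::desc b_md({K, N}, memory::data_type::f32, memory::format_tag::ab);
    memory::desc bias_md(
            {1, N}, memory::data_type::f32, memory::format_tag::ab);
    memory::desc c_md({M, N}, memory::data_type::f32, memory::format_tag::ab);

    matmul::primitive_desc pd(eng, a_md, b_md, bias_md, c_md);
    matmul(pd).execute(strm,
            {{DNNL_ARG_SRC, a_mem}, {DNNL_ARG_WEIGHTS, b_mem},
                    {DNNL_ARG_BIAS, bias_mem}, {DNNL_ARG_DST, c_mem}});
*/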

/// @} dnnl_api_matmul

/// @addtogroup dnnl_api_resampling Resampling
///
/// A primitive to compute a resampling operation on a 1D, 2D, or 3D data
/// tensor using the Nearest Neighbor or Linear (Bilinear, Trilinear)
/// interpolation method.
///
/// @sa @ref dev_guide_resampling in developer guide
///
/// @{

/// Resampling forward propagation primitive.
struct resampling_forward : public primitive {
    /// Primitive descriptor for a resampling forward propagation primitive.
    struct primitive_desc : public dnnl::primitive_desc {
        /// Default constructor. Produces an empty object.
        primitive_desc() = default;

        /// Constructs a primitive descriptor for a resampling forward
        /// propagation primitive using source and destination memory
        /// descriptors.
        ///
        /// @note
        ///     Destination memory descriptor may be initialized with
        ///     #dnnl::memory::format_tag::any value of @p format_tag.
        ///
        /// @param aengine Engine to use.
        /// @param aprop_kind Propagation kind. Possible values are
        ///     #dnnl::prop_kind::forward_training, and
        ///     #dnnl::prop_kind::forward_inference.
        /// @param aalgorithm resampling algorithm kind: either
        ///     #dnnl::algorithm::resampling_nearest, or
        ///     #dnnl::algorithm::resampling_linear
        /// @param src_desc Source memory descriptor.
        /// @param dst_desc Destination memory descriptor.
        /// @param attr Primitive attributes to use. Attributes are optional
        ///     and default to empty attributes.
        /// @param allow_empty A flag signifying whether construction is
        ///     allowed to fail without throwing an exception. In this case an
        ///     empty object will be produced. This flag is optional and
        ///     defaults to false.
        primitive_desc(const engine &aengine, prop_kind aprop_kind,
                algorithm aalgorithm, const memory::desc &src_desc,
                const memory::desc &dst_desc,
                const primitive_attr &attr = default_attr(),
                bool allow_empty = false)
            : primitive_desc(aengine, aprop_kind, aalgorithm, nullptr, src_desc,
                    &dst_desc, attr, allow_empty) {}

        /// Constructs a primitive descriptor for a resampling forward
        /// propagation primitive using source memory descriptor and
        /// factors.
        ///
        /// @param aengine Engine to use.
        /// @param aprop_kind Propagation kind. Possible values are
        ///     #dnnl::prop_kind::forward_training, and
        ///     #dnnl::prop_kind::forward_inference.
        /// @param aalgorithm resampling algorithm kind: either
        ///     #dnnl::algorithm::resampling_nearest, or
        ///     #dnnl::algorithm::resampling_linear
        /// @param factors Vector of scaling factors for spatial dimension.
        /// @param src_desc Source memory descriptor.
        /// @param attr Primitive attributes to use. Attributes are optional
        ///     and default to empty attributes.
        /// @param allow_empty A flag signifying whether construction is
        ///     allowed to fail without throwing an exception. In this case an
        ///     empty object will be produced. This flag is optional and
        ///     defaults to false.
        primitive_desc(const engine &aengine, prop_kind aprop_kind,
                algorithm aalgorithm, const std::vector<float> &factors,
                const memory::desc &src_desc,
                const primitive_attr &attr = default_attr(),
                bool allow_empty = false)
            : primitive_desc(aengine, aprop_kind, aalgorithm, &factors,
                    src_desc, nullptr, attr, allow_empty) {}

        /// Constructs a primitive descriptor for a resampling forward
        /// propagation primitive.
        ///
        /// @note
        ///     The destination memory descriptor may be initialized with
        ///     #dnnl::memory::format_tag::any value of @p format_tag.
        ///
        /// @param aengine Engine to use.
        /// @param aprop_kind Propagation kind. Possible values are
        ///     #dnnl::prop_kind::forward_training, and
        ///     #dnnl::prop_kind::forward_inference.
        /// @param aalgorithm resampling algorithm kind: either
        ///     #dnnl::algorithm::resampling_nearest, or
        ///     #dnnl::algorithm::resampling_linear
        /// @param factors Vector of scaling factors for spatial dimension.
        /// @param src_desc Source memory descriptor.
        /// @param dst_desc Destination memory descriptor.
        /// @param attr Primitive attributes to use. Attributes are optional
        ///     and default to empty attributes.
        /// @param allow_empty A flag signifying whether construction is
        ///     allowed to fail without throwing an exception. In this case an
        ///     empty object will be produced. This flag is optional and
        ///     defaults to false.
        primitive_desc(const engine &aengine, prop_kind aprop_kind,
                algorithm aalgorithm, const std::vector<float> &factors,
                const memory::desc &src_desc, const memory::desc &dst_desc,
                const primitive_attr &attr = default_attr(),
                bool allow_empty = false)
            : primitive_desc(aengine, aprop_kind, aalgorithm, &factors,
                    src_desc, &dst_desc, attr, allow_empty) {}

        /// Constructs a primitive descriptor for a resampling forward
        /// propagation primitive from a C API primitive descriptor that must
        /// have a matching kind.
        ///
        /// @param pd C API primitive descriptor for a resampling forward
        ///     propagation primitive.
        primitive_desc(dnnl_primitive_desc_t pd)
            : dnnl::primitive_desc(pd, dnnl::primitive::kind::resampling,
                    dnnl::prop_kind::forward_training,
                    dnnl::prop_kind::forward_inference) {}

        /// @copydoc dnnl::primitive_desc_base::src_desc()const
        memory::desc src_desc() const { return base::src_desc(0); }

        /// @copydoc dnnl::primitive_desc_base::dst_desc()const
        memory::desc dst_desc() const { return base::dst_desc(0); }

    private:
        primitive_desc(const engine &aengine, prop_kind aprop_kind,
                algorithm aalgorithm, const std::vector<float> *factors,
                const memory::desc &src_desc, const memory::desc *dst_desc,
                const primitive_attr &attr, bool allow_empty) {

            if (factors)
                memory::validate_dims(*factors, src_desc.get_ndims() - 2);

            dnnl_primitive_desc_t pd = nullptr;
            dnnl_status_t status
                    = dnnl_resampling_forward_primitive_desc_create(&pd,
                            aengine.get(), dnnl::convert_to_c(aprop_kind),
                            convert_to_c(aalgorithm), optional_arg(factors),
                            src_desc.get(), optional_arg(dst_desc), attr.get());

            if (!allow_empty)
                error::wrap_c_api(status,
                        "could not create a primitive descriptor for "
                        "the resampling forward propagation primitive. Run "
                        "workload with environment variable ONEDNN_VERBOSE=all "
                        "to get additional diagnostic information.");
            reset(pd);
        }
    };

    /// Default constructor. Produces an empty object.
    resampling_forward() = default;

    /// Constructs a resampling forward propagation primitive.
    /// @param pd Primitive descriptor for a resampling forward propagation
    ///     primitive.
    resampling_forward(const primitive_desc &pd) : primitive(pd) {}

    /// Constructs a resampling forward propagation primitive from a cache
    /// blob.
    /// @param pd Primitive descriptor for a resampling forward propagation
    ///     primitive.
    /// @param cache_blob Cache blob.
    resampling_forward(
            const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
        : primitive(pd, cache_blob) {}
};
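
/*
  Example (illustrative sketch, not part of the library API): 2x bilinear
  upsampling of an NCHW tensor using the factors-based constructor, which
  derives the destination shape from the source descriptor and the scaling
  factors. The engine `eng`, stream `strm`, and `src_mem` are assumptions.

    memory::desc src_md({1, 3, 112, 112}, memory::data_type::f32,
            memory::format_tag::nchw);
    std::vector<float> factors = {2.0f, 2.0f}; // height and width scales

    resampling_forward::primitive_desc fwd_pd(eng, prop_kind::forward_training,
            algorithm::resampling_linear, factors, src_md);
    memory dst_mem(fwd_pd.dst_desc(), eng); // 1x3x224x224
    resampling_forward(fwd_pd).execute(
            strm, {{DNNL_ARG_SRC, src_mem}, {DNNL_ARG_DST, dst_mem}});
*/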

/// Resampling backward propagation primitive.
struct resampling_backward : public primitive {
    /// Primitive descriptor for a resampling backward propagation primitive.
    struct primitive_desc : public dnnl::primitive_desc {
        /// Default constructor. Produces an empty object.
        primitive_desc() = default;

        /// Constructs a primitive descriptor for a resampling backward
        /// propagation primitive using source and destination memory
        /// descriptors.
        ///
        /// @param aengine Engine to use.
        /// @param aalgorithm resampling algorithm kind: either
        ///     #dnnl::algorithm::resampling_nearest, or
        ///     #dnnl::algorithm::resampling_linear
        /// @param diff_src_desc Diff source memory descriptor.
        /// @param diff_dst_desc Diff destination memory descriptor.
        /// @param hint_fwd_pd Primitive descriptor for a resampling
        ///     forward propagation primitive. It is used as a hint for
        ///     deciding which memory format to use.
        /// @param attr Primitive attributes to use. Attributes are optional
        ///     and default to empty attributes.
        /// @param allow_empty A flag signifying whether construction is
        ///     allowed to fail without throwing an exception. In this case an
        ///     empty object will be produced. This flag is optional and
        ///     defaults to false.
        primitive_desc(const engine &aengine, algorithm aalgorithm,
                const memory::desc &diff_src_desc,
                const memory::desc &diff_dst_desc,
                const resampling_forward::primitive_desc &hint_fwd_pd,
                const primitive_attr &attr = default_attr(),
                bool allow_empty = false)
            : primitive_desc(aengine, aalgorithm, nullptr, diff_src_desc,
                    diff_dst_desc, hint_fwd_pd, attr, allow_empty) {}

        /// Constructs a primitive descriptor for a resampling backward
        /// propagation primitive.
        ///
        /// @param aengine Engine to use.
        /// @param aalgorithm resampling algorithm kind: either
        ///     #dnnl::algorithm::resampling_nearest, or
        ///     #dnnl::algorithm::resampling_linear
        /// @param factors Vector of scaling factors for spatial dimension.
        /// @param diff_src_desc Diff source memory descriptor.
        /// @param diff_dst_desc Diff destination memory descriptor.
        /// @param hint_fwd_pd Primitive descriptor for a resampling
        ///     forward propagation primitive. It is used as a hint for
        ///     deciding which memory format to use.
        /// @param attr Primitive attributes to use. Attributes are optional
        ///     and default to empty attributes.
        /// @param allow_empty A flag signifying whether construction is
        ///     allowed to fail without throwing an exception. In this case an
        ///     empty object will be produced. This flag is optional and
        ///     defaults to false.
        primitive_desc(const engine &aengine, algorithm aalgorithm,
                const std::vector<float> &factors,
                const memory::desc &diff_src_desc,
                const memory::desc &diff_dst_desc,
                const resampling_forward::primitive_desc &hint_fwd_pd,
                const primitive_attr &attr = default_attr(),
                bool allow_empty = false)
            : primitive_desc(aengine, aalgorithm, &factors, diff_src_desc,
                    diff_dst_desc, hint_fwd_pd, attr, allow_empty) {}

        /// Constructs a primitive descriptor for a resampling backward
        /// propagation primitive from a C API primitive descriptor that must
        /// have a matching kind.
        ///
        /// @param pd C API primitive descriptor for a resampling backward
        ///     propagation primitive.
        primitive_desc(dnnl_primitive_desc_t pd)
            : dnnl::primitive_desc(pd, dnnl::primitive::kind::resampling,
                    dnnl::prop_kind::backward_data) {}

        /// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
        memory::desc diff_src_desc() const { return base::diff_src_desc(0); }

        /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
        memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }

    private:
        primitive_desc(const engine &aengine, algorithm aalgorithm,
                const std::vector<float> *factors,
                const memory::desc &diff_src_desc,
                const memory::desc &diff_dst_desc,
                const resampling_forward::primitive_desc &hint_fwd_pd,
                const primitive_attr &attr, bool allow_empty) {

            if (factors)
                memory::validate_dims(*factors, diff_src_desc.get_ndims() - 2);

            dnnl_primitive_desc_t pd = nullptr;
            dnnl_status_t status
                    = dnnl_resampling_backward_primitive_desc_create(&pd,
                            aengine.get(), convert_to_c(aalgorithm),
                            optional_arg(factors), diff_src_desc.get(),
                            diff_dst_desc.get(), hint_fwd_pd.get(), attr.get());

            if (!allow_empty)
                error::wrap_c_api(status,
                        "could not create a primitive descriptor for "
                        "the resampling backward propagation primitive. Run "
                        "workload with environment variable ONEDNN_VERBOSE=all "
                        "to get additional diagnostic information.");
            reset(pd);
        }
    };

    /// Default constructor. Produces an empty object.
    resampling_backward() = default;

    /// Constructs a resampling backward propagation primitive.
    /// @param pd Primitive descriptor for a resampling backward propagation
    ///     primitive.
    resampling_backward(const primitive_desc &pd) : primitive(pd) {}

    /// Constructs a resampling backward propagation primitive from a cache
    /// blob.
    /// @param pd Primitive descriptor for a resampling backward propagation
    ///     primitive.
    /// @param cache_blob Cache blob.
    resampling_backward(
            const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
        : primitive(pd, cache_blob) {}
};
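
/*
  Example (illustrative sketch, not part of the library API): the backward
  pass matching the upsampling sketched above. It assumes a forward primitive
  descriptor `fwd_pd` created for dnnl::prop_kind::forward_training with the
  same shapes; it is used as the format hint. `diff_src_mem` has the forward
  source shape and `diff_dst_mem` the forward destination shape.

    resampling_backward::primitive_desc bwd_pd(eng,
            algorithm::resampling_linear, factors, src_md, fwd_pd.dst_desc(),
            fwd_pd);
    resampling_backward(bwd_pd).execute(strm,
            {{DNNL_ARG_DIFF_DST, diff_dst_mem},
                    {DNNL_ARG_DIFF_SRC, diff_src_mem}});
*/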

/// @} dnnl_api_resampling

/// @addtogroup dnnl_api_pooling Pooling
///
/// A primitive to perform max or average pooling with dilation.
///
/// @sa @ref dev_guide_pooling in developer guide
///
/// @{

/// Pooling forward propagation primitive.
struct pooling_forward : public primitive {
    /// Primitive descriptor for a pooling forward propagation primitive.
    struct primitive_desc : public dnnl::primitive_desc {
        /// Default constructor. Produces an empty object.
        primitive_desc() = default;

        /// Constructs a primitive descriptor for a pooling forward propagation
        /// primitive.
        ///
        /// Arrays @p strides, @p kernel, @p dilation, @p padding_l
        /// and @p padding_r contain values for spatial dimensions only and
        /// hence must have the same number of elements as there are spatial
        /// dimensions. The order of values is the same as in the tensor:
        /// depth (for 3D tensors), height (for 3D and 2D tensors), and width.
        ///
        /// @param aengine Engine to use.
        /// @param aprop_kind Propagation kind. Possible values are
        ///     #dnnl::prop_kind::forward_training, and
        ///     #dnnl::prop_kind::forward_inference.
        /// @param aalgorithm Pooling algorithm kind: either
        ///     #dnnl::algorithm::pooling_max,
        ///     #dnnl::algorithm::pooling_avg_include_padding,
        ///     or #dnnl::algorithm::pooling_avg_exclude_padding.
        /// @param src_desc Source memory descriptor.
        /// @param dst_desc Destination memory descriptor.
        /// @param strides Vector of strides for spatial dimension.
        /// @param kernel Vector of kernel spatial dimensions.
        /// @param dilation Array of dilations for spatial dimension.
        /// @param padding_l Vector of padding values for low indices for each
        ///     spatial dimension `([[front,] top,] left)`.
        /// @param padding_r Vector of padding values for high indices for
        ///     each spatial dimension `([[back,] bottom,] right)`.
        /// @param attr Primitive attributes to use. Attributes are optional
        ///     and default to empty attributes.
        /// @param allow_empty A flag signifying whether construction is
        ///     allowed to fail without throwing an exception. In this case an
        ///     empty object will be produced. This flag is optional and
        ///     defaults to false.
        primitive_desc(const engine &aengine, prop_kind aprop_kind,
                algorithm aalgorithm, const memory::desc &src_desc,
                const memory::desc &dst_desc, const memory::dims &strides,
                const memory::dims &kernel, const memory::dims &dilation,
                const memory::dims &padding_l, const memory::dims &padding_r,
                const primitive_attr &attr = default_attr(),
                bool allow_empty = false) {

            memory::validate_dims(strides, src_desc.get_ndims() - 2);
            memory::validate_dims(kernel, src_desc.get_ndims() - 2);
            memory::validate_dims(padding_l, src_desc.get_ndims() - 2);
            memory::validate_dims(padding_r, src_desc.get_ndims() - 2);
            memory::validate_dims(dilation, src_desc.get_ndims() - 2);

            dnnl_primitive_desc_t pd = nullptr;
            dnnl_status_t status = dnnl_pooling_forward_primitive_desc_create(
                    &pd, aengine.get(), dnnl::convert_to_c(aprop_kind),
                    convert_to_c(aalgorithm), src_desc.get(), dst_desc.get(),
                    &strides[0], &kernel[0], &dilation[0], &padding_l[0],
                    &padding_r[0], attr.get());

            if (!allow_empty)
                error::wrap_c_api(status,
                        "could not create a primitive descriptor for a pooling "
                        "forward propagation primitive");
            reset(pd);
        }

        /// Constructs a primitive descriptor for a pooling forward propagation
        /// primitive from a C API primitive descriptor that must have a
        /// matching kind.
        ///
        /// @param pd C API primitive descriptor for a pooling forward
        ///     propagation primitive.
        primitive_desc(dnnl_primitive_desc_t pd)
            : dnnl::primitive_desc(pd, dnnl::primitive::kind::pooling,
                    dnnl::prop_kind::forward_training,
                    dnnl::prop_kind::forward_inference) {}

        /// @copydoc dnnl::primitive_desc_base::src_desc()const
        memory::desc src_desc() const { return base::src_desc(0); }

        /// @copydoc dnnl::primitive_desc_base::dst_desc()const
        memory::desc dst_desc() const { return base::dst_desc(0); }

        /// @copydoc dnnl::primitive_desc_base::workspace_desc()const
        memory::desc workspace_desc() const { return base::workspace_desc(); }

        /// @copydoc dnnl::primitive_desc_base::get_algorithm()const
        algorithm get_algorithm() const { return base::get_algorithm(); }

        /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
        prop_kind get_prop_kind() const { return base::get_prop_kind(); }

        /// @copydoc dnnl::primitive_desc_base::get_strides()const
        memory::dims get_strides() const { return base::get_strides(); }

        /// @copydoc dnnl::primitive_desc_base::get_kernel()const
        memory::dims get_kernel() const { return base::get_kernel(); }

        /// @copydoc dnnl::primitive_desc_base::get_dilations()const
        memory::dims get_dilations() const { return base::get_dilations(); }

        /// @copydoc dnnl::primitive_desc_base::get_padding_l()const
        memory::dims get_padding_l() const { return base::get_padding_l(); }

        /// @copydoc dnnl::primitive_desc_base::get_padding_r()const
        memory::dims get_padding_r() const { return base::get_padding_r(); }
    };

    /// Default constructor. Produces an empty object.
    pooling_forward() = default;

    /// Constructs a pooling forward propagation primitive.
    ///
    /// @param pd Primitive descriptor for a pooling forward propagation
    ///     primitive.
    pooling_forward(const primitive_desc &pd) : primitive(pd) {}

    /// Constructs a pooling forward propagation primitive from a cache blob.
    ///
    /// @param pd Primitive descriptor for a pooling forward propagation
    ///     primitive.
    /// @param cache_blob Cache blob.
    pooling_forward(
            const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
        : primitive(pd, cache_blob) {}
};
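
/*
  Example (illustrative sketch, not part of the library API): 3x3 max pooling
  with stride 2, symmetric padding of 1, and no dilation (a dilation value of
  0 means no dilation) on an NCHW tensor. With forward_training the primitive
  also produces a workspace that the matching backward primitive consumes.
  Engine, stream, and the source/destination memory objects are assumptions.

    memory::desc src_md({1, 64, 56, 56}, memory::data_type::f32,
            memory::format_tag::nchw);
    memory::desc dst_md({1, 64, 28, 28}, memory::data_type::f32,
            memory::format_tag::nchw);
    memory::dims strides = {2, 2}, kernel = {3, 3}, dilation = {0, 0};
    memory::dims padding_l = {1, 1}, padding_r = {1, 1};

    pooling_forward::primitive_desc pd(eng, prop_kind::forward_training,
            algorithm::pooling_max, src_md, dst_md, strides, kernel, dilation,
            padding_l, padding_r);
    memory workspace_mem(pd.workspace_desc(), eng);
    pooling_forward(pd).execute(strm,
            {{DNNL_ARG_SRC, src_mem}, {DNNL_ARG_DST, dst_mem},
                    {DNNL_ARG_WORKSPACE, workspace_mem}});
*/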
|
|
|
|
/// Pooling backward propagation primitive.
struct pooling_backward : public primitive {
    /// Primitive descriptor for a pooling backward propagation primitive.
    struct primitive_desc : public dnnl::primitive_desc {
        /// Default constructor. Produces an empty object.
        primitive_desc() = default;

        /// Constructs a primitive descriptor for a pooling backward propagation
        /// primitive.
        ///
        /// Arrays @p strides, @p kernel, @p dilation, @p padding_l
        /// and @p padding_r contain values for spatial dimensions only and
        /// hence must have the same number of elements as there are spatial
        /// dimensions. The order of values is the same as in the tensor:
        /// depth (for 3D tensors), height (for 3D and 2D tensors), and width.
        ///
        /// @param aengine Engine to use.
        /// @param aalgorithm Pooling algorithm kind: either
        ///     #dnnl::algorithm::pooling_max,
        ///     #dnnl::algorithm::pooling_avg_include_padding,
        ///     or #dnnl::algorithm::pooling_avg_exclude_padding.
        /// @param diff_src_desc Diff source memory descriptor.
        /// @param diff_dst_desc Diff destination memory descriptor.
        /// @param strides Vector of strides for spatial dimensions.
        /// @param kernel Vector of kernel spatial dimensions.
        /// @param dilation Vector of dilations for spatial dimensions.
        /// @param padding_l Vector of padding values for low indices for each
        ///     spatial dimension `([[front,] top,] left)`.
        /// @param padding_r Vector of padding values for high indices for
        ///     each spatial dimension `([[back,] bottom,] right)`.
        /// @param hint_fwd_pd Primitive descriptor for a pooling
        ///     forward propagation primitive. It is used as a hint for
        ///     deciding which memory format to use.
        /// @param attr Primitive attributes to use. Attributes are optional
        ///     and default to empty attributes.
        /// @param allow_empty A flag signifying whether construction is
        ///     allowed to fail without throwing an exception. In this case an
        ///     empty object will be produced. This flag is optional and
        ///     defaults to false.
        primitive_desc(const engine &aengine, algorithm aalgorithm,
                const memory::desc &diff_src_desc,
                const memory::desc &diff_dst_desc, const memory::dims &strides,
                const memory::dims &kernel, const memory::dims &dilation,
                const memory::dims &padding_l, const memory::dims &padding_r,
                const pooling_forward::primitive_desc &hint_fwd_pd,
                const primitive_attr &attr = default_attr(),
                bool allow_empty = false) {

            memory::validate_dims(strides, diff_src_desc.get_ndims() - 2);
            memory::validate_dims(kernel, diff_src_desc.get_ndims() - 2);
            memory::validate_dims(padding_l, diff_src_desc.get_ndims() - 2);
            memory::validate_dims(padding_r, diff_src_desc.get_ndims() - 2);
            memory::validate_dims(dilation, diff_src_desc.get_ndims() - 2);

            dnnl_primitive_desc_t pd = nullptr;
            dnnl_status_t status = dnnl_pooling_backward_primitive_desc_create(
                    &pd, aengine.get(), convert_to_c(aalgorithm),
                    diff_src_desc.get(), diff_dst_desc.get(), &strides[0],
                    &kernel[0], &dilation[0], &padding_l[0], &padding_r[0],
                    hint_fwd_pd.get(), attr.get());
            if (!allow_empty)
                error::wrap_c_api(status,
                        "could not create a descriptor for a pooling backward "
                        "propagation primitive");
            reset(pd);
        }

        /// Constructs a primitive descriptor for a pooling backward propagation
        /// primitive from a C API primitive descriptor that must have a
        /// matching kind.
        ///
        /// @param pd C API primitive descriptor for a pooling backward
        ///     propagation primitive.
        primitive_desc(dnnl_primitive_desc_t pd)
            : dnnl::primitive_desc(pd, dnnl::primitive::kind::pooling,
                    dnnl::prop_kind::backward_data) {}

        /// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
        memory::desc diff_src_desc() const { return base::diff_src_desc(0); }

        /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
        memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }

        /// @copydoc dnnl::primitive_desc_base::workspace_desc()const
        memory::desc workspace_desc() const { return base::workspace_desc(); }

        /// @copydoc dnnl::primitive_desc_base::get_algorithm()const
        algorithm get_algorithm() const { return base::get_algorithm(); }

        /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
        prop_kind get_prop_kind() const { return base::get_prop_kind(); }

        /// @copydoc dnnl::primitive_desc_base::get_strides()const
        memory::dims get_strides() const { return base::get_strides(); }

        /// @copydoc dnnl::primitive_desc_base::get_kernel()const
        memory::dims get_kernel() const { return base::get_kernel(); }

        /// @copydoc dnnl::primitive_desc_base::get_dilations()const
        memory::dims get_dilations() const { return base::get_dilations(); }

        /// @copydoc dnnl::primitive_desc_base::get_padding_l()const
        memory::dims get_padding_l() const { return base::get_padding_l(); }

        /// @copydoc dnnl::primitive_desc_base::get_padding_r()const
        memory::dims get_padding_r() const { return base::get_padding_r(); }
    };

    /// Default constructor. Produces an empty object.
    pooling_backward() = default;

    /// Constructs a pooling backward propagation primitive.
    ///
    /// @param pd Primitive descriptor for a pooling backward propagation
    ///     primitive.
    pooling_backward(const primitive_desc &pd) : primitive(pd) {}

    /// Constructs a pooling backward propagation primitive from a cache blob.
    ///
    /// @param pd Primitive descriptor for a pooling backward propagation
    ///     primitive.
    /// @param cache_blob Cache blob.
    pooling_backward(
            const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
        : primitive(pd, cache_blob) {}
};
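
// Illustrative usage sketch only; this is not part of the oneDNN API. It shows
// how a pooling backward primitive descriptor is typically created from the
// matching forward primitive descriptor used as a hint, and how the workspace
// produced by the forward pass is wired into execution. The helper name and
// the assumption that the forward pass already ran are hypothetical.
inline void example_pooling_backward_sketch(const engine &eng, stream &strm,
        const pooling_forward::primitive_desc &fwd_pd,
        const memory &workspace_mem, const memory &diff_src_mem,
        const memory &diff_dst_mem) {
    // Reuse the spatial parameters queried from the forward descriptor so the
    // backward descriptor is guaranteed to be consistent with it.
    pooling_backward::primitive_desc pd(eng, fwd_pd.get_algorithm(),
            diff_src_mem.get_desc(), diff_dst_mem.get_desc(),
            fwd_pd.get_strides(), fwd_pd.get_kernel(), fwd_pd.get_dilations(),
            fwd_pd.get_padding_l(), fwd_pd.get_padding_r(), fwd_pd);

    pooling_backward pooling_bwd(pd);
    pooling_bwd.execute(strm,
            {{DNNL_ARG_DIFF_DST, diff_dst_mem},
                    {DNNL_ARG_WORKSPACE, workspace_mem},
                    {DNNL_ARG_DIFF_SRC, diff_src_mem}});
    strm.wait();
}
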
/// @} dnnl_api_pooling

/// @addtogroup dnnl_api_prelu PReLU
///
/// PReLU primitive
/// A primitive to perform PReLU (leaky ReLU with a trainable alpha parameter).
///
/// @sa @ref dev_guide_prelu in developer guide
///
/// @{

/// PReLU forward propagation primitive.
struct prelu_forward : public primitive {
    /// Primitive descriptor for a PReLU forward propagation primitive.
    struct primitive_desc : public dnnl::primitive_desc {
        /// Default constructor. Produces an empty object.
        primitive_desc() = default;

        /// Constructs a primitive descriptor for a PReLU forward propagation
        /// primitive.
        ///
        /// @param aengine Engine to use.
        /// @param aprop_kind Propagation kind. Possible values are
        ///     #dnnl::prop_kind::forward_training and
        ///     #dnnl::prop_kind::forward_inference.
        /// @param src_desc Source memory descriptor.
        /// @param weight_desc Alpha parameters memory descriptor.
        /// @param dst_desc Destination memory descriptor.
        /// @param attr Primitive attributes to use. Attributes are optional
        ///     and default to empty attributes.
        /// @param allow_empty A flag signifying whether construction is
        ///     allowed to fail without throwing an exception. In this case an
        ///     empty object will be produced. This flag is optional and
        ///     defaults to false.
        primitive_desc(const engine &aengine, prop_kind aprop_kind,
                const memory::desc &src_desc, const memory::desc &weight_desc,
                const memory::desc &dst_desc,
                const primitive_attr &attr = default_attr(),
                bool allow_empty = false) {

            dnnl_primitive_desc_t pd = nullptr;
            dnnl_status_t status = dnnl_prelu_forward_primitive_desc_create(&pd,
                    aengine.get(), dnnl::convert_to_c(aprop_kind),
                    src_desc.get(), weight_desc.get(), dst_desc.get(),
                    attr.get());

            if (!allow_empty)
                error::wrap_c_api(status,
                        "could not create a primitive descriptor for "
                        "the prelu forward propagation primitive. Run workload "
                        "with environment variable ONEDNN_VERBOSE=all to get "
                        "additional diagnostic information.");
            reset(pd);
        }

        /// Constructs a primitive descriptor for a PReLU forward
        /// propagation primitive from a C API primitive descriptor that must
        /// have a matching kind.
        ///
        /// @param pd C API primitive descriptor for a PReLU forward
        ///     propagation primitive.
        primitive_desc(dnnl_primitive_desc_t pd)
            : dnnl::primitive_desc(pd, dnnl::primitive::kind::prelu,
                    dnnl::prop_kind::forward_training,
                    dnnl::prop_kind::forward_inference) {}

        /// @copydoc dnnl::primitive_desc_base::src_desc()const
        memory::desc src_desc() const { return base::src_desc(0); }

        /// @copydoc dnnl::primitive_desc_base::dst_desc()const
        memory::desc dst_desc() const { return base::dst_desc(0); }

        /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
        prop_kind get_prop_kind() const { return base::get_prop_kind(); }
    };

    /// Default constructor. Produces an empty object.
    prelu_forward() = default;

    /// Constructs a PReLU forward propagation primitive.
    /// @param pd Primitive descriptor for a PReLU forward propagation
    ///     primitive.
    prelu_forward(const primitive_desc &pd) : primitive(pd) {}

    /// Constructs a PReLU forward propagation primitive from a cache blob.
    /// @param pd Primitive descriptor for a PReLU forward propagation
    ///     primitive.
    /// @param cache_blob Cache blob.
    prelu_forward(
            const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
        : primitive(pd, cache_blob) {}
};
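
// Illustrative usage sketch only; this is not part of the oneDNN API. It shows
// a PReLU forward primitive with a per-channel alpha (weights) tensor. The
// helper name, tensor shapes, and the choice of one alpha per channel are
// assumptions made for the example.
inline void example_prelu_forward_sketch(const engine &eng, stream &strm) {
    const memory::dims data_dims = {2, 16, 28, 28}; // NCHW data
    const memory::dims alpha_dims = {1, 16, 1, 1}; // one alpha per channel

    memory::desc data_md(
            data_dims, memory::data_type::f32, memory::format_tag::nchw);
    memory::desc alpha_md(
            alpha_dims, memory::data_type::f32, memory::format_tag::nchw);

    memory src_mem(data_md, eng), dst_mem(data_md, eng);
    memory alpha_mem(alpha_md, eng);

    prelu_forward::primitive_desc pd(
            eng, prop_kind::forward_training, data_md, alpha_md, data_md);

    // The alpha tensor is passed through the DNNL_ARG_WEIGHTS argument slot.
    prelu_forward prelu(pd);
    prelu.execute(strm,
            {{DNNL_ARG_SRC, src_mem}, {DNNL_ARG_WEIGHTS, alpha_mem},
                    {DNNL_ARG_DST, dst_mem}});
    strm.wait();
}
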
/// PReLU backward propagation primitive.
struct prelu_backward : public primitive {
    /// Primitive descriptor for a PReLU backward propagation primitive.
    struct primitive_desc : public dnnl::primitive_desc {
        /// Default constructor. Produces an empty object.
        primitive_desc() = default;

        /// Constructs a primitive descriptor for a PReLU backward propagation
        /// primitive.
        ///
        /// @param aengine Engine to use.
        /// @param src_desc Source memory descriptor.
        /// @param weight_desc Alpha parameters memory descriptor.
        /// @param diff_src_desc Diff source memory descriptor.
        /// @param diff_weights_desc Diff alpha parameters memory descriptor.
        /// @param diff_dst_desc Diff destination memory descriptor.
        /// @param hint_fwd_pd Primitive descriptor for a PReLU
        ///     forward propagation primitive. It is used as a hint for
        ///     deciding which memory format to use.
        /// @param attr Primitive attributes to use. Attributes are optional
        ///     and default to empty attributes.
        /// @param allow_empty A flag signifying whether construction is
        ///     allowed to fail without throwing an exception. In this case an
        ///     empty object will be produced. This flag is optional and
        ///     defaults to false.
        primitive_desc(const engine &aengine, const memory::desc &src_desc,
                const memory::desc &weight_desc,
                const memory::desc &diff_src_desc,
                const memory::desc &diff_weights_desc,
                const memory::desc &diff_dst_desc,
                const prelu_forward::primitive_desc &hint_fwd_pd,
                const primitive_attr &attr = default_attr(),
                bool allow_empty = false) {

            dnnl_primitive_desc_t pd = nullptr;
            dnnl_status_t status = dnnl_prelu_backward_primitive_desc_create(
                    &pd, aengine.get(), src_desc.get(), weight_desc.get(),
                    diff_src_desc.get(), diff_weights_desc.get(),
                    diff_dst_desc.get(), hint_fwd_pd.get(), attr.get());

            if (!allow_empty)
                error::wrap_c_api(status,
                        "could not create a primitive descriptor for "
                        "the prelu backward propagation primitive. Run "
                        "workload with environment variable ONEDNN_VERBOSE=all "
                        "to get additional diagnostic information.");
            reset(pd);
        }

        /// Constructs a primitive descriptor for a PReLU backward
        /// propagation primitive from a C API primitive descriptor that must
        /// have a matching kind.
        ///
        /// @param pd C API primitive descriptor for a PReLU backward
        ///     propagation primitive.
        primitive_desc(dnnl_primitive_desc_t pd)
            : dnnl::primitive_desc(pd, dnnl::primitive::kind::prelu,
                    dnnl::prop_kind::backward) {}

        /// @copydoc dnnl::primitive_desc_base::src_desc()const
        memory::desc src_desc() const { return base::src_desc(0); }

        /// @copydoc dnnl::primitive_desc_base::diff_src_desc()const
        memory::desc diff_src_desc() const { return base::diff_src_desc(0); }

        /// @copydoc dnnl::primitive_desc_base::diff_dst_desc()const
        memory::desc diff_dst_desc() const { return base::diff_dst_desc(0); }

        /// @copydoc dnnl::primitive_desc_base::get_prop_kind()const
        prop_kind get_prop_kind() const { return base::get_prop_kind(); }
    };

    /// Default constructor. Produces an empty object.
    prelu_backward() = default;

    /// Constructs a PReLU backward propagation primitive.
    /// @param pd Primitive descriptor for a PReLU backward propagation
    ///     primitive.
    prelu_backward(const primitive_desc &pd) : primitive(pd) {}

    /// Constructs a PReLU backward propagation primitive from a cache blob.
    /// @param pd Primitive descriptor for a PReLU backward propagation
    ///     primitive.
    /// @param cache_blob Cache blob.
    prelu_backward(
            const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
        : primitive(pd, cache_blob) {}
};
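
// Illustrative usage sketch only; this is not part of the oneDNN API. The
// PReLU backward pass reuses the forward primitive descriptor as a hint and
// produces both a diff source and a diff alpha (diff weights) tensor. The
// helper name and the assumption that the forward tensors already exist are
// hypothetical.
inline void example_prelu_backward_sketch(const engine &eng, stream &strm,
        const prelu_forward::primitive_desc &fwd_pd, const memory &src_mem,
        const memory &alpha_mem, const memory &diff_dst_mem) {
    // Reuse the forward memory descriptors for the corresponding diff tensors.
    prelu_backward::primitive_desc pd(eng, fwd_pd.src_desc(),
            alpha_mem.get_desc(), fwd_pd.src_desc(), alpha_mem.get_desc(),
            diff_dst_mem.get_desc(), fwd_pd);

    memory diff_src_mem(pd.diff_src_desc(), eng);
    memory diff_alpha_mem(alpha_mem.get_desc(), eng);

    prelu_backward prelu_bwd(pd);
    prelu_bwd.execute(strm,
            {{DNNL_ARG_SRC, src_mem}, {DNNL_ARG_WEIGHTS, alpha_mem},
                    {DNNL_ARG_DIFF_DST, diff_dst_mem},
                    {DNNL_ARG_DIFF_SRC, diff_src_mem},
                    {DNNL_ARG_DIFF_WEIGHTS, diff_alpha_mem}});
    strm.wait();
}
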
/// @} dnnl_api_prelu

/// @addtogroup dnnl_api_reduction Reduction
///
/// A primitive to compute a reduction operation on a data tensor
/// using min, max, mul, sum, mean, and norm_lp operations.
///
/// @sa @ref dev_guide_reduction in developer guide
///
/// @{

/// Reduction.
struct reduction : public primitive {
    /// Primitive descriptor for a reduction primitive.
    struct primitive_desc : public dnnl::primitive_desc {
        /// Default constructor. Produces an empty object.
        primitive_desc() = default;

        /// Constructs a primitive descriptor for a reduction primitive using
        /// algorithm-specific parameters and source and destination memory
        /// descriptors.
        ///
        /// @note
        ///     The destination memory descriptor may be initialized with the
        ///     #dnnl::memory::format_tag::any format tag.
        ///
        /// @param aengine Engine to use.
        /// @param aalgorithm Reduction algorithm kind. Possible values:
        ///     #dnnl_reduction_max, #dnnl_reduction_min, #dnnl_reduction_sum,
        ///     #dnnl_reduction_mul, #dnnl_reduction_mean,
        ///     #dnnl_reduction_norm_lp_max, #dnnl_reduction_norm_lp_sum,
        ///     #dnnl_reduction_norm_lp_power_p_max,
        ///     #dnnl_reduction_norm_lp_power_p_sum.
        /// @param src_desc Source memory descriptor.
        /// @param dst_desc Destination memory descriptor.
        /// @param p Algorithm-specific parameter.
        /// @param eps Algorithm-specific parameter.
        /// @param attr Primitive attributes to use. Attributes are optional
        ///     and default to empty attributes.
        /// @param allow_empty A flag signifying whether construction is
        ///     allowed to fail without throwing an exception. In this case an
        ///     empty object will be produced. This flag is optional and
        ///     defaults to false.
        primitive_desc(const engine &aengine, algorithm aalgorithm,
                const memory::desc &src_desc, const memory::desc &dst_desc,
                float p, float eps, const primitive_attr &attr = default_attr(),
                bool allow_empty = false) {

            dnnl_primitive_desc_t pd = nullptr;
            dnnl_status_t status = dnnl_reduction_primitive_desc_create(&pd,
                    aengine.get(), convert_to_c(aalgorithm), src_desc.get(),
                    dst_desc.get(), p, eps, attr.get());

            if (!allow_empty)
                error::wrap_c_api(status,
                        "could not create a primitive descriptor for "
                        "the reduction primitive. Run workload with "
                        "environment variable ONEDNN_VERBOSE=all to get "
                        "additional diagnostic information.");
            reset(pd);
        }

        /// Constructs a primitive descriptor for a reduction primitive from a C
        /// API primitive descriptor that must have a matching kind.
        ///
        /// @param pd C API primitive descriptor for a reduction primitive.
        primitive_desc(dnnl_primitive_desc_t pd)
            : dnnl::primitive_desc(pd, dnnl::primitive::kind::reduction) {}

        /// @copydoc dnnl::primitive_desc_base::src_desc()const
        memory::desc src_desc() const { return base::src_desc(0); }

        /// @copydoc dnnl::primitive_desc_base::dst_desc()const
        memory::desc dst_desc() const { return base::dst_desc(0); }

        /// @copydoc dnnl::primitive_desc_base::get_p()const
        float get_p() const { return base::get_p(); }

        /// @copydoc dnnl::primitive_desc_base::get_epsilon()const
        float get_epsilon() const { return base::get_epsilon(); }

        /// @copydoc dnnl::primitive_desc_base::get_algorithm()const
        algorithm get_algorithm() const { return base::get_algorithm(); }
    };

    /// Default constructor. Produces an empty object.
    reduction() = default;

    /// Constructs a reduction primitive.
    /// @param pd Primitive descriptor for a reduction primitive.
    reduction(const primitive_desc &pd) : primitive(pd) {}

    /// Constructs a reduction primitive from a cache blob.
    /// @param pd Primitive descriptor for a reduction primitive.
    /// @param cache_blob Cache blob.
    reduction(const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
        : primitive(pd, cache_blob) {}
};
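
// Illustrative usage sketch only; this is not part of the oneDNN API. It
// reduces a 2D tensor over its second dimension with the mean algorithm; in
// the destination descriptor the reduced dimension is set to 1. The helper
// name and shapes are assumptions made for the example.
inline void example_reduction_mean_sketch(const engine &eng, stream &strm) {
    memory::desc src_md(
            {32, 1024}, memory::data_type::f32, memory::format_tag::ab);
    memory::desc dst_md(
            {32, 1}, memory::data_type::f32, memory::format_tag::ab);

    memory src_mem(src_md, eng), dst_mem(dst_md, eng);

    // The p and eps parameters are only meaningful for the norm_lp family of
    // algorithms; for the mean reduction they are simply passed as zeros.
    reduction::primitive_desc pd(
            eng, algorithm::reduction_mean, src_md, dst_md, 0.f, 0.f);

    reduction reduce(pd);
    reduce.execute(strm, {{DNNL_ARG_SRC, src_mem}, {DNNL_ARG_DST, dst_mem}});
    strm.wait();
}
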
/// @} dnnl_api_reduction

/// @} dnnl_api_primitives

/// @addtogroup dnnl_api_service Service
///
/// A set of functions that aid in oneDNN debugging and profiling.
///
/// @{

/// @copydoc dnnl_version_t
using version_t = dnnl_version_t;

/// Status values returned by the library functions.
enum class status {
    /// @copydoc dnnl_success
    success = dnnl_success,
    /// @copydoc dnnl_out_of_memory
    out_of_memory = dnnl_out_of_memory,
    /// @copydoc dnnl_invalid_arguments
    invalid_arguments = dnnl_invalid_arguments,
    /// @copydoc dnnl_unimplemented
    unimplemented = dnnl_unimplemented,
    /// @copydoc dnnl_last_impl_reached
    last_impl_reached = dnnl_last_impl_reached,
    /// @copydoc dnnl_runtime_error
    runtime_error = dnnl_runtime_error,
    /// @copydoc dnnl_not_required
    not_required = dnnl_not_required,
};

/// @copydoc dnnl_set_verbose()
inline status set_verbose(int level) {
    return static_cast<status>(dnnl_set_verbose(level));
}

/// @copydoc dnnl_version()
inline const version_t *version() {
    return dnnl_version();
}

/// Returns the floating-point math mode that will be used by default
/// for all subsequently created primitives.
///
/// @returns The default floating-point math mode.
inline fpmath_mode get_default_fpmath_mode() {
    dnnl_fpmath_mode_t mode;
    error::wrap_c_api(dnnl_get_default_fpmath_mode(&mode),
            "could not get a default fpmath mode");
    return static_cast<fpmath_mode>(mode);
}

/// @copydoc dnnl_set_default_fpmath_mode()
inline status set_default_fpmath_mode(fpmath_mode mode) {
    return static_cast<status>(
            dnnl_set_default_fpmath_mode(convert_to_c(mode)));
}
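
// Illustrative usage sketch only; this is not part of the oneDNN API. It shows
// a typical call sequence for the service functions declared above. The helper
// name is an assumption made for the example.
inline void example_service_sketch() {
    // Query the library version at run time (v->major, v->minor, v->patch).
    const version_t *v = version();
    (void)v;

    // Enable basic verbose output; this returns a status instead of throwing.
    status st = set_verbose(1);
    (void)st;

    // Allow implicit down-conversion to bf16 for primitives created after this
    // call, inspect the current setting, then restore strict f32 math.
    set_default_fpmath_mode(fpmath_mode::bf16);
    fpmath_mode current = get_default_fpmath_mode();
    (void)current;
    set_default_fpmath_mode(fpmath_mode::strict);
}
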
/// @copydoc dnnl_set_jit_dump()
inline status set_jit_dump(int enable) {
    return static_cast<status>(dnnl_set_jit_dump(enable));
}

/// @copydoc dnnl_set_jit_profiling_flags()
inline status set_jit_profiling_flags(unsigned flags) {
    return static_cast<status>(dnnl_set_jit_profiling_flags(flags));
}

/// @copydoc dnnl_set_jit_profiling_jitdumpdir()
inline status set_jit_profiling_jitdumpdir(const std::string &dir) {
    return static_cast<status>(dnnl_set_jit_profiling_jitdumpdir(dir.c_str()));
}

/// @copydoc dnnl_cpu_isa_t
enum class cpu_isa {
    /// @copydoc dnnl_cpu_isa_default
    isa_default = dnnl_cpu_isa_default,
    /// @copydoc dnnl_cpu_isa_sse41
    sse41 = dnnl_cpu_isa_sse41,
    /// @copydoc dnnl_cpu_isa_avx
    avx = dnnl_cpu_isa_avx,
    /// @copydoc dnnl_cpu_isa_avx2
    avx2 = dnnl_cpu_isa_avx2,
    /// @copydoc dnnl_cpu_isa_avx2_vnni
    avx2_vnni = dnnl_cpu_isa_avx2_vnni,
    /// @copydoc dnnl_cpu_isa_avx2_vnni_2
    avx2_vnni_2 = dnnl_cpu_isa_avx2_vnni_2,
    /// @copydoc dnnl_cpu_isa_avx512_core
    avx512_core = dnnl_cpu_isa_avx512_core,
    /// @copydoc dnnl_cpu_isa_avx512_core_vnni
    avx512_core_vnni = dnnl_cpu_isa_avx512_core_vnni,
    /// @copydoc dnnl_cpu_isa_avx512_core_bf16
    avx512_core_bf16 = dnnl_cpu_isa_avx512_core_bf16,
    /// @copydoc dnnl_cpu_isa_avx10_1_512
    avx10_1_512 = dnnl_cpu_isa_avx10_1_512,
    /// @copydoc dnnl_cpu_isa_avx512_core_fp16
    avx512_core_fp16 = dnnl_cpu_isa_avx512_core_fp16,
    /// @copydoc dnnl_cpu_isa_avx10_1_512_amx
    avx10_1_512_amx = dnnl_cpu_isa_avx10_1_512_amx,
    /// @copydoc dnnl_cpu_isa_avx512_core_amx
    avx512_core_amx = dnnl_cpu_isa_avx512_core_amx,
    /// @copydoc dnnl_cpu_isa_avx10_1_512_amx_fp16
    avx10_1_512_amx_fp16 = dnnl_cpu_isa_avx10_1_512_amx_fp16,
    /// @copydoc dnnl_cpu_isa_avx512_core_amx_fp16
    avx512_core_amx_fp16 = dnnl_cpu_isa_avx512_core_amx_fp16,
    /// @copydoc dnnl_cpu_isa_avx10_2_512
    avx10_2_512 = dnnl_cpu_isa_avx10_2_512,
    /// @copydoc dnnl_cpu_isa_avx10_2_512_amx_2
    avx10_2_512_amx_2 = dnnl_cpu_isa_avx10_2_512_amx_2,
};

/// @copydoc dnnl_set_max_cpu_isa()
inline status set_max_cpu_isa(cpu_isa isa) {
    return static_cast<status>(
            dnnl_set_max_cpu_isa(static_cast<dnnl_cpu_isa_t>(isa)));
}

/// @copydoc dnnl_get_effective_cpu_isa()
inline cpu_isa get_effective_cpu_isa() {
    return static_cast<cpu_isa>(dnnl_get_effective_cpu_isa());
}

/// @copydoc dnnl_cpu_isa_hints_t
enum class cpu_isa_hints {
    /// @copydoc dnnl_cpu_isa_no_hints
    no_hints = dnnl_cpu_isa_no_hints,
    /// @copydoc dnnl_cpu_isa_prefer_ymm
    prefer_ymm = dnnl_cpu_isa_prefer_ymm,
};

/// @copydoc dnnl_set_cpu_isa_hints()
inline status set_cpu_isa_hints(cpu_isa_hints isa_hints) {
    return static_cast<status>(dnnl_set_cpu_isa_hints(
            static_cast<dnnl_cpu_isa_hints_t>(isa_hints)));
}

/// @copydoc dnnl_get_cpu_isa_hints()
inline cpu_isa_hints get_cpu_isa_hints() {
    return static_cast<cpu_isa_hints>(dnnl_get_cpu_isa_hints());
}
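
// Illustrative usage sketch only; this is not part of the oneDNN API. It shows
// how CPU dispatch control is typically exercised: both settings generally
// take effect only before JIT kernels are generated, and depending on build
// options the setters may return a non-success status. The helper name is an
// assumption made for the example.
inline void example_cpu_isa_sketch() {
    // Cap dispatch at AVX2 even if the machine supports wider vector ISAs.
    status st = set_max_cpu_isa(cpu_isa::avx2);
    (void)st;

    // Ask the implementation to prefer YMM (256-bit) registers where possible.
    st = set_cpu_isa_hints(cpu_isa_hints::prefer_ymm);
    (void)st;

    // Inspect what the library will actually use on this machine.
    cpu_isa effective = get_effective_cpu_isa();
    cpu_isa_hints hints = get_cpu_isa_hints();
    (void)effective;
    (void)hints;
}
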
/// @} dnnl_api_service

#ifdef DNNL_EXPERIMENTAL_PROFILING
/// @addtogroup dnnl_api_profiling Profiling
/// @{

/// Profiling data kind.
enum class profiling_data_kind {
    /// Undefined profiling data kind.
    undef = dnnl_profiling_data_kind_undef,
    /// Data kind to query an execution time in nanoseconds.
    time = dnnl_profiling_data_kind_time,
};

/// Resets a profiler's state.
///
/// @param stream Stream associated with the profiler.
inline void reset_profiling(stream &stream) {
    error::wrap_c_api(
            dnnl_reset_profiling(stream.get()), "could not reset profiling");
}

/// Returns requested profiling data. The profiling data accumulates for each
/// primitive execution. The size of the vector will be equal to the number
/// of executions since the last `dnnl::reset_profiling` call.
///
/// The profiling data can be reset by calling #dnnl::reset_profiling.
///
/// @note
///     It is required to wait for all submitted primitives to complete
///     using #dnnl::stream::wait prior to querying profiling data.
///
/// @param stream Stream that was used for executing a primitive that
///     is being profiled.
/// @param data_kind Profiling data kind to query.
///
/// @returns A vector with the requested profiling data.
inline std::vector<uint64_t> get_profiling_data(
        stream &stream, profiling_data_kind data_kind) {
    int num_entries = 0;
    error::wrap_c_api(
            dnnl_query_profiling_data(stream.get(),
                    static_cast<dnnl_profiling_data_kind_t>(data_kind),
                    &num_entries, nullptr),
            "could not get number of entries for profiling data");

    if (num_entries == 0) return {};

    std::vector<uint64_t> data(num_entries);
    error::wrap_c_api(
            dnnl_query_profiling_data(stream.get(),
                    static_cast<dnnl_profiling_data_kind_t>(data_kind),
                    &num_entries, data.data()),
            "could not get profiling data");
    return data;
}
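
// Illustrative usage sketch only; this is not part of the oneDNN API. It shows
// the query sequence for per-execution times on a stream that was created with
// profiling enabled. The primitive, its argument map, and the helper name are
// assumptions made for the example.
inline std::vector<uint64_t> example_profiling_sketch(stream &strm,
        const primitive &prim, const std::unordered_map<int, memory> &args) {
    // Drop any data accumulated by previous executions.
    reset_profiling(strm);

    prim.execute(strm, args);

    // All submitted work must finish before profiling data is queried.
    strm.wait();

    // One entry (in nanoseconds) per execution since the last reset.
    return get_profiling_data(strm, profiling_data_kind::time);
}
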
/// @} dnnl_api_profiling
#endif

/// @addtogroup dnnl_api_primitive_cache Primitive Cache
///
/// A set of functions that provide primitive cache control.
///
/// @{

/// Returns the number of primitives that can be held in the primitive cache
/// at the same time.
inline int get_primitive_cache_capacity() {
    int result = 0;
    error::wrap_c_api(dnnl_get_primitive_cache_capacity(&result),
            "could not get primitive cache capacity");
    return result;
}

/// @copydoc dnnl_set_primitive_cache_capacity(int capacity)
inline void set_primitive_cache_capacity(int capacity) {
    error::wrap_c_api(dnnl_set_primitive_cache_capacity(capacity),
            "could not set primitive cache capacity");
}
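
// Illustrative usage sketch only; this is not part of the oneDNN API. It shows
// querying and resizing the primitive cache; setting the capacity to 0 clears
// the cache and leaves caching disabled until a non-zero capacity is set
// again. The helper name and the capacity value are assumptions made for the
// example.
inline void example_primitive_cache_sketch() {
    int capacity = get_primitive_cache_capacity();
    (void)capacity;

    // Flush all cached primitives, then re-enable caching with room for up to
    // 1024 entries.
    set_primitive_cache_capacity(0);
    set_primitive_cache_capacity(1024);
}
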
/// @} dnnl_api_primitive_cache

/// @addtogroup dnnl_api_blas BLAS functions
///
/// A subset of Basic Linear Algebra Subprograms (BLAS) functions that perform
/// matrix-matrix multiplication.
///
/// @{

/// @copydoc dnnl_sgemm()
inline status sgemm(char transa, char transb, dnnl_dim_t M, dnnl_dim_t N,
        dnnl_dim_t K, float alpha, const float *A, dnnl_dim_t lda,
        const float *B, dnnl_dim_t ldb, float beta, float *C, dnnl_dim_t ldc) {
    return static_cast<status>(dnnl_sgemm(
            transa, transb, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc));
}
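
// Illustrative usage sketch only; this is not part of the oneDNN API. It shows
// a small C = alpha * A * B + beta * C call through dnnl::sgemm with row-major
// matrices. The sizes and the helper name are assumptions made for the
// example.
inline status example_sgemm_sketch() {
    const dnnl_dim_t M = 2, N = 3, K = 4;
    std::vector<float> A(M * K, 1.f); // M x K, row-major, lda = K
    std::vector<float> B(K * N, 2.f); // K x N, row-major, ldb = N
    std::vector<float> C(M * N, 0.f); // M x N, row-major, ldc = N

    // 'N' means neither matrix is transposed.
    return sgemm('N', 'N', M, N, K, 1.f, A.data(), K, B.data(), N, 0.f,
            C.data(), N);
}
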
/// @copydoc dnnl_gemm_u8s8s32()
inline status gemm_u8s8s32(char transa, char transb, char offsetc, dnnl_dim_t M,
        dnnl_dim_t N, dnnl_dim_t K, float alpha, const uint8_t *A,
        dnnl_dim_t lda, uint8_t ao, const int8_t *B, dnnl_dim_t ldb, int8_t bo,
        float beta, int32_t *C, dnnl_dim_t ldc, const int32_t *co) {
    return static_cast<status>(dnnl_gemm_u8s8s32(transa, transb, offsetc, M, N,
            K, alpha, A, lda, ao, B, ldb, bo, beta, C, ldc, co));
}

/// @copydoc dnnl_gemm_s8s8s32()
inline status gemm_s8s8s32(char transa, char transb, char offsetc, dnnl_dim_t M,
        dnnl_dim_t N, dnnl_dim_t K, float alpha, const int8_t *A,
        dnnl_dim_t lda, int8_t ao, const int8_t *B, dnnl_dim_t ldb, int8_t bo,
        float beta, int32_t *C, dnnl_dim_t ldc, const int32_t *co) {
    return static_cast<status>(dnnl_gemm_s8s8s32(transa, transb, offsetc, M, N,
            K, alpha, A, lda, ao, B, ldb, bo, beta, C, ldc, co));
}
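
// Illustrative usage sketch only; this is not part of the oneDNN API. It shows
// an int8 GEMM with u8 A, s8 B, and s32 C. With offsetc = 'F' the single
// element pointed to by `co` is applied to every element of C; `ao` and `bo`
// are the A and B offsets (see dnnl_gemm_u8s8s32 for the exact formula). The
// sizes and the helper name are assumptions made for the example.
inline status example_gemm_u8s8s32_sketch() {
    const dnnl_dim_t M = 2, N = 2, K = 3;
    std::vector<uint8_t> A(M * K, 1); // M x K, row-major, lda = K
    std::vector<int8_t> B(K * N, 1); // K x N, row-major, ldb = N
    std::vector<int32_t> C(M * N, 0); // M x N, row-major, ldc = N
    const int32_t co = 0; // single C offset, selected by offsetc = 'F'

    return gemm_u8s8s32('N', 'N', 'F', M, N, K, 1.f, A.data(), K, 0, B.data(),
            N, 0, 0.f, C.data(), N, &co);
}
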
/// @} dnnl_api_blas

// implementation section

/// @cond DO_NOT_DOCUMENT_THIS
inline primitive::primitive(const_dnnl_primitive_desc_t c_pd) {
    dnnl_primitive_t result;
    error::wrap_c_api(dnnl_primitive_create(&result, c_pd),
            "could not create a primitive");
    reset(result);
}

inline primitive::primitive(const_dnnl_primitive_desc_t c_pd,
        const std::vector<uint8_t> &cache_blob) {
    dnnl_primitive_t result;
    size_t size = cache_blob.size();
    const uint8_t *cache_blob_data = cache_blob.data();
    error::wrap_c_api(dnnl_primitive_create_from_cache_blob(
                              &result, c_pd, size, cache_blob_data),
            "could not create a primitive from a cache blob");
    reset(result);
}

inline primitive::primitive(const primitive_desc &pd) : primitive(pd.get()) {}
inline primitive::primitive(
        const primitive_desc &pd, const std::vector<uint8_t> &cache_blob)
    : primitive(pd.get(), cache_blob) {}

inline void primitive::execute(const stream &astream,
        const std::unordered_map<int, memory> &args) const {
    std::vector<dnnl_exec_arg_t> c_args;
    c_args.reserve(args.size());
    for (const auto &a : args)
        c_args.push_back({a.first, a.second.get(true)});

    error::wrap_c_api(dnnl_primitive_execute(get(), astream.get(),
                              (int)c_args.size(), c_args.data()),
            "could not execute a primitive");
}

/// @endcond

#undef DNNL_DEFINE_BITMASK_OPS

} // namespace dnnl

/// The oneAPI namespace.
/// Contains the oneapi::dnnl namespace as an alias to the ::dnnl namespace.
namespace oneapi {
// Note: without this guard, doxygen warns of a potentially recursive namespace.
#ifndef DOXYGEN_SHOULD_SKIP_THIS
/// oneDNN alias namespace
namespace dnnl = ::dnnl;
#endif
} // namespace oneapi

/// @} dnnl_api

// NOLINTEND(readability-identifier-naming)

#endif /* ONEAPI_DNNL_DNNL_HPP */