mirror of
https://github.com/uxlfoundation/oneDNN.git
synced 2025-10-20 18:43:49 +08:00
common: introduce quantization mode
This commit is contained in:
@ -481,6 +481,35 @@ dnnl_status_t DNNL_API dnnl_primitive_attr_set_scales_v2(
|
||||
const dnnl_dims_t group_dims, dnnl_data_type_t data_type,
|
||||
int is_on_host);
|
||||
|
||||
/// Sets primitive attributes scaling factors for primitive operations for a
|
||||
/// given memory argument. The scaling factors must be passed at execution time
|
||||
/// as an argument with index #DNNL_ARG_ATTR_SCALES | arg.
|
||||
/// @sa dnnl_primitive_attr_set_scales
|
||||
///
|
||||
/// @param attr Primitive attributes.
|
||||
/// @param arg Parameter argument index as passed to the
|
||||
/// dnnl_primitive_execute() call.
|
||||
/// @param mask Scaling factors correspondence mask that defines the
|
||||
/// correspondence between the tensor dimensions and the @p scales array.
|
||||
/// The set i-th bit indicates that a dedicated scaling factor is used for
|
||||
/// each index along that dimension. Set the mask to 0 to use a common
|
||||
/// scaling factor for the whole tensor.
|
||||
/// @param ndims Number of group dimensions.
|
||||
/// @param group_dims Scaling factors correspondence groups that define the
|
||||
/// correspondence between the tensor dimensions and the scales array.
|
||||
/// The group dimensions should only be provided for each logical dimension
|
||||
/// that has correspondence mask @p mask set.
|
||||
/// @param data_type Scaling factors data_type.
|
||||
/// @param is_on_host Indicates whether the scale is a host-side scalar.
|
||||
/// @param qmode Quantization mode, can be #dnnl_quantization_mode_static_sazp
|
||||
/// or #dnnl_quantization_mode_dynamic_mx
|
||||
/// @returns #dnnl_success on success and a status describing the error
|
||||
/// otherwise.
|
||||
dnnl_status_t DNNL_API dnnl_primitive_attr_set_scales_v3(
|
||||
dnnl_primitive_attr_t attr, int arg, int mask, int ndims,
|
||||
const dnnl_dims_t group_dims, dnnl_data_type_t data_type,
|
||||
int is_on_host, dnnl_quantization_mode_t qmode);
|
||||
|
||||
/// Sets primitive attributes zero points for primitive operations for a given
|
||||
/// memory argument. The zero points must be passed at execution time
|
||||
/// as an argument with index #DNNL_ARG_ATTR_ZERO_POINTS | arg.
|
||||
|
@ -320,6 +320,27 @@ inline dnnl_rounding_mode_t convert_to_c(rounding_mode mode) {
|
||||
return static_cast<dnnl_rounding_mode_t>(mode);
|
||||
}
|
||||
|
||||
/// Quantization kind
|
||||
enum class quantization_mode {
|
||||
/// used for unspecified quantization kind
|
||||
undef = dnnl_quantization_mode_undef,
|
||||
/// static quantization mode: quantization parameter is computed
|
||||
/// ahead of time and passed to oneDNN as an input.
|
||||
static_sazp = dnnl_quantization_mode_static_sazp,
|
||||
/// dynamic quantization mode following OCP MX spec: quantization
|
||||
/// parameter is computed by oneDNN following the OCP MX spec
|
||||
/// formula and written as an output.
|
||||
dynamic_mx = dnnl_quantization_mode_dynamic_mx,
|
||||
};
|
||||
|
||||
/// Converts a quantization kind enum value from C++ API to C API type.
|
||||
///
|
||||
/// @param mode C++ API quantization kind enum value.
|
||||
/// @returns Corresponding C API quantization kind enum value.
|
||||
inline dnnl_quantization_mode_t convert_to_c(quantization_mode qmode) {
|
||||
return static_cast<dnnl_quantization_mode_t>(qmode);
|
||||
}
|
||||
|
||||
/// Propagation kind.
|
||||
enum class prop_kind {
|
||||
/// Undefined propagation kind.
|
||||
@ -4172,36 +4193,35 @@ struct primitive_attr : public handle<dnnl_primitive_attr_t> {
|
||||
"could not set scales primitive attribute");
|
||||
}
|
||||
|
||||
/// Sets scaling factors for primitive operations for a given memory
|
||||
/// argument. The scaling factors must be passed at execution time
|
||||
/// as an argument with index #DNNL_ARG_ATTR_SCALES | arg.
|
||||
/// Sets primitive attributes scaling factors for a given memory
|
||||
/// argument. The scaling factors must be passed at execution time as
|
||||
/// an argument with index #DNNL_ARG_ATTR_SCALES | arg.
|
||||
///
|
||||
/// @note If `is_on_host` is true, sets a single host-side scalar scaling
|
||||
/// factor for the specified memory argument. The scaling factor should be
|
||||
/// passed as a host scalar memory object at execution time with index
|
||||
/// #DNNL_ARG_ATTR_SCALES | arg.
|
||||
///
|
||||
/// @sa dnnl_primitive_attr_set_scales_v2
|
||||
/// @sa dnnl_primitive_attr_set_scales_v3
|
||||
///
|
||||
/// @param arg Parameter argument index as passed to the
|
||||
/// primitive::execute() call.
|
||||
/// @param mask Scales correspondence mask that defines the
|
||||
/// correspondence between the tensor dimensions and the @p
|
||||
/// scales vector. The set i-th bit indicates that a dedicated
|
||||
/// scale is used for each index along that dimension. Set the
|
||||
/// mask to 0 to use a common scale for the whole output tensor.
|
||||
/// primitive execute() call.
|
||||
/// @param mask Scaling factors correspondence mask that defines the
|
||||
/// correspondence between the tensor dimensions and the @p scales array.
|
||||
/// The set i-th bit indicates that a dedicated scaling factor is used for
|
||||
/// each index along that dimension. Set the mask to 0 to use a common
|
||||
/// scaling factor for the whole tensor.
|
||||
/// @param groups Scaling factors correspondence groups that define the
|
||||
/// correspondence between the tensor dimensions and the scales array.
|
||||
/// The set i-th dimension indicates a number of groups of scaling
|
||||
/// factors used for that logical dimension in a memory indicated by @p arg.
|
||||
/// The group dimensions should only be provided for each logical dimension
|
||||
/// that has correspondence mask @p mask set.
|
||||
/// @param data_type Scaling factors data_type.
|
||||
/// @param is_on_host Indicates whether the scaling factor is a host-side scalar.
|
||||
/// @param qmode Quantization mode, can be #quantization_mode::static_sazp
|
||||
/// or #quantization_mode::dynamic_mx
|
||||
void set_scales(int arg, int mask, const memory::dims &groups,
|
||||
memory::data_type data_type = memory::data_type::f32,
|
||||
bool is_on_host = false) {
|
||||
error::wrap_c_api(dnnl_primitive_attr_set_scales_v2(get(), arg, mask,
|
||||
bool is_on_host = false,
|
||||
quantization_mode qmode = quantization_mode::static_sazp) {
|
||||
error::wrap_c_api(dnnl_primitive_attr_set_scales_v3(get(), arg, mask,
|
||||
(int)groups.size(), groups.data(),
|
||||
memory::convert_to_c(data_type), is_on_host),
|
||||
memory::convert_to_c(data_type), is_on_host,
|
||||
convert_to_c(qmode)),
|
||||
"could not set scales primitive attribute");
|
||||
}
|
||||
|
||||
@ -4220,8 +4240,9 @@ struct primitive_attr : public handle<dnnl_primitive_attr_t> {
|
||||
/// @param data_type Scaling factors data_type.
|
||||
void set_host_scale(
|
||||
int arg, memory::data_type data_type = memory::data_type::f32) {
|
||||
error::wrap_c_api(dnnl_primitive_attr_set_scales_v2(get(), arg, 0, 0,
|
||||
nullptr, memory::convert_to_c(data_type), 1),
|
||||
error::wrap_c_api(dnnl_primitive_attr_set_scales_v3(get(), arg, 0, 0,
|
||||
nullptr, memory::convert_to_c(data_type), 1,
|
||||
dnnl_quantization_mode_static_sazp),
|
||||
"could not set scales primitive attribute");
|
||||
}
|
||||
|
||||
|
@ -46,6 +46,7 @@ const char DNNL_API *dnnl_rnn_flags2str(dnnl_rnn_flags_t v);
|
||||
const char DNNL_API *dnnl_rnn_direction2str(dnnl_rnn_direction_t v);
|
||||
const char DNNL_API *dnnl_scratchpad_mode2str(dnnl_scratchpad_mode_t v);
|
||||
const char DNNL_API *dnnl_rounding_mode2str(dnnl_rounding_mode_t v);
|
||||
const char DNNL_API *dnnl_quantization_mode2str(dnnl_quantization_mode_t v);
|
||||
const char DNNL_API *dnnl_cpu_isa2str(dnnl_cpu_isa_t v);
|
||||
const char DNNL_API *dnnl_cpu_isa_hints2str(dnnl_cpu_isa_hints_t v);
|
||||
|
||||
|
@ -2424,6 +2424,21 @@ typedef enum {
|
||||
dnnl_rounding_mode_stochastic,
|
||||
} dnnl_rounding_mode_t;
|
||||
|
||||
/// Quantization kind
|
||||
typedef enum {
|
||||
/// used for unspecified quantization kind
|
||||
dnnl_quantization_mode_undef,
|
||||
/// static quantization mode: quantization parameter is computed
|
||||
/// ahead of time with scale applied after zero-point (\f$x_{f32}
|
||||
/// = scale * (x_{quant} - zp)\f$) and passed to oneDNN as an
|
||||
/// input.
|
||||
dnnl_quantization_mode_static_sazp,
|
||||
/// dynamic quantization mode following OCP MX spec: quantization
|
||||
/// parameter is computed by oneDNN following the OCP MX spec
|
||||
/// formula and written as an output.
|
||||
dnnl_quantization_mode_dynamic_mx,
|
||||
} dnnl_quantization_mode_t;
|
||||
|
||||
/// @struct dnnl_primitive_attr
|
||||
/// @brief An opaque structure for primitive descriptor attributes.
|
||||
///
|
||||
|
@ -213,6 +213,7 @@ def sanitize_value(v):
|
||||
v = v.split("dnnl_accumulation_mode_")[-1]
|
||||
v = v.split("dnnl_rounding_mode_")[-1]
|
||||
v = v.split("dnnl_scratchpad_mode_")[-1]
|
||||
v = v.split("dnnl_quantization_mode_")[-1]
|
||||
v = v.split("dnnl_")[-1]
|
||||
return v
|
||||
|
||||
|
@ -211,6 +211,13 @@ const rounding_mode_t environment = dnnl_rounding_mode_environment;
|
||||
const rounding_mode_t stochastic = dnnl_rounding_mode_stochastic;
|
||||
} // namespace rounding_mode
|
||||
|
||||
using quantization_mode_t = dnnl_quantization_mode_t;
|
||||
namespace quantization_mode {
|
||||
const quantization_mode_t undef = dnnl_quantization_mode_undef;
|
||||
const quantization_mode_t static_sazp = dnnl_quantization_mode_static_sazp;
|
||||
const quantization_mode_t dynamic_mx = dnnl_quantization_mode_dynamic_mx;
|
||||
} // namespace quantization_mode
|
||||
|
||||
using sparse_encoding_t = dnnl_sparse_encoding_t;
|
||||
namespace sparse_encoding {
|
||||
const sparse_encoding_t undef = dnnl_sparse_encoding_undef;
|
||||
|
@ -1899,6 +1899,14 @@ const char *dnnl_rounding_mode2str(dnnl_rounding_mode_t v) {
|
||||
return "unknown rounding_mode";
|
||||
}
|
||||
|
||||
const char *dnnl_quantization_mode2str(dnnl_quantization_mode_t v) {
|
||||
if (v == dnnl_quantization_mode_undef) return "undef";
|
||||
if (v == dnnl_quantization_mode_static_sazp) return "static_sazp";
|
||||
if (v == dnnl_quantization_mode_dynamic_mx) return "dynamic_mx";
|
||||
assert(!"unknown quantization_mode");
|
||||
return "unknown quantization_mode";
|
||||
}
|
||||
|
||||
const char *dnnl_cpu_isa2str(dnnl_cpu_isa_t v) {
|
||||
if (v == dnnl_cpu_isa_default) return "cpu_isa_default";
|
||||
if (v == dnnl_cpu_isa_sse41) return "cpu_isa_sse41";
|
||||
|
@ -577,16 +577,26 @@ status_t dnnl_primitive_attr_set_scales_mask(
|
||||
status_t dnnl_primitive_attr_set_scales(primitive_attr_t *attr, int arg,
|
||||
int mask, int group_ndims, const dims_t group_dims,
|
||||
data_type_t data_type) {
|
||||
return dnnl_primitive_attr_set_scales_v2(
|
||||
attr, arg, mask, group_ndims, group_dims, data_type, 0);
|
||||
return dnnl_primitive_attr_set_scales_v3(attr, arg, mask, group_ndims,
|
||||
group_dims, data_type, 0, quantization_mode::static_sazp);
|
||||
}
|
||||
|
||||
status_t dnnl_primitive_attr_set_scales_v2(primitive_attr_t *attr, int arg,
|
||||
int mask, int group_ndims, const dims_t group_dims,
|
||||
data_type_t data_type, int is_on_host) {
|
||||
return dnnl_primitive_attr_set_scales_v3(attr, arg, mask, group_ndims,
|
||||
group_dims, data_type, is_on_host, quantization_mode::static_sazp);
|
||||
}
|
||||
|
||||
status_t dnnl_primitive_attr_set_scales_v3(primitive_attr_t *attr, int arg,
|
||||
int mask, int group_ndims, const dims_t group_dims,
|
||||
data_type_t data_type, int is_on_host, quantization_mode_t qmode) {
|
||||
using namespace data_type;
|
||||
VCHECK_ATTR(attr, VERBOSE_NULL_ARG);
|
||||
VCHECK_ATTR(arg >= 0, VERBOSE_BAD_PARAM, "arg");
|
||||
VCHECK_ATTR(utils::one_of(qmode, quantization_mode::static_sazp,
|
||||
quantization_mode::dynamic_mx),
|
||||
VERBOSE_BAD_PARAM, "qmode");
|
||||
VCHECK_ATTR(
|
||||
utils::one_of(data_type, f32, bf16, f16, e8m0, f8_e5m2, f8_e4m3),
|
||||
VERBOSE_INVALID_DATATYPE, "scales");
|
||||
@ -596,12 +606,12 @@ status_t dnnl_primitive_attr_set_scales_v2(primitive_attr_t *attr, int arg,
|
||||
if (is_on_host) { // only single value host-side scale is supported
|
||||
VCHECK_ATTR(mask == 0, VERBOSE_BAD_PARAM, "mask");
|
||||
VCHECK_ATTR(group_ndims == 0, VERBOSE_BAD_PARAM, "group_ndims");
|
||||
return attr->scales_.set(arg, 0, data_type, 0, {}, true);
|
||||
} else {
|
||||
VCHECK_ATTR(mask >= 0, VERBOSE_BAD_PARAM, "mask");
|
||||
VCHECK_ATTR(group_ndims >= 0, VERBOSE_BAD_PARAM, "group_ndims");
|
||||
return attr->scales_.set(arg, mask, data_type, group_ndims, group_dims);
|
||||
}
|
||||
return attr->scales_.set(
|
||||
arg, mask, data_type, group_ndims, group_dims, is_on_host, qmode);
|
||||
}
|
||||
|
||||
status_t dnnl_primitive_attr_set_zero_points_mask(
|
||||
|
@ -35,6 +35,7 @@ size_t quant_entry_t::get_hash() const {
|
||||
seed = primitive_hashing::get_array_hash(
|
||||
seed, group_dims_, group_ndims_);
|
||||
seed = hash_combine(seed, is_host_scalar_);
|
||||
seed = hash_combine(seed, qmode_);
|
||||
return seed;
|
||||
}
|
||||
|
||||
@ -43,6 +44,7 @@ void quant_entry_t::serialize(serialization_stream_t &sstream) const {
|
||||
sstream.append(data_type_);
|
||||
sstream.append_array(group_ndims_, group_dims_);
|
||||
sstream.append(is_host_scalar_);
|
||||
sstream.append(qmode_);
|
||||
}
|
||||
|
||||
quant_entry_t quant_entry_t::deserialize(deserializer_t &d) {
|
||||
@ -53,6 +55,7 @@ quant_entry_t quant_entry_t::deserialize(deserializer_t &d) {
|
||||
d.pop_array(group_ndims, e.group_dims_);
|
||||
e.group_ndims_ = static_cast<int>(group_ndims);
|
||||
d.pop(e.is_host_scalar_);
|
||||
d.pop(e.qmode_);
|
||||
return e;
|
||||
}
|
||||
|
||||
@ -67,6 +70,9 @@ std::string quant_entry_t::get_verbose() const {
|
||||
.append(std::to_string(group_dims_[1]));
|
||||
}
|
||||
if (is_host_scalar_) { s.append(":host_scalar"); }
|
||||
if (qmode_ != quantization_mode::static_sazp) {
|
||||
s.append(":").append(dnnl_quantization_mode2str(qmode_));
|
||||
}
|
||||
return s;
|
||||
}
|
||||
|
||||
|
@ -57,7 +57,8 @@ struct quant_entry_t : public c_compatible {
|
||||
return set(mask, data_type, 0, {});
|
||||
}
|
||||
status_t set(int mask, data_type_t data_type, int group_ndims,
|
||||
const dims_t group_dims, bool is_host_scalar = false) {
|
||||
const dims_t group_dims, bool is_host_scalar = false,
|
||||
quantization_mode_t qmode = quantization_mode::static_sazp) {
|
||||
mask_ = mask;
|
||||
data_type_ = data_type;
|
||||
group_ndims_ = group_ndims;
|
||||
@ -65,11 +66,12 @@ struct quant_entry_t : public c_compatible {
|
||||
utils::array_copy(group_dims_, group_dims, group_ndims_);
|
||||
}
|
||||
is_host_scalar_ = is_host_scalar;
|
||||
qmode_ = qmode;
|
||||
return status::success;
|
||||
}
|
||||
status_t set(const quant_entry_t &other) {
|
||||
return set(other.mask_, other.data_type_, other.group_ndims_,
|
||||
other.group_dims_, other.is_host_scalar());
|
||||
other.group_dims_, other.is_host_scalar(), other.qmode_);
|
||||
}
|
||||
|
||||
quant_entry_t &operator=(const quant_entry_t &rhs) {
|
||||
@ -89,12 +91,16 @@ struct quant_entry_t : public c_compatible {
|
||||
dim_t get_group(int d) const {
|
||||
// If groups were not requested, return `1` for convenience.
|
||||
if (group_ndims_ == default_quant_entry().group_ndims_) return 1;
|
||||
// But if they were, any out of bound access would return `0` and likely
|
||||
// we allow negative indexes to address from last to first
|
||||
if (d < 0) d += group_ndims_;
|
||||
// Any out of bound access would return `0` and likely
|
||||
// lead to a division by zero which is fast to catch.
|
||||
if (d >= group_ndims_) return 0;
|
||||
if (d >= group_ndims_ || d < 0) return 0;
|
||||
return group_dims_[d];
|
||||
}
|
||||
bool is_host_scalar() const { return is_host_scalar_; }
|
||||
quantization_mode_t get_quantization_mode() const { return qmode_; }
|
||||
bool is_mx() const { return qmode_ == quantization_mode::dynamic_mx; }
|
||||
|
||||
status_t get_md(memory_desc_t &out_md, const memory_desc_t &base_md) const {
|
||||
if (has_default_values()) {
|
||||
@ -137,6 +143,7 @@ struct quant_entry_t : public c_compatible {
|
||||
&& IMPLICATION(group_ndims_ > 0,
|
||||
utils::array_cmp(
|
||||
group_dims_, rhs.group_dims_, group_ndims_))
|
||||
&& qmode_ == rhs.qmode_
|
||||
&& is_host_scalar_ == rhs.is_host_scalar_;
|
||||
}
|
||||
|
||||
@ -157,6 +164,7 @@ private:
|
||||
int group_ndims_ = 0;
|
||||
dims_t group_dims_ {};
|
||||
bool is_host_scalar_ = false;
|
||||
quantization_mode_t qmode_ = quantization_mode::undef;
|
||||
};
|
||||
|
||||
std::ostream &operator<<(std::ostream &ss, const quant_entry_t &e);
|
||||
@ -177,10 +185,11 @@ struct quant_entries_t : public c_compatible {
|
||||
return set(arg, mask, default_data_type_, 0, {});
|
||||
}
|
||||
status_t set(int arg, int mask, data_type_t data_type, int group_ndims,
|
||||
const dims_t group_dims, bool is_host_scalar = false) {
|
||||
const dims_t group_dims, bool is_host_scalar = false,
|
||||
quantization_mode_t qmode = quantization_mode::static_sazp) {
|
||||
if (!check_arg(arg)) return status::invalid_arguments;
|
||||
CHECK(entries_[arg].set(
|
||||
mask, data_type, group_ndims, group_dims, is_host_scalar));
|
||||
CHECK(entries_[arg].set(mask, data_type, group_ndims, group_dims,
|
||||
is_host_scalar, qmode));
|
||||
return status::success;
|
||||
}
|
||||
// Use this interface with `default_quant_entry` when need to remove a
|
||||
|
@ -170,10 +170,15 @@ struct primitive_desc_t : public c_compatible {
|
||||
}
|
||||
if (arg & DNNL_ARG_ATTR_SCALES) {
|
||||
int scale_arg = arg & ~DNNL_ARG_ATTR_SCALES;
|
||||
return !attr()->scales_.has_default_values(scale_arg)
|
||||
? arg_usage_t::input
|
||||
: arg_usage_t::unused;
|
||||
if (!attr()->scales_.has_default_values(scale_arg)) {
|
||||
if (attr()->scales_.get(scale_arg).is_mx())
|
||||
return arg_usage_t::output;
|
||||
else
|
||||
return arg_usage_t::input;
|
||||
} else
|
||||
return arg_usage_t::unused;
|
||||
}
|
||||
|
||||
if (arg == DNNL_ARG_SCRATCHPAD)
|
||||
return !is_zero_md(scratchpad_md()) ? arg_usage_t::output
|
||||
: arg_usage_t::unused;
|
||||
|
@ -71,7 +71,8 @@ status_t cvt_primitive_args(const primitive_desc_t *pd, int nargs,
|
||||
args[arg] = {mem, false};
|
||||
n_outputs++;
|
||||
extra_outputs += (arg == DNNL_ARG_SCRATCHPAD)
|
||||
|| (arg == DNNL_ARG_ATTR_DROPOUT_MASK);
|
||||
|| (arg == DNNL_ARG_ATTR_DROPOUT_MASK)
|
||||
|| (arg & DNNL_ARG_ATTR_SCALES);
|
||||
break;
|
||||
case primitive_desc_t::arg_usage_t::unused:
|
||||
VINFO(primitive, exec, check, primitive,
|
||||
|
@ -187,6 +187,7 @@ void serialize(serialization_stream_t &sstream, const primitive_attr_t &attr) {
|
||||
// acc_mode
|
||||
sstream.append(attr.acc_mode_);
|
||||
|
||||
// scales
|
||||
if (!attr.scales_.has_default_values()) {
|
||||
sstream.append('s');
|
||||
attr.scales_.serialize(sstream);
|
||||
|
Reference in New Issue
Block a user