common: introduce quantization mode

This commit is contained in:
Mourad Gouicem
2025-01-14 02:07:26 -08:00
parent 98ae9938af
commit 1934858751
13 changed files with 151 additions and 37 deletions

View File

@ -481,6 +481,35 @@ dnnl_status_t DNNL_API dnnl_primitive_attr_set_scales_v2(
const dnnl_dims_t group_dims, dnnl_data_type_t data_type,
int is_on_host);
/// Sets primitive attributes scaling factors for primitive operations for a
/// given memory argument. The scaling factors must be passed at execution time
/// as an argument with index #DNNL_ARG_ATTR_SCALES | arg.
/// @sa dnnl_primitive_attr_set_scales
///
/// @param attr Primitive attributes.
/// @param arg Parameter argument index as passed to the
/// dnnl_primitive_execute() call.
/// @param mask Scaling factors correspondence mask that defines the
/// correspondence between the tensor dimensions and the @p scales array.
/// The set i-th bit indicates that a dedicated scaling factor is used for
/// each index along that dimension. Set the mask to 0 to use a common
/// scaling factor for the whole tensor.
/// @param ndims Number of group dimensions.
/// @param group_dims Scaling factors correspondence groups that define the
/// correspondence between the tensor dimensions and the scales array.
/// The group dimensions should only be provided for each logical dimension
/// that has correspondence mask @p mask set.
/// @param data_type Scaling factors data_type.
/// @param is_on_host Indicates whether the scale is a host-side scalar.
/// @param qmode Quantization mode, can be #dnnl_quantization_mode_static_sazp
/// or #dnnl_quantization_mode_dynamic_mx
/// @returns #dnnl_success on success and a status describing the error
/// otherwise.
dnnl_status_t DNNL_API dnnl_primitive_attr_set_scales_v3(
dnnl_primitive_attr_t attr, int arg, int mask, int ndims,
const dnnl_dims_t group_dims, dnnl_data_type_t data_type,
int is_on_host, dnnl_quantization_mode_t qmode);
/// Sets primitive attributes zero points for primitive operations for a given
/// memory argument. The zero points must be passed at execution time
/// as an argument with index #DNNL_ARG_ATTR_ZERO_POINTS | arg.

View File

@ -320,6 +320,27 @@ inline dnnl_rounding_mode_t convert_to_c(rounding_mode mode) {
return static_cast<dnnl_rounding_mode_t>(mode);
}
/// Quantization kind
enum class quantization_mode {
/// used for unspecified quantization kind
undef = dnnl_quantization_mode_undef,
/// static quantization mode: quantization parameter is computed
/// ahead of time and passed to oneDNN as an input.
static_sazp = dnnl_quantization_mode_static_sazp,
/// dynamic quantization mode following OCP MX spec: quantization
/// parameter is computed by oneDNN following the OCP MX spec
/// formula and written as an output.
dynamic_mx = dnnl_quantization_mode_dynamic_mx,
};
/// Converts a quantization kind enum value from C++ API to C API type.
///
/// @param mode C++ API quantization kind enum value.
/// @returns Corresponding C API quantization kind enum value.
inline dnnl_quantization_mode_t convert_to_c(quantization_mode qmode) {
return static_cast<dnnl_quantization_mode_t>(qmode);
}
/// Propagation kind.
enum class prop_kind {
/// Undefined propagation kind.
@ -4172,36 +4193,35 @@ struct primitive_attr : public handle<dnnl_primitive_attr_t> {
"could not set scales primitive attribute");
}
/// Sets scaling factors for primitive operations for a given memory
/// argument. The scaling factors must be passed at execution time
/// as an argument with index #DNNL_ARG_ATTR_SCALES | arg.
/// Sets primitive attributes scaling factors for a given memory
/// argument. The scaling factors must be passed at execution time as
/// an argument with index #DNNL_ARG_ATTR_SCALES | arg.
///
/// @note If `is_on_host` is true, sets a single host-side scalar scaling
/// factor for the specified memory argument. The scaling factor should be
/// passed as a host scalar memory object at execution time with index
/// #DNNL_ARG_ATTR_SCALES | arg.
///
/// @sa dnnl_primitive_attr_set_scales_v2
/// @sa dnnl_primitive_attr_set_scales_v3
///
/// @param arg Parameter argument index as passed to the
/// primitive::execute() call.
/// @param mask Scales correspondence mask that defines the
/// correspondence between the tensor dimensions and the @p
/// scales vector. The set i-th bit indicates that a dedicated
/// scale is used for each index along that dimension. Set the
/// mask to 0 to use a common scale for the whole output tensor.
/// primitive execute() call.
/// @param mask Scaling factors correspondence mask that defines the
/// correspondence between the tensor dimensions and the @p scales array.
/// The set i-th bit indicates that a dedicated scaling factor is used for
/// each index along that dimension. Set the mask to 0 to use a common
/// scaling factor for the whole tensor.
/// @param groups Scaling factors correspondence groups that define the
/// correspondence between the tensor dimensions and the scales array.
/// The set i-th dimension indicates a number of groups of scaling
/// factors used for that logical dimension in a memory indicated by @p arg.
/// The group dimensions should only be provided for each logical dimension
/// that has correspondence mask @p mask set.
/// @param data_type Scaling factors data_type.
/// @param is_on_host Indicates whether the scaling factor is a host-side scalar.
/// @param qmode Quantization mode, can be #quantization_mode::static_sazp
/// or #quantization_mode::dynamic_mx
void set_scales(int arg, int mask, const memory::dims &groups,
memory::data_type data_type = memory::data_type::f32,
bool is_on_host = false) {
error::wrap_c_api(dnnl_primitive_attr_set_scales_v2(get(), arg, mask,
bool is_on_host = false,
quantization_mode qmode = quantization_mode::static_sazp) {
error::wrap_c_api(dnnl_primitive_attr_set_scales_v3(get(), arg, mask,
(int)groups.size(), groups.data(),
memory::convert_to_c(data_type), is_on_host),
memory::convert_to_c(data_type), is_on_host,
convert_to_c(qmode)),
"could not set scales primitive attribute");
}
@ -4220,8 +4240,9 @@ struct primitive_attr : public handle<dnnl_primitive_attr_t> {
/// @param data_type Scaling factors data_type.
void set_host_scale(
int arg, memory::data_type data_type = memory::data_type::f32) {
error::wrap_c_api(dnnl_primitive_attr_set_scales_v2(get(), arg, 0, 0,
nullptr, memory::convert_to_c(data_type), 1),
error::wrap_c_api(dnnl_primitive_attr_set_scales_v3(get(), arg, 0, 0,
nullptr, memory::convert_to_c(data_type), 1,
dnnl_quantization_mode_static_sazp),
"could not set scales primitive attribute");
}

View File

@ -46,6 +46,7 @@ const char DNNL_API *dnnl_rnn_flags2str(dnnl_rnn_flags_t v);
const char DNNL_API *dnnl_rnn_direction2str(dnnl_rnn_direction_t v);
const char DNNL_API *dnnl_scratchpad_mode2str(dnnl_scratchpad_mode_t v);
const char DNNL_API *dnnl_rounding_mode2str(dnnl_rounding_mode_t v);
const char DNNL_API *dnnl_quantization_mode2str(dnnl_quantization_mode_t v);
const char DNNL_API *dnnl_cpu_isa2str(dnnl_cpu_isa_t v);
const char DNNL_API *dnnl_cpu_isa_hints2str(dnnl_cpu_isa_hints_t v);

View File

@ -2424,6 +2424,21 @@ typedef enum {
dnnl_rounding_mode_stochastic,
} dnnl_rounding_mode_t;
/// Quantization kind
typedef enum {
/// used for unspecified quantization kind
dnnl_quantization_mode_undef,
/// static quantization mode: quantization parameter is computed
/// ahead of time with scale applied after zero-point (\f$x_{f32}
/// = scale * (x_{quant} - zp)\f$) and passed to oneDNN as an
/// input.
dnnl_quantization_mode_static_sazp,
/// dynamic quantization mode following OCP MX spec: quantization
/// parameter is computed by oneDNN following the OCP MX spec
/// formula and written as an output.
dnnl_quantization_mode_dynamic_mx,
} dnnl_quantization_mode_t;
/// @struct dnnl_primitive_attr
/// @brief An opaque structure for primitive descriptor attributes.
///

View File

@ -213,6 +213,7 @@ def sanitize_value(v):
v = v.split("dnnl_accumulation_mode_")[-1]
v = v.split("dnnl_rounding_mode_")[-1]
v = v.split("dnnl_scratchpad_mode_")[-1]
v = v.split("dnnl_quantization_mode_")[-1]
v = v.split("dnnl_")[-1]
return v

View File

@ -211,6 +211,13 @@ const rounding_mode_t environment = dnnl_rounding_mode_environment;
const rounding_mode_t stochastic = dnnl_rounding_mode_stochastic;
} // namespace rounding_mode
using quantization_mode_t = dnnl_quantization_mode_t;
namespace quantization_mode {
const quantization_mode_t undef = dnnl_quantization_mode_undef;
const quantization_mode_t static_sazp = dnnl_quantization_mode_static_sazp;
const quantization_mode_t dynamic_mx = dnnl_quantization_mode_dynamic_mx;
} // namespace quantization_mode
using sparse_encoding_t = dnnl_sparse_encoding_t;
namespace sparse_encoding {
const sparse_encoding_t undef = dnnl_sparse_encoding_undef;

View File

@ -1899,6 +1899,14 @@ const char *dnnl_rounding_mode2str(dnnl_rounding_mode_t v) {
return "unknown rounding_mode";
}
const char *dnnl_quantization_mode2str(dnnl_quantization_mode_t v) {
if (v == dnnl_quantization_mode_undef) return "undef";
if (v == dnnl_quantization_mode_static_sazp) return "static_sazp";
if (v == dnnl_quantization_mode_dynamic_mx) return "dynamic_mx";
assert(!"unknown quantization_mode");
return "unknown quantization_mode";
}
const char *dnnl_cpu_isa2str(dnnl_cpu_isa_t v) {
if (v == dnnl_cpu_isa_default) return "cpu_isa_default";
if (v == dnnl_cpu_isa_sse41) return "cpu_isa_sse41";

View File

@ -577,16 +577,26 @@ status_t dnnl_primitive_attr_set_scales_mask(
status_t dnnl_primitive_attr_set_scales(primitive_attr_t *attr, int arg,
int mask, int group_ndims, const dims_t group_dims,
data_type_t data_type) {
return dnnl_primitive_attr_set_scales_v2(
attr, arg, mask, group_ndims, group_dims, data_type, 0);
return dnnl_primitive_attr_set_scales_v3(attr, arg, mask, group_ndims,
group_dims, data_type, 0, quantization_mode::static_sazp);
}
status_t dnnl_primitive_attr_set_scales_v2(primitive_attr_t *attr, int arg,
int mask, int group_ndims, const dims_t group_dims,
data_type_t data_type, int is_on_host) {
return dnnl_primitive_attr_set_scales_v3(attr, arg, mask, group_ndims,
group_dims, data_type, is_on_host, quantization_mode::static_sazp);
}
status_t dnnl_primitive_attr_set_scales_v3(primitive_attr_t *attr, int arg,
int mask, int group_ndims, const dims_t group_dims,
data_type_t data_type, int is_on_host, quantization_mode_t qmode) {
using namespace data_type;
VCHECK_ATTR(attr, VERBOSE_NULL_ARG);
VCHECK_ATTR(arg >= 0, VERBOSE_BAD_PARAM, "arg");
VCHECK_ATTR(utils::one_of(qmode, quantization_mode::static_sazp,
quantization_mode::dynamic_mx),
VERBOSE_BAD_PARAM, "qmode");
VCHECK_ATTR(
utils::one_of(data_type, f32, bf16, f16, e8m0, f8_e5m2, f8_e4m3),
VERBOSE_INVALID_DATATYPE, "scales");
@ -596,12 +606,12 @@ status_t dnnl_primitive_attr_set_scales_v2(primitive_attr_t *attr, int arg,
if (is_on_host) { // only single value host-side scale is supported
VCHECK_ATTR(mask == 0, VERBOSE_BAD_PARAM, "mask");
VCHECK_ATTR(group_ndims == 0, VERBOSE_BAD_PARAM, "group_ndims");
return attr->scales_.set(arg, 0, data_type, 0, {}, true);
} else {
VCHECK_ATTR(mask >= 0, VERBOSE_BAD_PARAM, "mask");
VCHECK_ATTR(group_ndims >= 0, VERBOSE_BAD_PARAM, "group_ndims");
return attr->scales_.set(arg, mask, data_type, group_ndims, group_dims);
}
return attr->scales_.set(
arg, mask, data_type, group_ndims, group_dims, is_on_host, qmode);
}
status_t dnnl_primitive_attr_set_zero_points_mask(

View File

@ -35,6 +35,7 @@ size_t quant_entry_t::get_hash() const {
seed = primitive_hashing::get_array_hash(
seed, group_dims_, group_ndims_);
seed = hash_combine(seed, is_host_scalar_);
seed = hash_combine(seed, qmode_);
return seed;
}
@ -43,6 +44,7 @@ void quant_entry_t::serialize(serialization_stream_t &sstream) const {
sstream.append(data_type_);
sstream.append_array(group_ndims_, group_dims_);
sstream.append(is_host_scalar_);
sstream.append(qmode_);
}
quant_entry_t quant_entry_t::deserialize(deserializer_t &d) {
@ -53,6 +55,7 @@ quant_entry_t quant_entry_t::deserialize(deserializer_t &d) {
d.pop_array(group_ndims, e.group_dims_);
e.group_ndims_ = static_cast<int>(group_ndims);
d.pop(e.is_host_scalar_);
d.pop(e.qmode_);
return e;
}
@ -67,6 +70,9 @@ std::string quant_entry_t::get_verbose() const {
.append(std::to_string(group_dims_[1]));
}
if (is_host_scalar_) { s.append(":host_scalar"); }
if (qmode_ != quantization_mode::static_sazp) {
s.append(":").append(dnnl_quantization_mode2str(qmode_));
}
return s;
}

View File

@ -57,7 +57,8 @@ struct quant_entry_t : public c_compatible {
return set(mask, data_type, 0, {});
}
status_t set(int mask, data_type_t data_type, int group_ndims,
const dims_t group_dims, bool is_host_scalar = false) {
const dims_t group_dims, bool is_host_scalar = false,
quantization_mode_t qmode = quantization_mode::static_sazp) {
mask_ = mask;
data_type_ = data_type;
group_ndims_ = group_ndims;
@ -65,11 +66,12 @@ struct quant_entry_t : public c_compatible {
utils::array_copy(group_dims_, group_dims, group_ndims_);
}
is_host_scalar_ = is_host_scalar;
qmode_ = qmode;
return status::success;
}
status_t set(const quant_entry_t &other) {
return set(other.mask_, other.data_type_, other.group_ndims_,
other.group_dims_, other.is_host_scalar());
other.group_dims_, other.is_host_scalar(), other.qmode_);
}
quant_entry_t &operator=(const quant_entry_t &rhs) {
@ -89,12 +91,16 @@ struct quant_entry_t : public c_compatible {
dim_t get_group(int d) const {
// If groups were not requested, return `1` for convenience.
if (group_ndims_ == default_quant_entry().group_ndims_) return 1;
// But if they were, any out of bound access would return `0` and likely
// we allow negative indexes to address from last to first
if (d < 0) d += group_ndims_;
// Any out of bound access would return `0` and likely
// lead to a division by zero which is fast to catch.
if (d >= group_ndims_) return 0;
if (d >= group_ndims_ || d < 0) return 0;
return group_dims_[d];
}
bool is_host_scalar() const { return is_host_scalar_; }
quantization_mode_t get_quantization_mode() const { return qmode_; }
bool is_mx() const { return qmode_ == quantization_mode::dynamic_mx; }
status_t get_md(memory_desc_t &out_md, const memory_desc_t &base_md) const {
if (has_default_values()) {
@ -137,6 +143,7 @@ struct quant_entry_t : public c_compatible {
&& IMPLICATION(group_ndims_ > 0,
utils::array_cmp(
group_dims_, rhs.group_dims_, group_ndims_))
&& qmode_ == rhs.qmode_
&& is_host_scalar_ == rhs.is_host_scalar_;
}
@ -157,6 +164,7 @@ private:
int group_ndims_ = 0;
dims_t group_dims_ {};
bool is_host_scalar_ = false;
quantization_mode_t qmode_ = quantization_mode::undef;
};
std::ostream &operator<<(std::ostream &ss, const quant_entry_t &e);
@ -177,10 +185,11 @@ struct quant_entries_t : public c_compatible {
return set(arg, mask, default_data_type_, 0, {});
}
status_t set(int arg, int mask, data_type_t data_type, int group_ndims,
const dims_t group_dims, bool is_host_scalar = false) {
const dims_t group_dims, bool is_host_scalar = false,
quantization_mode_t qmode = quantization_mode::static_sazp) {
if (!check_arg(arg)) return status::invalid_arguments;
CHECK(entries_[arg].set(
mask, data_type, group_ndims, group_dims, is_host_scalar));
CHECK(entries_[arg].set(mask, data_type, group_ndims, group_dims,
is_host_scalar, qmode));
return status::success;
}
// Use this interface with `default_quant_entry` when need to remove a

View File

@ -170,10 +170,15 @@ struct primitive_desc_t : public c_compatible {
}
if (arg & DNNL_ARG_ATTR_SCALES) {
int scale_arg = arg & ~DNNL_ARG_ATTR_SCALES;
return !attr()->scales_.has_default_values(scale_arg)
? arg_usage_t::input
: arg_usage_t::unused;
if (!attr()->scales_.has_default_values(scale_arg)) {
if (attr()->scales_.get(scale_arg).is_mx())
return arg_usage_t::output;
else
return arg_usage_t::input;
} else
return arg_usage_t::unused;
}
if (arg == DNNL_ARG_SCRATCHPAD)
return !is_zero_md(scratchpad_md()) ? arg_usage_t::output
: arg_usage_t::unused;

View File

@ -71,7 +71,8 @@ status_t cvt_primitive_args(const primitive_desc_t *pd, int nargs,
args[arg] = {mem, false};
n_outputs++;
extra_outputs += (arg == DNNL_ARG_SCRATCHPAD)
|| (arg == DNNL_ARG_ATTR_DROPOUT_MASK);
|| (arg == DNNL_ARG_ATTR_DROPOUT_MASK)
|| (arg & DNNL_ARG_ATTR_SCALES);
break;
case primitive_desc_t::arg_usage_t::unused:
VINFO(primitive, exec, check, primitive,

View File

@ -187,6 +187,7 @@ void serialize(serialization_stream_t &sstream, const primitive_attr_t &attr) {
// acc_mode
sstream.append(attr.acc_mode_);
// scales
if (!attr.scales_.has_default_values()) {
sstream.append('s');
attr.scales_.serialize(sstream);