x64: matmul: Enable grouped ZP for per_oc/per_ocic & rework scales

Author: Ovchinnikov Dmitriy
Date: 2025-10-07 07:58:07 -07:00
Committed by: Dmitriy Ovchinnikov
Parent: 051b020bb1
Commit: 0ba6d85080
8 changed files with 809 additions and 485 deletions
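
The grouped weight zero points enabled here correspond to the wei:per_oc / wei:per_ocic attributes added to the benchdnn batch file at the end of this diff. As a rough sketch of how such a configuration is requested on the user side (assuming oneDNN's C++ API with the grouped set_scales/set_zero_points overloads available in recent releases; the mask bits, 192x1 groups and s4/f16 data types are only illustrative):

#include <dnnl.hpp>
using namespace dnnl;

// Sketch only: grouped weight scales and zero points for a 2D int4-weight
// matmul, mirroring the new "wei:per_ocic:*:192x1" benchdnn cases.
primitive_attr make_grouped_wei_quant_attr() {
    primitive_attr attr;
    // For 2D weights (K x N): bit (ndims - 2) selects K, bit (ndims - 1)
    // selects N, so per_ocic == (1 << 0) | (1 << 1).
    const int per_ocic_mask = (1 << 0) | (1 << 1);
    // Groups are given per quantized dimension: {group over K, group over N}.
    attr.set_scales(DNNL_ARG_WEIGHTS, per_ocic_mask, {192, 1},
            memory::data_type::f16);
    attr.set_zero_points(DNNL_ARG_WEIGHTS, per_ocic_mask, {192, 1},
            memory::data_type::s4);
    return attr;
}
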

View File

@ -508,13 +508,13 @@ public:
movups(addr, x);
}
void uni_vmovdqu16(const Xbyak::Xmm &x, const Xbyak::Address &addr) {
void uni_vmovdqu16(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
if (is_valid_isa(avx512_core))
vmovdqu16(x, addr);
vmovdqu16(x, op);
else if (is_valid_isa(avx))
vmovups(x, addr);
vmovups(x, op);
else
movups(x, addr);
movups(x, op);
}
void uni_vmovups(const Xbyak::Address &addr, const Xbyak::Xmm &x) {

View File

@ -1055,8 +1055,8 @@ void matmul_amx_blocking_params_micro_t::set_blocking_parameters(
// TODO: review extendable_k_ condition to cover more cases
extendable_k_ = (K % wei_k_blk != 0) && (brgemm_k_elems > wei_k_blk)
&& wei_zp_type == none && !use_buffer_a
&& !packed_sparse_weights && current_lda_ == K;
&& wei_zp_type == none && !apply_scales_in_buffer_b
&& !use_buffer_a && !packed_sparse_weights && current_lda_ == K;
if (extendable_k_) {
if (brgemm_k_elems >= K) {

View File

@ -252,13 +252,20 @@ status_t brgemm_matmul_t<isa>::pd_t::init(engine_t *engine) {
auto check_attr_zero_points = [&]() -> bool {
const auto &zp = attr()->zero_points_;
static const std::vector<int> supported_args {
DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST};
DNNL_ARG_SRC, DNNL_ARG_DST};
for (int arg : supported_args) {
if (!zp.has_default_values(arg)) {
const int mask = zp.get_mask(arg);
if (mask > 0) return false;
}
}
if (!zp.has_default_values(DNNL_ARG_WEIGHTS)) {
const auto mask = zp.get_mask(DNNL_ARG_WEIGHTS);
const auto kn_mask = wei_qmask_N() + wei_qmask_K();
const bool zp_over_batch = (mask & kn_mask) != mask;
const bool mask_ok = (mask & ~kn_mask) == 0;
return !(zp_over_batch && batch() > 1) && mask_ok;
}
return true;
};
const bool problem_dt_correct = one_of(true, is_f4, is_int8, is_f8, is_bf16,
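
For reference, the new WEIGHTS zero-point rule in check_attr_zero_points boils down to mask arithmetic over the K and N bits; the hypothetical standalone helper below (not part of the patch) mirrors it and lists a few accepted/rejected masks:

// Hypothetical standalone mirror of the check_attr_zero_points logic above.
// The mask is accepted only when it is confined to the K/N bits; the
// zp_over_batch term additionally guards masks reaching beyond K/N when the
// problem is batched.
bool wei_zp_mask_ok(int mask, int ndims, int batch) {
    const int kn_mask = (1 << (ndims - 1)) + (1 << (ndims - 2));
    const bool zp_over_batch = (mask & kn_mask) != mask;
    const bool mask_ok = (mask & ~kn_mask) == 0;
    return !(zp_over_batch && batch > 1) && mask_ok;
}

// For a batched 3D matmul (ndims = 3, kn_mask = 0b110):
//   per_oc   (N only)   -> mask 0b100 -> accepted
//   per_ocic (K and N)  -> mask 0b110 -> accepted
//   mask touching batch -> mask 0b111 -> rejected
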
@ -286,6 +293,7 @@ status_t brgemm_matmul_t<isa>::pd_t::init(engine_t *engine) {
| primitive_attr_t::skip_mask_t::scales_groups
| primitive_attr_t::skip_mask_t::
zero_points_data_type
| primitive_attr_t::skip_mask_t::zero_points_groups
| primitive_attr_t::skip_mask_t::post_ops
| primitive_attr_t::skip_mask_t::sum_dt
| primitive_attr_t::skip_mask_t::fpmath_mode,
@ -428,11 +436,6 @@ status_t brgemm_matmul_t<isa>::pd_t::init(engine_t *engine) {
auto scratchpad = scratchpad_registry().registrar();
init_scratchpad(scratchpad, bgmmc_);
const auto wei_scale_count = bgmmc_.is_wei_scale_per_k
? (bgmmc_.is_wei_scale_per_n ? N() * K() : K())
: N();
book_precomputed_scales(scratchpad, attr()->scales_, wei_scale_count,
/* scale_adjust_factor = */ 1.f, bgmmc_.req_transpose_scales);
return status::success;
}
@ -1189,9 +1192,9 @@ void brgemm_matmul_t<isa>::copy_a_chunk_in_buffer(
ctx.zp_a_compensation_result_ptr
= (void *)brgmm_ctx.get_zp_b_compensation_result_ptr(
ithr, m_blk_idx);
ctx.zp_b_neg_value_ptr = (void *)brgmm_ctx.get_zp_b_neg_val_ptr();
ctx.zp_ab_comp_ptr = (void *)brgmm_ctx.get_zp_ab_mixed_comp_ptr();
ctx.dynamic_src_ld = brgmm_ctx.get_src_stride();
ctx.zp_b_neg_val_ptr = brgmm_ctx.get_wei_zp_neg_ptr();
for (int gb = 0; gb < gemm_batch_iters; gb++) {
const int k = k_start + gb * bgmmc.K_blk;
@ -1251,46 +1254,29 @@ void brgemm_matmul_t<isa>::copy_b_chunk_in_buffer(
ctx.zp_a_compensation_ptr = (void *)brgmm_ctx.get_zp_a_compensation_ptr(
ithr, b_idx, n_blk_idx);
ctx.zp_a_neg_value_ptr = (void *)brgmm_ctx.get_zp_a_neg_val_ptr();
ctx.zp_b_value_ptr = (void *)brgmm_ctx.get_zp_b_val_ptr();
ctx.compensation_ptr
= (void *)brgmm_ctx.get_s8s8_comp_ptr(ithr, b_idx, n_blk_idx);
ctx.dynamic_src_stride = brgmm_ctx.copy_B_wei_stride();
// For best performance, scales should be taken in copy kernels as is.
// The biggest challenge is grouped scales (`wei_scales_gK > 1` == true,
// `bgmmc.is_wei_scale_per_k` == true) as `gK` starts interfering with
// `K_blk`. In case when `bgmmc.gK_and_K_blk_are_divisible` == true, it's
// easy to process a full `K_blk` piece or a portion of it by applying a
// single scale group. When doing a portion of B, it requires an additional
// loop over a single block, which `wei_scales_n_g` is responsible for.
//
// Otherwise, the weights scales path copies data into a full-size B
// scratchpad tensor and applies scales per point.
const auto wei_scales_gK = bgmmc.wei_scales_k_group_size;
const auto apply_single_scales_group = bgmmc.is_wei_scale_per_k
&& bgmmc.gK_and_K_blk_are_divisible && wei_scales_gK > 1;
// `div_up` covers the case when K_blk is less than gK.
const auto wei_scales_n_g = apply_single_scales_group
? div_up(bgmmc.K_blk, wei_scales_gK)
: 1;
for_(int gb = 0; gb < gemm_batch; gb++)
for (int k_i = 0; k_i < wei_scales_n_g; k_i++) {
const int k = k_start + gb * bgmmc.K_blk + k_i * wei_scales_gK;
// For grouped zero points/scales the k-block size has to vary, so the
// copy kernel may be called with an unaligned (k, k_iters) pair.
auto call_copy_kernel = [&](int k, int k_iters, int gb,
bool aligned_blocks = false) {
ctx.src = (void *)brgmm_ctx.get_data_B_kn_ptr(B_data_batch_ptr, k, n);
ctx.tr_src = (void *)brgmm_ctx.get_buf_B_ptr(
ithr, k_blk_idx, n_blk_idx, gb, k_i * wei_scales_gK);
ctx.compensation_ptr
= (void *)brgmm_ctx.get_s8s8_comp_ptr(ithr, b_idx, n_blk_idx);
ctx.current_K_start = k;
ctx.current_K_iters = nstl::min(bgmmc.K_blk, bgmmc.K);
if (apply_single_scales_group) {
// Update the copy_B K_block to process when apply a single scale.
ctx.current_K_iters = nstl::min(wei_scales_gK, ctx.current_K_iters);
}
ctx.current_K_pad = brgmm_ctx.get_current_K_pad(ctx.current_K_iters);
// Use k for buffer locating only when the block is unaligned
if (aligned_blocks)
ctx.tr_src = (void *)brgmm_ctx.get_buf_B_ptr(
ithr, k_blk_idx, n_blk_idx, gb);
else
ctx.tr_src = (void *)brgmm_ctx.get_buf_B_k_ptr(ithr, k);
ctx.current_K_start = k;
ctx.current_K_iters = k_iters;
ctx.current_K_pad = brgmm_ctx.get_current_K_pad(k_iters);
ctx.src_scales_ptr = brgmm_ctx.get_src_scales_ptr();
ctx.wei_scales_ptr = brgmm_ctx.get_wei_scales_ptr(n, k);
ctx.zp_b_value_ptr = brgmm_ctx.get_wei_zp_ptr(n, k);
if (bgmmc.blocked_B && !bgmmc.is_f16_with_int_wei
&& isa == avx512_core_fp16) {
cvt_float16_to_float((float *)ctx.tr_src, (float16_t *)ctx.src,
@ -1298,37 +1284,52 @@ void brgemm_matmul_t<isa>::copy_b_chunk_in_buffer(
} else {
(*copy_B_kernel_)(&ctx);
}
}
};
if (!is_K_tail) return;
// Grouped zero points and/or scales
if (bgmmc.is_wei_zp_per_k || bgmmc.is_wei_scale_per_k) {
const auto &k_group = bgmmc.is_wei_zp_per_k ? bgmmc.wei_zp_k_gsize
: bgmmc.wei_scales_k_gsize;
const auto brgemm_k_blk = nstl::min(bgmmc.K, bgmmc.K_blk);
const auto adj_k_blk = nstl::min(brgemm_k_blk, k_group);
assert(adj_k_blk > 0);
auto k = k_start;
// is_K_tail is computed incorrectly when K < K_blk: it should be
// is_K_tail = true with gemm_batch = 0, but currently it is
// is_K_tail = false with gemm_batch = 1. This causes a segfault when
// the k-group size is smaller than K_blk for blocked formats.
const auto work_amount = bgmmc.K < bgmmc.K_blk
? bgmmc.K
: gemm_batch * bgmmc.K_blk
+ is_K_tail * (bgmmc.K % bgmmc.K_blk);
const auto k_end = k_start + work_amount;
// `div_up` covers the case when K_blk is less than gK.
const auto wei_scales_n_g_tail = apply_single_scales_group
? div_up(bgmmc.K % bgmmc.K_blk, wei_scales_gK)
: 1;
for (int k_i = 0; k_i < wei_scales_n_g_tail; k_i++) {
const int k = k_start + gemm_batch * bgmmc.K_blk + k_i * wei_scales_gK;
ctx.src = (void *)brgmm_ctx.get_data_B_kn_ptr(B_data_batch_ptr, k, n);
ctx.tr_src = (void *)brgmm_ctx.get_buf_B_ptr(
ithr, k_blk_idx, n_blk_idx, gemm_batch, k_i * wei_scales_gK);
ctx.compensation_ptr
= (void *)brgmm_ctx.get_s8s8_comp_ptr(ithr, b_idx, n_blk_idx);
ctx.current_K_start = k;
ctx.current_K_iters = bgmmc.K % bgmmc.K_blk;
if (apply_single_scales_group) {
// Update the copy_B K_block to process when apply a single scale.
ctx.current_K_iters = nstl::min(wei_scales_gK, ctx.current_K_iters);
// Handle first block
if (k_start % adj_k_blk > 0) {
const auto first_blk_size = adj_k_blk - (k_start % adj_k_blk);
call_copy_kernel(k_start, first_blk_size, 0);
k += first_blk_size;
}
ctx.current_K_pad = brgmm_ctx.get_current_K_pad(ctx.current_K_iters);
ctx.src_scales_ptr = brgmm_ctx.get_src_scales_ptr();
ctx.wei_scales_ptr = brgmm_ctx.get_wei_scales_ptr(n, k);
if (bgmmc.blocked_B && !bgmmc.is_f16_with_int_wei
&& isa == avx512_core_fp16) {
cvt_float16_to_float((float *)ctx.tr_src, (float16_t *)ctx.src,
bgmmc.wei_n_blk * ctx.current_K_iters);
} else {
(*copy_B_kernel_)(&ctx);
// Handle full blocks
for (; (k + adj_k_blk) <= k_end; k += adj_k_blk) {
const auto gb = (k - k_start) / bgmmc.K_blk;
call_copy_kernel(k, adj_k_blk, gb);
}
// Handle last block
if (k_end > k) {
const auto gb = (k - k_start) / bgmmc.K_blk;
call_copy_kernel(k, k_end - k, gb);
}
} else { // Default case with k_blk blocking
for (int gb = 0; gb < gemm_batch; ++gb) {
const auto k = k_start + gb * bgmmc.K_blk;
const auto k_iters = nstl::min(bgmmc.K_blk, bgmmc.K);
call_copy_kernel(k, k_iters, gb, /*aligned_blocks=*/true);
}
if (is_K_tail) {
const auto k = k_start + gemm_batch * bgmmc.K_blk;
const auto k_iters = bgmmc.K % bgmmc.K_blk;
call_copy_kernel(k, k_iters, gemm_batch, /*aligned_blocks=*/true);
}
}
}
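
To make the grouped branch above easier to follow, here is a standalone sketch of how a K chunk is split so that every copy_B call stays inside one zero-point/scale group; the sizes and the simplified work-amount bound are illustrative only:

#include <algorithm>
#include <cstdio>

// Standalone sketch: partial first block up to the next group border,
// then full groups, then the trailing remainder.
int main() {
    const int K = 60, K_blk = 24, k_group = 16;
    const int k_start = 24; // second K_blk chunk
    const int k_end = std::min(k_start + K_blk, K); // simplified work bound
    const int adj_k_blk = std::min(std::min(K, K_blk), k_group);

    int k = k_start;
    if (k_start % adj_k_blk > 0) { // align to the next group border
        const int first = adj_k_blk - k_start % adj_k_blk;
        std::printf("copy_B [%d, %d)\n", k, k + first);
        k += first;
    }
    for (; k + adj_k_blk <= k_end; k += adj_k_blk) // full groups
        std::printf("copy_B [%d, %d)\n", k, k + adj_k_blk);
    if (k_end > k) // trailing remainder
        std::printf("copy_B [%d, %d)\n", k, k_end);
    // Prints: copy_B [24, 32) and copy_B [32, 48).
    return 0;
}
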
@ -1373,7 +1374,7 @@ struct brgemm_matmul_t<isa>::brg_matmul_exec_ctx_t {
// setup scales / zp pointers
const void *src_zero_points = CTX_IN_MEM(
const void *, DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC);
const void *wei_zero_points = CTX_IN_MEM(
wei_zp_ptr_ = CTX_IN_MEM(
const void *, DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS);
const void *dst_zero_points = CTX_IN_MEM(
const void *, DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST);
@ -1383,18 +1384,18 @@ struct brgemm_matmul_t<isa>::brg_matmul_exec_ctx_t {
pd->attr()->zero_points_.get_data_type(DNNL_ARG_SRC),
src_zero_points, 0)
: 0;
zero_point_b_val_ = wei_zero_points ? cpu::io::load_int_value(
pd->attr()->zero_points_.get_data_type(
DNNL_ARG_WEIGHTS),
wei_zero_points, 0)
: 0;
zero_point_b_negative_val_ = -zero_point_b_val_;
zero_point_c_val_ = dst_zero_points
? cpu::io::load_int_value(
pd->attr()->zero_points_.get_data_type(DNNL_ARG_DST),
dst_zero_points, 0)
: 0;
wei_zp_neg_val_ = (-1)
* (wei_zp_ptr_ ? cpu::io::load_int_value(
pd->attr()->zero_points_.get_data_type(
DNNL_ARG_WEIGHTS),
wei_zp_ptr_, 0)
: 0);
memory_tracking::grantor_t scratchpad = ctx.get_scratchpad_grantor();
const auto &bgmmc = pd->get_brgemm_matmul_conf();
@ -1403,10 +1404,6 @@ struct brgemm_matmul_t<isa>::brg_matmul_exec_ctx_t {
const float *, DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC);
wei_scales_ = CTX_IN_MEM(
const float *, DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS);
wei_scales_tr_ = bgmmc_.is_wei_scale_per_k
&& !bgmmc_.gK_and_K_blk_are_divisible
? scratchpad.template get<float>(key_precomputed_scales)
: nullptr;
dst_scales_ = CTX_IN_MEM(
const float *, DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST);
dst_scales_inv_ = scratchpad.template get<float>(key_matmul_dst_scales);
@ -1887,17 +1884,36 @@ struct brgemm_matmul_t<isa>::brg_matmul_exec_ctx_t {
// a block of memory per thread.
// `gb` defines an offset to a specific block over K inside a portion of
// blocks.
// `k` defines an offset inside a specific block over K. It is the
// smallest granularity possible. It is used only when wei_decompression
// feature is requested with a single scale over a sub-piece of `K_blk`.
char *get_buf_B_ptr(
int ithr, int k_blk_idx, int n_blk_idx, int gb, int k = 0) const {
char *get_buf_B_ptr(int ithr, int k_blk_idx, int n_blk_idx, int gb) const {
UNUSED(n_blk_idx);
if (!bgmmc_.use_buffer_b) return nullptr;
int k_blk_local = k_blk_idx % get_K_chunk_size();
return buf_B_ptr_ + ithr * bgmmc_.buffer_b_per_thread_sz
const auto offset = ithr * bgmmc_.buffer_b_per_thread_sz
+ k_blk_local * bgmmc_.buffer_b_k_brg_stride
+ gb * bgmmc_.buffer_b_gb_stride + k * bgmmc_.buffer_b_k_stride;
+ gb * bgmmc_.buffer_b_gb_stride;
return buf_B_ptr_ + offset;
}
/* Returns a pointer into buffer B for an unaligned K offset inside a
 * thread-local buffer. Used by copy kernels, including the grouped
 * ZP/scales path. For a VNNI granularity > 1 it returns a pointer to the
 * start of the VNNI block. Functionality overlaps with get_buf_B_ptr();
 * TODO: merge the two implementations.
*/
char *get_buf_B_k_ptr(const int ithr, const int k) const {
if (!bgmmc_.use_buffer_b) return nullptr;
const int batch_block_size = bgmmc_.K_blk * bgmmc_.brgemm_batch_size;
const auto batch_blocking = std::div(k, batch_block_size);
const auto k_blk_idx = batch_blocking.quot;
const auto k_blk_local = k_blk_idx % get_K_chunk_size();
const auto k_in_batch = batch_blocking.rem;
auto offset = ithr * bgmmc_.buffer_b_per_thread_sz;
offset += k_blk_local * bgmmc_.buffer_b_k_brg_stride;
// Round down to the start of the VNNI block
const auto k_outer = (k_in_batch / vnni_factor) * vnni_factor;
offset += k_outer * bgmmc_.buffer_b_k_stride;
return buf_B_ptr_ + offset;
}
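
The only subtle step in get_buf_B_k_ptr() above is rounding k down to the start of its VNNI block inside the batch block. A small illustration of the index math; every size and stride below is invented for the example:

#include <cstddef>
#include <cstdlib>

// Illustrative mirror of the get_buf_B_k_ptr() index math; all sizes and
// strides are made up, only the arithmetic is the point.
size_t buf_B_k_offset(int ithr, int k) {
    const int K_blk = 32, brgemm_batch_size = 2, K_chunk_size = 4;
    const int vnni_factor = 2; // e.g. bf16
    const size_t per_thread_sz = 1 << 20;
    const size_t k_brg_stride = 1 << 14;
    const size_t k_stride = 1 << 8;

    const int batch_block_size = K_blk * brgemm_batch_size; // 64
    const std::div_t d = std::div(k, batch_block_size);
    const int k_blk_local = d.quot % K_chunk_size;
    const int k_in_batch = d.rem;
    // Round down to the start of the VNNI block, since the copy kernel
    // writes whole VNNI blocks.
    const int k_outer = (k_in_batch / vnni_factor) * vnni_factor;
    return ithr * per_thread_sz + k_blk_local * k_brg_stride
            + k_outer * k_stride;
}
// buf_B_k_offset(t, 70) == buf_B_k_offset(t, 71): both fall into the same
// VNNI block (batch block 1, k_in_batch 6 and 7, k_outer 6).
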
char *get_buf_C_ptr(int ithr, int m_blk_idx, int n_blk_idx) const {
@ -2049,107 +2065,17 @@ struct brgemm_matmul_t<isa>::brg_matmul_exec_ctx_t {
// Returns a pointer to the weights scales for the corresponding block based
// on @p n and @p k.
//
// For grouped scales it also prepares a scratchpad buffer when `K_blk` and
// `group_size` don't divide each other.
const void *get_wei_scales_ptr(int n, int k = 0) const {
if (!wei_scales_) return wei_scales_;
const auto gK = bgmmc_.wei_scales_k_group_size;
const auto in_stride_n = bgmmc_.is_wei_scale_per_n ? 1 : 0;
const auto in_stride_k = bgmmc_.is_wei_scale_per_k
? (bgmmc_.is_wei_scale_per_n ? bgmmc_.N : 1)
: 0;
int k_g = gK > 1 ? k / gK : k;
const auto offset = n * in_stride_n + k_g * in_stride_k;
const auto *ptr = reinterpret_cast<const char *const>(wei_scales_)
+ offset * bgmmc_.wei_scales_dt_sz;
if (bgmmc_.gK_and_K_blk_are_divisible || !bgmmc_.is_wei_scale_per_k)
return ptr;
if (bgmmc_.req_transpose_scales) {
// Transpose case covers grouped and non-grouped scenarios.
const auto in_tr_stride_n = bgmmc_.is_wei_scale_per_n
? (bgmmc_.is_wei_scale_per_k ? bgmmc_.K : 1)
: 0;
const auto in_tr_stride_k = bgmmc_.is_wei_scale_per_k ? 1 : 0;
const auto offset_tr = k * in_tr_stride_k + n * in_tr_stride_n;
auto *ptr_tr = reinterpret_cast<char *const>(wei_scales_tr_)
+ offset_tr * bgmmc_.wei_scales_dt_sz;
// `kk_offset` covers a situation when
// `!bgmmc_.gK_and_K_blk_are_divisible`. Need to process the start
// of the next K block in a special way.
//
// Process it separately for simpler main body code.
const int kk_offset = (k % gK != 0) ? (gK - (k % gK)) : 0;
if (kk_offset) {
for_(int nn = 0; nn < nstl::min(bgmmc_.N - n, bgmmc_.N_blk);
nn++)
for (int kk = 0; kk < kk_offset; kk++) {
const auto in_idx = nn * in_stride_n;
const auto out_idx = nn * bgmmc_.K + kk;
const float wei_scales_val = cpu::io::load_float_value(
bgmmc_.wei_scales_dt, ptr, in_idx);
cpu::io::store_float_value(bgmmc_.wei_scales_dt,
wei_scales_val, ptr_tr, out_idx);
}
}
for_(int nn = 0; nn < nstl::min(bgmmc_.N - n, bgmmc_.N_blk); nn++)
for (int kk = 0;
kk < nstl::min(bgmmc_.K - k, bgmmc_.K_blk) - kk_offset;
kk++) {
const auto in_idx = nn * in_stride_n + (kk / gK) * in_stride_k;
const auto out_idx = nn * bgmmc_.K + kk;
const auto ptr_kk_offset = bool(kk_offset) * in_stride_k
* bgmmc_.wei_scales_dt_sz;
const auto ptr_tr_kk_offset
= kk_offset * in_tr_stride_k * bgmmc_.wei_scales_dt_sz;
const float wei_scales_val = cpu::io::load_float_value(
bgmmc_.wei_scales_dt, ptr + ptr_kk_offset, in_idx);
cpu::io::store_float_value(bgmmc_.wei_scales_dt, wei_scales_val,
ptr_tr + ptr_tr_kk_offset, out_idx);
}
return ptr_tr;
} else {
const auto offset_non_g = n * in_stride_n + k * in_stride_k;
auto *ptr_non_g = reinterpret_cast<char *const>(wei_scales_tr_)
+ offset_non_g * bgmmc_.wei_scales_dt_sz;
const int kk_offset = (k % gK != 0) ? (gK - (k % gK)) : 0;
if (kk_offset) {
for_(int kk = 0; kk < kk_offset; kk++)
for (int nn = 0; nn < nstl::min(bgmmc_.N - n, bgmmc_.N_blk);
nn++) {
const auto in_idx = nn * in_stride_n;
const auto out_idx = nn * in_stride_n + kk * in_stride_k;
const float wei_scales_val = cpu::io::load_float_value(
bgmmc_.wei_scales_dt, ptr, in_idx);
cpu::io::store_float_value(bgmmc_.wei_scales_dt,
wei_scales_val, ptr_non_g, out_idx);
}
}
for_(int kk = 0;
kk < nstl::min(bgmmc_.K - k, bgmmc_.K_blk) - kk_offset;
kk++)
for (int nn = 0; nn < nstl::min(bgmmc_.N - n, bgmmc_.N_blk); nn++) {
const auto in_idx = nn * in_stride_n + (kk / gK) * in_stride_k;
const auto out_idx = nn * in_stride_n + kk * in_stride_k;
const auto ptr_kk_offset = bool(kk_offset) * in_stride_k
* bgmmc_.wei_scales_dt_sz;
const auto ptr_non_g_kk_offset
= kk_offset * in_stride_k * bgmmc_.wei_scales_dt_sz;
const float wei_scales_val = cpu::io::load_float_value(
bgmmc_.wei_scales_dt, ptr + ptr_kk_offset, in_idx);
cpu::io::store_float_value(bgmmc_.wei_scales_dt, wei_scales_val,
ptr_non_g + ptr_non_g_kk_offset, out_idx);
}
return ptr_non_g;
if (bgmmc_.is_wei_scale_common) return wei_scales_;
auto offset = n;
if (bgmmc_.is_wei_scale_per_k) {
const auto &k_group_sz = bgmmc_.wei_scales_k_gsize;
const auto k_idx = k / k_group_sz;
offset += k_idx * bgmmc_.N;
}
offset = offset * bgmmc_.wei_scales_dt_sz;
return ((char *)wei_scales_ + offset);
}
const void *get_dst_scales_ptr() const { return dst_scales_; }
@ -2167,11 +2093,28 @@ struct brgemm_matmul_t<isa>::brg_matmul_exec_ctx_t {
return &zero_point_a_negative_val_;
}
const int32_t *get_zp_b_neg_val_ptr() const {
return &zero_point_b_negative_val_;
}
const void *get_wei_zp_neg_ptr() const { return &wei_zp_neg_val_; }
const int32_t *get_zp_b_val_ptr() const { return &zero_point_b_val_; }
const void *get_wei_zp_ptr(int n, int k = 0) const {
if (!bgmmc_.has_zero_point_b) return nullptr;
if (bgmmc_.is_wei_zp_common)
return wei_zp_ptr_; // single zero point value
// Locate the group based on (n,k)
auto offset = n;
if (bgmmc_.is_wei_zp_per_k) {
const auto &k_group_sz = bgmmc_.wei_zp_k_gsize;
const auto k_idx = k / k_group_sz;
offset += k_idx * bgmmc_.N;
}
const auto dt_sz = types::data_type_size(bgmmc_.wei_zp_dt);
const auto elems_per_byte
= one_of(bgmmc_.wei_zp_dt, data_type::s4, data_type::u4) ? 2
: 1;
offset = offset * dt_sz / elems_per_byte;
return (char *)wei_zp_ptr_ + offset;
}
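
get_wei_scales_ptr() and get_wei_zp_ptr() now share the same group-based element index, n + (k / group) * N; for 4-bit zero points the byte offset is then effectively halved, since two s4/u4 values share a byte. A small worked example with invented N and group size (not taken from the patch):

#include <cstddef>

// Illustrative mirror of the offset math above, e.g. wei:per_ocic:s4:32x1.
size_t grouped_wei_zp_byte_offset(int n, int k) {
    const int N = 4096, k_group = 32;
    const size_t elem = n + (k / k_group) * (size_t)N; // group element index
    const int elems_per_byte = 2; // two s4/u4 values packed per byte
    return elem / elems_per_byte;
}
// grouped_wei_zp_byte_offset(10, 70) == (10 + 2 * 4096) / 2 == 4101 bytes
// from the base zero-points pointer.
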
const int32_t *get_zp_ab_mixed_comp_ptr() const {
return &zero_point_mixed_ab_compensation_component_;
@ -2457,6 +2400,7 @@ struct brgemm_matmul_t<isa>::brg_matmul_exec_ctx_t {
bool packed_sparse_weights() const { return bgmmc_.packed_sparse_weights; }
int get_current_K_pad(int current_K_iters) const {
if (bgmmc_.is_wei_zp_per_k || bgmmc_.is_wei_scale_per_k) return 0;
if (current_K_iters % bgmmc_.wei_k_blk == 0) return 0;
return (bgmmc_.extendable_k || bgmmc_.use_fused_copy_a)
? bgmmc_.wei_k_blk
@ -2514,14 +2458,7 @@ private:
const char *bias_ptr_;
const void *src_scales_;
const void *wei_scales_;
// This pointer is coming from scratchpad and is needed to expand (K/g)xN
// scales to KxN scales as copy_B kernels rely on the full register scales
// for weights decompression feature in case when K_blk is not divisible by
// the scales group size or vice versa.
// TODO: implement the logic at calling copy routines spot to handle the
// unsupported scenario mentioned above to avoid the need in this pointer
// and the overhead around filling that big buffer.
void *wei_scales_tr_;
const void *dst_scales_;
const void *dst_scales_inv_;
int32_t *s8s8_compensation_ptr_;
@ -2531,10 +2468,10 @@ private:
int32_t *reorder_zp_a_comp_ptr_;
int32_t zero_point_a_negative_val_;
int32_t zero_point_b_val_;
int32_t zero_point_b_negative_val_;
int32_t zero_point_mixed_ab_compensation_component_;
int32_t zero_point_c_val_;
int32_t wei_zp_neg_val_;
const void *wei_zp_ptr_;
std::vector<const void *> post_ops_binary_rhs_arg_vec_;
int base_brg_ker_idx_;

File diff suppressed because it is too large.

View File

@ -59,7 +59,7 @@ struct jit_brgemm_matmul_copy_a_t {
const void *tr_src;
const void *zp_b_compensation_buffer_ptr;
const void *zp_a_compensation_result_ptr;
const void *zp_b_neg_value_ptr;
const void *zp_b_neg_val_ptr;
const void *zp_ab_comp_ptr;
dim_t current_K_start;

View File

@ -1296,6 +1296,7 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc,
bgmmc.dst_dt = dst_d.data_type();
bgmmc.wei_dt = weights_d.data_type();
bgmmc.orig_wei_dt = weights_d.data_type();
bgmmc.wei_zp_dt = attr.zero_points_.get(DNNL_ARG_WEIGHTS).get_data_type();
bgmmc.with_reduce = mmd.reduce_desc.format_kind != format_kind::undef;
bgmmc.reduce_dt
@ -1396,19 +1397,19 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc,
bgmmc.with_src_scales = !src_scales.has_default_values();
bgmmc.with_wei_scales = !wei_scales.has_default_values();
if (bgmmc.with_wei_scales) {
const auto wei_qmask_N = 1 << (bgmmc.ndims - 1);
const auto wei_qmask_K = 1 << (bgmmc.ndims - 2);
bgmmc.is_wei_scale_per_k = wei_scales.get_mask() & wei_qmask_K;
bgmmc.is_wei_scale_per_n = wei_scales.get_mask() & wei_qmask_N;
const auto &wei_scale_mask = wei_scales.get_mask();
bgmmc.is_wei_scale_common = wei_scale_mask == 0;
bgmmc.is_wei_scale_per_k = wei_scale_mask & 1 << (bgmmc.ndims - 2);
bgmmc.is_wei_scale_per_n = wei_scale_mask & 1 << (bgmmc.ndims - 1);
bgmmc.apply_scales_in_buffer_b = bgmmc.is_wei_scale_per_k
&& bgmmc.with_wei_decompression && bgmmc.N * bgmmc.K != 1;
bgmmc.wei_scales_dt = wei_scales.get_data_type();
bgmmc.wei_scales_dt_sz = types::data_type_size(bgmmc.wei_scales_dt);
bgmmc.wei_scales_k_group_size = wei_scales.get_group(0);
bgmmc.wei_scales_k_gsize = wei_scales.get_group(0);
// only common and per-oc-channel scales are supported
// per-ic-channel scales are supported only with weight decompression
VCONDCHECK_BG(wei_scales.get_mask() == 0 || bgmmc.is_wei_scale_per_n
VCONDCHECK_BG(bgmmc.is_wei_scale_common || bgmmc.is_wei_scale_per_n
|| IMPLICATION(bgmmc.is_wei_scale_per_k,
bgmmc.with_wei_decompression),
VERBOSE_UNSUPPORTED_SCALES_CFG);
@ -1420,6 +1421,28 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc,
VCONDCHECK_BG(!(bgmmc.with_dst_scales && dst_scales.get_mask() > 0),
VERBOSE_UNSUPPORTED_SCALES_CFG);
const auto &wei_zp = attr.zero_points_.get(DNNL_ARG_WEIGHTS);
const auto has_wei_zp = !wei_zp.has_default_values();
if (has_wei_zp) {
const auto wei_zp_mask = wei_zp.get_mask();
bgmmc.is_wei_zp_common = wei_zp_mask == 0;
bgmmc.is_wei_zp_per_k = wei_zp_mask & (1 << (bgmmc.ndims - 2));
bgmmc.is_wei_zp_per_n = wei_zp_mask & (1 << (bgmmc.ndims - 1));
bgmmc.wei_zp_dt = wei_zp.get_data_type();
bgmmc.wei_zp_k_gsize = wei_zp.get_group(0);
VCONDCHECK_BG(wei_zp_mask == 0 || bgmmc.is_wei_zp_per_k
|| bgmmc.is_wei_zp_per_n,
VERBOSE_UNSUPPORTED_ZP_CFG);
// Check if K groups for scales and for zero points are identical
VCONDCHECK_BG(
IMPLICATION(bgmmc.is_wei_zp_per_k && bgmmc.is_wei_scale_per_k,
bgmmc.wei_zp_k_gsize == bgmmc.wei_scales_k_gsize),
VERBOSE_UNSUPPORTED_ZP_CFG);
}
const auto &p = attr.post_ops_;
bgmmc.with_sum = p.find(primitive_kind::sum) != -1;
const int eltwise_ind = p.find(primitive_kind::eltwise);
@ -1532,9 +1555,6 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc,
bgmmc.transposed_B = bm_conf_utils.check_is_transposed(bgmmc.wei_tag)
|| bgmmc.wei_tag == adbc;
bgmmc.use_buffer_b = bm_conf_utils.use_buffer_b();
bgmmc.req_transpose_scales = bgmmc.apply_scales_in_buffer_b
&& bgmmc.is_wei_scale_per_k && bgmmc.is_wei_scale_per_n
&& bgmmc.transposed_B;
if ((bm_conf_utils.is_f32_f16() || bm_conf_utils.is_f32_bf16())
&& is_superset(bgmmc.isa, avx2) && bm_conf_utils.use_buffer_b()) {
@ -1719,14 +1739,6 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc,
= bm_conf_utils.wei_down_convert_to_vnni();
}
// This setting must be updated post blocking as it has a dependency on
// `bgmmc.K_blk`. See `gK_and_K_blk_are_divisible` comment.
if (bgmmc.is_wei_scale_per_k) {
const auto gK = bgmmc.wei_scales_k_group_size;
bgmmc.gK_and_K_blk_are_divisible = gK > 1
&& ((bgmmc.K_blk % gK == 0) || (gK % bgmmc.K_blk == 0));
}
VCHECK_BG(bm_conf_utils.set_B_flags(weights_md), VERBOSE_BLOCKING_FAIL, "");
bgmmc.M_tail = bgmmc.is_runtime_M ? 0 : bgmmc.M % bgmmc.M_blk;

View File

@ -136,6 +136,7 @@ struct brgemm_matmul_conf_t {
data_type_t reduce_dt;
data_type_t orig_src_dt;
data_type_t orig_wei_dt;
int nthr;
int nthr_k = 1, nthr_m = 1, nthr_n = 1, nthr_b = 1;
@ -227,24 +228,27 @@ struct brgemm_matmul_conf_t {
bool is_runtime_M = false;
bool is_runtime_N = false;
bool is_runtime_K = false;
bool extendable_k = false;
bool is_src_batch_layout_trivial = false;
bool is_wei_batch_layout_trivial = false;
bool is_dst_batch_layout_trivial = false;
// Attributes related to quantization
// Scales
bool apply_scales_in_buffer_b = false;
size_t wei_scales_dt_sz = 0;
bool is_wei_scale_per_n = false;
bool is_wei_scale_per_k = false;
bool req_transpose_scales = false;
bool apply_scales_in_buffer_b = false;
// For generic cases, when the groups are chosen such that they cannot divide
// K_blk into equal pieces, it gets hard to call a kernel with a single
// "per_N line" of scales. In this case the weights are copied into a larger
// memory buffer and used like a full tensor.
// TODO: convert to a method. The state must be set at a specific point of
// the initialization since it relies on blocking.
bool gK_and_K_blk_are_divisible = false;
size_t wei_scales_dt_sz = 0;
dim_t wei_scales_k_group_size = 0;
bool is_wei_scale_common = false;
dim_t wei_scales_k_gsize = 0;
data_type_t wei_scales_dt = data_type::undef;
bool extendable_k = false;
// Zero points
dim_t wei_zp_k_gsize = 0;
bool is_wei_zp_per_k = false;
bool is_wei_zp_per_n = false;
bool is_wei_zp_common = false;
data_type_t wei_zp_dt = data_type::undef;
bool is_gemv = false;

View File

@ -319,8 +319,7 @@
--dt=f32:s4:f32,f32:u4:f32
--wtag=any,ab,ba
--attr-scales=,wei:common:0.5:f32,wei:per_ocic:f32:32x1
--attr-zero-points=,wei:common:2
# ,wei:per_oc:s4:4096x1 # PER_OC skips for all matmul implementations
--attr-zero-points=,wei:common:2,wei:per_oc:s4
--attr-fpmath=strict:true
1x4096:4096x4096
@ -328,7 +327,32 @@
--dt=f32:s8:f32,f32:u8:f32
--wtag=any,ab,ba
--attr-scales=,wei:common:0.5:f32,wei:per_ocic:f32:32x1
--attr-zero-points=,wei:common:2
# ,wei:per_ocic:s8:32x1 # PER_OC skips for all matmul implementations
--attr-zero-points=,wei:common:2,wei:per_ocic:s8:32x1
--attr-fpmath=strict:true
1x4096:4096x4096
## Additional grouped scales/ZP testing
--reset
--dt=f16:s4:f16
--wtag=abc,acb
--attr-scales=,wei:common:2,wei:per_oc:f16,wei:per_ocic:f16:192x1
--attr-zero-points=,wei:common:2,wei:per_oc:s4,wei:per_ocic:s4:192x1
--attr-fpmath=f16:true
12x4x576:12x576x192
--reset
--dt=bf16:s4:bf16
--wtag=abc,acb
--attr-scales=,wei:common:2,wei:per_oc:f16,wei:per_ocic:f16:192x1
--attr-zero-points=,wei:common:2,wei:per_oc:s4,wei:per_ocic:s4:192x1
--attr-fpmath=bf16:true
12x4x576:12x576x192
--reset
--dt=f32:s4:f32
--wtag=abc,acb
--attr-scales=,wei:common:2,wei:per_oc:f16,wei:per_ocic:f16:192x1
--attr-zero-points=,wei:common:2,wei:per_oc:s4,wei:per_ocic:s4:192x1
--attr-fpmath=strict:true
12x4x576:12x576x192