diff --git a/src/cpu/x64/jit_generator.hpp b/src/cpu/x64/jit_generator.hpp
index 3c4c4d9c63..84144036df 100644
--- a/src/cpu/x64/jit_generator.hpp
+++ b/src/cpu/x64/jit_generator.hpp
@@ -508,13 +508,13 @@ public:
         movups(addr, x);
     }
 
-    void uni_vmovdqu16(const Xbyak::Xmm &x, const Xbyak::Address &addr) {
+    void uni_vmovdqu16(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
         if (is_valid_isa(avx512_core))
-            vmovdqu16(x, addr);
+            vmovdqu16(x, op);
         else if (is_valid_isa(avx))
-            vmovups(x, addr);
+            vmovups(x, op);
         else
-            movups(x, addr);
+            movups(x, op);
     }
 
     void uni_vmovups(const Xbyak::Address &addr, const Xbyak::Xmm &x) {
diff --git a/src/cpu/x64/matmul/amx_blocking_heuristics.cpp b/src/cpu/x64/matmul/amx_blocking_heuristics.cpp
index 26ee12e40c..e95b2fe755 100644
--- a/src/cpu/x64/matmul/amx_blocking_heuristics.cpp
+++ b/src/cpu/x64/matmul/amx_blocking_heuristics.cpp
@@ -1055,8 +1055,8 @@ void matmul_amx_blocking_params_micro_t::set_blocking_parameters(
     // TODO: review extendable_k_ condition to cover more cases
     extendable_k_ = (K % wei_k_blk != 0) && (brgemm_k_elems > wei_k_blk)
-            && wei_zp_type == none && !use_buffer_a
-            && !packed_sparse_weights && current_lda_ == K;
+            && wei_zp_type == none && !apply_scales_in_buffer_b
+            && !use_buffer_a && !packed_sparse_weights && current_lda_ == K;
 
     if (extendable_k_) {
         if (brgemm_k_elems >= K) {
diff --git a/src/cpu/x64/matmul/brgemm_matmul.cpp b/src/cpu/x64/matmul/brgemm_matmul.cpp
index 762bf872d2..37fc5cf1fc 100644
--- a/src/cpu/x64/matmul/brgemm_matmul.cpp
+++ b/src/cpu/x64/matmul/brgemm_matmul.cpp
@@ -252,13 +252,20 @@ status_t brgemm_matmul_t::pd_t::init(engine_t *engine) {
     auto check_attr_zero_points = [&]() -> bool {
         const auto &zp = attr()->zero_points_;
         static const std::vector<int> supported_args {
-                DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST};
+                DNNL_ARG_SRC, DNNL_ARG_DST};
         for (int arg : supported_args) {
             if (!zp.has_default_values(arg)) {
                 const int mask = zp.get_mask(arg);
                 if (mask > 0) return false;
             }
         }
+        if (!zp.has_default_values(DNNL_ARG_WEIGHTS)) {
+            const auto mask = zp.get_mask(DNNL_ARG_WEIGHTS);
+            const auto kn_mask = wei_qmask_N() + wei_qmask_K();
+            const bool zp_over_batch = (mask & kn_mask) != mask;
+            const bool mask_ok = (mask & ~kn_mask) == 0;
+            return !(zp_over_batch && batch() > 1) && mask_ok;
+        }
         return true;
     };
     const bool problem_dt_correct = one_of(true, is_f4, is_int8, is_f8, is_bf16,
@@ -286,6 +293,7 @@ status_t brgemm_matmul_t::pd_t::init(engine_t *engine) {
                     | primitive_attr_t::skip_mask_t::scales_groups
                     | primitive_attr_t::skip_mask_t::
                             zero_points_data_type
+                    | primitive_attr_t::skip_mask_t::zero_points_groups
                     | primitive_attr_t::skip_mask_t::post_ops
                     | primitive_attr_t::skip_mask_t::sum_dt
                     | primitive_attr_t::skip_mask_t::fpmath_mode,
@@ -428,11 +436,6 @@ status_t brgemm_matmul_t::pd_t::init(engine_t *engine) {
     auto scratchpad = scratchpad_registry().registrar();
     init_scratchpad(scratchpad, bgmmc_);
 
-    const auto wei_scale_count = bgmmc_.is_wei_scale_per_k
-            ? (bgmmc_.is_wei_scale_per_n ?
N() * K() : K()) - : N(); - book_precomputed_scales(scratchpad, attr()->scales_, wei_scale_count, - /* scale_adjust_factor = */ 1.f, bgmmc_.req_transpose_scales); return status::success; } @@ -1189,9 +1192,9 @@ void brgemm_matmul_t::copy_a_chunk_in_buffer( ctx.zp_a_compensation_result_ptr = (void *)brgmm_ctx.get_zp_b_compensation_result_ptr( ithr, m_blk_idx); - ctx.zp_b_neg_value_ptr = (void *)brgmm_ctx.get_zp_b_neg_val_ptr(); ctx.zp_ab_comp_ptr = (void *)brgmm_ctx.get_zp_ab_mixed_comp_ptr(); ctx.dynamic_src_ld = brgmm_ctx.get_src_stride(); + ctx.zp_b_neg_val_ptr = brgmm_ctx.get_wei_zp_neg_ptr(); for (int gb = 0; gb < gemm_batch_iters; gb++) { const int k = k_start + gb * bgmmc.K_blk; @@ -1251,46 +1254,29 @@ void brgemm_matmul_t::copy_b_chunk_in_buffer( ctx.zp_a_compensation_ptr = (void *)brgmm_ctx.get_zp_a_compensation_ptr( ithr, b_idx, n_blk_idx); ctx.zp_a_neg_value_ptr = (void *)brgmm_ctx.get_zp_a_neg_val_ptr(); - ctx.zp_b_value_ptr = (void *)brgmm_ctx.get_zp_b_val_ptr(); + ctx.compensation_ptr + = (void *)brgmm_ctx.get_s8s8_comp_ptr(ithr, b_idx, n_blk_idx); + ctx.dynamic_src_stride = brgmm_ctx.copy_B_wei_stride(); - // For best performance, scales should be taken in copy kernels as is. - // The biggest challenge is grouped scales (`wei_scales_gK > 1` == true, - // (`bgmmc.is_wei_scale_per_k` == true) as `gK` starts interfering with - // `K_blk`. In case when `bgmmc.gK_and_K_blk_are_divisible` == true, it's - // easy to process a full `K_blk` piece or a portion of it by applying a - // single scale group. When doing a portion of B, it requires an additional - // loop over a single block, which `wei_scales_n_g` is responsible for. - // - // In other case, weights scales would copy data in a full-size of B - // scratchpad tensor and apply scales per point. - const auto wei_scales_gK = bgmmc.wei_scales_k_group_size; - const auto apply_single_scales_group = bgmmc.is_wei_scale_per_k - && bgmmc.gK_and_K_blk_are_divisible && wei_scales_gK > 1; - - // `div_up` covers the case when K_blk is less than gK. - const auto wei_scales_n_g = apply_single_scales_group - ? div_up(bgmmc.K_blk, wei_scales_gK) - : 1; - - for_(int gb = 0; gb < gemm_batch; gb++) - for (int k_i = 0; k_i < wei_scales_n_g; k_i++) { - const int k = k_start + gb * bgmmc.K_blk + k_i * wei_scales_gK; + // For the grouped Zero points/scales need to vary k-block size + // For this case need to call copy kernel with unaligned (k, k_iters) + auto call_copy_kernel = [&](int k, int k_iters, int gb, + bool aligned_blocks = false) { ctx.src = (void *)brgmm_ctx.get_data_B_kn_ptr(B_data_batch_ptr, k, n); - ctx.tr_src = (void *)brgmm_ctx.get_buf_B_ptr( - ithr, k_blk_idx, n_blk_idx, gb, k_i * wei_scales_gK); - ctx.compensation_ptr - = (void *)brgmm_ctx.get_s8s8_comp_ptr(ithr, b_idx, n_blk_idx); - ctx.current_K_start = k; - ctx.current_K_iters = nstl::min(bgmmc.K_blk, bgmmc.K); - if (apply_single_scales_group) { - // Update the copy_B K_block to process when apply a single scale. 
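Note: a toy model of the weights zero-point mask check added to check_attr_zero_points() earlier in this patch; it mirrors the accept/reject logic of the new DNNL_ARG_WEIGHTS branch, with plain integers standing in for wei_qmask_K()/wei_qmask_N()/batch(). The helper name and the bit assignments are illustrative only, not the library API.

#include <cstdio>

// Sketch of the new DNNL_ARG_WEIGHTS branch in check_attr_zero_points():
// k_bit/n_bit stand in for wei_qmask_K()/wei_qmask_N(), batch for batch().
bool wei_zp_mask_ok(int mask, int k_bit, int n_bit, long batch) {
    const int kn_mask = k_bit + n_bit;
    const bool zp_over_batch = (mask & kn_mask) != mask; // bits outside K/N
    const bool mask_ok = (mask & ~kn_mask) == 0;         // only K/N bits set
    return !(zp_over_batch && batch > 1) && mask_ok;
}

int main() {
    // 3D example: assume batch bit = 1, K bit = 2, N bit = 4 (illustrative).
    std::printf("%d\n", (int)wei_zp_mask_ok(/*mask=*/2 + 4, 2, 4, 8)); // 1: K+N groups
    std::printf("%d\n", (int)wei_zp_mask_ok(/*mask=*/1 + 4, 2, 4, 8)); // 0: batch bit set
    return 0;
}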
- ctx.current_K_iters = nstl::min(wei_scales_gK, ctx.current_K_iters); - } - ctx.current_K_pad = brgmm_ctx.get_current_K_pad(ctx.current_K_iters); + // Use k for buffer locating only when the block is unaligned + if (aligned_blocks) + ctx.tr_src = (void *)brgmm_ctx.get_buf_B_ptr( + ithr, k_blk_idx, n_blk_idx, gb); + else + ctx.tr_src = (void *)brgmm_ctx.get_buf_B_k_ptr(ithr, k); + ctx.current_K_start = k; + ctx.current_K_iters = k_iters; + ctx.current_K_pad = brgmm_ctx.get_current_K_pad(k_iters); ctx.src_scales_ptr = brgmm_ctx.get_src_scales_ptr(); ctx.wei_scales_ptr = brgmm_ctx.get_wei_scales_ptr(n, k); + ctx.zp_b_value_ptr = brgmm_ctx.get_wei_zp_ptr(n, k); if (bgmmc.blocked_B && !bgmmc.is_f16_with_int_wei && isa == avx512_core_fp16) { cvt_float16_to_float((float *)ctx.tr_src, (float16_t *)ctx.src, @@ -1298,37 +1284,52 @@ void brgemm_matmul_t::copy_b_chunk_in_buffer( } else { (*copy_B_kernel_)(&ctx); } - } + }; - if (!is_K_tail) return; + // grouped zero points &-or scales + if (bgmmc.is_wei_zp_per_k || bgmmc.is_wei_scale_per_k) { + const auto &k_group = bgmmc.is_wei_zp_per_k ? bgmmc.wei_zp_k_gsize + : bgmmc.wei_scales_k_gsize; + const auto brgemm_k_blk = nstl::min(bgmmc.K, bgmmc.K_blk); + const auto adj_k_blk = nstl::min(brgemm_k_blk, k_group); + assert(adj_k_blk > 0); + auto k = k_start; + // is_K_tail behaves incorrectly for the case K < K_blk + // Should be is_K_tail = true && gemm_batch = 0 + // Now: is_K_tail = false, gemm_batch = 1. + // Causes Segfault when the `group over k` < K_blk for blocked formats. + const auto work_amount = bgmmc.K < bgmmc.K_blk + ? bgmmc.K + : gemm_batch * bgmmc.K_blk + + is_K_tail * (bgmmc.K % bgmmc.K_blk); + const auto k_end = k_start + work_amount; - // `div_up` covers the case when K_blk is less than gK. - const auto wei_scales_n_g_tail = apply_single_scales_group - ? div_up(bgmmc.K % bgmmc.K_blk, wei_scales_gK) - : 1; - - for (int k_i = 0; k_i < wei_scales_n_g_tail; k_i++) { - const int k = k_start + gemm_batch * bgmmc.K_blk + k_i * wei_scales_gK; - ctx.src = (void *)brgmm_ctx.get_data_B_kn_ptr(B_data_batch_ptr, k, n); - ctx.tr_src = (void *)brgmm_ctx.get_buf_B_ptr( - ithr, k_blk_idx, n_blk_idx, gemm_batch, k_i * wei_scales_gK); - ctx.compensation_ptr - = (void *)brgmm_ctx.get_s8s8_comp_ptr(ithr, b_idx, n_blk_idx); - ctx.current_K_start = k; - ctx.current_K_iters = bgmmc.K % bgmmc.K_blk; - if (apply_single_scales_group) { - // Update the copy_B K_block to process when apply a single scale. 
- ctx.current_K_iters = nstl::min(wei_scales_gK, ctx.current_K_iters); + // Handle first block + if (k_start % adj_k_blk > 0) { + const auto first_blk_size = adj_k_blk - (k_start % adj_k_blk); + call_copy_kernel(k_start, first_blk_size, 0); + k += first_blk_size; } - ctx.current_K_pad = brgmm_ctx.get_current_K_pad(ctx.current_K_iters); - ctx.src_scales_ptr = brgmm_ctx.get_src_scales_ptr(); - ctx.wei_scales_ptr = brgmm_ctx.get_wei_scales_ptr(n, k); - if (bgmmc.blocked_B && !bgmmc.is_f16_with_int_wei - && isa == avx512_core_fp16) { - cvt_float16_to_float((float *)ctx.tr_src, (float16_t *)ctx.src, - bgmmc.wei_n_blk * ctx.current_K_iters); - } else { - (*copy_B_kernel_)(&ctx); + // Handle full blocks + for (; (k + adj_k_blk) <= k_end; k += adj_k_blk) { + const auto gb = (k - k_start) / bgmmc.K_blk; + call_copy_kernel(k, adj_k_blk, gb); + } + // Handle last block + if (k_end > k) { + const auto gb = (k - k_start) / bgmmc.K_blk; + call_copy_kernel(k, k_end - k, gb); + } + } else { // Default case with k_blk blocking + for (int gb = 0; gb < gemm_batch; ++gb) { + const auto k = k_start + gb * bgmmc.K_blk; + const auto k_iters = nstl::min(bgmmc.K_blk, bgmmc.K); + call_copy_kernel(k, k_iters, gb, /*aligned_blocks=*/true); + } + if (is_K_tail) { + const auto k = k_start + gemm_batch * bgmmc.K_blk; + const auto k_iters = bgmmc.K % bgmmc.K_blk; + call_copy_kernel(k, k_iters, gemm_batch, /*aligned_blocks=*/true); } } } @@ -1373,7 +1374,7 @@ struct brgemm_matmul_t::brg_matmul_exec_ctx_t { // setup scales / zp pointers const void *src_zero_points = CTX_IN_MEM( const void *, DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC); - const void *wei_zero_points = CTX_IN_MEM( + wei_zp_ptr_ = CTX_IN_MEM( const void *, DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS); const void *dst_zero_points = CTX_IN_MEM( const void *, DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST); @@ -1383,18 +1384,18 @@ struct brgemm_matmul_t::brg_matmul_exec_ctx_t { pd->attr()->zero_points_.get_data_type(DNNL_ARG_SRC), src_zero_points, 0) : 0; - zero_point_b_val_ = wei_zero_points ? cpu::io::load_int_value( - pd->attr()->zero_points_.get_data_type( - DNNL_ARG_WEIGHTS), - wei_zero_points, 0) - : 0; - zero_point_b_negative_val_ = -zero_point_b_val_; zero_point_c_val_ = dst_zero_points ? cpu::io::load_int_value( pd->attr()->zero_points_.get_data_type(DNNL_ARG_DST), dst_zero_points, 0) : 0; + wei_zp_neg_val_ = (-1) + * (wei_zp_ptr_ ? cpu::io::load_int_value( + pd->attr()->zero_points_.get_data_type( + DNNL_ARG_WEIGHTS), + wei_zp_ptr_, 0) + : 0); memory_tracking::grantor_t scratchpad = ctx.get_scratchpad_grantor(); const auto &bgmmc = pd->get_brgemm_matmul_conf(); @@ -1403,10 +1404,6 @@ struct brgemm_matmul_t::brg_matmul_exec_ctx_t { const float *, DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC); wei_scales_ = CTX_IN_MEM( const float *, DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS); - wei_scales_tr_ = bgmmc_.is_wei_scale_per_k - && !bgmmc_.gK_and_K_blk_are_divisible - ? scratchpad.template get(key_precomputed_scales) - : nullptr; dst_scales_ = CTX_IN_MEM( const float *, DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST); dst_scales_inv_ = scratchpad.template get(key_matmul_dst_scales); @@ -1887,17 +1884,36 @@ struct brgemm_matmul_t::brg_matmul_exec_ctx_t { // a block of memory per thread. // `gb` defines an offset to a specific block over K inside a portion of // blocks. - // `k` defines an offset inside a specific block over K. It is the - // smallest granularity possible. It is used only when wei_decompression - // feature is requested with a single scale over a sub-piece of `K_blk`. 
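Note: a host-side sketch of the block splitting done just above for grouped zero points/scales: the chunk [k_start, k_start + work) is cut on group boundaries, producing a leading partial piece, full adj_k_blk pieces, and a tail, each handed to the copy kernel with its gemm-batch index. Plain integers and a callback replace the real brgmm_ctx/ctx plumbing; the helper name is illustrative.

#include <algorithm>
#include <cstdio>
#include <functional>

// Sketch: walk [k_start, k_start + work) in pieces that never cross a
// zero-point/scale group boundary, mirroring copy_b_chunk_in_buffer().
void for_each_group_aligned_block(int k_start, int work, int K_blk, int k_group,
        const std::function<void(int k, int k_iters, int gb)> &copy) {
    const int adj_k_blk = std::min(K_blk, k_group); // never exceed one group
    int k = k_start;
    const int k_end = k_start + work;
    // leading partial block up to the next group boundary
    if (k_start % adj_k_blk > 0) {
        const int first = adj_k_blk - k_start % adj_k_blk;
        copy(k, std::min(first, k_end - k), 0); // min() guards tiny chunks
        k += first;
    }
    // full group-sized blocks
    for (; k + adj_k_blk <= k_end; k += adj_k_blk)
        copy(k, adj_k_blk, (k - k_start) / K_blk);
    // trailing remainder
    if (k_end > k) copy(k, k_end - k, (k - k_start) / K_blk);
}

int main() {
    // K chunk of 96 elements starting at k=40, K_blk=32, groups of 16 over K
    for_each_group_aligned_block(40, 96, 32, 16, [](int k, int iters, int gb) {
        std::printf("copy k=%d iters=%d gb=%d\n", k, iters, gb);
    });
    return 0;
}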
- char *get_buf_B_ptr( - int ithr, int k_blk_idx, int n_blk_idx, int gb, int k = 0) const { + char *get_buf_B_ptr(int ithr, int k_blk_idx, int n_blk_idx, int gb) const { UNUSED(n_blk_idx); if (!bgmmc_.use_buffer_b) return nullptr; int k_blk_local = k_blk_idx % get_K_chunk_size(); - return buf_B_ptr_ + ithr * bgmmc_.buffer_b_per_thread_sz + const auto offset = ithr * bgmmc_.buffer_b_per_thread_sz + k_blk_local * bgmmc_.buffer_b_k_brg_stride - + gb * bgmmc_.buffer_b_gb_stride + k * bgmmc_.buffer_b_k_stride; + + gb * bgmmc_.buffer_b_gb_stride; + return buf_B_ptr_ + offset; + } + + /* Returns a pointer to buffer B based on unaligned K inside + * a thread buffer. Used for copy kernels including grouped ZP/Scales. + * For the vnni granularity > 1 it returns a pointer to a start of vnni block. + * Functionality intersects with get_buf_B_ptr(). TODO: make combined solution. + */ + char *get_buf_B_k_ptr(const int ithr, const int k) const { + if (!bgmmc_.use_buffer_b) return nullptr; + + const int batch_block_size = bgmmc_.K_blk * bgmmc_.brgemm_batch_size; + const auto batch_blocking = std::div(k, batch_block_size); + const auto k_blk_idx = batch_blocking.quot; + const auto k_blk_local = k_blk_idx % get_K_chunk_size(); + const auto k_in_batch = batch_blocking.rem; + + auto offset = ithr * bgmmc_.buffer_b_per_thread_sz; + offset += k_blk_local * bgmmc_.buffer_b_k_brg_stride; + // div down to the start of the vnni block + const auto k_outer = (k_in_batch / vnni_factor) * vnni_factor; + offset += k_outer * bgmmc_.buffer_b_k_stride; + return buf_B_ptr_ + offset; } char *get_buf_C_ptr(int ithr, int m_blk_idx, int n_blk_idx) const { @@ -2049,107 +2065,17 @@ struct brgemm_matmul_t::brg_matmul_exec_ctx_t { // Returns a pointer to the weights scales for the correspondent block based // on @p n and @p k. - // - // For grouped scales it also prepares a scratchpad buffer when `K_blk` and - // `group_size` don't divide each other. const void *get_wei_scales_ptr(int n, int k = 0) const { - if (!wei_scales_) return wei_scales_; - - const auto gK = bgmmc_.wei_scales_k_group_size; - const auto in_stride_n = bgmmc_.is_wei_scale_per_n ? 1 : 0; - const auto in_stride_k = bgmmc_.is_wei_scale_per_k - ? (bgmmc_.is_wei_scale_per_n ? bgmmc_.N : 1) - : 0; - int k_g = gK > 1 ? k / gK : k; - const auto offset = n * in_stride_n + k_g * in_stride_k; - const auto *ptr = reinterpret_cast(wei_scales_) - + offset * bgmmc_.wei_scales_dt_sz; - if (bgmmc_.gK_and_K_blk_are_divisible || !bgmmc_.is_wei_scale_per_k) - return ptr; - - if (bgmmc_.req_transpose_scales) { - // Transpose case covers grouped and non-grouped scenarios. - const auto in_tr_stride_n = bgmmc_.is_wei_scale_per_n - ? (bgmmc_.is_wei_scale_per_k ? bgmmc_.K : 1) - : 0; - const auto in_tr_stride_k = bgmmc_.is_wei_scale_per_k ? 1 : 0; - const auto offset_tr = k * in_tr_stride_k + n * in_tr_stride_n; - auto *ptr_tr = reinterpret_cast(wei_scales_tr_) - + offset_tr * bgmmc_.wei_scales_dt_sz; - - // `kk_offset` covers a situation when - // `!bgmmc_.gK_and_K_blk_are_divisible`. Need to process the start - // of the next K block in a special way. - // - // Process it separately for simpler main body code. - const int kk_offset = (k % gK != 0) ? 
(gK - (k % gK)) : 0; - if (kk_offset) { - for_(int nn = 0; nn < nstl::min(bgmmc_.N - n, bgmmc_.N_blk); - nn++) - for (int kk = 0; kk < kk_offset; kk++) { - const auto in_idx = nn * in_stride_n; - const auto out_idx = nn * bgmmc_.K + kk; - const float wei_scales_val = cpu::io::load_float_value( - bgmmc_.wei_scales_dt, ptr, in_idx); - cpu::io::store_float_value(bgmmc_.wei_scales_dt, - wei_scales_val, ptr_tr, out_idx); - } - } - - for_(int nn = 0; nn < nstl::min(bgmmc_.N - n, bgmmc_.N_blk); nn++) - for (int kk = 0; - kk < nstl::min(bgmmc_.K - k, bgmmc_.K_blk) - kk_offset; - kk++) { - const auto in_idx = nn * in_stride_n + (kk / gK) * in_stride_k; - const auto out_idx = nn * bgmmc_.K + kk; - const auto ptr_kk_offset = bool(kk_offset) * in_stride_k - * bgmmc_.wei_scales_dt_sz; - const auto ptr_tr_kk_offset - = kk_offset * in_tr_stride_k * bgmmc_.wei_scales_dt_sz; - const float wei_scales_val = cpu::io::load_float_value( - bgmmc_.wei_scales_dt, ptr + ptr_kk_offset, in_idx); - cpu::io::store_float_value(bgmmc_.wei_scales_dt, wei_scales_val, - ptr_tr + ptr_tr_kk_offset, out_idx); - } - - return ptr_tr; - } else { - const auto offset_non_g = n * in_stride_n + k * in_stride_k; - auto *ptr_non_g = reinterpret_cast(wei_scales_tr_) - + offset_non_g * bgmmc_.wei_scales_dt_sz; - - const int kk_offset = (k % gK != 0) ? (gK - (k % gK)) : 0; - if (kk_offset) { - for_(int kk = 0; kk < kk_offset; kk++) - for (int nn = 0; nn < nstl::min(bgmmc_.N - n, bgmmc_.N_blk); - nn++) { - const auto in_idx = nn * in_stride_n; - const auto out_idx = nn * in_stride_n + kk * in_stride_k; - const float wei_scales_val = cpu::io::load_float_value( - bgmmc_.wei_scales_dt, ptr, in_idx); - cpu::io::store_float_value(bgmmc_.wei_scales_dt, - wei_scales_val, ptr_non_g, out_idx); - } - } - - for_(int kk = 0; - kk < nstl::min(bgmmc_.K - k, bgmmc_.K_blk) - kk_offset; - kk++) - for (int nn = 0; nn < nstl::min(bgmmc_.N - n, bgmmc_.N_blk); nn++) { - const auto in_idx = nn * in_stride_n + (kk / gK) * in_stride_k; - const auto out_idx = nn * in_stride_n + kk * in_stride_k; - const auto ptr_kk_offset = bool(kk_offset) * in_stride_k - * bgmmc_.wei_scales_dt_sz; - const auto ptr_non_g_kk_offset - = kk_offset * in_stride_k * bgmmc_.wei_scales_dt_sz; - const float wei_scales_val = cpu::io::load_float_value( - bgmmc_.wei_scales_dt, ptr + ptr_kk_offset, in_idx); - cpu::io::store_float_value(bgmmc_.wei_scales_dt, wei_scales_val, - ptr_non_g + ptr_non_g_kk_offset, out_idx); - } - - return ptr_non_g; + if (bgmmc_.is_wei_scale_common) return wei_scales_; + auto offset = n; + if (bgmmc_.is_wei_scale_per_k) { + const auto &k_group_sz = bgmmc_.wei_scales_k_gsize; + const auto k_idx = k / k_group_sz; + offset += k_idx * bgmmc_.N; } + + offset = offset * bgmmc_.wei_scales_dt_sz; + return ((char *)wei_scales_ + offset); } const void *get_dst_scales_ptr() const { return dst_scales_; } @@ -2167,11 +2093,28 @@ struct brgemm_matmul_t::brg_matmul_exec_ctx_t { return &zero_point_a_negative_val_; } - const int32_t *get_zp_b_neg_val_ptr() const { - return &zero_point_b_negative_val_; - } + const void *get_wei_zp_neg_ptr() const { return &wei_zp_neg_val_; } - const int32_t *get_zp_b_val_ptr() const { return &zero_point_b_val_; } + const void *get_wei_zp_ptr(int n, int k = 0) const { + if (!bgmmc_.has_zero_point_b) return nullptr; + if (bgmmc_.is_wei_zp_common) + return wei_zp_ptr_; // single zero point value + // Locate the group based on (n,k) + auto offset = n; + + if (bgmmc_.is_wei_zp_per_k) { + const auto &k_group_sz = bgmmc_.wei_zp_k_gsize; + const auto k_idx = 
k / k_group_sz; + offset += k_idx * bgmmc_.N; + } + + const auto dt_sz = types::data_type_size(bgmmc_.wei_zp_dt); + const auto elems_per_byte + = one_of(bgmmc_.wei_zp_dt, data_type::s4, data_type::u4) ? 2 + : 1; + offset = offset * dt_sz / elems_per_byte; + return (char *)wei_zp_ptr_ + offset; + } const int32_t *get_zp_ab_mixed_comp_ptr() const { return &zero_point_mixed_ab_compensation_component_; @@ -2457,6 +2400,7 @@ struct brgemm_matmul_t::brg_matmul_exec_ctx_t { bool packed_sparse_weights() const { return bgmmc_.packed_sparse_weights; } int get_current_K_pad(int current_K_iters) const { + if (bgmmc_.is_wei_zp_per_k || bgmmc_.is_wei_scale_per_k) return 0; if (current_K_iters % bgmmc_.wei_k_blk == 0) return 0; return (bgmmc_.extendable_k || bgmmc_.use_fused_copy_a) ? bgmmc_.wei_k_blk @@ -2514,14 +2458,7 @@ private: const char *bias_ptr_; const void *src_scales_; const void *wei_scales_; - // This pointer is coming from scratchpad and is needed to expand (K/g)xN - // scales to KxN scales as copy_B kernels rely on the full register scales - // for weights decompression feature in case when K_blk is not divisible by - // the scales group size or vice versa. - // TODO: implement the logic at calling copy routines spot to handle the - // unsupported scenario mentioned above to avoid the need in this pointer - // and the overhead around filling that big buffer. - void *wei_scales_tr_; + const void *dst_scales_; const void *dst_scales_inv_; int32_t *s8s8_compensation_ptr_; @@ -2531,10 +2468,10 @@ private: int32_t *reorder_zp_a_comp_ptr_; int32_t zero_point_a_negative_val_; - int32_t zero_point_b_val_; - int32_t zero_point_b_negative_val_; int32_t zero_point_mixed_ab_compensation_component_; int32_t zero_point_c_val_; + int32_t wei_zp_neg_val_; + const void *wei_zp_ptr_; std::vector post_ops_binary_rhs_arg_vec_; int base_brg_ker_idx_; diff --git a/src/cpu/x64/matmul/brgemm_matmul_copy_utils.cpp b/src/cpu/x64/matmul/brgemm_matmul_copy_utils.cpp index 510651bbe3..5bb6220dfa 100644 --- a/src/cpu/x64/matmul/brgemm_matmul_copy_utils.cpp +++ b/src/cpu/x64/matmul/brgemm_matmul_copy_utils.cpp @@ -396,7 +396,7 @@ void jit_brgemm_matmul_copy_a_impl_t::copy_K_loop( } // step 3: multiply by zp_b_val - mov(reg_zp_b_neg_val_ptr, ptr[param1 + GET_OFF(zp_b_neg_value_ptr)]); + mov(reg_zp_b_neg_val_ptr, ptr[param1 + GET_OFF(zp_b_neg_val_ptr)]); const auto vmm_zp_b_neg_val = get_vmm_comp_acc(is_ymm_ ? 2 : 1); uni_vbroadcastss(vmm_zp_b_neg_val, ptr[reg_zp_b_neg_val_ptr]); uni_vpmulld(get_vmm_comp_acc(0), get_vmm_comp_acc(0), vmm_zp_b_neg_val); @@ -1892,7 +1892,7 @@ void jit_brgemm_matmul_copy_a_transposed_int8_impl_t::compute_k_loop( } // multiply by zp_b_val - mov(reg_tmp_, ptr[param1 + GET_OFF(zp_b_neg_value_ptr)]); + mov(reg_tmp_, ptr[param1 + GET_OFF(zp_b_neg_val_ptr)]); vbroadcastss(get_zmm_src(0), ptr[reg_tmp_]); vpmulld(zmm_comp_acc_, zmm_comp_acc_, get_zmm_src(0)); @@ -2069,6 +2069,12 @@ template struct jit_brgemm_matmul_copy_a_transposed_impl_t; * This class contains common methods and properties for all copy B kernels. * Now it consists of `load_value` and `decompress_reg` and it's considered * to contain all common methods for copy B kernels. + * Default scenario for weight decompression is: + * 1. load_value() + * 2. apply zero point shift (if needed) + * 3. convert to f32 + * 4. apply scaling (if needed) + * 5. down convert if destination datatype is not f32. 
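Note: a scalar model of the five-step decompression order documented above for the copy_B kernels, assuming int8 weights, an f32 scale, and a bf16 destination; the bf16 conversion here truncates for brevity, whereas the kernels use proper down-conversion.

#include <cstdint>
#include <cstring>
#include <vector>

// Sketch: per-element weight decompression in the order used vector-wide by
// the copy_B kernels: load -> zero-point shift -> cvt to f32 -> scale -> narrow.
static uint16_t f32_to_bf16_trunc(float f) {
    uint32_t bits;
    std::memcpy(&bits, &f, sizeof(bits));
    return static_cast<uint16_t>(bits >> 16); // truncation, not round-to-nearest
}

std::vector<uint16_t> decompress_weights(
        const std::vector<int8_t> &wei, int32_t zp, float scale) {
    std::vector<uint16_t> out;
    out.reserve(wei.size());
    for (int8_t w : wei) {
        const int32_t shifted = static_cast<int32_t>(w) - zp; // step 2
        const float f = static_cast<float>(shifted);          // step 3
        const float scaled = f * scale;                       // step 4
        out.push_back(f32_to_bf16_trunc(scaled));             // step 5
    }
    return out;
}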
*/ struct jit_brgemm_matmul_copy_b_common_t : public jit_brgemm_matmul_copy_b_t, public jit_generator_t { @@ -2165,7 +2171,7 @@ protected: } case data_type::bf16: if (is_xf16) { - vmovdqu16(vmm_in, op); + uni_vmovdqu16(vmm_in, op); } else { uni_vpmovzxwd(vmm_in, op); uni_vpslld(vmm_in, vmm_in, 16); @@ -2213,6 +2219,102 @@ protected: } } + /** @brief Loads common zero point value and broadcasts over `zp_vmm` register. + * Handles only per_k and common values. + * @tparam Vmm Vector register type (Zmm, Ymm, etc.) + * @param zp_vmm Vector register to load and broadcast zero point value into + * @param ptr_reg Register containing pointer to zero point value in memory + **/ + template + void load_common_zp_value(const Vmm &zp_vmm, const Xbyak::Reg64 &ptr_reg) { + using Vmm_lower_t = typename vreg_traits_t::Vmm_lower_t; + // Handle only per_k and common values + const bool only_per_k + = conf_->is_wei_zp_per_k && !conf_->is_wei_zp_per_n; + const bool require_load = conf_->has_zero_point_b + && (conf_->is_wei_zp_common || only_per_k); + if (!require_load) return; + + const auto zp_dt = conf_->wei_zp_dt; + const auto tmp_xmm = Xmm(zp_vmm.getIdx()); + const auto vmm_lower = Vmm_lower_t(zp_vmm.getIdx()); + MAYBE_UNUSED(tmp_xmm); + MAYBE_UNUSED(vmm_lower); + const auto &addr = ptr[ptr_reg]; + + const bool need_upconvert = one_of(zp_dt, data_type::s8, data_type::u8, + data_type::u4, data_type::s4); + + if (need_upconvert) { + uni_vpinsrb(tmp_xmm, tmp_xmm, addr, 0); + if (one_of(zp_dt, data_type::s4, data_type::s8)) + uni_vpmovsxbd(tmp_xmm, tmp_xmm); + else + uni_vpmovzxbd(tmp_xmm, tmp_xmm); + + // For 4-bit int need to shift left on 28 bits + if (one_of(zp_dt, data_type::s4, data_type::u4)) + uni_vpslld(tmp_xmm, tmp_xmm, 28); + // Then shift back to the right on 28 bits + if (zp_dt == data_type::u4) vpsrld(tmp_xmm, tmp_xmm, 28); + if (zp_dt == data_type::s4) vpsrad(tmp_xmm, tmp_xmm, 28); + } + + const auto &op = need_upconvert + ? static_cast(tmp_xmm) + : static_cast(addr); + uni_vpbroadcastd(zp_vmm, op); + } + + /** @brief Loads common scale value and broadcasts over `scale_vmm` register. + * Handles only per_k and common values. + * @tparam Vmm Vector register type (Zmm, Ymm, etc.) + * @param scale_vmm Vector register to load and broadcast scale value into + * @param ptr_reg Register containing pointer to scale value in memory + **/ + template + void load_common_scale_value( + const Vmm &scale_vmm, const Xbyak::Reg64 &ptr_reg) { + const bool only_per_k + = conf_->is_wei_scale_per_k && !conf_->is_wei_scale_per_n; + const bool require_scales = conf_->apply_scales_in_buffer_b + && (conf_->is_wei_scale_common || only_per_k); + if (!require_scales) return; + + const auto &scales_dt = conf_->wei_scales_dt; + const auto &addr = ptr[ptr_reg]; + switch (scales_dt) { + case data_type::f32: uni_vbroadcastss(scale_vmm, addr); break; + case data_type::bf16: + vpbroadcastw(scale_vmm, addr); + uni_vpslld(scale_vmm, scale_vmm, 16); + break; + case data_type::f16: vcvtph2psx(scale_vmm, addr); break; + default: assert(!"unsupported wei_scales data type"); + } + } + + /** @brief Helper method to load scales into vector register. + * Supports f32, bf16 and f16 data types. + * @tparam Vmm Vector register type (Zmm, Ymm, etc.) 
+ * @param vmm Vector register to load scale value into + * @param op Operand to load scale value from + **/ + template + void load_scale_value(const Vmm &vmm, const Xbyak::Operand &op, + data_type_t dt, bool is_tail = false) { + const auto masked_vmm = maybe_mask(vmm, is_tail); + switch (dt) { + case data_type::f32: uni_vmovups(masked_vmm, op); break; + case data_type::bf16: + uni_vpmovzxwd(masked_vmm, op); + uni_vpslld(vmm, vmm, 16); + break; + case data_type::f16: vcvtph2ps(masked_vmm, op); break; + default: assert(!"unsupported wei_scales data type"); + } + } + /** * @brief Applies zero point shift to vector register * Shifts input values by subtracting zero point values. @@ -2363,27 +2465,6 @@ protected: downconvert_to_dst_dt(input1, input2, dst_dt); } - /** @brief Helper method to load scales into vector register. - * Supports f32, bf16 and f16 data types. - * @tparam Vmm Vector register type (Zmm, Ymm, etc.) - * @param vmm Vector register to load scale value into - * @param op Operand to load scale value from - **/ - template - void load_scale_value(const Vmm &vmm, const Xbyak::Operand &op, - data_type_t dt, bool is_tail = false) { - const auto masked_vmm = maybe_mask(vmm, is_tail); - switch (dt) { - case data_type::f32: uni_vmovups(masked_vmm, op); break; - case data_type::bf16: - uni_vpmovzxwd(masked_vmm, op); - uni_vpslld(vmm, vmm, 16); - break; - case data_type::f16: vcvtph2ps(masked_vmm, op); break; - default: assert(!"unsupported wei_scales data type"); - } - } - // Common used masks to permute data using opmask_t = const Xbyak::Opmask; opmask_t k3333 = k1; @@ -3237,12 +3318,6 @@ struct jit_brgemm_matmul_copy_b_bf16_t , wei_scales_typesize(conf->wei_scales_dt_sz) , src_stride(conf->copy_B_wei_stride) , tr_src_stride(conf_->LDB * k_blk_step * tr_typesize) - // If scales groups are enabled and are divisible by K_blk, the kernel - // processes one "per_N line" of scales and is called several times. - , wei_scales_N_stride(conf_->wei_scales_k_group_size > 1 - && conf_->gK_and_K_blk_are_divisible - ? 0 - : conf_->N * wei_scales_typesize) , is_src_int4(one_of(conf->orig_wei_dt, data_type::s4, data_type::u4)) , is_dynamic_stride(is_runtime_value(src_stride)) , is_dynamic_N(conf->is_runtime_N) @@ -3250,7 +3325,9 @@ struct jit_brgemm_matmul_copy_b_bf16_t , req_cvtps2bf16(conf->is_bf32 || conf->is_bf16_with_int_wei) , req_zp_b_shift(conf->has_zero_point_b && conf->with_wei_decompression) , req_apply_wei_scales(conf->apply_scales_in_buffer_b) - , typesize_wei_scale(is_src_int4 ? 2 : 1) {} + , is_wei_grouped_over_k( + conf_->is_wei_zp_per_k || conf_->is_wei_scale_per_k) + , elems_per_byte(is_src_int4 ? 
2 : 1) {} void operator()(ctx_t *ctx) override { jit_generator_t::operator()(ctx); } status_t create_kernel() override { @@ -3267,7 +3344,7 @@ private: enum { k_blk_step = 2, n_blk_step = 16 }; const int typesize, tr_typesize, wei_scales_typesize; - const dim_t src_stride, tr_src_stride, wei_scales_N_stride; + const dim_t src_stride, tr_src_stride; const bool is_src_int4; const bool is_dynamic_stride; const bool is_dynamic_N; @@ -3275,13 +3352,18 @@ private: const bool req_cvtps2bf16; const bool req_zp_b_shift; const bool req_apply_wei_scales; - const dim_t typesize_wei_scale; + const bool is_wei_grouped_over_k; + const dim_t elems_per_byte; constexpr static int reg_src_offs = 0; - constexpr static int reg_tr_src_offs = 8; - constexpr static int reg_current_K_pad_offs_ = 16; - constexpr static int stack_space_needed = 24; + + constexpr static int reg_k_iters_offs_ = 16; + constexpr static int reg_current_K_pad_offs_ = 24; + + constexpr static int reg_K_start_offs_ = 32; + + constexpr static int stack_space_needed = 40; reg64_t reg_src = rax; reg64_t reg_tr_src = rbx; @@ -3297,6 +3379,7 @@ private: reg64_t reg_copy_block_n_shift = rsi; reg64_t reg_wei_scales = rdx; + reg64_t reg_zp_ptr = r13; reg64_t reg_dynamic_tail = rcx; Xbyak::Reg8 reg8_mask_shift = reg_dynamic_tail.cvt8(); @@ -3360,13 +3443,54 @@ void jit_brgemm_matmul_copy_b_bf16_t::copy_2x32( return Vmm(reg_idx); }; - auto load = [this, get_vmm, ncolumns, columns_tail](int blk, int k, int n) { + /** Loads zero points, when is_wei_zp_per_n is set. + * Zeropoints size over N dimension always equals to N. + */ + auto load_zero_point = [this, ncolumns, columns_tail](int n) { + if (!conf_->is_wei_zp_per_n) return; + const bool is_tail = (ncolumns - n) < n_blk_step; + const auto zp_dt = conf_->wei_zp_dt; + const auto zp_dt_sz = types::data_type_size(zp_dt); + const auto elems_per_byte + = one_of(zp_dt, data_type::s4, data_type::u4) ? 2 : 1; + const auto offset = n * zp_dt_sz / elems_per_byte; + const auto addr = maybe_EVEX_compress_addr(reg_zp_ptr, offset); + if (is_tail && !isa_has_masks(conf_->isa)) { + load_bytes(vmm_zp_b_shift, addr, columns_tail / elems_per_byte); + load_value(vmm_zp_b_shift, vmm_zp_b_shift, vmm_permd, zp_dt); + } + load_value(vmm_zp_b_shift, addr, vmm_permd, zp_dt, is_tail); + }; + + /** Loads scales, when is_wei_scale_per_n is set. + * Scales size over N dimension always equals to N. + */ + auto load_scales = [this, ncolumns, columns_tail](int n) { + if (!conf_->is_wei_scale_per_n || !conf_->apply_scales_in_buffer_b) + return; + + const bool is_tail = (ncolumns - n) < n_blk_step; + const auto &scales_dt = conf_->wei_scales_dt; + const auto scales_dt_sz = types::data_type_size(scales_dt); + const auto offset = n * scales_dt_sz; + const auto addr = maybe_EVEX_compress_addr(reg_wei_scales, offset); + if (is_tail && !isa_has_masks(conf_->isa)) { + load_bytes( + vmm_wei_scales, addr, columns_tail * wei_scales_typesize); + load_scale_value(vmm_wei_scales, vmm_wei_scales, scales_dt, + /*is_tail=*/false); + } + load_scale_value(vmm_wei_scales, addr, scales_dt, is_tail); + }; + + auto load = [this, get_vmm, ncolumns, columns_tail, load_scales, + load_zero_point](int blk, int k, int n) { auto src_reg = get_vmm(blk, k % k_blk_step); const bool is_tail = ncolumns - n < n_blk_step; auto src_load = maybe_mask(src_reg, is_tail); const auto offset = ((is_dynamic_stride ? 0 : k * src_stride) + (n * typesize)) - / typesize_wei_scale; + / elems_per_byte; const auto reg_src_load = is_dynamic_stride && k % 2 != 0 ? 
reg_src_load_1 : reg_src; auto load_addr = maybe_EVEX_compress_addr(reg_src_load, offset); @@ -3379,19 +3503,39 @@ void jit_brgemm_matmul_copy_b_bf16_t::copy_2x32( load_value( src_reg, load_addr, vmm_permd, conf_->orig_wei_dt, is_tail); } - - const auto scales_offset - = (is_dynamic_stride ? 0 : k * wei_scales_N_stride) - + n * wei_scales_typesize; - const auto scales_addr - = maybe_EVEX_compress_addr(reg_wei_scales, scales_offset); - if (req_apply_wei_scales) - load_scale_value( - vmm_wei_scales, scales_addr, conf_->wei_scales_dt, is_tail); - decompress_and_downcvt_reg(src_load, vmm_zp_b_shift, vmm_wei_scales, + load_zero_point(n); + load_scales(n); + decompress_and_downcvt_reg(src_reg, vmm_zp_b_shift, vmm_wei_scales, conf_->orig_wei_dt, conf_->wei_dt); }; + /** Stores half of the block using mask for the case when vnni_granularity == 2 */ + auto store_half_block = [&](const Vmm &src_vmm0, const Vmm &src_vmm1, + const Xbyak::Address &store_addr) { + const auto zmm1 = zmm(src_vmm1.getIdx()); + const auto zmm0 = zmm(src_vmm0.getIdx()); + uni_vxorps(zmm1, zmm1, zmm1); + //if k % 2 == 1 then save only odd indices + // otherwise: using only even indices + Label even_k, end_permute; + mov(reg_tmp, ptr[rsp + reg_K_start_offs_]); + test(reg_tmp, 1); + jz(even_k, T_NEAR); + vinsertf64x4(zmm0, zmm1, ymm(src_vmm0.getIdx()), 1); + vpermw(zmm0, vmm_permw, zmm0); + uni_vmovdqu16(store_addr | kAAAA, zmm0); + jmp(end_permute); + L(even_k); + vinsertf64x4(zmm0, zmm1, ymm(src_vmm0.getIdx()), 0); + vpermw(zmm0, vmm_permw, zmm0); + uni_vmovdqu16(store_addr, zmm0); + L(end_permute); + }; + + // The case when it's required to store half block + // When grouped over K weights and K == 1 + const auto kernel_early_stop = is_wei_grouped_over_k && nrows == 1; + int iter = 0; int n_iters; if (is_dynamic_N || do_N_loop) { @@ -3399,11 +3543,13 @@ void jit_brgemm_matmul_copy_b_bf16_t::copy_2x32( } else { n_iters = conf_->wei_n_blk; } + for_(int k = 0; k < nrows; k += k_blk_step) for (int n = 0; n < n_iters; n += n_blk_step) { const int k_blk = k / k_blk_step; const dim_t tr_src_off = k_blk * tr_src_stride + n * k_blk_step * tr_typesize; + const auto store_addr = maybe_EVEX_compress_addr(reg_tr_src, tr_src_off); const auto store_addr_ymm1 @@ -3431,18 +3577,22 @@ void jit_brgemm_matmul_copy_b_bf16_t::copy_2x32( load(blk_idx, k, n); - if (nrows - k >= k_blk_step) { - load(blk_idx, k + 1, n); - if (is_superset(conf_->isa, avx512_core)) { - const auto src_ymm1 = ymm(src_vmm1.getIdx()); - vinsertf64x4(src_zmm0, src_zmm0, src_ymm1, 1); - } - } - if (!is_superset(conf_->isa, avx512_core)) { - uni_vxorps(src_vmm1, src_vmm1, src_vmm1); + // Store only half block + if (kernel_early_stop) { + store_half_block(src_vmm0, src_vmm1, store_addr); + iter++; + continue; } + // Load second K half blk and downconvert if required. 
+ if (nrows - k >= k_blk_step) + load(blk_idx, k + 1, n); + else + uni_vxorps(src_vmm1, src_vmm1, src_vmm1); + if (is_superset(conf_->isa, avx512_core)) { + const auto src_ymm1 = ymm(src_vmm1.getIdx()); + vinsertf64x4(src_zmm0, src_zmm0, src_ymm1, 1); vpermw(src_zmm0, vmm_permw, src_zmm0); uni_vmovups(store_addr, src_zmm0); } else { @@ -3474,14 +3624,19 @@ void jit_brgemm_matmul_copy_b_bf16_t::init_masks() { mov(reg_tmp, reinterpret_cast(bf16_vnni_permute)); vmovdqa64(vmm_permw, ptr[reg_tmp]); + if (isa_has_masks(conf_->isa)) { + // 64-bit mask is also used when is_wei_[zp\scales]_per_k + mov(reg_tmp, 0xAAAAAAAAAAAAAAAA); + kmovq(kAAAA, reg_tmp); + mov(reg_tmp, 0x5555555555555555); + kmovq(k5555, reg_tmp); + } + if (is_src_int4) { alignas(64) static constexpr const uint32_t int4_permute[16] = {0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15}; mov(reg_tmp, reinterpret_cast(int4_permute)); vmovdqa32(vmm_permd, ptr[reg_tmp]); - - kmovx(kAAAA, 0xaaaa); - kmovx(k5555, 0x5555); } } } @@ -3575,22 +3730,47 @@ void jit_brgemm_matmul_copy_b_bf16_t::generate() { mov(reg_src, ptr[param1 + GET_OFF(src)]); mov(reg_tr_src, ptr[param1 + GET_OFF(tr_src)]); + mov(ptr[rsp + reg_tr_src_offs], reg_tr_src); mov(reg_N_blk, ptr[param1 + GET_OFF(current_N_blk)]); mov(reg_wei_scales, ptr[param1 + GET_OFF(wei_scales_ptr)]); + // Due to lack of registers save k_iters and k_pad into stack space + mov(reg_tmp, ptr[param1 + GET_OFF(current_K_iters)]); + mov(ptr[rsp + reg_k_iters_offs_], reg_tmp); + mov(reg_tmp, ptr[param1 + GET_OFF(current_K_pad)]); + mov(ptr[rsp + reg_current_K_pad_offs_], reg_tmp); + mov(reg_tmp, ptr[param1 + GET_OFF(current_K_start)]); + mov(ptr[rsp + reg_K_start_offs_], reg_tmp); + mov(reg_tmp, 0); + if (is_dynamic_stride) { mov(reg_src_stride, ptr[param1 + GET_OFF(dynamic_src_stride)]); mov(reg_src_stride_x2, ptr[param1 + GET_OFF(dynamic_src_stride)]); shl(reg_src_stride_x2, 1); } - if (req_zp_b_shift) { - mov(reg_tmp, ptr[param1 + GET_OFF(zp_b_value_ptr)]); - uni_vpbroadcastd(vmm_zp_b_shift, ptr[reg_tmp]); - } + + mov(reg_zp_ptr, ptr[param1 + GET_OFF(zp_b_value_ptr)]); + load_common_zp_value(vmm_zp_b_shift, reg_zp_ptr); + load_common_scale_value(vmm_wei_scales, reg_wei_scales); init_masks(); auto compute_K_loop_body = [&](const reg64_t ®_K, int ncolumns, bool is_N_tail, bool zeropad) { + // Compute special K-loop for per-k attributes + // Only when k_group_size < k_blk_step + // Otherwise default K-loop is used + if (is_wei_grouped_over_k) { + const int k_group_size = conf_->is_wei_zp_per_k + ? 
conf_->wei_zp_k_gsize + : conf_->wei_scales_k_gsize; + if (k_group_size < k_blk_step) { + if (zeropad) return; + copy_block( + k_group_size, ncolumns, is_N_tail, /*zeropad= */ false); + return; + } + } + const int k_unroll = 8; Label K_loop_unrolled, K_loop_single, K_loop_tail_or_done; @@ -3601,10 +3781,7 @@ void jit_brgemm_matmul_copy_b_bf16_t::generate() { copy_block(k_unroll * k_blk_step, ncolumns, is_N_tail, zeropad); if (!zeropad && !is_dynamic_stride) - add(reg_src, - (k_unroll * k_blk_step * src_stride) / typesize_wei_scale); - if (!zeropad && req_apply_wei_scales) - add(reg_wei_scales, k_unroll * k_blk_step * wei_scales_N_stride); + add(reg_src, (k_unroll * k_blk_step * src_stride) / elems_per_byte); add(reg_tr_src, k_unroll * tr_src_stride); sub(reg_K, k_unroll * k_blk_step); @@ -3617,9 +3794,7 @@ void jit_brgemm_matmul_copy_b_bf16_t::generate() { copy_block(k_blk_step, ncolumns, is_N_tail, zeropad); if (!zeropad && !is_dynamic_stride) - add(reg_src, (k_blk_step * src_stride) / typesize_wei_scale); - if (!zeropad && req_apply_wei_scales) - add(reg_wei_scales, k_blk_step * wei_scales_N_stride); + add(reg_src, (k_blk_step * src_stride) / elems_per_byte); add(reg_tr_src, tr_src_stride); sub(reg_K, k_blk_step); @@ -3645,9 +3820,7 @@ void jit_brgemm_matmul_copy_b_bf16_t::generate() { // 'param1' register (rcx on Windows) re-written in compute_K_loop_body // so we need to read and keep 'current_K_pad' parameter in stack before // the call - mov(reg_K_iters, ptr[param1 + GET_OFF(current_K_pad)]); - mov(ptr[rsp + reg_current_K_pad_offs_], reg_K_iters); - mov(reg_K_iters, ptr[param1 + GET_OFF(current_K_iters)]); + mov(reg_K_iters, ptr[rsp + reg_k_iters_offs_]); compute_K_loop_body(reg_K_iters, ncolumns, is_N_tail, false); mov(reg_K_iters, ptr[rsp + reg_current_K_pad_offs_]); compute_K_loop_body(reg_K_iters, ncolumns, is_N_tail, true); @@ -3696,13 +3869,7 @@ struct jit_brgemm_matmul_copy_b_f32_t , src_elems_per_byte_(is_src_int4_ || is_src_f4_ ? 2 : 1) , wei_scales_typesize_(conf_->wei_scales_dt_sz) , src_stride_(conf_->copy_B_wei_stride) - , tr_src_stride_(conf_->LDB * typesize_out_) - // If scales groups are enabled and are divisible by K_blk, the kernel - // processes one "per_N line" of scales and is called several times. - , wei_scales_N_stride_(conf_->wei_scales_k_group_size > 1 - && conf_->gK_and_K_blk_are_divisible - ? 
0 - : conf_->N * wei_scales_typesize_) {} + , tr_src_stride_(conf_->LDB * typesize_out_) {} void operator()(ctx_t *ctx) override { jit_generator_t::operator()(ctx); } status_t create_kernel() override { @@ -3720,7 +3887,7 @@ private: const bool is_src_f4_, is_src_int4_, req_zp_b_shift_, req_apply_wei_scales_; const size_t typesize_in_, src_elems_per_byte_, wei_scales_typesize_; const size_t typesize_out_ = sizeof(float); - dim_t src_stride_, tr_src_stride_, wei_scales_N_stride_; + dim_t src_stride_, tr_src_stride_; reg64_t reg_src = rax; reg64_t reg_tr_src = rbx; @@ -3731,6 +3898,7 @@ private: reg64_t reg_tmp = r15; reg32_t regw_tmp = r15d; reg64_t reg_wei_scales = rdx; + reg64_t reg_zp_ptr = r11; Vmm vmm_zero = Vmm(0); Vmm vmm_wei_scales = Vmm(1); @@ -3777,10 +3945,39 @@ void jit_brgemm_matmul_copy_b_f32_t::copy_16_x_n_block( else load_value(src_vmm, addr, vmm_permd, conf_->orig_wei_dt, is_tail); - const auto scales_addr = maybe_EVEX_compress_addr(reg_wei_scales, - k * wei_scales_N_stride_ + n * wei_scales_typesize_); decompress_reg(maybe_mask(src_vmm, is_tail), vmm_zp_b_shift, - scales_addr, conf_->orig_wei_dt); + vmm_wei_scales, conf_->orig_wei_dt); + }; + + /** Loads zero points, when is_wei_zp_per_n is set. + * Zeropoints size over N dimension always equals to N. + */ + auto load_zero_point = [this, ncolumns](int n) { + if (!conf_->is_wei_zp_per_n) return; + + const bool is_tail = (ncolumns - n) < simd_w_; + const auto zp_dt = conf_->wei_zp_dt; + const auto zp_dt_sz = types::data_type_size(zp_dt); + const auto elems_per_byte + = one_of(zp_dt, data_type::s4, data_type::u4) ? 2 : 1; + const auto offset = n * zp_dt_sz / elems_per_byte; + const auto addr = maybe_EVEX_compress_addr(reg_zp_ptr, offset); + load_value(vmm_zp_b_shift, addr, vmm_permd, zp_dt, is_tail); + }; + + /** Loads scales, when is_wei_scale_per_n is set. + * Scales size over N dimension always equals to N. 
+ */ + auto load_scales = [this, ncolumns](int n) { + if (!conf_->is_wei_scale_per_n || !conf_->apply_scales_in_buffer_b) + return; + + const bool is_tail = (ncolumns - n) < simd_w_; + const auto &scales_dt = conf_->wei_scales_dt; + const auto scales_dt_sz = types::data_type_size(scales_dt); + const auto offset = n * scales_dt_sz; + const auto addr = maybe_EVEX_compress_addr(reg_wei_scales, offset); + load_scale_value(vmm_wei_scales, addr, scales_dt, is_tail); }; const int columns_tail = ncolumns % simd_w_; @@ -3811,6 +4008,8 @@ void jit_brgemm_matmul_copy_b_f32_t::copy_16_x_n_block( continue; } + load_zero_point(n); + load_scales(n); const int blk_idx = iter % max_regs_available; load(blk_idx, k, n); @@ -3833,8 +4032,6 @@ void jit_brgemm_matmul_copy_b_f32_t::compute_k_loop(int ncolumns) { copy_16_x_n_block(unroll, ncolumns); add(reg_src, (unroll * src_stride_) / src_elems_per_byte_); add(reg_tr_src, unroll * tr_src_stride_); - if (req_apply_wei_scales_) - add(reg_wei_scales, unroll * wei_scales_N_stride_); sub(reg_K_iters, unroll); jmp(K_start_label, T_NEAR); @@ -3857,6 +4054,7 @@ void jit_brgemm_matmul_copy_b_f32_t::generate() { mov(reg_K_iters, ptr[param1 + GET_OFF(current_K_iters)]); mov(reg_N_blk, ptr[param1 + GET_OFF(current_N_blk)]); mov(reg_wei_scales, ptr[param1 + GET_OFF(wei_scales_ptr)]); + mov(reg_zp_ptr, ptr[param1 + GET_OFF(zp_b_value_ptr)]); kmovw(kFFFF, 0xffff); // 1111111111111111 if (is_src_int4_ || is_src_f4_) { alignas(64) static constexpr const uint32_t int4_permute[16] @@ -3888,10 +4086,8 @@ void jit_brgemm_matmul_copy_b_f32_t::generate() { vmovdqa32(vmm_f4_lut, ptr[reg_tmp]); } - if (req_zp_b_shift_) { - mov(reg_tmp, ptr[param1 + GET_OFF(zp_b_value_ptr)]); - uni_vpbroadcastd(vmm_zp_b_shift, ptr[reg_tmp]); - } + load_common_zp_value(vmm_zp_b_shift, reg_zp_ptr); + load_common_scale_value(vmm_wei_scales, reg_wei_scales); Label done; if (conf_->N_tail > 0) { @@ -3925,6 +4121,8 @@ struct jit_brgemm_matmul_copy_b_transposed_t , wei_scales_typesize_(conf_->wei_scales_dt_sz) , vnni_granularity_(data_type_vnni_granularity(conf_->wei_dt)) , k_blk_step_(vlen_ / tr_typesize_) + , is_wei_grouped_over_k_( + conf_->is_wei_zp_per_k || conf_->is_wei_scale_per_k) , do_compute_compensation_( conf_->has_zero_point_a || conf_->s8s8_compensation_required) , is_bf32_(conf->is_bf32) @@ -3938,8 +4136,6 @@ struct jit_brgemm_matmul_copy_b_transposed_t , req_zp_b_shift_( conf_->has_zero_point_b && conf_->with_wei_decompression) , req_apply_wei_scales_(conf_->apply_scales_in_buffer_b) - , single_wei_scales_value_(conf_->wei_scales_k_group_size > 1 - && conf_->gK_and_K_blk_are_divisible) , avx512_core_dot_product_( do_compute_compensation_ && !isa_has_int8_vnni(conf->isa)) // See the note in `create_brgemm_matmul_copy_b` why `orig_wei_dt` used. @@ -3960,10 +4156,6 @@ struct jit_brgemm_matmul_copy_b_transposed_t : 0))) , src_stride_(conf_->copy_B_wei_stride) , tr_src_stride_(conf_->LDB * vnni_granularity_ * tr_typesize_) - // If scales groups are enabled and are divisible by K_blk, the kernel - // processes a single scale value and is called several times. - , wei_scales_K_stride_((single_wei_scales_value_ ? 1 : conf_->K) - * wei_scales_typesize_) , src_elems_per_byte_(is_src_int4_ ? 
2 : 1) , is_dynamic_N_(conf->is_runtime_N) {} @@ -3991,6 +4183,7 @@ private: const int wei_scales_typesize_; const int vnni_granularity_; const int k_blk_step_; + const bool is_wei_grouped_over_k_; const bool do_compute_compensation_; const bool is_bf32_; const bool is_bf16_with_int_wei_; @@ -4000,14 +4193,12 @@ private: const bool req_s8s8_comp_; const bool req_zp_b_shift_; const bool req_apply_wei_scales_; - const bool single_wei_scales_value_; const bool avx512_core_dot_product_; const bool use_fp16_instructions_; const bool use_bf16_instructions_; const int max_tmp_idx; - const dim_t src_stride_, tr_src_stride_, wei_scales_K_stride_, - src_elems_per_byte_; + const dim_t src_stride_, tr_src_stride_, src_elems_per_byte_; const bool is_dynamic_N_; constexpr static int ldb_step_idx_offs = 0; @@ -4016,7 +4207,7 @@ private: reg64_t reg_src_base = rax; reg64_t reg_tr_src_base = rbx; reg64_t reg_comp_ptr = rdx; - reg64_t reg_wei_scales_base = rsi; + reg64_t reg_zp_ptr = rdx; reg64_t reg_K_iters = r8; reg64_t reg_N_iters = r9; @@ -4025,7 +4216,7 @@ private: reg64_t reg_zp_comp_ptr = r12; reg64_t reg_zp_a_neg_val_ptr = r13; reg64_t reg_K_start = r14; - reg64_t reg_wei_scales = rdx; + reg64_t reg_wei_scales = rsi; reg64_t regq_tmp = r15; reg32_t regw_tmp = r15d; @@ -4096,6 +4287,51 @@ private: return next_row_idx < num_rows || dynamic_tail; } + /** + * Loads zero point value and broadcasts it over Vmm register. + * Supported data types: s4/u4/s8/u8/s32. + * + * @param n N-dimension local index. + * @param is_tail Bool flag indicating if tail is processing. + */ + void load_zero_point(int n, bool is_tail) { + if (!conf_->is_wei_zp_per_n) return; + const auto zp_dt = conf_->wei_zp_dt; + const auto zp_dt_sz = types::data_type_size(zp_dt); + const auto elems_per_byte + = one_of(zp_dt, data_type::s4, data_type::u4) ? 2 : 1; + const auto offset = n * zp_dt_sz / elems_per_byte; + const auto addr = maybe_EVEX_compress_addr(reg_zp_ptr, offset); + + const bool is_odd_index = n % elems_per_byte == 1; + const auto tmp_xmm = Xmm(vmm_zp_b_val.getIdx()); + MAYBE_UNUSED(tmp_xmm); + MAYBE_UNUSED(is_odd_index); + + const bool need_upconvert = one_of(zp_dt, data_type::s8, data_type::u8, + data_type::s4, data_type::u4); + if (need_upconvert) { + uni_vpinsrb(tmp_xmm, tmp_xmm, addr, 0); + if (one_of(zp_dt, data_type::s8, data_type::s4)) + uni_vpmovsxbd(tmp_xmm, tmp_xmm); + else + uni_vpmovzxbd(tmp_xmm, tmp_xmm); + + // 4-bit integer must be shifted left depending + // which element of 2 is required + if (one_of(zp_dt, data_type::s4, data_type::u4)) + uni_vpslld(tmp_xmm, tmp_xmm, 28 - is_odd_index * 4); + // Then shift back to the right on 28 bits + if (zp_dt == data_type::u4) vpsrld(tmp_xmm, tmp_xmm, 28); + if (zp_dt == data_type::s4) vpsrad(tmp_xmm, tmp_xmm, 28); + } + const auto &op = need_upconvert + ? static_cast(tmp_xmm) + : static_cast(addr); + const auto masked_vmm = maybe_mask(vmm_zp_b_val, is_tail); + uni_vpbroadcastd(masked_vmm, op); + } + /** * Loads scales and broadcasts it over Vmm register. * Supported data types: f32, bf16, f16. @@ -4107,65 +4343,42 @@ private: if (!conf_->is_wei_scale_per_n || !conf_->apply_scales_in_buffer_b) return; - const auto offset = n * wei_scales_K_stride_; - - if (wei_scales_K_stride_ == wei_scales_typesize_) { - // A single scale per kernel case. - - // Enable broadcast address for f16 to avoid vmm manipulations. 
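Note: the 4-bit zero-point path in load_zero_point() above (and in load_common_zp_value() earlier) extends an s4/u4 nibble by shifting it to the top of a 32-bit lane and shifting it back, arithmetically for s4 and logically for u4. A scalar model of that trick, assuming two values per byte with the even element in the low nibble:

#include <cstdint>
#include <cstdio>

// Sketch: emulate the vpslld + vpsrad/vpsrld sequence used to sign- or
// zero-extend one int4 zero point picked out of a packed byte.
int32_t extract_int4(uint8_t packed, bool odd_index, bool is_signed) {
    // place the selected nibble into bits [31:28]
    uint32_t v = packed;
    v <<= odd_index ? 24 : 28; // odd element lives in the high nibble
    if (is_signed)
        return static_cast<int32_t>(v) >> 28; // arithmetic shift: sign-extend s4
    return static_cast<int32_t>(v >> 28);     // logical shift: zero-extend u4
}

int main() {
    const uint8_t byte = 0xF3; // low nibble = 3, high nibble = 0xF
    std::printf("%d %d\n", extract_int4(byte, /*odd_index=*/true, /*is_signed=*/true), // -1
            extract_int4(byte, /*odd_index=*/false, /*is_signed=*/true));              // 3
    return 0;
}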
- const auto wei_scales_addr = EVEX_compress_addr(reg_wei_scales, - offset, conf_->wei_scales_dt == data_type::f16); - switch (conf_->wei_scales_dt) { - case data_type::f32: - uni_vbroadcastss(vmm_wei_scales, wei_scales_addr); - break; - case data_type::bf16: - vpbroadcastw(vmm_wei_scales, wei_scales_addr); - uni_vpslld(vmm_wei_scales, vmm_wei_scales, 16); - break; - case data_type::f16: - vcvtph2psx(vmm_wei_scales, wei_scales_addr); - break; - default: assert(!"unsupported wei_scales data type"); - } - } else { - // A broadcasted ahead-of-time scales case. - - // This branch assumes that `wei_scales` have been transposed outside - // before passing their values in here. This is done in - // `get_wei_scales_ptr` function in brgemm_matmul_ctx_t. - // - // It's important that even when groups are specified, the amount of - // memory allocated is KxN, thus, there's an over-use of memory, but - // such usage allows us to simplify the kernel logic and just load - // weights into a full vector register. - const auto wei_scales_addr - = EVEX_compress_addr(reg_wei_scales, offset); - const auto vmm_wei_scales_masked - = maybe_mask(vmm_wei_scales, is_tail); - - switch (conf_->wei_scales_dt) { - case data_type::f32: - uni_vmovups(vmm_wei_scales_masked, wei_scales_addr); - break; - case data_type::bf16: - uni_vpmovzxwd(vmm_wei_scales_masked, wei_scales_addr); - uni_vpslld(vmm_wei_scales, vmm_wei_scales, 16); - break; - case data_type::f16: - vcvtph2ps(vmm_wei_scales_masked, wei_scales_addr); - break; - default: assert(!"unsupported wei_scales data type"); - } + const auto &scales_dt = conf_->wei_scales_dt; + const auto &scales_dt_sz = conf_->wei_scales_dt_sz; + const auto offset = n * scales_dt_sz; + const auto masked_vmm = maybe_mask(vmm_wei_scales, is_tail); + const auto addr = EVEX_compress_addr( + reg_wei_scales, offset, scales_dt == data_type::f16); + vpxord(vmm_wei_scales, vmm_wei_scales, vmm_wei_scales); + switch (scales_dt) { + case data_type::f32: uni_vbroadcastss(vmm_wei_scales, addr); break; + case data_type::bf16: + vpbroadcastw(masked_vmm, addr); + uni_vpslld(vmm_wei_scales, vmm_wei_scales, 16); + break; + case data_type::f16: vcvtph2psx(vmm_wei_scales, addr); break; + default: assert(!"unsupported wei_scales data type"); } } + /** Stores half of the block using mask for the case when vnni_granularity == 2 */ + void store_half_block(const Zmm &r, const Xbyak::Address &store_addr) { + Label even_k, end_permute; + test(reg_K_start, 1); + jz(even_k, T_NEAR); + // Shift left by 16 bytes to store odd indices + uni_vpslld(r, r, 16); + vmovdqu16(store_addr | kAAAA, r); + jmp(end_permute); + L(even_k); + // Store even indices, odd set to zero. + vmovdqu16(store_addr, r); + L(end_permute); + } + void generate() override; }; -// This method applies scales for weights decompression scenario. Given it's the -// transposed kernel, B is in column-major format, but scales in the library are -// in row-major format. 
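Note: store_half_block() above (and save_half_block() in the cvt_bf16 kernel later in this patch) writes only every other 16-bit lane of a vnni pair, picking the lane by the parity of the current K index; a scalar sketch of that interleaving, with the pair buffer laid out as [k0 n0, k1 n0, k0 n1, k1 n1, ...]:

#include <cstdint>
#include <cstdio>
#include <vector>

// Sketch: one K row of bf16 values is merged into an interleaved vnni pair,
// touching only the 16-bit lanes that belong to this row's K parity
// (even K -> lane 0 of each pair, odd K -> lane 1), like the kAAAA/k5555
// masked stores above.
void store_half_vnni_pair(std::vector<uint16_t> &dst_pair, // size 2 * n_block
        const std::vector<uint16_t> &row, int k_parity) {
    for (size_t n = 0; n < row.size(); ++n)
        dst_pair[2 * n + k_parity] = row[n];
}

int main() {
    std::vector<uint16_t> pair(8, 0);
    store_half_vnni_pair(pair, {1, 2, 3, 4}, /*k_parity=*/1); // odd K row
    for (uint16_t v : pair)
        std::printf("%u ", v); // 0 1 0 2 0 3 0 4
    std::printf("\n");
    return 0;
}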
template void jit_brgemm_matmul_copy_b_transposed_t::init_tail_mask( const int columns_tail, const bool use_int4_mask) { @@ -4232,6 +4445,24 @@ bool jit_brgemm_matmul_copy_b_transposed_t::preload_int4(const Xmm &xmm_in, } return true; } + + // The case when the kernel is grouped over K and need to load odd or even columns + const auto preload_for_k_1_blk + = is_src_int4_ && is_wei_grouped_over_k_ && columns_tail == 1; + if (preload_for_k_1_blk) { + // Unconditionally load 1 byte, then shift if odd index + load_bytes(xmm_in, addr, 1); + Label load_done, even_k; + test(reg_K_start, 1); + jz(even_k, T_NEAR); + vpsrlq(xmm_in, xmm_in, 4); + jmp(load_done); + L(even_k); + vpsllq(xmm_in, xmm_in, 4); + vpsrlq(xmm_in, xmm_in, 4); + L(load_done); + return true; + } return false; } @@ -4288,6 +4519,7 @@ void jit_brgemm_matmul_copy_b_transposed_t::copy_row_x_col( if (is_src_int4_) init_tail_mask(columns_tail, true); load_value(src_reg, src_op, vmm_permd, conf_->orig_wei_dt, is_tail); if (is_src_int4_) init_tail_mask(columns_tail, false); + load_zero_point(i, is_tail); load_scales(i, is_tail); decompress_reg(src_reg_masked, vmm_zp_b_val, vmm_wei_scales, conf_->orig_wei_dt); @@ -4317,6 +4549,7 @@ void jit_brgemm_matmul_copy_b_transposed_t::copy_row_x_col( load_value(src_reg_next, src_op, vmm_permd, conf_->orig_wei_dt, is_tail); if (is_src_int4_) init_tail_mask(columns_tail, false); + load_zero_point(i, is_tail); load_scales(i, is_tail); decompress_reg(src_next_masked, vmm_zp_b_val, vmm_wei_scales, conf_->orig_wei_dt); @@ -4347,13 +4580,12 @@ void jit_brgemm_matmul_copy_b_transposed_t::copy_row_x_col( } const auto is_tail = columns_tail > 0; - auto src_load = is_tail ? src_reg | kTail | T_z : src_reg; const auto src_offset = (i * src_stride_) / src_elems_per_byte_; const auto addr = EVEX_compress_addr(reg_src, src_offset); + auto src_masked_reg = maybe_mask(src_reg, is_tail); if ((conf_->is_f16_with_int_wei || conf_->is_f32_with_int_wei) && conf_->wei_dt == data_type::f32) { const auto xmm_preload = Xmm(src_reg.getIdx()); - MAYBE_UNUSED(xmm_preload); const bool preloaded_int4 = preload_int4( xmm_preload, i, columns_tail, is_tail, src_offset); @@ -4363,28 +4595,38 @@ void jit_brgemm_matmul_copy_b_transposed_t::copy_row_x_col( if (is_src_int4_) init_tail_mask(columns_tail, true); load_value(src_reg, src_op, vmm_permd, conf_->orig_wei_dt, is_tail); if (is_src_int4_) init_tail_mask(columns_tail, false); + load_zero_point(i, is_tail); load_scales(i, is_tail); - decompress_reg(maybe_mask(src_reg, is_tail), vmm_zp_b_val, - vmm_wei_scales, conf_->orig_wei_dt); + decompress_reg(src_masked_reg, vmm_zp_b_val, vmm_wei_scales, + conf_->orig_wei_dt); } else if (use_fp16_instructions_) { if (conf_->isa == avx512_core_fp16) { - vcvtph2psx(src_load, addr); + vcvtph2psx(src_masked_reg, addr); } else { - vcvtph2ps(src_load, addr); + vcvtph2ps(src_masked_reg, addr); } } else if (use_bf16_instructions_) { // Upconvert: load 16 bits and move them 16 bits left. 
- uni_vpmovzxwd(src_load, addr); - uni_vpslld(src_load, src_load, 16); + uni_vpmovzxwd(src_masked_reg, addr); + uni_vpslld(src_masked_reg, src_masked_reg, 16); } else { - vmovdqu8(src_load, addr); + vmovdqu8(src_masked_reg, addr); } L(load_done); }; - auto store = [this](Zmm r, int i) { + auto store = [this, columns_tail, ncolumns, cur_k_blk_step](Zmm r, int i) { auto addr = EVEX_compress_addr(reg_tr_src, i * tr_src_stride_); - vmovups(addr, r); + if (is_wei_grouped_over_k_) { + const bool is_tail = columns_tail > 0 && ncolumns < cur_k_blk_step; + if (is_tail && i >= columns_tail) return; + if (vnni_granularity_ == 2 && ncolumns == 1) + store_half_block(r, addr); + else + vmovups(addr, r); + } else { + vmovups(addr, r); + } }; auto transpose16x8 = [&](int base_idx) { @@ -4662,7 +4904,7 @@ void jit_brgemm_matmul_copy_b_transposed_t::compute_K_loop(bool is_N_tail, mov(reg_src, reg_src_base); mov(reg_tr_src, reg_tr_src_base); - if (req_apply_wei_scales_) mov(reg_wei_scales, reg_wei_scales_base); + if (curr_K_tail > 0) { cmp(reg_K_iters, k_blk_step_); jl(K_loop_tail_or_done, T_NEAR); @@ -4672,8 +4914,6 @@ void jit_brgemm_matmul_copy_b_transposed_t::compute_K_loop(bool is_N_tail, copy_row_x_col(nrows, k_blk_step_); add(reg_src, (k_blk_step_ * typesize_) / src_elems_per_byte_); add(reg_tr_src, k_blk_step_ / vnni_granularity_ * tr_src_stride_); - if (req_apply_wei_scales_ && !single_wei_scales_value_) - add(reg_wei_scales, k_blk_step_ * wei_scales_typesize_); sub(reg_K_iters, k_blk_step_); cmp(reg_K_iters, k_blk_step_); @@ -4750,8 +4990,20 @@ void jit_brgemm_matmul_copy_b_transposed_t::compute_N_loop( add(reg_tr_src_base, n_blk_step_ * vnni_granularity_ * tr_typesize_); } - if (req_apply_wei_scales_) - add(reg_wei_scales_base, n_blk_step_ * wei_scales_K_stride_); + if (conf_->is_wei_scale_per_n) { + const auto &scales_dt_sz = conf_->wei_scales_dt_sz; + const auto offset = n_blk_step_ * scales_dt_sz; + add(reg_wei_scales, offset); + } + + if (conf_->is_wei_zp_per_n) { + const auto &zp_dt = conf_->wei_zp_dt; + const auto zp_dt_sz = types::data_type_size(zp_dt); + const auto elems_per_byte + = one_of(zp_dt, data_type::s4, data_type::u4) ? 
2 : 1; + const auto offset = n_blk_step_ * zp_dt_sz / elems_per_byte; + add(reg_zp_ptr, offset); + } if (req_zp_comp_) add(reg_zp_comp_ptr, comp_shift_); if (req_s8s8_comp_) add(reg_comp_ptr, comp_shift_); @@ -4780,20 +5032,20 @@ void jit_brgemm_matmul_copy_b_transposed_t::generate() { mov(regq_tmp.cvt16(), 1); vpbroadcastw(vmm_ones_words, regq_tmp.cvt16()); } - if (req_zp_b_shift_) { - mov(regq_tmp, ptr[param1 + GET_OFF(zp_b_value_ptr)]); - uni_vpbroadcastd(vmm_zp_b_val, ptr[regq_tmp]); - } mov(reg_src_base, ptr[param1 + GET_OFF(src)]); mov(reg_tr_src_base, ptr[param1 + GET_OFF(tr_src)]); mov(reg_K_iters, ptr[param1 + GET_OFF(current_K_iters)]); mov(reg_N_iters, ptr[param1 + GET_OFF(current_N_blk)]); - mov(reg_wei_scales_base, ptr[param1 + GET_OFF(wei_scales_ptr)]); + mov(reg_wei_scales, ptr[param1 + GET_OFF(wei_scales_ptr)]); + mov(reg_zp_ptr, ptr[param1 + GET_OFF(zp_b_value_ptr)]); + mov(reg_K_start, ptr[param1 + GET_OFF(current_K_start)]); if (!is_ymm_) { - kmovw(k5555, 0x5555); - kmovw(kAAAA, 0xaaaa); + // 64-bit mask is also used when is_wei_[zp\scales]_per_k + kmovq(kAAAA, 0xAAAAAAAAAAAAAAAA); + kmovq(k5555, 0x5555555555555555); + kmovw(k3333, 0x3333); kmovw(kCCCC, 0xcccc); kmovw(k0F0F, 0x0f0f); @@ -4806,12 +5058,21 @@ void jit_brgemm_matmul_copy_b_transposed_t::generate() { vmovdqa32(vmm_permd, ptr[regq_tmp]); } + load_common_zp_value(vmm_zp_b_val, reg_zp_ptr); + load_common_scale_value(vmm_wei_scales, reg_wei_scales); + const dim_t N_chunk_elems = conf_->N_chunk_elems; assert(N_chunk_elems % n_blk_step_ == 0 || N_chunk_elems == conf_->N); UNUSED(N_chunk_elems); - const auto K_blk_tail = nstl::min(conf_->K, conf_->K_blk) % k_blk_step_; - const auto K_tail_tail = (conf_->K % conf_->K_blk) % k_blk_step_; + const auto &k_blk = conf_->K_blk; + const auto K_blk_tail = nstl::min(conf_->K, k_blk) % k_blk_step_; + const auto K_tail_tail = (conf_->K % k_blk) % k_blk_step_; + + const auto grouped_k = is_wei_grouped_over_k_ + ? (conf_->is_wei_zp_per_k ? conf_->wei_zp_k_gsize + : conf_->wei_scales_k_gsize) + : 0; auto compute_body = [&](bool is_first_K_iter, bool is_last_K_iter) { if (is_last_K_iter) { @@ -4828,10 +5089,15 @@ void jit_brgemm_matmul_copy_b_transposed_t::generate() { } } + if (is_wei_grouped_over_k_ && grouped_k < k_blk_step_) { + compute_N_loop(grouped_k, is_first_K_iter, is_last_K_iter); + return; + } + Label compute_body_done; if (conf_->K_tail > 0 && K_blk_tail != K_tail_tail) { Label not_K_tail; - cmp(reg_K_iters, conf_->K_blk); + cmp(reg_K_iters, k_blk); je(not_K_tail, T_NEAR); compute_N_loop(K_tail_tail, is_first_K_iter, is_last_K_iter); jmp(compute_body_done, T_NEAR); @@ -4857,8 +5123,7 @@ void jit_brgemm_matmul_copy_b_transposed_t::generate() { mov(regq_tmp, 1); uni_vpbroadcastb(vmm_comp_mul, regq_tmp.cvt8()); - const auto last_K_threshold - = rnd_up(conf_->K, conf_->K_blk) - conf_->K_blk; + const auto last_K_threshold = rnd_up(conf_->K, k_blk) - k_blk; Label not_first, not_first_not_last; cmp(reg_K_start, 0); jne(not_first, T_NEAR); @@ -4909,19 +5174,15 @@ struct jit_brgemm_matmul_copy_b_cvt_bf16_t , src_stride_( (conf->LDB * k_blk_step * typesize_) / src_elems_per_byte_) , tr_src_stride_(conf_->LDB * k_blk_step * tr_typesize_) - // If scales groups are enabled and are divisible by K_blk, the kernel - // processes one "per_N line" of scales and is called several times. - , wei_scales_N_stride_(conf_->wei_scales_k_group_size > 1 - && conf_->gK_and_K_blk_are_divisible - ? 
0
-                            : conf_->N * wei_scales_typesize_)
         , req_zp_b_shift_(
                   conf_->has_zero_point_b && conf_->with_wei_decompression)
         , req_apply_wei_scales_(conf_->apply_scales_in_buffer_b)
-        , reserved_regs_(req_apply_wei_scales_ ? 5
-                        : is_src_int4_         ? 2
-                        : req_zp_b_shift_      ? 1
-                                               : 0) {}
+        , reserved_regs_(req_apply_wei_scales_ ? 6
+                        : req_zp_b_shift_      ? 4
+                        : is_src_int4_         ? 1
+                                               : 0)
+        , is_wei_grouped_over_k_(
+                  conf_->is_wei_zp_per_k || conf_->is_wei_scale_per_k) {}
 
     void operator()(ctx_t *ctx) override { jit_generator_t::operator()(ctx); }
     status_t create_kernel() override {
@@ -4939,11 +5200,11 @@ private:
     enum { k_blk_step = 2, n_blk_step = 16 };
     const int typesize_, tr_typesize_, wei_scales_typesize_;
     const bool is_src_int4_;
-    const dim_t src_elems_per_byte_, src_stride_, tr_src_stride_,
-            wei_scales_N_stride_;
+    const dim_t src_elems_per_byte_, src_stride_, tr_src_stride_;
     const bool req_zp_b_shift_;
     const bool req_apply_wei_scales_;
     const int reserved_regs_;
+    const bool is_wei_grouped_over_k_;
 
     reg64_t reg_src = rax;
     reg64_t reg_tr_src = rbx;
@@ -4957,11 +5218,15 @@ private:
     reg64_t reg_src_back = r12;
     reg64_t reg_tr_src_back = r13;
 
-    Vmm vmm_zp_b_val = Vmm(0);
-    Vmm vmm_permd = Vmm(1);
-    Vmm vmm_wei_scales0 = Vmm(2);
-    Vmm vmm_wei_scales1 = Vmm(3);
-    Vmm vmm_tmp = Vmm(4);
+    reg64_t reg_wei_zp = r14;
+    reg64_t reg_k_start = r15;
+
+    Vmm vmm_permd = Vmm(0);
+    Vmm vmm_zp_b_val0 = Vmm(1);
+    Vmm vmm_zp_b_val1 = Vmm(2);
+    Vmm vmm_tmp = Vmm(3);
+    Vmm vmm_wei_scales0 = Vmm(4);
+    Vmm vmm_wei_scales1 = Vmm(5);
 
     Vmm get_vmm(const int blk, const int idx) {
         const int max_isa_regs = isa_num_vregs(conf_->isa);
@@ -4974,9 +5239,48 @@ private:
     }
 
     void init_masks();
-    void get_scales(const int blk, const int k, const int n,
-            const bool is_n_tail, const bool is_k_tail);
+    void get_wei_scales(
+            const int n, const bool is_n_tail, const bool is_k_tail);
+    void get_zero_points(const int n, const bool is_tail, const bool is_k_tail);
     void copy_block(const int nrows, const int ncolumns, bool zeropad);
+
+    /** Adjusts strides for weights grouped over K.
+     * k_blk_step is a constant 2, so this path handles the
+     * nrows == 1 case: when k_start is odd, the src pointer
+     * is moved back to the beginning of the 2x32 VNNI block
+     * it belongs to.
+     */
+    void maybe_update_strides(int nrows) {
+        if (is_wei_grouped_over_k_ && nrows < k_blk_step) {
+            Label even_k;
+            test(reg_k_start, 1);
+            jz(even_k, T_NEAR);
+            // Shift back to the start of the vnni block
+            sub(reg_src, typesize_ / src_elems_per_byte_);
+            L(even_k);
+        }
+    }
+
+    void save_half_block(const int blk_idx, const Xbyak::Address &store_addr) {
+        const auto src0 = get_vmm(blk_idx, 0);
+        const auto zmm0 = zmm(src0.getIdx());
+        // If k_start is odd, store only the odd 16-bit lanes;
+        // otherwise store only the even lanes, using the opmasks.
+        Label even_k, end_permute;
+        test(reg_k_start, 1);
+        jz(even_k, T_NEAR);
+        // Odd indices case
+        vmovdqu16(store_addr | kAAAA, zmm0);
+        jmp(end_permute);
+        L(even_k);
+        // Clean the whole block before storing
+        uni_vpxor(vmm_tmp, vmm_tmp, vmm_tmp);
+        vmovdqu16(store_addr, vmm_tmp);
+        // Store only even indices
+        vmovdqu16(store_addr | k5555 | T_z, zmm0);
+        L(end_permute);
+    }
+
     void generate() override;
 };
 
@@ -4991,46 +5295,35 @@ void jit_brgemm_matmul_copy_b_cvt_bf16_t<Vmm>::init_masks() {
         mov(reg_tmp, reinterpret_cast<size_t>(bf16_vnni_permute));
         vmovdqa32(vmm_permd, ptr[reg_tmp]);
 
-        mov(regw_tmp, 0x5555);
-        kmovw(k5555, regw_tmp);
-        mov(regw_tmp, 0xaaaa);
-        kmovw(kAAAA, regw_tmp);
+        // 64-bit masks are also used when is_wei_zp_per_k/is_wei_scale_per_k
+        mov(reg_tmp, 0xAAAAAAAAAAAAAAAA);
+        kmovq(kAAAA, reg_tmp);
+        mov(reg_tmp, 0x5555555555555555);
+        kmovq(k5555, reg_tmp);
     }
 }
 
+/** Loads scales into two registers and permutes them.
+ * Since groups over the K dimension are handled outside the kernel,
+ * both registers are loaded from the same address.
+ */
 template <typename Vmm>
-void jit_brgemm_matmul_copy_b_cvt_bf16_t<Vmm>::get_scales(const int blk,
-        const int k, const int n, const bool is_n_tail, const bool is_k_tail) {
+void jit_brgemm_matmul_copy_b_cvt_bf16_t<Vmm>::get_wei_scales(
+        const int n, const bool is_n_tail, const bool is_k_tail) {
-    if (!req_apply_wei_scales_) return;
+    if (!req_apply_wei_scales_ || !conf_->is_wei_scale_per_n) return;
 
     const auto zmm_wei_scales1 = maybe_mask(vmm_wei_scales1, is_n_tail);
     const auto zmm_tmp = maybe_mask(vmm_tmp, is_n_tail);
-    const auto base_offset = [&](int k) {
-        return k * wei_scales_N_stride_ + n * wei_scales_typesize_;
-    };
-    auto wei_scales_addr0
-            = maybe_EVEX_compress_addr(reg_wei_scales, base_offset(k));
-    auto wei_scales_addr1
-            = maybe_EVEX_compress_addr(reg_wei_scales, base_offset(k + 1));
+    const auto base_offset
+            = [&](int n_idx) { return n_idx * wei_scales_typesize_; };
 
-    const auto load_scales = [&](const Vmm &vmm, const Address &addr) {
-        switch (conf_->wei_scales_dt) {
-            case data_type::f32: uni_vmovups(vmm, addr); break;
-            case data_type::bf16:
-                uni_vpmovzxwd(vmm, addr);
-                uni_vpslld(vmm, vmm, 16);
-                break;
-            case data_type::f16: vcvtph2ps(vmm, addr); break;
-            default: assert(!"unsupported wei_scales data type");
-        }
-    };
+    auto wei_scales_addr
+            = maybe_EVEX_compress_addr(reg_wei_scales, base_offset(n));
 
-    load_scales(zmm_tmp, wei_scales_addr0);
-    if (is_k_tail)
-        vpxord(vmm_wei_scales1, vmm_wei_scales1, vmm_wei_scales1);
-    else
-        load_scales(zmm_wei_scales1, wei_scales_addr1);
+    load_scale_value(zmm_tmp, wei_scales_addr, conf_->wei_scales_dt, is_n_tail);
+
+    uni_vmovups(vmm_wei_scales1, vmm_tmp);
 
     vinsertf64x4(vmm_wei_scales0, vmm_tmp, Ymm(vmm_wei_scales1.getIdx()), 1);
     vextractf64x4(Ymm(vmm_tmp.getIdx()), vmm_tmp, 1);
@@ -5039,6 +5332,37 @@ void jit_brgemm_matmul_copy_b_cvt_bf16_t<Vmm>::get_scales(const int blk,
     vpermd(vmm_wei_scales1, vmm_permd, vmm_wei_scales1);
 }
 
+/** Loads zero points into two registers and permutes them.
+ * Since groups over the K dimension are handled outside the kernel,
+ * both registers are loaded from the same address.
+ */
+template <typename Vmm>
+void jit_brgemm_matmul_copy_b_cvt_bf16_t<Vmm>::get_zero_points(
+        const int n, const bool is_n_tail, const bool is_k_tail) {
+    if (!conf_->is_wei_zp_per_n) return;
+
+    const auto zp_dt = conf_->wei_zp_dt;
+
+    const auto base_offset = [&](int n_idx) {
+        const auto zp_dt_sz = types::data_type_size(zp_dt);
+        const auto elems_per_byte
+                = one_of(zp_dt, data_type::s4, data_type::u4) ? 2 : 1;
+        return n_idx * zp_dt_sz / elems_per_byte;
+    };
+
+    const auto addr = maybe_EVEX_compress_addr(reg_wei_zp, base_offset(n));
+    load_value(vmm_tmp, addr, vmm_permd, zp_dt, is_n_tail);
+
+    uni_vmovups(vmm_zp_b_val1, vmm_tmp);
+
+    const auto zmm_zp_b_val1 = maybe_mask(vmm_zp_b_val1, is_n_tail);
+    vinserti64x4(vmm_zp_b_val0, vmm_tmp, Ymm(vmm_zp_b_val1.getIdx()), 1);
+    vextracti64x4(Ymm(vmm_tmp.getIdx()), vmm_tmp, 1);
+    vinserti64x4(vmm_zp_b_val1, zmm_zp_b_val1, Ymm(vmm_tmp.getIdx()), 0);
+    vpermd(vmm_zp_b_val0, vmm_permd, vmm_zp_b_val0);
+    vpermd(vmm_zp_b_val1, vmm_permd, vmm_zp_b_val1);
+}
+
 template <typename Vmm>
 void jit_brgemm_matmul_copy_b_cvt_bf16_t<Vmm>::copy_block(
         const int nrows, int ncolumns, bool zeropad) {
@@ -5064,17 +5388,21 @@ void jit_brgemm_matmul_copy_b_cvt_bf16_t<Vmm>::copy_block(
         const auto stride = (n_blk_step * typesize_) / src_elems_per_byte_;
         auto load_addr0 = maybe_EVEX_compress_addr(reg_src, offset);
         auto load_addr1 = maybe_EVEX_compress_addr(reg_src, offset + stride);
-        load_value(src_vmm0, load_addr0, vmm_permd, conf_->orig_wei_dt);
-        load_value(src_vmm1, load_addr1, vmm_permd, conf_->orig_wei_dt);
         const bool is_n_tail = ncolumns - n < n_blk_step;
         const bool is_k_tail = nrows - k < k_blk_step;
-        get_scales(blk, k, n, is_n_tail, is_k_tail);
-        decompress_and_downcvt_2reg(src_vmm0, src_vmm1, vmm_zp_b_val,
-                vmm_zp_b_val, vmm_wei_scales0, vmm_wei_scales0,
+
+        load_value(src_vmm0, load_addr0, vmm_permd, conf_->orig_wei_dt);
+        load_value(src_vmm1, load_addr1, vmm_permd, conf_->orig_wei_dt);
+        get_wei_scales(n, is_n_tail, is_k_tail);
+        get_zero_points(n, is_n_tail, is_k_tail);
+        decompress_and_downcvt_2reg(src_vmm0, src_vmm1, vmm_zp_b_val0,
+                vmm_zp_b_val1, vmm_wei_scales0, vmm_wei_scales1,
                 conf_->orig_wei_dt, conf_->wei_dt);
     };
 
+    maybe_update_strides(nrows);
+
     int iter = 0;
     for_(int k = 0; k < nrows; k += k_blk_step)
     for (int n = 0; n < ncolumns; n += n_blk_step) {
@@ -5090,6 +5418,14 @@ void jit_brgemm_matmul_copy_b_cvt_bf16_t<Vmm>::copy_block(
             uni_vpxor(store_vmm, store_vmm, store_vmm);
         else
             load(blk_idx, k, n);
+
+        // Special case for grouped zp/scales when nrows == 1
+        if (is_wei_grouped_over_k_ && nrows == 1) {
+            save_half_block(blk_idx, store_addr);
+            iter++;
+            continue;
+        }
+
         uni_vmovups(store_addr, store_vmm);
 
         iter++;
@@ -5107,14 +5443,29 @@ void jit_brgemm_matmul_copy_b_cvt_bf16_t<Vmm>::generate() {
     mov(reg_tr_src, ptr[param1 + GET_OFF(tr_src)]);
     mov(reg_N_blk, ptr[param1 + GET_OFF(current_N_blk)]);
     mov(reg_wei_scales, ptr[param1 + GET_OFF(wei_scales_ptr)]);
+    mov(reg_wei_zp, ptr[param1 + GET_OFF(zp_b_value_ptr)]);
+    mov(reg_k_start, ptr[param1 + GET_OFF(current_K_start)]);
 
-    if (req_zp_b_shift_) {
-        mov(reg_tmp, ptr[param1 + GET_OFF(zp_b_value_ptr)]);
-        uni_vpbroadcastd(vmm_zp_b_val, ptr[reg_tmp]);
-    }
+    load_common_zp_value(vmm_zp_b_val0, reg_wei_zp);
+    load_common_zp_value(vmm_zp_b_val1, reg_wei_zp);
+    load_common_scale_value(vmm_wei_scales0, reg_wei_scales);
+    load_common_scale_value(vmm_wei_scales1, reg_wei_scales);
 
    auto compute_K_loop_body = [&](const reg64_t &reg_K, int
ncolumns, bool zeropad) { + // Compute special K-loop for per-k attributes + // Only when k_group_size < k_blk_step + // Otherwise default K-loop is used + if (is_wei_grouped_over_k_) { + const int k_group_size = conf_->is_wei_zp_per_k + ? conf_->wei_zp_k_gsize + : conf_->wei_scales_k_gsize; + if (k_group_size == 1) { + if (zeropad) return; + copy_block(k_group_size, ncolumns, /*zeropad= */ false); + return; + } + } const int k_unroll = 8; Label K_loop_unrolled, K_loop_single, K_loop_tail_or_done; @@ -5125,8 +5476,6 @@ void jit_brgemm_matmul_copy_b_cvt_bf16_t::generate() { copy_block(k_unroll * k_blk_step, ncolumns, zeropad); add(reg_src, k_unroll * src_stride_); add(reg_tr_src, k_unroll * tr_src_stride_); - if (req_apply_wei_scales_) - add(reg_wei_scales, k_unroll * k_blk_step * wei_scales_N_stride_); sub(reg_K, k_unroll * k_blk_step); cmp(reg_K, k_unroll * k_blk_step); @@ -5139,8 +5488,6 @@ void jit_brgemm_matmul_copy_b_cvt_bf16_t::generate() { copy_block(k_blk_step, ncolumns, zeropad); add(reg_src, src_stride_); add(reg_tr_src, tr_src_stride_); - if (req_apply_wei_scales_) - add(reg_wei_scales, k_blk_step * wei_scales_N_stride_); sub(reg_K, k_blk_step); jmp(K_loop_single, T_NEAR); diff --git a/src/cpu/x64/matmul/brgemm_matmul_copy_utils.hpp b/src/cpu/x64/matmul/brgemm_matmul_copy_utils.hpp index 27295bc2b8..14051833f9 100644 --- a/src/cpu/x64/matmul/brgemm_matmul_copy_utils.hpp +++ b/src/cpu/x64/matmul/brgemm_matmul_copy_utils.hpp @@ -59,7 +59,7 @@ struct jit_brgemm_matmul_copy_a_t { const void *tr_src; const void *zp_b_compensation_buffer_ptr; const void *zp_a_compensation_result_ptr; - const void *zp_b_neg_value_ptr; + const void *zp_b_neg_val_ptr; const void *zp_ab_comp_ptr; dim_t current_K_start; diff --git a/src/cpu/x64/matmul/brgemm_matmul_utils.cpp b/src/cpu/x64/matmul/brgemm_matmul_utils.cpp index 6aa656b6d2..b8397d865c 100644 --- a/src/cpu/x64/matmul/brgemm_matmul_utils.cpp +++ b/src/cpu/x64/matmul/brgemm_matmul_utils.cpp @@ -1296,6 +1296,7 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc, bgmmc.dst_dt = dst_d.data_type(); bgmmc.wei_dt = weights_d.data_type(); bgmmc.orig_wei_dt = weights_d.data_type(); + bgmmc.wei_zp_dt = attr.zero_points_.get(DNNL_ARG_WEIGHTS).get_data_type(); bgmmc.with_reduce = mmd.reduce_desc.format_kind != format_kind::undef; bgmmc.reduce_dt @@ -1396,19 +1397,19 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc, bgmmc.with_src_scales = !src_scales.has_default_values(); bgmmc.with_wei_scales = !wei_scales.has_default_values(); if (bgmmc.with_wei_scales) { - const auto wei_qmask_N = 1 << (bgmmc.ndims - 1); - const auto wei_qmask_K = 1 << (bgmmc.ndims - 2); - bgmmc.is_wei_scale_per_k = wei_scales.get_mask() & wei_qmask_K; - bgmmc.is_wei_scale_per_n = wei_scales.get_mask() & wei_qmask_N; + const auto &wei_scale_mask = wei_scales.get_mask(); + bgmmc.is_wei_scale_common = wei_scale_mask == 0; + bgmmc.is_wei_scale_per_k = wei_scale_mask & 1 << (bgmmc.ndims - 2); + bgmmc.is_wei_scale_per_n = wei_scale_mask & 1 << (bgmmc.ndims - 1); bgmmc.apply_scales_in_buffer_b = bgmmc.is_wei_scale_per_k && bgmmc.with_wei_decompression && bgmmc.N * bgmmc.K != 1; bgmmc.wei_scales_dt = wei_scales.get_data_type(); bgmmc.wei_scales_dt_sz = types::data_type_size(bgmmc.wei_scales_dt); - bgmmc.wei_scales_k_group_size = wei_scales.get_group(0); + bgmmc.wei_scales_k_gsize = wei_scales.get_group(0); // only common and per-oc-channel scales are supported // only per-ic-channel scales is supprted with weight decompression - 
VCONDCHECK_BG(wei_scales.get_mask() == 0 || bgmmc.is_wei_scale_per_n + VCONDCHECK_BG(bgmmc.is_wei_scale_common || bgmmc.is_wei_scale_per_n || IMPLICATION(bgmmc.is_wei_scale_per_k, bgmmc.with_wei_decompression), VERBOSE_UNSUPPORTED_SCALES_CFG); @@ -1420,6 +1421,28 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc, VCONDCHECK_BG(!(bgmmc.with_dst_scales && dst_scales.get_mask() > 0), VERBOSE_UNSUPPORTED_SCALES_CFG); + const auto &wei_zp = attr.zero_points_.get(DNNL_ARG_WEIGHTS); + const auto has_wei_zp = !wei_zp.has_default_values(); + + if (has_wei_zp) { + const auto wei_zp_mask = wei_zp.get_mask(); + bgmmc.is_wei_zp_common = wei_zp_mask == 0; + bgmmc.is_wei_zp_per_k = wei_zp_mask & (1 << (bgmmc.ndims - 2)); + bgmmc.is_wei_zp_per_n = wei_zp_mask & (1 << (bgmmc.ndims - 1)); + bgmmc.wei_zp_dt = wei_zp.get_data_type(); + bgmmc.wei_zp_k_gsize = wei_zp.get_group(0); + + VCONDCHECK_BG(wei_zp_mask == 0 || bgmmc.is_wei_zp_per_k + || bgmmc.is_wei_zp_per_n, + VERBOSE_UNSUPPORTED_ZP_CFG); + + // Check if K groups for scales and for zero points are identical + VCONDCHECK_BG( + IMPLICATION(bgmmc.is_wei_zp_per_k && bgmmc.is_wei_scale_per_k, + bgmmc.wei_zp_k_gsize == bgmmc.wei_scales_k_gsize), + VERBOSE_UNSUPPORTED_ZP_CFG); + } + const auto &p = attr.post_ops_; bgmmc.with_sum = p.find(primitive_kind::sum) != -1; const int eltwise_ind = p.find(primitive_kind::eltwise); @@ -1532,9 +1555,6 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc, bgmmc.transposed_B = bm_conf_utils.check_is_transposed(bgmmc.wei_tag) || bgmmc.wei_tag == adbc; bgmmc.use_buffer_b = bm_conf_utils.use_buffer_b(); - bgmmc.req_transpose_scales = bgmmc.apply_scales_in_buffer_b - && bgmmc.is_wei_scale_per_k && bgmmc.is_wei_scale_per_n - && bgmmc.transposed_B; if ((bm_conf_utils.is_f32_f16() || bm_conf_utils.is_f32_bf16()) && is_superset(bgmmc.isa, avx2) && bm_conf_utils.use_buffer_b()) { @@ -1719,14 +1739,6 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc, = bm_conf_utils.wei_down_convert_to_vnni(); } - // This setting must be updated post blocking as it has a dependency on - // `bgmmc.K_blk`. See `gK_and_K_blk_are_divisible` comment. - if (bgmmc.is_wei_scale_per_k) { - const auto gK = bgmmc.wei_scales_k_group_size; - bgmmc.gK_and_K_blk_are_divisible = gK > 1 - && ((bgmmc.K_blk % gK == 0) || (gK % bgmmc.K_blk == 0)); - } - VCHECK_BG(bm_conf_utils.set_B_flags(weights_md), VERBOSE_BLOCKING_FAIL, ""); bgmmc.M_tail = bgmmc.is_runtime_M ? 
0 : bgmmc.M % bgmmc.M_blk; diff --git a/src/cpu/x64/matmul/brgemm_matmul_utils.hpp b/src/cpu/x64/matmul/brgemm_matmul_utils.hpp index a58ae2d7d9..e2d86eaeff 100644 --- a/src/cpu/x64/matmul/brgemm_matmul_utils.hpp +++ b/src/cpu/x64/matmul/brgemm_matmul_utils.hpp @@ -136,6 +136,7 @@ struct brgemm_matmul_conf_t { data_type_t reduce_dt; data_type_t orig_src_dt; data_type_t orig_wei_dt; + int nthr; int nthr_k = 1, nthr_m = 1, nthr_n = 1, nthr_b = 1; @@ -227,24 +228,27 @@ struct brgemm_matmul_conf_t { bool is_runtime_M = false; bool is_runtime_N = false; bool is_runtime_K = false; + bool extendable_k = false; bool is_src_batch_layout_trivial = false; bool is_wei_batch_layout_trivial = false; bool is_dst_batch_layout_trivial = false; + + // Attributes related to quantization + // Scales + bool apply_scales_in_buffer_b = false; + size_t wei_scales_dt_sz = 0; bool is_wei_scale_per_n = false; bool is_wei_scale_per_k = false; - bool req_transpose_scales = false; - bool apply_scales_in_buffer_b = false; - // For generic cases, when groups are selected the way they can't divide a - // K_blk in equal pieces, it gets really hard to call a kernel with a - // single "per_N line" of scales. In this case weights will be copied - // to a larger memory buffer and used like a full tensor. - // TODO: convert to a method. State must be set in a specific place of the - // initialization as relies on blocking. - bool gK_and_K_blk_are_divisible = false; - size_t wei_scales_dt_sz = 0; - dim_t wei_scales_k_group_size = 0; + bool is_wei_scale_common = false; + dim_t wei_scales_k_gsize = 0; data_type_t wei_scales_dt = data_type::undef; - bool extendable_k = false; + + // Zero points + dim_t wei_zp_k_gsize = 0; + bool is_wei_zp_per_k = false; + bool is_wei_zp_per_n = false; + bool is_wei_zp_common = false; + data_type_t wei_zp_dt = data_type::undef; bool is_gemv = false; diff --git a/tests/benchdnn/inputs/matmul/harness_matmul_decompression b/tests/benchdnn/inputs/matmul/harness_matmul_decompression index 1bb91ddba0..3f2098705f 100644 --- a/tests/benchdnn/inputs/matmul/harness_matmul_decompression +++ b/tests/benchdnn/inputs/matmul/harness_matmul_decompression @@ -319,8 +319,7 @@ --dt=f32:s4:f32,f32:u4:f32 --wtag=any,ab,ba --attr-scales=,wei:common:0.5:f32,wei:per_ocic:f32:32x1 ---attr-zero-points=,wei:common:2 -# ,wei:per_oc:s4:4096x1 # PER_OC skips for all matmul implemetations +--attr-zero-points=,wei:common:2,wei:per_oc:s4 --attr-fpmath=strict:true 1x4096:4096x4096 @@ -328,7 +327,32 @@ --dt=f32:s8:f32,f32:u8:f32 --wtag=any,ab,ba --attr-scales=,wei:common:0.5:f32,wei:per_ocic:f32:32x1 ---attr-zero-points=,wei:common:2 -# ,wei:per_ocic:s8:32x1 # PER_OC skips for all matmul implementations +--attr-zero-points=,wei:common:2,wei:per_ocic:s8:32x1 --attr-fpmath=strict:true 1x4096:4096x4096 + + +## Additional grouped scales/ZP testing +--reset +--dt=f16:s4:f16 +--wtag=abc,acb +--attr-scales=,wei:common:2,wei:per_oc:f16,wei:per_ocic:f16:192x1 +--attr-zero-points=,wei:common:2,wei:per_oc:s4,wei:per_ocic:s4:192x1 +--attr-fpmath=f16:true +12x4x576:12x576x192 + +--reset +--dt=bf16:s4:bf16 +--wtag=abc,acb +--attr-scales=,wei:common:2,wei:per_oc:f16,wei:per_ocic:f16:192x1 +--attr-zero-points=,wei:common:2,wei:per_oc:s4,wei:per_ocic:s4:192x1 +--attr-fpmath=bf16:true +12x4x576:12x576x192 + +--reset +--dt=f32:s4:f32 +--wtag=abc,acb +--attr-scales=,wei:common:2,wei:per_oc:f16,wei:per_ocic:f16:192x1 +--attr-zero-points=,wei:common:2,wei:per_oc:s4,wei:per_ocic:s4:192x1 +--attr-fpmath=strict:true +12x4x576:12x576x192
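
Editorial note: the masked 16-bit stores used by save_half_block() above (kAAAA for an odd k_start, k5555 with a prior zero-fill for an even one) can be summarized by a minimal scalar sketch. The sketch below is illustrative only, is not part of the patch, and ignores the vpermd lane shuffle that precedes the real store.

#include <array>
#include <cstdint>
#include <cstdio>

// Scalar model of save_half_block(): one register holds 32 bf16 values of a
// single K row; the destination VNNI block interleaves rows k and k+1.
void save_half_block_model(std::array<uint16_t, 64> &dst_vnni_block,
        const std::array<uint16_t, 32> &row, int k_start) {
    if (k_start % 2 == 1) {
        // kAAAA path: write only the odd lanes, keep the even lanes intact.
        for (int i = 0; i < 32; ++i)
            dst_vnni_block[2 * i + 1] = row[i];
    } else {
        // k5555 path: zero the whole block first, then write the even lanes.
        dst_vnni_block.fill(0);
        for (int i = 0; i < 32; ++i)
            dst_vnni_block[2 * i] = row[i];
    }
}

int main() {
    std::array<uint16_t, 64> block {};
    std::array<uint16_t, 32> row0 {}, row1 {};
    for (int i = 0; i < 32; ++i) {
        row0[i] = static_cast<uint16_t>(i);
        row1[i] = static_cast<uint16_t>(100 + i);
    }
    save_half_block_model(block, row0, /*k_start=*/0); // even K row
    save_half_block_model(block, row1, /*k_start=*/1); // odd K row
    // Prints "0 100 1 101": rows 0 and 1 interleaved in VNNI order.
    printf("%u %u %u %u\n", unsigned(block[0]), unsigned(block[1]),
            unsigned(block[2]), unsigned(block[3]));
    return 0;
}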
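
A second sketch, also illustrative only, models the per-N / per-K-group addressing behind the decompress_and_downcvt_2reg() call: value(k, n) = (w - zp) * scale, with group sizes named after the config fields wei_zp_k_gsize and wei_scales_k_gsize. The group-major [group][n] layout and low-nibble-first s4 packing are assumptions of this sketch, not guarantees of the implementation.

#include <cstdint>
#include <cstdio>
#include <vector>

struct quant_cfg_t {
    int64_t K, N;
    int64_t wei_zp_k_gsize; // 0 means a single group over K
    int64_t wei_scales_k_gsize;
};

// s4 zero points are packed two per byte (elems_per_byte == 2 in the kernel);
// low nibble first is an assumption of this model.
static int unpack_s4(const uint8_t *buf, int64_t idx) {
    const uint8_t byte = buf[idx / 2];
    const int v = (idx % 2 == 0) ? (byte & 0xf) : (byte >> 4);
    return v >= 8 ? v - 16 : v; // sign-extend the 4-bit value
}

float dequant_elem(const quant_cfg_t &c, int8_t w, int64_t k, int64_t n,
        const uint8_t *zp_s4, const float *scales) {
    const int64_t zp_g = c.wei_zp_k_gsize ? k / c.wei_zp_k_gsize : 0;
    const int64_t sc_g = c.wei_scales_k_gsize ? k / c.wei_scales_k_gsize : 0;
    const int zp = unpack_s4(zp_s4, zp_g * c.N + n);
    const float sc = scales[sc_g * c.N + n];
    return (static_cast<float>(w) - static_cast<float>(zp)) * sc;
}

int main() {
    // Mirrors the benchdnn case wei:per_ocic:s4:192x1 on a 576x192 B matrix:
    // three groups of 192 along K, per-N values within each group.
    quant_cfg_t c {576, 192, 192, 192};
    std::vector<uint8_t> zp((3 * 192 + 1) / 2, 0x21); // zp nibbles 1 and 2
    std::vector<float> scales(3 * 192, 0.5f);
    // k = 200 falls into group 1; zp = 1, scale = 0.5 -> (7 - 1) * 0.5 = 3.00
    printf("%.2f\n",
            dequant_elem(c, /*w=*/7, /*k=*/200, /*n=*/0, zp.data(),
                    scales.data()));
    return 0;
}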