Mirror of https://github.com/uxlfoundation/oneDNN.git (synced 2025-10-20 10:03:50 +08:00)

x64: matmul: Enable grouped ZP for per_oc/per_ocic & rework scales

Committed by Dmitriy Ovchinnikov
parent 051b020bb1
commit 0ba6d85080
@@ -508,13 +508,13 @@ public:
         movups(addr, x);
     }

-    void uni_vmovdqu16(const Xbyak::Xmm &x, const Xbyak::Address &addr) {
+    void uni_vmovdqu16(const Xbyak::Xmm &x, const Xbyak::Operand &op) {
         if (is_valid_isa(avx512_core))
-            vmovdqu16(x, addr);
+            vmovdqu16(x, op);
         else if (is_valid_isa(avx))
-            vmovups(x, addr);
+            vmovups(x, op);
         else
-            movups(x, addr);
+            movups(x, op);
     }

     void uni_vmovups(const Xbyak::Address &addr, const Xbyak::Xmm &x) {
@@ -1055,8 +1055,8 @@ void matmul_amx_blocking_params_micro_t::set_blocking_parameters(

     // TODO: review extendable_k_ condition to cover more cases
     extendable_k_ = (K % wei_k_blk != 0) && (brgemm_k_elems > wei_k_blk)
-            && wei_zp_type == none && !use_buffer_a
-            && !packed_sparse_weights && current_lda_ == K;
+            && wei_zp_type == none && !apply_scales_in_buffer_b
+            && !use_buffer_a && !packed_sparse_weights && current_lda_ == K;

     if (extendable_k_) {
         if (brgemm_k_elems >= K) {
@@ -252,13 +252,20 @@ status_t brgemm_matmul_t<isa>::pd_t::init(engine_t *engine) {
     auto check_attr_zero_points = [&]() -> bool {
         const auto &zp = attr()->zero_points_;
         static const std::vector<int> supported_args {
-                DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST};
+                DNNL_ARG_SRC, DNNL_ARG_DST};
         for (int arg : supported_args) {
             if (!zp.has_default_values(arg)) {
                 const int mask = zp.get_mask(arg);
                 if (mask > 0) return false;
             }
         }
+        if (!zp.has_default_values(DNNL_ARG_WEIGHTS)) {
+            const auto mask = zp.get_mask(DNNL_ARG_WEIGHTS);
+            const auto kn_mask = wei_qmask_N() + wei_qmask_K();
+            const bool zp_over_batch = (mask & kn_mask) != mask;
+            const bool mask_ok = (mask & ~kn_mask) == 0;
+            return !(zp_over_batch && batch() > 1) && mask_ok;
+        }
         return true;
     };
     const bool problem_dt_correct = one_of(true, is_f4, is_int8, is_f8, is_bf16,
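Not part of the patch: a standalone C++ sketch of the weights zero-point mask check added above, for a batched matmul whose last two dimensions are K and N. The helper and its parameters are hypothetical; the real code queries wei_qmask_N()/wei_qmask_K() and batch() from the primitive descriptor.

    #include <cstdint>

    // mask: per-dimension quantization mask; ndims: number of matmul dims;
    // batch: product of all batch dimensions.
    bool wei_zp_mask_ok(int mask, int ndims, int64_t batch) {
        const int kn_mask = (1 << (ndims - 2)) + (1 << (ndims - 1));
        const bool zp_over_batch = (mask & kn_mask) != mask; // bits outside K/N set
        const bool mask_ok = (mask & ~kn_mask) == 0; // only K/N bits allowed
        return !(zp_over_batch && batch > 1) && mask_ok;
    }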
@@ -286,6 +293,7 @@ status_t brgemm_matmul_t<isa>::pd_t::init(engine_t *engine) {
                     | primitive_attr_t::skip_mask_t::scales_groups
                     | primitive_attr_t::skip_mask_t::
                             zero_points_data_type
+                    | primitive_attr_t::skip_mask_t::zero_points_groups
                     | primitive_attr_t::skip_mask_t::post_ops
                     | primitive_attr_t::skip_mask_t::sum_dt
                     | primitive_attr_t::skip_mask_t::fpmath_mode,
@@ -428,11 +436,6 @@ status_t brgemm_matmul_t<isa>::pd_t::init(engine_t *engine) {

     auto scratchpad = scratchpad_registry().registrar();
     init_scratchpad(scratchpad, bgmmc_);
-    const auto wei_scale_count = bgmmc_.is_wei_scale_per_k
-            ? (bgmmc_.is_wei_scale_per_n ? N() * K() : K())
-            : N();
-    book_precomputed_scales(scratchpad, attr()->scales_, wei_scale_count,
-            /* scale_adjust_factor = */ 1.f, bgmmc_.req_transpose_scales);

     return status::success;
 }
@@ -1189,9 +1192,9 @@ void brgemm_matmul_t<isa>::copy_a_chunk_in_buffer(
     ctx.zp_a_compensation_result_ptr
             = (void *)brgmm_ctx.get_zp_b_compensation_result_ptr(
                     ithr, m_blk_idx);
-    ctx.zp_b_neg_value_ptr = (void *)brgmm_ctx.get_zp_b_neg_val_ptr();
     ctx.zp_ab_comp_ptr = (void *)brgmm_ctx.get_zp_ab_mixed_comp_ptr();
     ctx.dynamic_src_ld = brgmm_ctx.get_src_stride();
+    ctx.zp_b_neg_val_ptr = brgmm_ctx.get_wei_zp_neg_ptr();

     for (int gb = 0; gb < gemm_batch_iters; gb++) {
         const int k = k_start + gb * bgmmc.K_blk;
@@ -1251,46 +1254,29 @@ void brgemm_matmul_t<isa>::copy_b_chunk_in_buffer(
     ctx.zp_a_compensation_ptr = (void *)brgmm_ctx.get_zp_a_compensation_ptr(
             ithr, b_idx, n_blk_idx);
     ctx.zp_a_neg_value_ptr = (void *)brgmm_ctx.get_zp_a_neg_val_ptr();
-    ctx.zp_b_value_ptr = (void *)brgmm_ctx.get_zp_b_val_ptr();
-    ctx.compensation_ptr
-            = (void *)brgmm_ctx.get_s8s8_comp_ptr(ithr, b_idx, n_blk_idx);

     ctx.dynamic_src_stride = brgmm_ctx.copy_B_wei_stride();

-    // For best performance, scales should be taken in copy kernels as is.
-    // The biggest challenge is grouped scales (`wei_scales_gK > 1` == true,
-    // (`bgmmc.is_wei_scale_per_k` == true) as `gK` starts interfering with
-    // `K_blk`. In case when `bgmmc.gK_and_K_blk_are_divisible` == true, it's
-    // easy to process a full `K_blk` piece or a portion of it by applying a
-    // single scale group. When doing a portion of B, it requires an additional
-    // loop over a single block, which `wei_scales_n_g` is responsible for.
-    //
-    // In other case, weights scales would copy data in a full-size of B
-    // scratchpad tensor and apply scales per point.
-    const auto wei_scales_gK = bgmmc.wei_scales_k_group_size;
-    const auto apply_single_scales_group = bgmmc.is_wei_scale_per_k
-            && bgmmc.gK_and_K_blk_are_divisible && wei_scales_gK > 1;
-
-    // `div_up` covers the case when K_blk is less than gK.
-    const auto wei_scales_n_g = apply_single_scales_group
-            ? div_up(bgmmc.K_blk, wei_scales_gK)
-            : 1;
-
-    for_(int gb = 0; gb < gemm_batch; gb++)
-    for (int k_i = 0; k_i < wei_scales_n_g; k_i++) {
-        const int k = k_start + gb * bgmmc.K_blk + k_i * wei_scales_gK;
+    // For the grouped Zero points/scales need to vary k-block size
+    // For this case need to call copy kernel with unaligned (k, k_iters)
+    auto call_copy_kernel = [&](int k, int k_iters, int gb,
+                                    bool aligned_blocks = false) {
         ctx.src = (void *)brgmm_ctx.get_data_B_kn_ptr(B_data_batch_ptr, k, n);
-        ctx.tr_src = (void *)brgmm_ctx.get_buf_B_ptr(
-                ithr, k_blk_idx, n_blk_idx, gb, k_i * wei_scales_gK);
+        ctx.compensation_ptr
+                = (void *)brgmm_ctx.get_s8s8_comp_ptr(ithr, b_idx, n_blk_idx);
-        ctx.current_K_start = k;
-        ctx.current_K_iters = nstl::min(bgmmc.K_blk, bgmmc.K);
-        if (apply_single_scales_group) {
-            // Update the copy_B K_block to process when apply a single scale.
-            ctx.current_K_iters = nstl::min(wei_scales_gK, ctx.current_K_iters);
-        }
-        ctx.current_K_pad = brgmm_ctx.get_current_K_pad(ctx.current_K_iters);
+        // Use k for buffer locating only when the block is unaligned
+        if (aligned_blocks)
+            ctx.tr_src = (void *)brgmm_ctx.get_buf_B_ptr(
+                    ithr, k_blk_idx, n_blk_idx, gb);
+        else
+            ctx.tr_src = (void *)brgmm_ctx.get_buf_B_k_ptr(ithr, k);
+
+        ctx.current_K_start = k;
+        ctx.current_K_iters = k_iters;
+        ctx.current_K_pad = brgmm_ctx.get_current_K_pad(k_iters);
         ctx.src_scales_ptr = brgmm_ctx.get_src_scales_ptr();
         ctx.wei_scales_ptr = brgmm_ctx.get_wei_scales_ptr(n, k);
+        ctx.zp_b_value_ptr = brgmm_ctx.get_wei_zp_ptr(n, k);
         if (bgmmc.blocked_B && !bgmmc.is_f16_with_int_wei
                 && isa == avx512_core_fp16) {
             cvt_float16_to_float((float *)ctx.tr_src, (float16_t *)ctx.src,
@@ -1298,37 +1284,52 @@ void brgemm_matmul_t<isa>::copy_b_chunk_in_buffer(
         } else {
             (*copy_B_kernel_)(&ctx);
         }
-    }
+    };

-    if (!is_K_tail) return;
+    // grouped zero points &-or scales
+    if (bgmmc.is_wei_zp_per_k || bgmmc.is_wei_scale_per_k) {
+        const auto &k_group = bgmmc.is_wei_zp_per_k ? bgmmc.wei_zp_k_gsize
+                                                    : bgmmc.wei_scales_k_gsize;
+        const auto brgemm_k_blk = nstl::min(bgmmc.K, bgmmc.K_blk);
+        const auto adj_k_blk = nstl::min(brgemm_k_blk, k_group);
+        assert(adj_k_blk > 0);
+        auto k = k_start;
+        // is_K_tail behaves incorrectly for the case K < K_blk
+        // Should be is_K_tail = true && gemm_batch = 0
+        // Now: is_K_tail = false, gemm_batch = 1.
+        // Causes Segfault when the `group over k` < K_blk for blocked formats.
+        const auto work_amount = bgmmc.K < bgmmc.K_blk
+                ? bgmmc.K
+                : gemm_batch * bgmmc.K_blk
+                        + is_K_tail * (bgmmc.K % bgmmc.K_blk);
+        const auto k_end = k_start + work_amount;
+
-    // `div_up` covers the case when K_blk is less than gK.
-    const auto wei_scales_n_g_tail = apply_single_scales_group
-            ? div_up(bgmmc.K % bgmmc.K_blk, wei_scales_gK)
-            : 1;
-
-    for (int k_i = 0; k_i < wei_scales_n_g_tail; k_i++) {
-        const int k = k_start + gemm_batch * bgmmc.K_blk + k_i * wei_scales_gK;
-        ctx.src = (void *)brgmm_ctx.get_data_B_kn_ptr(B_data_batch_ptr, k, n);
-        ctx.tr_src = (void *)brgmm_ctx.get_buf_B_ptr(
-                ithr, k_blk_idx, n_blk_idx, gemm_batch, k_i * wei_scales_gK);
-        ctx.compensation_ptr
-                = (void *)brgmm_ctx.get_s8s8_comp_ptr(ithr, b_idx, n_blk_idx);
-        ctx.current_K_start = k;
-        ctx.current_K_iters = bgmmc.K % bgmmc.K_blk;
-        if (apply_single_scales_group) {
-            // Update the copy_B K_block to process when apply a single scale.
-            ctx.current_K_iters = nstl::min(wei_scales_gK, ctx.current_K_iters);
+        // Handle first block
+        if (k_start % adj_k_blk > 0) {
+            const auto first_blk_size = adj_k_blk - (k_start % adj_k_blk);
+            call_copy_kernel(k_start, first_blk_size, 0);
+            k += first_blk_size;
         }
-        ctx.current_K_pad = brgmm_ctx.get_current_K_pad(ctx.current_K_iters);
-        ctx.src_scales_ptr = brgmm_ctx.get_src_scales_ptr();
-        ctx.wei_scales_ptr = brgmm_ctx.get_wei_scales_ptr(n, k);
-        if (bgmmc.blocked_B && !bgmmc.is_f16_with_int_wei
-                && isa == avx512_core_fp16) {
-            cvt_float16_to_float((float *)ctx.tr_src, (float16_t *)ctx.src,
-                    bgmmc.wei_n_blk * ctx.current_K_iters);
-        } else {
-            (*copy_B_kernel_)(&ctx);
+        // Handle full blocks
+        for (; (k + adj_k_blk) <= k_end; k += adj_k_blk) {
+            const auto gb = (k - k_start) / bgmmc.K_blk;
+            call_copy_kernel(k, adj_k_blk, gb);
         }
+        // Handle last block
+        if (k_end > k) {
+            const auto gb = (k - k_start) / bgmmc.K_blk;
+            call_copy_kernel(k, k_end - k, gb);
         }
+    } else { // Default case with k_blk blocking
+        for (int gb = 0; gb < gemm_batch; ++gb) {
+            const auto k = k_start + gb * bgmmc.K_blk;
+            const auto k_iters = nstl::min(bgmmc.K_blk, bgmmc.K);
+            call_copy_kernel(k, k_iters, gb, /*aligned_blocks=*/true);
+        }
+        if (is_K_tail) {
+            const auto k = k_start + gemm_batch * bgmmc.K_blk;
+            const auto k_iters = bgmmc.K % bgmmc.K_blk;
+            call_copy_kernel(k, k_iters, gemm_batch, /*aligned_blocks=*/true);
+        }
+    }
 }
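Not part of the patch: a minimal, self-contained sketch of the head/full/tail splitting that the grouped branch above performs over [k_start, k_end), assuming a positive adj_k_blk; the names are made up for illustration.

    #include <algorithm>
    #include <cstdio>

    void split_k_range(int k_start, int k_end, int adj_k_blk) {
        int k = k_start;
        // head: advance to the next group boundary
        if (k_start % adj_k_blk > 0) {
            const int head = std::min(adj_k_blk - k_start % adj_k_blk, k_end - k);
            std::printf("copy [%d, %d)\n", k, k + head);
            k += head;
        }
        // full adj_k_blk blocks
        for (; k + adj_k_blk <= k_end; k += adj_k_blk)
            std::printf("copy [%d, %d)\n", k, k + adj_k_blk);
        // tail
        if (k_end > k) std::printf("copy [%d, %d)\n", k, k_end);
    }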
@@ -1373,7 +1374,7 @@ struct brgemm_matmul_t<isa>::brg_matmul_exec_ctx_t {
         // setup scales / zp pointers
         const void *src_zero_points = CTX_IN_MEM(
                 const void *, DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC);
-        const void *wei_zero_points = CTX_IN_MEM(
+        wei_zp_ptr_ = CTX_IN_MEM(
                 const void *, DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS);
         const void *dst_zero_points = CTX_IN_MEM(
                 const void *, DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST);
@@ -1383,18 +1384,18 @@ struct brgemm_matmul_t<isa>::brg_matmul_exec_ctx_t {
                         pd->attr()->zero_points_.get_data_type(DNNL_ARG_SRC),
                         src_zero_points, 0)
                 : 0;
-        zero_point_b_val_ = wei_zero_points ? cpu::io::load_int_value(
-                                    pd->attr()->zero_points_.get_data_type(
-                                            DNNL_ARG_WEIGHTS),
-                                    wei_zero_points, 0)
-                                            : 0;
-        zero_point_b_negative_val_ = -zero_point_b_val_;
         zero_point_c_val_ = dst_zero_points
                 ? cpu::io::load_int_value(
                         pd->attr()->zero_points_.get_data_type(DNNL_ARG_DST),
                         dst_zero_points, 0)
                 : 0;

+        wei_zp_neg_val_ = (-1)
+                * (wei_zp_ptr_ ? cpu::io::load_int_value(
+                           pd->attr()->zero_points_.get_data_type(
+                                   DNNL_ARG_WEIGHTS),
+                           wei_zp_ptr_, 0)
+                               : 0);
         memory_tracking::grantor_t scratchpad = ctx.get_scratchpad_grantor();

         const auto &bgmmc = pd->get_brgemm_matmul_conf();
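Not part of the patch: the execution context now pre-negates the single weights zero point once, so copy kernels can add the negated value instead of subtracting per element. A rough standalone illustration (the s8/s32 switch below is an assumption; the real code goes through cpu::io::load_int_value with the attribute's data type):

    #include <cstdint>

    int32_t neg_wei_zero_point(const void *wei_zp, bool stored_as_s8) {
        if (!wei_zp) return 0;
        const int32_t val = stored_as_s8 ? *static_cast<const int8_t *>(wei_zp)
                                         : *static_cast<const int32_t *>(wei_zp);
        return -val;
    }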
@@ -1403,10 +1404,6 @@ struct brgemm_matmul_t<isa>::brg_matmul_exec_ctx_t {
                 const float *, DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC);
         wei_scales_ = CTX_IN_MEM(
                 const float *, DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS);
-        wei_scales_tr_ = bgmmc_.is_wei_scale_per_k
-                        && !bgmmc_.gK_and_K_blk_are_divisible
-                ? scratchpad.template get<float>(key_precomputed_scales)
-                : nullptr;
         dst_scales_ = CTX_IN_MEM(
                 const float *, DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST);
         dst_scales_inv_ = scratchpad.template get<float>(key_matmul_dst_scales);
@@ -1887,17 +1884,36 @@ struct brgemm_matmul_t<isa>::brg_matmul_exec_ctx_t {
     // a block of memory per thread.
     // `gb` defines an offset to a specific block over K inside a portion of
     // blocks.
-    // `k` defines an offset inside a specific block over K. It is the
-    // smallest granularity possible. It is used only when wei_decompression
-    // feature is requested with a single scale over a sub-piece of `K_blk`.
-    char *get_buf_B_ptr(
-            int ithr, int k_blk_idx, int n_blk_idx, int gb, int k = 0) const {
+    char *get_buf_B_ptr(int ithr, int k_blk_idx, int n_blk_idx, int gb) const {
         UNUSED(n_blk_idx);
         if (!bgmmc_.use_buffer_b) return nullptr;
         int k_blk_local = k_blk_idx % get_K_chunk_size();
-        return buf_B_ptr_ + ithr * bgmmc_.buffer_b_per_thread_sz
+        const auto offset = ithr * bgmmc_.buffer_b_per_thread_sz
                 + k_blk_local * bgmmc_.buffer_b_k_brg_stride
-                + gb * bgmmc_.buffer_b_gb_stride + k * bgmmc_.buffer_b_k_stride;
+                + gb * bgmmc_.buffer_b_gb_stride;
+        return buf_B_ptr_ + offset;
     }

+    /* Returns a pointer to buffer B based on unaligned K inside
+     * a thread buffer. Used for copy kernels including grouped ZP/Scales.
+     * For the vnni granularity > 1 it returns a pointer to a start of vnni block.
+     * Functionality intersects with get_buf_B_ptr(). TODO: make combined solution.
+     */
+    char *get_buf_B_k_ptr(const int ithr, const int k) const {
+        if (!bgmmc_.use_buffer_b) return nullptr;
+
+        const int batch_block_size = bgmmc_.K_blk * bgmmc_.brgemm_batch_size;
+        const auto batch_blocking = std::div(k, batch_block_size);
+        const auto k_blk_idx = batch_blocking.quot;
+        const auto k_blk_local = k_blk_idx % get_K_chunk_size();
+        const auto k_in_batch = batch_blocking.rem;
+
+        auto offset = ithr * bgmmc_.buffer_b_per_thread_sz;
+        offset += k_blk_local * bgmmc_.buffer_b_k_brg_stride;
+        // div down to the start of the vnni block
+        const auto k_outer = (k_in_batch / vnni_factor) * vnni_factor;
+        offset += k_outer * bgmmc_.buffer_b_k_stride;
+        return buf_B_ptr_ + offset;
+    }
+
     char *get_buf_C_ptr(int ithr, int m_blk_idx, int n_blk_idx) const {
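Not part of the patch: a self-contained sketch of the offset arithmetic in get_buf_B_k_ptr() above — split k into (batch block, remainder) with std::div and round the remainder down to the start of its vnni block. The parameter names are assumptions.

    #include <cstddef>
    #include <cstdlib>

    size_t buf_B_k_offset(int k, int K_blk, int brgemm_batch_size, int vnni_factor,
            size_t k_brg_stride, size_t k_stride, int K_chunk_size) {
        const int batch_block_size = K_blk * brgemm_batch_size;
        const auto d = std::div(k, batch_block_size);
        const int k_blk_local = d.quot % K_chunk_size; // block index inside the chunk
        const int k_outer = (d.rem / vnni_factor) * vnni_factor; // vnni block start
        return k_blk_local * k_brg_stride + k_outer * k_stride;
    }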
@@ -2049,107 +2065,17 @@ struct brgemm_matmul_t<isa>::brg_matmul_exec_ctx_t {

     // Returns a pointer to the weights scales for the correspondent block based
     // on @p n and @p k.
-    //
-    // For grouped scales it also prepares a scratchpad buffer when `K_blk` and
-    // `group_size` don't divide each other.
     const void *get_wei_scales_ptr(int n, int k = 0) const {
-        if (!wei_scales_) return wei_scales_;
-
-        const auto gK = bgmmc_.wei_scales_k_group_size;
-        const auto in_stride_n = bgmmc_.is_wei_scale_per_n ? 1 : 0;
-        const auto in_stride_k = bgmmc_.is_wei_scale_per_k
-                ? (bgmmc_.is_wei_scale_per_n ? bgmmc_.N : 1)
-                : 0;
-        int k_g = gK > 1 ? k / gK : k;
-        const auto offset = n * in_stride_n + k_g * in_stride_k;
-        const auto *ptr = reinterpret_cast<const char *const>(wei_scales_)
-                + offset * bgmmc_.wei_scales_dt_sz;
-        if (bgmmc_.gK_and_K_blk_are_divisible || !bgmmc_.is_wei_scale_per_k)
-            return ptr;
-
-        if (bgmmc_.req_transpose_scales) {
-            // Transpose case covers grouped and non-grouped scenarios.
-            const auto in_tr_stride_n = bgmmc_.is_wei_scale_per_n
-                    ? (bgmmc_.is_wei_scale_per_k ? bgmmc_.K : 1)
-                    : 0;
-            const auto in_tr_stride_k = bgmmc_.is_wei_scale_per_k ? 1 : 0;
-            const auto offset_tr = k * in_tr_stride_k + n * in_tr_stride_n;
-            auto *ptr_tr = reinterpret_cast<char *const>(wei_scales_tr_)
-                    + offset_tr * bgmmc_.wei_scales_dt_sz;
-
-            // `kk_offset` covers a situation when
-            // `!bgmmc_.gK_and_K_blk_are_divisible`. Need to process the start
-            // of the next K block in a special way.
-            //
-            // Process it separately for simpler main body code.
-            const int kk_offset = (k % gK != 0) ? (gK - (k % gK)) : 0;
-            if (kk_offset) {
-                for_(int nn = 0; nn < nstl::min(bgmmc_.N - n, bgmmc_.N_blk);
-                        nn++)
-                for (int kk = 0; kk < kk_offset; kk++) {
-                    const auto in_idx = nn * in_stride_n;
-                    const auto out_idx = nn * bgmmc_.K + kk;
-                    const float wei_scales_val = cpu::io::load_float_value(
-                            bgmmc_.wei_scales_dt, ptr, in_idx);
-                    cpu::io::store_float_value(bgmmc_.wei_scales_dt,
-                            wei_scales_val, ptr_tr, out_idx);
-                }
-            }
-
-            for_(int nn = 0; nn < nstl::min(bgmmc_.N - n, bgmmc_.N_blk); nn++)
-            for (int kk = 0;
-                    kk < nstl::min(bgmmc_.K - k, bgmmc_.K_blk) - kk_offset;
-                    kk++) {
-                const auto in_idx = nn * in_stride_n + (kk / gK) * in_stride_k;
-                const auto out_idx = nn * bgmmc_.K + kk;
-                const auto ptr_kk_offset = bool(kk_offset) * in_stride_k
-                        * bgmmc_.wei_scales_dt_sz;
-                const auto ptr_tr_kk_offset
-                        = kk_offset * in_tr_stride_k * bgmmc_.wei_scales_dt_sz;
-                const float wei_scales_val = cpu::io::load_float_value(
-                        bgmmc_.wei_scales_dt, ptr + ptr_kk_offset, in_idx);
-                cpu::io::store_float_value(bgmmc_.wei_scales_dt, wei_scales_val,
-                        ptr_tr + ptr_tr_kk_offset, out_idx);
-            }
-
-            return ptr_tr;
-        } else {
-            const auto offset_non_g = n * in_stride_n + k * in_stride_k;
-            auto *ptr_non_g = reinterpret_cast<char *const>(wei_scales_tr_)
-                    + offset_non_g * bgmmc_.wei_scales_dt_sz;
-
-            const int kk_offset = (k % gK != 0) ? (gK - (k % gK)) : 0;
-            if (kk_offset) {
-                for_(int kk = 0; kk < kk_offset; kk++)
-                for (int nn = 0; nn < nstl::min(bgmmc_.N - n, bgmmc_.N_blk);
-                        nn++) {
-                    const auto in_idx = nn * in_stride_n;
-                    const auto out_idx = nn * in_stride_n + kk * in_stride_k;
-                    const float wei_scales_val = cpu::io::load_float_value(
-                            bgmmc_.wei_scales_dt, ptr, in_idx);
-                    cpu::io::store_float_value(bgmmc_.wei_scales_dt,
-                            wei_scales_val, ptr_non_g, out_idx);
-                }
-            }
-
-            for_(int kk = 0;
-                    kk < nstl::min(bgmmc_.K - k, bgmmc_.K_blk) - kk_offset;
-                    kk++)
-            for (int nn = 0; nn < nstl::min(bgmmc_.N - n, bgmmc_.N_blk); nn++) {
-                const auto in_idx = nn * in_stride_n + (kk / gK) * in_stride_k;
-                const auto out_idx = nn * in_stride_n + kk * in_stride_k;
-                const auto ptr_kk_offset = bool(kk_offset) * in_stride_k
-                        * bgmmc_.wei_scales_dt_sz;
-                const auto ptr_non_g_kk_offset
-                        = kk_offset * in_stride_k * bgmmc_.wei_scales_dt_sz;
-                const float wei_scales_val = cpu::io::load_float_value(
-                        bgmmc_.wei_scales_dt, ptr + ptr_kk_offset, in_idx);
-                cpu::io::store_float_value(bgmmc_.wei_scales_dt, wei_scales_val,
-                        ptr_non_g + ptr_non_g_kk_offset, out_idx);
-            }
-
-            return ptr_non_g;
+        if (bgmmc_.is_wei_scale_common) return wei_scales_;
+        auto offset = n;
+        if (bgmmc_.is_wei_scale_per_k) {
+            const auto &k_group_sz = bgmmc_.wei_scales_k_gsize;
+            const auto k_idx = k / k_group_sz;
+            offset += k_idx * bgmmc_.N;
+        }
+
+        offset = offset * bgmmc_.wei_scales_dt_sz;
+        return ((char *)wei_scales_ + offset);
     }

     const void *get_dst_scales_ptr() const { return dst_scales_; }
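Not part of the patch: the simplified get_wei_scales_ptr() above reduces grouped-scale addressing to one index computation. For per_ocic scales stored as a (K / group_k) x N row-major array, the scale used for element (k, n) sits at the index computed by this minimal sketch:

    #include <cstddef>

    size_t wei_scale_index(int n, int k, int group_k, int N) {
        return static_cast<size_t>(n) + static_cast<size_t>(k / group_k) * N;
    }

For example, with N = 192 and group_k = 192, k = 200 falls into group row 1, giving index n + 192.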
@@ -2167,11 +2093,28 @@ struct brgemm_matmul_t<isa>::brg_matmul_exec_ctx_t {
         return &zero_point_a_negative_val_;
     }

-    const int32_t *get_zp_b_neg_val_ptr() const {
-        return &zero_point_b_negative_val_;
-    }
+    const void *get_wei_zp_neg_ptr() const { return &wei_zp_neg_val_; }

-    const int32_t *get_zp_b_val_ptr() const { return &zero_point_b_val_; }
+    const void *get_wei_zp_ptr(int n, int k = 0) const {
+        if (!bgmmc_.has_zero_point_b) return nullptr;
+        if (bgmmc_.is_wei_zp_common)
+            return wei_zp_ptr_; // single zero point value
+        // Locate the group based on (n,k)
+        auto offset = n;
+
+        if (bgmmc_.is_wei_zp_per_k) {
+            const auto &k_group_sz = bgmmc_.wei_zp_k_gsize;
+            const auto k_idx = k / k_group_sz;
+            offset += k_idx * bgmmc_.N;
+        }
+
+        const auto dt_sz = types::data_type_size(bgmmc_.wei_zp_dt);
+        const auto elems_per_byte
+                = one_of(bgmmc_.wei_zp_dt, data_type::s4, data_type::u4) ? 2
+                                                                         : 1;
+        offset = offset * dt_sz / elems_per_byte;
+        return (char *)wei_zp_ptr_ + offset;
+    }

     const int32_t *get_zp_ab_mixed_comp_ptr() const {
         return &zero_point_mixed_ab_compensation_component_;
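Not part of the patch: a sketch of the byte-offset step in get_wei_zp_ptr() above. For s4/u4 zero points two elements share one byte, so the element index is effectively halved; for byte-sized types the index is scaled by the type size.

    #include <cstddef>

    size_t wei_zp_byte_offset(size_t elem_idx, size_t dt_size, bool is_int4) {
        const size_t elems_per_byte = is_int4 ? 2 : 1;
        return elem_idx * dt_size / elems_per_byte;
    }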
@@ -2457,6 +2400,7 @@ struct brgemm_matmul_t<isa>::brg_matmul_exec_ctx_t {
     bool packed_sparse_weights() const { return bgmmc_.packed_sparse_weights; }

     int get_current_K_pad(int current_K_iters) const {
+        if (bgmmc_.is_wei_zp_per_k || bgmmc_.is_wei_scale_per_k) return 0;
         if (current_K_iters % bgmmc_.wei_k_blk == 0) return 0;
         return (bgmmc_.extendable_k || bgmmc_.use_fused_copy_a)
                 ? bgmmc_.wei_k_blk
@@ -2514,14 +2458,7 @@ private:
     const char *bias_ptr_;
     const void *src_scales_;
     const void *wei_scales_;
-    // This pointer is coming from scratchpad and is needed to expand (K/g)xN
-    // scales to KxN scales as copy_B kernels rely on the full register scales
-    // for weights decompression feature in case when K_blk is not divisible by
-    // the scales group size or vice versa.
-    // TODO: implement the logic at calling copy routines spot to handle the
-    // unsupported scenario mentioned above to avoid the need in this pointer
-    // and the overhead around filling that big buffer.
-    void *wei_scales_tr_;

     const void *dst_scales_;
     const void *dst_scales_inv_;
     int32_t *s8s8_compensation_ptr_;
@@ -2531,10 +2468,10 @@ private:
     int32_t *reorder_zp_a_comp_ptr_;

     int32_t zero_point_a_negative_val_;
-    int32_t zero_point_b_val_;
-    int32_t zero_point_b_negative_val_;
     int32_t zero_point_mixed_ab_compensation_component_;
     int32_t zero_point_c_val_;
+    int32_t wei_zp_neg_val_;
+    const void *wei_zp_ptr_;
     std::vector<const void *> post_ops_binary_rhs_arg_vec_;

     int base_brg_ker_idx_;
File diff suppressed because it is too large
@@ -59,7 +59,7 @@ struct jit_brgemm_matmul_copy_a_t {
         const void *tr_src;
         const void *zp_b_compensation_buffer_ptr;
         const void *zp_a_compensation_result_ptr;
-        const void *zp_b_neg_value_ptr;
+        const void *zp_b_neg_val_ptr;
         const void *zp_ab_comp_ptr;

         dim_t current_K_start;
@@ -1296,6 +1296,7 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc,
     bgmmc.dst_dt = dst_d.data_type();
     bgmmc.wei_dt = weights_d.data_type();
     bgmmc.orig_wei_dt = weights_d.data_type();
+    bgmmc.wei_zp_dt = attr.zero_points_.get(DNNL_ARG_WEIGHTS).get_data_type();

     bgmmc.with_reduce = mmd.reduce_desc.format_kind != format_kind::undef;
     bgmmc.reduce_dt
@@ -1396,19 +1397,19 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc,
     bgmmc.with_src_scales = !src_scales.has_default_values();
     bgmmc.with_wei_scales = !wei_scales.has_default_values();
     if (bgmmc.with_wei_scales) {
-        const auto wei_qmask_N = 1 << (bgmmc.ndims - 1);
-        const auto wei_qmask_K = 1 << (bgmmc.ndims - 2);
-        bgmmc.is_wei_scale_per_k = wei_scales.get_mask() & wei_qmask_K;
-        bgmmc.is_wei_scale_per_n = wei_scales.get_mask() & wei_qmask_N;
+        const auto &wei_scale_mask = wei_scales.get_mask();
+        bgmmc.is_wei_scale_common = wei_scale_mask == 0;
+        bgmmc.is_wei_scale_per_k = wei_scale_mask & 1 << (bgmmc.ndims - 2);
+        bgmmc.is_wei_scale_per_n = wei_scale_mask & 1 << (bgmmc.ndims - 1);
         bgmmc.apply_scales_in_buffer_b = bgmmc.is_wei_scale_per_k
                 && bgmmc.with_wei_decompression && bgmmc.N * bgmmc.K != 1;
         bgmmc.wei_scales_dt = wei_scales.get_data_type();
         bgmmc.wei_scales_dt_sz = types::data_type_size(bgmmc.wei_scales_dt);
-        bgmmc.wei_scales_k_group_size = wei_scales.get_group(0);
+        bgmmc.wei_scales_k_gsize = wei_scales.get_group(0);

         // only common and per-oc-channel scales are supported
         // only per-ic-channel scales is supprted with weight decompression
-        VCONDCHECK_BG(wei_scales.get_mask() == 0 || bgmmc.is_wei_scale_per_n
+        VCONDCHECK_BG(bgmmc.is_wei_scale_common || bgmmc.is_wei_scale_per_n
                         || IMPLICATION(bgmmc.is_wei_scale_per_k,
                                 bgmmc.with_wei_decompression),
                 VERBOSE_UNSUPPORTED_SCALES_CFG);
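Not part of the patch: how the mask decoding above plays out for a plain 2D matmul with K x N weights (ndims == 2) — bit (ndims - 2) marks the K (per_ic) dimension, bit (ndims - 1) marks the N (per_oc) dimension, and mask == 0 means a common scale. A hypothetical standalone helper:

    struct wei_scale_flags_t {
        bool common, per_k, per_n;
    };

    wei_scale_flags_t decode_wei_scale_mask(int mask, int ndims) {
        wei_scale_flags_t f {};
        f.common = mask == 0;
        f.per_k = (mask & (1 << (ndims - 2))) != 0; // per_ic / grouped over K
        f.per_n = (mask & (1 << (ndims - 1))) != 0; // per_oc
        return f;
    }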
@@ -1420,6 +1421,28 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc,
     VCONDCHECK_BG(!(bgmmc.with_dst_scales && dst_scales.get_mask() > 0),
             VERBOSE_UNSUPPORTED_SCALES_CFG);

+    const auto &wei_zp = attr.zero_points_.get(DNNL_ARG_WEIGHTS);
+    const auto has_wei_zp = !wei_zp.has_default_values();
+
+    if (has_wei_zp) {
+        const auto wei_zp_mask = wei_zp.get_mask();
+        bgmmc.is_wei_zp_common = wei_zp_mask == 0;
+        bgmmc.is_wei_zp_per_k = wei_zp_mask & (1 << (bgmmc.ndims - 2));
+        bgmmc.is_wei_zp_per_n = wei_zp_mask & (1 << (bgmmc.ndims - 1));
+        bgmmc.wei_zp_dt = wei_zp.get_data_type();
+        bgmmc.wei_zp_k_gsize = wei_zp.get_group(0);
+
+        VCONDCHECK_BG(wei_zp_mask == 0 || bgmmc.is_wei_zp_per_k
+                        || bgmmc.is_wei_zp_per_n,
+                VERBOSE_UNSUPPORTED_ZP_CFG);
+
+        // Check if K groups for scales and for zero points are identical
+        VCONDCHECK_BG(
+                IMPLICATION(bgmmc.is_wei_zp_per_k && bgmmc.is_wei_scale_per_k,
+                        bgmmc.wei_zp_k_gsize == bgmmc.wei_scales_k_gsize),
+                VERBOSE_UNSUPPORTED_ZP_CFG);
+    }
+
     const auto &p = attr.post_ops_;
     bgmmc.with_sum = p.find(primitive_kind::sum) != -1;
     const int eltwise_ind = p.find(primitive_kind::eltwise);
@@ -1532,9 +1555,6 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc,
     bgmmc.transposed_B = bm_conf_utils.check_is_transposed(bgmmc.wei_tag)
             || bgmmc.wei_tag == adbc;
     bgmmc.use_buffer_b = bm_conf_utils.use_buffer_b();
-    bgmmc.req_transpose_scales = bgmmc.apply_scales_in_buffer_b
-            && bgmmc.is_wei_scale_per_k && bgmmc.is_wei_scale_per_n
-            && bgmmc.transposed_B;

     if ((bm_conf_utils.is_f32_f16() || bm_conf_utils.is_f32_bf16())
             && is_superset(bgmmc.isa, avx2) && bm_conf_utils.use_buffer_b()) {
@@ -1719,14 +1739,6 @@ status_t init_brgemm_matmul_conf(cpu_isa_t isa, brgemm_matmul_conf_t &bgmmc,
                 = bm_conf_utils.wei_down_convert_to_vnni();
     }

-    // This setting must be updated post blocking as it has a dependency on
-    // `bgmmc.K_blk`. See `gK_and_K_blk_are_divisible` comment.
-    if (bgmmc.is_wei_scale_per_k) {
-        const auto gK = bgmmc.wei_scales_k_group_size;
-        bgmmc.gK_and_K_blk_are_divisible = gK > 1
-                && ((bgmmc.K_blk % gK == 0) || (gK % bgmmc.K_blk == 0));
-    }
-
     VCHECK_BG(bm_conf_utils.set_B_flags(weights_md), VERBOSE_BLOCKING_FAIL, "");

     bgmmc.M_tail = bgmmc.is_runtime_M ? 0 : bgmmc.M % bgmmc.M_blk;
@@ -136,6 +136,7 @@ struct brgemm_matmul_conf_t {
     data_type_t reduce_dt;
     data_type_t orig_src_dt;
     data_type_t orig_wei_dt;

     int nthr;
     int nthr_k = 1, nthr_m = 1, nthr_n = 1, nthr_b = 1;

@@ -227,24 +228,27 @@ struct brgemm_matmul_conf_t {
     bool is_runtime_M = false;
     bool is_runtime_N = false;
     bool is_runtime_K = false;
-    bool extendable_k = false;
     bool is_src_batch_layout_trivial = false;
     bool is_wei_batch_layout_trivial = false;
     bool is_dst_batch_layout_trivial = false;

+    // Attributes related to quantization
+    // Scales
-    bool apply_scales_in_buffer_b = false;
-    size_t wei_scales_dt_sz = 0;
     bool is_wei_scale_per_n = false;
     bool is_wei_scale_per_k = false;
-    bool req_transpose_scales = false;
+    bool apply_scales_in_buffer_b = false;
-    // For generic cases, when groups are selected the way they can't divide a
-    // K_blk in equal pieces, it gets really hard to call a kernel with a
-    // single "per_N line" of scales. In this case weights will be copied
-    // to a larger memory buffer and used like a full tensor.
-    // TODO: convert to a method. State must be set in a specific place of the
-    // initialization as relies on blocking.
-    bool gK_and_K_blk_are_divisible = false;
+    size_t wei_scales_dt_sz = 0;
-    dim_t wei_scales_k_group_size = 0;
+    bool is_wei_scale_common = false;
+    dim_t wei_scales_k_gsize = 0;
     data_type_t wei_scales_dt = data_type::undef;
+    bool extendable_k = false;

+    // Zero points
+    dim_t wei_zp_k_gsize = 0;
+    bool is_wei_zp_per_k = false;
+    bool is_wei_zp_per_n = false;
+    bool is_wei_zp_common = false;
+    data_type_t wei_zp_dt = data_type::undef;

     bool is_gemv = false;

@@ -319,8 +319,7 @@
 --dt=f32:s4:f32,f32:u4:f32
 --wtag=any,ab,ba
 --attr-scales=,wei:common:0.5:f32,wei:per_ocic:f32:32x1
---attr-zero-points=,wei:common:2
-# ,wei:per_oc:s4:4096x1 # PER_OC skips for all matmul implemetations
+--attr-zero-points=,wei:common:2,wei:per_oc:s4
 --attr-fpmath=strict:true
 1x4096:4096x4096

@@ -328,7 +327,32 @@
 --dt=f32:s8:f32,f32:u8:f32
 --wtag=any,ab,ba
 --attr-scales=,wei:common:0.5:f32,wei:per_ocic:f32:32x1
---attr-zero-points=,wei:common:2
-# ,wei:per_ocic:s8:32x1 # PER_OC skips for all matmul implementations
+--attr-zero-points=,wei:common:2,wei:per_ocic:s8:32x1
 --attr-fpmath=strict:true
 1x4096:4096x4096
+
+
+## Additional grouped scales/ZP testing
+--reset
+--dt=f16:s4:f16
+--wtag=abc,acb
+--attr-scales=,wei:common:2,wei:per_oc:f16,wei:per_ocic:f16:192x1
+--attr-zero-points=,wei:common:2,wei:per_oc:s4,wei:per_ocic:s4:192x1
+--attr-fpmath=f16:true
+12x4x576:12x576x192
+
+--reset
+--dt=bf16:s4:bf16
+--wtag=abc,acb
+--attr-scales=,wei:common:2,wei:per_oc:f16,wei:per_ocic:f16:192x1
+--attr-zero-points=,wei:common:2,wei:per_oc:s4,wei:per_ocic:s4:192x1
+--attr-fpmath=bf16:true
+12x4x576:12x576x192
+
+--reset
+--dt=f32:s4:f32
+--wtag=abc,acb
+--attr-scales=,wei:common:2,wei:per_oc:f16,wei:per_ocic:f16:192x1
+--attr-zero-points=,wei:common:2,wei:per_oc:s4,wei:per_ocic:s4:192x1
+--attr-fpmath=strict:true
+12x4x576:12x576x192
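Not part of the patch: one of the added grouped cases can be reproduced on its own with a benchdnn invocation along these lines (the binary path depends on the build tree); the flags and problem descriptor are taken verbatim from the added test lines above.

    ./benchdnn --matmul --dt=f16:s4:f16 --wtag=acb \
        --attr-scales=wei:per_ocic:f16:192x1 \
        --attr-zero-points=wei:per_ocic:s4:192x1 \
        --attr-fpmath=f16:true 12x4x576:12x576x192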