mirror of
https://github.com/uxlfoundation/oneDNN.git
synced 2025-10-20 18:43:49 +08:00
5681 lines
208 KiB
C++
5681 lines
208 KiB
C++
/*******************************************************************************
|
|
* Copyright 2021-2025 Intel Corporation
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*******************************************************************************/
|
|
|
|
#include "common/c_types_map.hpp"
|
|
#include "common/nstl.hpp"
|
|
#include "common/type_helpers.hpp"
|
|
#include "common/utils.hpp"
|
|
#include "cpu/x64/jit_generator.hpp"
|
|
|
|
#include "cpu/x64/matmul/brgemm_matmul_copy_utils.hpp"
|
|
|
|
namespace dnnl {
|
|
namespace impl {
|
|
namespace cpu {
|
|
namespace x64 {
|
|
namespace matmul {
|
|
|
|
using namespace dnnl::impl::format_tag;
|
|
using namespace dnnl::impl::utils;
|
|
using namespace Xbyak;
|
|
|
|
#define GET_OFF(x) offsetof(ctx_t, x)
|
|
|
|
template <typename Vmm>
|
|
struct jit_brgemm_matmul_copy_a_impl_t : public jit_brgemm_matmul_copy_a_t,
|
|
public jit_generator_t {
|
|
DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_brgemm_matmul_copy_a_impl_t)
|
|
|
|
jit_brgemm_matmul_copy_a_impl_t(const brgemm_matmul_conf_t *conf)
|
|
: jit_brgemm_matmul_copy_a_t(conf)
|
|
, jit_generator_t(jit_name())
|
|
, typesize_(conf_->a_dt_sz)
|
|
, tr_typesize_(conf_->tr_a_dt_sz)
|
|
, vnni_granularity_(data_type_vnni_granularity(conf_->src_dt))
|
|
, k_step_(vlen_ / nstl::max(typesize_, tr_typesize_))
|
|
, src_stride_(conf_->copy_A_src_stride)
|
|
, tr_src_stride_((conf_->use_buffer_a_tail_only
|
|
? static_cast<dim_t>(conf_->wei_k_blk)
|
|
: conf_->LDA)
|
|
* tr_typesize_)
|
|
, do_compute_compensation_(
|
|
conf_->has_zero_point_b && !conf_->with_wei_decompression)
|
|
, avx512_core_dot_product_(
|
|
do_compute_compensation_ && !isa_has_int8_vnni(conf->isa))
|
|
// See the note in `create_brgemm_matmul_copy_b` why `orig_src_dt` used.
|
|
, use_fp16_instructions_(conf_->isa == avx512_core_fp16
|
|
&& conf_->orig_src_dt == data_type::f16
|
|
&& conf_->src_dt == data_type::f32)
|
|
, k_loop_unroll_(is_ymm_ ? 7 : 16)
|
|
, vmm_copy_idx_(is_ymm_ ? 13
|
|
: avx512_core_dot_product_ ? 27
|
|
: 29) {}
|
|
|
|
void operator()(ctx_t *ctx) override { jit_generator_t::operator()(ctx); }
|
|
status_t create_kernel() override {
|
|
return jit_generator_t::create_kernel();
|
|
}
|
|
|
|
private:
|
|
using reg64_t = const Xbyak::Reg64;
|
|
using reg32_t = const Xbyak::Reg32;
|
|
using opmask_t = const Xbyak::Opmask;
|
|
|
|
static constexpr int vlen_ = vreg_traits_t<Vmm>::vlen;
|
|
static constexpr bool is_ymm_ = std::is_same<Vmm, Xbyak::Ymm>::value;
|
|
static constexpr int num_comp_acc_ = is_ymm_ ? 7 : 8;
|
|
|
|
const int typesize_;
|
|
const int tr_typesize_;
|
|
const int vnni_granularity_;
|
|
const int k_step_;
|
|
const dim_t src_stride_;
|
|
const dim_t tr_src_stride_;
|
|
const bool do_compute_compensation_;
|
|
const bool avx512_core_dot_product_;
|
|
const bool use_fp16_instructions_;
|
|
|
|
const int k_loop_unroll_;
|
|
const int vmm_copy_idx_;
|
|
|
|
opmask_t kTail_load = k7;
|
|
opmask_t kTail_store = k6;
|
|
opmask_t kTail_comp = k5;
|
|
|
|
reg64_t reg_src = rax;
|
|
reg64_t reg_tr_src = rbx;
|
|
reg64_t reg_K_start = abi_not_param1;
|
|
|
|
reg64_t reg_zp_comp_buf_ptr = rdx;
|
|
reg64_t reg_zp_comp_res_ptr = rsi;
|
|
|
|
reg64_t reg_M_blk = r9;
|
|
reg64_t reg_K_blk = r10;
|
|
reg64_t reg_batch = r11;
|
|
reg64_t reg_aux_src = r12;
|
|
reg64_t reg_aux_tr_src = r13;
|
|
reg64_t regq_tmp = r14;
|
|
reg64_t imm_addr64 = r15;
|
|
reg64_t reg_zp_ab_comp_ptr = imm_addr64;
|
|
reg64_t reg_zp_b_neg_val_ptr = reg_K_blk;
|
|
|
|
// Required in every dot product for INT8 non-VNNI computation.
|
|
Vmm vmm_ones_words = Vmm(28);
|
|
Vmm vmm_dot_product_temp = Vmm(29);
|
|
|
|
Vmm vmm_comp_mul = Vmm(is_ymm_ ? 14 : 30); // 1s
|
|
Vmm vmm_comp_add = Vmm(is_ymm_ ? 15 : 31); // 128
|
|
|
|
// Allows to shift A data by 128 for s8s8 problem for AVX512 in copy
|
|
// routine, not in compute kernel. It's disabled for now, as it
|
|
// requires setting some hint to brgemm kernel to avoid double shifting
|
|
const bool allow_input_shift_for_s8s8 = false;
|
|
|
|
Vmm get_vmm_comp_acc(int i) {
|
|
assert(i >= 0 && i < num_comp_acc_);
|
|
return Vmm(i);
|
|
}
|
|
|
|
Vmm get_vmm_copy(int i) {
|
|
assert(i >= 0 && i < k_loop_unroll_);
|
|
return Vmm(vmm_copy_idx_ - i);
|
|
}
|
|
|
|
void load_vmm(int idx, int offset) {}
|
|
void store_vmm(int idx, int offset) {}
|
|
void load_tail(int k_tail, size_t offset) {}
|
|
void store_tail(int k_tail, size_t offset) {}
|
|
void reduce_compensation_across_accumulators(int num_accumulators);
|
|
void copy_K_loop(bool is_K_tail, bool is_first_K_iter, bool is_last_K_iter);
|
|
void copy_M_loop(bool is_K_tail, bool is_first_K_iter, bool is_last_K_iter);
|
|
inline void dot_product(Vmm v1, Vmm v2, Vmm v3) {
|
|
if (!avx512_core_dot_product_)
|
|
vpdpbusd(v1, v2, v3, get_encoding());
|
|
else {
|
|
vpmaddubsw(vmm_dot_product_temp, v2, v3);
|
|
vpmaddwd(
|
|
vmm_dot_product_temp, vmm_dot_product_temp, vmm_ones_words);
|
|
vpaddd(v1, v1, vmm_dot_product_temp);
|
|
}
|
|
}
|
|
void generate() override;
|
|
};
|
|
|
|
template <>
|
|
void jit_brgemm_matmul_copy_a_impl_t<Zmm>::load_vmm(int idx, int offset) {
|
|
const auto addr = EVEX_compress_addr(reg_src, offset);
|
|
if (use_fp16_instructions_) {
|
|
vcvtph2psx(get_vmm_copy(idx), addr);
|
|
} else {
|
|
vmovdqu8(get_vmm_copy(idx), addr);
|
|
}
|
|
}
|
|
|
|
template <>
|
|
void jit_brgemm_matmul_copy_a_impl_t<Ymm>::load_vmm(int idx, int offset) {
|
|
uni_vmovups(get_vmm_copy(idx), ptr[reg_src + offset]);
|
|
}
|
|
|
|
template <>
|
|
void jit_brgemm_matmul_copy_a_impl_t<Zmm>::store_vmm(int idx, int offset) {
|
|
auto tr_src_addr = EVEX_compress_addr(reg_tr_src, offset);
|
|
vmovdqu8(tr_src_addr, get_vmm_copy(idx));
|
|
}
|
|
|
|
template <>
|
|
void jit_brgemm_matmul_copy_a_impl_t<Ymm>::store_vmm(int idx, int offset) {
|
|
uni_vmovups(ptr[reg_tr_src + offset], get_vmm_copy(idx));
|
|
}
|
|
|
|
template <>
|
|
void jit_brgemm_matmul_copy_a_impl_t<Zmm>::load_tail(
|
|
int k_tail, size_t offset) {
|
|
const auto kmovx = [this](Opmask k, size_t q) {
|
|
if (conf_->is_bf32) {
|
|
mov(regq_tmp.cvt32(), q);
|
|
jit_generator_t::kmovw(k, regq_tmp.cvt32());
|
|
} else {
|
|
mov(regq_tmp, q);
|
|
jit_generator_t::kmovq(k, regq_tmp);
|
|
}
|
|
};
|
|
|
|
const size_t dt_step
|
|
= conf_->is_bf32 || use_fp16_instructions_ ? 1 : typesize_;
|
|
const size_t tail_mask_load = size_t(((size_t)1 << (dt_step * k_tail)) - 1);
|
|
kmovx(kTail_load, tail_mask_load);
|
|
const int k_tail_st = rnd_up(k_tail, vnni_granularity_);
|
|
const size_t full_mask
|
|
= conf_->is_bf32 ? ((size_t)1 << 16) - 1 : 0xffffffffffffffff;
|
|
const size_t tail_mask_store = k_tail_st == k_step_
|
|
? full_mask
|
|
: size_t(((size_t)1 << (dt_step * k_tail_st)) - 1);
|
|
kmovx(kTail_store, tail_mask_store);
|
|
|
|
auto zmm_tail = get_vmm_copy(0) | kTail_load | T_z;
|
|
auto load_addr = EVEX_compress_addr(reg_src, offset * typesize_);
|
|
if (conf_->is_bf32)
|
|
vmovups(zmm_tail, load_addr);
|
|
else if (use_fp16_instructions_)
|
|
vcvtph2psx(zmm_tail, load_addr);
|
|
else
|
|
vmovdqu8(zmm_tail, load_addr);
|
|
}
|
|
|
|
template <>
|
|
void jit_brgemm_matmul_copy_a_impl_t<Ymm>::load_tail(
|
|
int k_tail, size_t offset) {
|
|
const auto vmm_tail = get_vmm_copy(0);
|
|
load_bytes(vmm_tail, reg_src, offset * typesize_, k_tail * typesize_);
|
|
}
|
|
|
|
template <>
|
|
void jit_brgemm_matmul_copy_a_impl_t<Zmm>::store_tail(
|
|
int k_tail, size_t offset) {
|
|
auto tr_src_addr = EVEX_compress_addr(reg_tr_src, offset * tr_typesize_);
|
|
if (conf_->is_bf32) {
|
|
Ymm ymm_downcvt_bf16 = Ymm(get_vmm_copy(0).getIdx());
|
|
vcvtneps2bf16(ymm_downcvt_bf16, get_vmm_copy(0));
|
|
vmovdqu16(tr_src_addr, ymm_downcvt_bf16 | kTail_store);
|
|
} else if (use_fp16_instructions_) {
|
|
vmovups(tr_src_addr, get_vmm_copy(0) | kTail_store);
|
|
} else
|
|
vmovdqu8(tr_src_addr, get_vmm_copy(0) | kTail_store);
|
|
}
|
|
|
|
template <>
|
|
void jit_brgemm_matmul_copy_a_impl_t<Ymm>::store_tail(
|
|
int k_tail, size_t offset) {
|
|
const int k_tail_st = rnd_up(k_tail, vnni_granularity_);
|
|
const auto vmm_tail = get_vmm_copy(0);
|
|
store_bytes(
|
|
vmm_tail, reg_tr_src, offset * tr_typesize_, k_tail_st * typesize_);
|
|
}
|
|
|
|
template <typename Vmm>
|
|
void jit_brgemm_matmul_copy_a_impl_t<
|
|
Vmm>::reduce_compensation_across_accumulators(int num_accumulators) {
|
|
int num = num_accumulators;
|
|
while (num > 1) {
|
|
for (int i = 0; i < num / 2; i++) {
|
|
const auto vmm_acc0 = get_vmm_comp_acc(i);
|
|
const auto vmm_acc1 = get_vmm_comp_acc(div_up(num, 2) + i);
|
|
uni_vpaddd(vmm_acc0, vmm_acc0, vmm_acc1);
|
|
}
|
|
num = div_up(num, 2);
|
|
}
|
|
}
|
|
|
|
template <typename Vmm>
|
|
void jit_brgemm_matmul_copy_a_impl_t<Vmm>::copy_K_loop(
|
|
bool is_K_tail, bool is_first_K_iter, bool is_last_K_iter) {
|
|
const int K_blk = is_K_tail ? conf_->K % conf_->K_blk
|
|
: nstl::min(conf_->K, conf_->K_blk);
|
|
const int k_tail = K_blk % k_step_;
|
|
const int num_k_iters = K_blk / k_step_;
|
|
const int num_acc = utils::saturate(1, (int)num_comp_acc_, num_k_iters);
|
|
|
|
if (do_compute_compensation_) {
|
|
for (int i = 0; i < num_acc; i++) {
|
|
const auto vmm_acc = get_vmm_comp_acc(i);
|
|
uni_vpxor(vmm_acc, vmm_acc, vmm_acc);
|
|
}
|
|
}
|
|
|
|
auto maybe_compute_compensation = [this, num_acc](int k_idx, Vmm vmm_copy) {
|
|
if (do_compute_compensation_) {
|
|
const auto vmm_comp_acc = get_vmm_comp_acc(k_idx % num_acc);
|
|
if (conf_->src_dt == data_type::s8)
|
|
dot_product(vmm_comp_acc, vmm_comp_mul, vmm_copy);
|
|
else
|
|
dot_product(vmm_comp_acc, vmm_copy, vmm_comp_mul);
|
|
}
|
|
};
|
|
|
|
for (int kb = 0; kb < div_up(num_k_iters, k_loop_unroll_); kb++) {
|
|
const int k_end
|
|
= nstl::min(k_loop_unroll_, num_k_iters - kb * k_loop_unroll_);
|
|
for (int k = 0; k < k_end; k++) {
|
|
const int k_idx = kb * k_loop_unroll_ + k;
|
|
const size_t offset
|
|
= static_cast<size_t>(k_idx) * k_step_ * typesize_;
|
|
load_vmm(k, offset);
|
|
maybe_compute_compensation(k_idx, get_vmm_copy(k));
|
|
}
|
|
if (allow_input_shift_for_s8s8 && conf_->s8s8_compensation_required) {
|
|
for (int k = 0; k < k_end; k++)
|
|
vpaddb(get_vmm_copy(k), get_vmm_copy(k), vmm_comp_add);
|
|
}
|
|
if (conf_->is_bf32) {
|
|
assert(typesize_ != tr_typesize_);
|
|
int k = 0;
|
|
const int k_end_2 = rnd_dn(k_end, 2);
|
|
for (; k < k_end_2; k += 2) {
|
|
const size_t offset = ((size_t)kb * k_loop_unroll_ + k)
|
|
* k_step_ * tr_typesize_;
|
|
auto tr_src_addr = EVEX_compress_addr(reg_tr_src, offset);
|
|
|
|
auto zmm_src = get_vmm_copy(k);
|
|
auto zmm_src_next = get_vmm_copy(k + 1);
|
|
|
|
vcvtne2ps2bf16(zmm_src, zmm_src_next, zmm_src);
|
|
vmovups(tr_src_addr, zmm_src);
|
|
}
|
|
if (k < k_end) {
|
|
const size_t offset = ((size_t)kb * k_loop_unroll_ + k)
|
|
* k_step_ * tr_typesize_;
|
|
auto tr_src_addr = EVEX_compress_addr(reg_tr_src, offset);
|
|
Ymm ymm_downcvt_bf16 = Ymm(get_vmm_copy(k).getIdx());
|
|
vcvtneps2bf16(ymm_downcvt_bf16, get_vmm_copy(k));
|
|
vmovdqu16(tr_src_addr, ymm_downcvt_bf16);
|
|
}
|
|
} else {
|
|
for (int k = 0; k < k_end; k++) {
|
|
const size_t offset
|
|
= (static_cast<size_t>(kb) * k_loop_unroll_ + k)
|
|
* k_step_ * tr_typesize_;
|
|
store_vmm(k, offset);
|
|
}
|
|
}
|
|
}
|
|
|
|
if (k_tail > 0) {
|
|
load_tail(k_tail, num_k_iters * k_step_);
|
|
maybe_compute_compensation(0, get_vmm_copy(0));
|
|
|
|
if (allow_input_shift_for_s8s8 && conf_->s8s8_compensation_required)
|
|
vpaddb(get_vmm_copy(0), get_vmm_copy(0), vmm_comp_add);
|
|
|
|
store_tail(k_tail, num_k_iters * k_step_);
|
|
}
|
|
|
|
if (do_compute_compensation_) {
|
|
reduce_compensation_across_accumulators(num_acc);
|
|
|
|
const auto addr_buf = ptr[reg_zp_comp_buf_ptr];
|
|
if (!is_first_K_iter)
|
|
uni_vpaddd(get_vmm_comp_acc(0), get_vmm_comp_acc(0), addr_buf);
|
|
if (!is_last_K_iter) {
|
|
uni_vmovups(addr_buf, get_vmm_comp_acc(0));
|
|
return;
|
|
}
|
|
|
|
// is_last_K_iter == true: we need to reduce values within acc
|
|
// register, add mixed ab_compensation component if any, multiply
|
|
// it by negative zp_b_value and finally store the result
|
|
|
|
// step 1: reduce values within acc register
|
|
const auto ymm_red0 = Ymm(get_vmm_comp_acc(0).getIdx());
|
|
const auto ymm_red1 = Ymm(get_vmm_comp_acc(1).getIdx());
|
|
if (!is_ymm_) {
|
|
vextracti64x4(ymm_red1, Zmm(get_vmm_comp_acc(0).getIdx()), 1);
|
|
vphaddd(ymm_red0, ymm_red0, ymm_red1);
|
|
}
|
|
uni_vpxor(ymm_red1, ymm_red1, ymm_red1);
|
|
uni_vphaddd(ymm_red0, ymm_red0, ymm_red1);
|
|
uni_vphaddd(ymm_red0, ymm_red0, ymm_red1);
|
|
const auto xmm_red1 = Xmm(ymm_red1.getIdx());
|
|
vextractf128(xmm_red1, ymm_red0, 1);
|
|
uni_vpaddd(ymm_red0, ymm_red0, ymm_red1);
|
|
|
|
const auto vmm_in_mask = get_vmm_comp_acc(1);
|
|
if (is_ymm_) {
|
|
static const uint32_t mask_in[8]
|
|
= {0xffffffff, 0, 0, 0, 0, 0, 0, 0};
|
|
mov(regq_tmp, reinterpret_cast<size_t>(mask_in));
|
|
vmovups(vmm_in_mask, ptr[regq_tmp]);
|
|
}
|
|
|
|
// step 2: add -K * zp_a_val as mixed ab_compensation component
|
|
if (conf_->src_zp_type != brgemm_broadcast_t::none) {
|
|
assert(conf_->src_zp_type == brgemm_broadcast_t::per_tensor);
|
|
reg64_t reg_zp_ab_comp_ptr = imm_addr64;
|
|
mov(reg_zp_ab_comp_ptr, ptr[param1 + GET_OFF(zp_ab_comp_ptr)]);
|
|
if (is_ymm_) {
|
|
const auto vmm_zp = get_vmm_comp_acc(2);
|
|
vmaskmovps(vmm_zp, vmm_in_mask, ptr[reg_zp_ab_comp_ptr]);
|
|
uni_vpaddd(ymm_red0, ymm_red0, vmm_zp);
|
|
} else {
|
|
const auto addr_ab_comp = zword_b[reg_zp_ab_comp_ptr];
|
|
const auto zmm_res = get_vmm_comp_acc(0) | kTail_comp;
|
|
vpaddd(zmm_res, get_vmm_comp_acc(0), addr_ab_comp);
|
|
}
|
|
}
|
|
|
|
// step 3: multiply by zp_b_val
|
|
mov(reg_zp_b_neg_val_ptr, ptr[param1 + GET_OFF(zp_b_neg_val_ptr)]);
|
|
const auto vmm_zp_b_neg_val = get_vmm_comp_acc(is_ymm_ ? 2 : 1);
|
|
uni_vbroadcastss(vmm_zp_b_neg_val, ptr[reg_zp_b_neg_val_ptr]);
|
|
uni_vpmulld(get_vmm_comp_acc(0), get_vmm_comp_acc(0), vmm_zp_b_neg_val);
|
|
|
|
// step 4: store the final result value
|
|
if (is_ymm_)
|
|
vmaskmovps(ptr[reg_zp_comp_res_ptr], vmm_in_mask, ymm_red0);
|
|
else
|
|
vmovups(ptr[reg_zp_comp_res_ptr], get_vmm_comp_acc(0) | kTail_comp);
|
|
}
|
|
}
|
|
|
|
template <typename Vmm>
|
|
void jit_brgemm_matmul_copy_a_impl_t<Vmm>::copy_M_loop(
|
|
bool is_K_tail, bool is_first_K_iter, bool is_last_K_iter) {
|
|
|
|
if (do_compute_compensation_) {
|
|
mov(imm_addr64, 1);
|
|
uni_vpbroadcastb(vmm_comp_mul, imm_addr64.cvt8());
|
|
if (!(is_first_K_iter && is_last_K_iter))
|
|
mov(reg_zp_comp_buf_ptr,
|
|
ptr[param1 + GET_OFF(zp_b_compensation_buffer_ptr)]);
|
|
|
|
if (is_last_K_iter) {
|
|
mov(reg_zp_comp_res_ptr,
|
|
ptr[param1 + GET_OFF(zp_a_compensation_result_ptr)]);
|
|
if (!is_ymm_) {
|
|
mov(regq_tmp, 1);
|
|
jit_generator_t::kmovw(kTail_comp, imm_addr64.cvt32());
|
|
}
|
|
}
|
|
}
|
|
|
|
Label loop_M;
|
|
L(loop_M);
|
|
|
|
copy_K_loop(is_K_tail, is_first_K_iter, is_last_K_iter);
|
|
|
|
add(reg_src, src_stride_);
|
|
add(reg_tr_src, tr_src_stride_);
|
|
if (do_compute_compensation_) {
|
|
// shift comp pointers
|
|
if (!(is_first_K_iter && is_last_K_iter))
|
|
add(reg_zp_comp_buf_ptr, sizeof(int32_t) * 16);
|
|
if (is_last_K_iter) add(reg_zp_comp_res_ptr, sizeof(int32_t));
|
|
}
|
|
|
|
dec(reg_M_blk);
|
|
jnz(loop_M, T_NEAR);
|
|
}
|
|
|
|
template <typename Vmm>
|
|
void jit_brgemm_matmul_copy_a_impl_t<Vmm>::generate() {
|
|
preamble();
|
|
|
|
if (avx512_core_dot_product_) {
|
|
mov(regq_tmp.cvt16(), 1);
|
|
vpbroadcastw(vmm_ones_words, regq_tmp.cvt16());
|
|
}
|
|
|
|
mov(reg_src, ptr[param1 + GET_OFF(src)]);
|
|
mov(reg_tr_src, ptr[param1 + GET_OFF(tr_src)]);
|
|
mov(reg_K_blk, ptr[param1 + GET_OFF(current_K_blk)]);
|
|
mov(reg_M_blk, ptr[param1 + GET_OFF(current_M_blk)]);
|
|
|
|
if (allow_input_shift_for_s8s8 && conf_->s8s8_compensation_required) {
|
|
mov(imm_addr64, 128);
|
|
uni_vpbroadcastb(vmm_comp_add, imm_addr64.cvt8());
|
|
}
|
|
|
|
auto copy_body = [this](bool is_first_K_iter, bool is_last_K_iter) {
|
|
Label copy_body_done;
|
|
// might be different from conf_->K_tail
|
|
const dim_t K_blk_tail
|
|
= conf_->K_tail > 0 ? conf_->K % conf_->K_blk : 0;
|
|
if (K_blk_tail > 0) {
|
|
Label not_K_tail;
|
|
cmp(reg_K_blk, K_blk_tail);
|
|
jne(not_K_tail, T_NEAR);
|
|
copy_M_loop(true, is_first_K_iter, is_last_K_iter);
|
|
jmp(copy_body_done, T_NEAR);
|
|
|
|
L(not_K_tail);
|
|
}
|
|
|
|
copy_M_loop(false, is_first_K_iter, is_last_K_iter);
|
|
L(copy_body_done);
|
|
};
|
|
|
|
Label done;
|
|
if (do_compute_compensation_) {
|
|
assert(conf_->wei_zp_type == brgemm_broadcast_t::per_tensor);
|
|
|
|
mov(reg_K_start, ptr[param1 + GET_OFF(current_K_start)]);
|
|
const auto last_K_threshold
|
|
= rnd_up(conf_->K, conf_->K_blk) - conf_->K_blk;
|
|
Label not_first, not_first_not_last;
|
|
cmp(reg_K_start, 0);
|
|
jne(not_first, T_NEAR);
|
|
{
|
|
// first K iteration
|
|
Label first_not_last;
|
|
cmp(reg_K_start, last_K_threshold);
|
|
jl(first_not_last, T_NEAR);
|
|
copy_body(true, true);
|
|
jmp(done, T_NEAR);
|
|
|
|
L(first_not_last);
|
|
copy_body(true, false);
|
|
jmp(done, T_NEAR);
|
|
}
|
|
|
|
L(not_first);
|
|
cmp(reg_K_start, last_K_threshold);
|
|
jl(not_first_not_last, T_NEAR);
|
|
|
|
copy_body(false, true);
|
|
jmp(done, T_NEAR);
|
|
L(not_first_not_last);
|
|
}
|
|
copy_body(false, false);
|
|
L(done);
|
|
|
|
postamble();
|
|
}
|
|
|
|
template struct jit_brgemm_matmul_copy_a_impl_t<Zmm>;
|
|
template struct jit_brgemm_matmul_copy_a_impl_t<Ymm>;
|
|
|
|
template <typename Vmm>
|
|
struct jit_brgemm_matmul_copy_a_transposed_impl_t
|
|
: public jit_brgemm_matmul_copy_a_t,
|
|
public jit_generator_t {
|
|
DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_brgemm_matmul_copy_a_transposed_impl_t)
|
|
|
|
jit_brgemm_matmul_copy_a_transposed_impl_t(const brgemm_matmul_conf_t *conf)
|
|
: jit_brgemm_matmul_copy_a_t(conf)
|
|
, jit_generator_t(jit_name())
|
|
, typesize(conf_->a_dt_sz)
|
|
, tr_typesize(conf_->tr_a_dt_sz)
|
|
, rows_step(16)
|
|
, columns_step(rows_step)
|
|
, src_stride(conf_->copy_A_src_stride)
|
|
, dst_stride(conf_->LDA * tr_typesize)
|
|
, m_loop_src_shift(columns_step * typesize)
|
|
, m_loop_dst_shift(columns_step * dst_stride)
|
|
, k_loop_src_shift(rows_step * src_stride)
|
|
, k_loop_dst_shift(rows_step * tr_typesize)
|
|
, is_f32(conf_->src_dt == data_type::f32)
|
|
, is_bf32(conf_->is_bf32)
|
|
, is_dynamic_src_ld(conf_->is_runtime_M)
|
|
// See the note in `create_brgemm_matmul_copy_b` why `orig_src_dt` used.
|
|
, use_fp16_instructions_(conf_->isa == avx512_core_fp16
|
|
&& conf_->orig_src_dt == data_type::f16
|
|
&& conf_->src_dt == data_type::f32) {}
|
|
|
|
void operator()(ctx_t *ctx) override { jit_generator_t::operator()(ctx); }
|
|
status_t create_kernel() override {
|
|
return jit_generator_t::create_kernel();
|
|
}
|
|
|
|
private:
|
|
using reg64_t = const Xbyak::Reg64;
|
|
using reg32_t = const Xbyak::Reg32;
|
|
using opmask_t = const Xbyak::Opmask;
|
|
|
|
const size_t typesize;
|
|
const size_t tr_typesize;
|
|
const int rows_step;
|
|
const int columns_step;
|
|
const dim_t src_stride, dst_stride;
|
|
const dim_t m_loop_src_shift;
|
|
const dim_t m_loop_dst_shift;
|
|
const dim_t k_loop_src_shift;
|
|
const dim_t k_loop_dst_shift;
|
|
const bool is_f32;
|
|
const bool is_bf32;
|
|
const bool is_dynamic_src_ld;
|
|
const bool use_fp16_instructions_;
|
|
|
|
opmask_t kFFFF = k1;
|
|
opmask_t k3333 = k1;
|
|
opmask_t k5555 = k2;
|
|
opmask_t kAAAA = k3;
|
|
opmask_t kAA = k4;
|
|
opmask_t kCCCC = k4;
|
|
opmask_t k55 = k5;
|
|
opmask_t k0F0F = k5;
|
|
opmask_t kCC = k6;
|
|
opmask_t kF0F0 = k6;
|
|
opmask_t k33 = k7;
|
|
opmask_t kTail = is_f32 ? k7 : k1;
|
|
|
|
reg64_t regq_tmp = r15;
|
|
reg32_t regw_tmp = regq_tmp.cvt32();
|
|
reg64_t reg_k_src = r14;
|
|
reg64_t reg_k_dst = r13;
|
|
reg64_t reg_m_src = r12;
|
|
reg64_t reg_m_dst = r11;
|
|
reg64_t reg_aux_src0 = r10;
|
|
reg64_t reg_aux_src1 = r9;
|
|
reg64_t reg_loop_k = rax;
|
|
reg64_t reg_loop_m = rbx;
|
|
reg64_t imm_addr64 = rdx;
|
|
// Note: this must be assigned to rcx as it's used in shl instruction,
|
|
// clashes with abi_param1 on Windows OS
|
|
reg64_t reg_opmask_shift_compute = rcx;
|
|
|
|
Xbyak::Zmm vidx1 = zmm31;
|
|
Xbyak::Zmm vidx2 = zmm30;
|
|
Xbyak::Zmm vidx3 = zmm29;
|
|
Xbyak::Zmm vidx4 = zmm28;
|
|
Xbyak::Zmm vidx5 = zmm27;
|
|
Xbyak::Zmm zmm_tmp = zmm26;
|
|
|
|
constexpr static int current_M_blk_offt_ = 0;
|
|
constexpr static int src_offt_ = 8;
|
|
constexpr static int tr_src_offt_ = 16;
|
|
constexpr static int current_K_blk_offt_ = 24;
|
|
constexpr static int dynamic_src_ld_offt_ = 32;
|
|
constexpr static int dynamic_src_ld_x_2_offt_ = 40;
|
|
constexpr static int dynamic_src_ld_x_kstep_offt_ = 48;
|
|
constexpr static int stack_space_needed_ = 56;
|
|
|
|
void vmovdqa64(Vmm v, const int64_t *addr) {
|
|
mov(imm_addr64, reinterpret_cast<size_t>(addr));
|
|
jit_generator_t::vmovdqa64(v, ptr[imm_addr64]);
|
|
}
|
|
|
|
void vmovdqa32(Vmm v, const int32_t *addr) {
|
|
mov(imm_addr64, reinterpret_cast<size_t>(addr));
|
|
jit_generator_t::vmovdqa32(v, ptr[imm_addr64]);
|
|
}
|
|
|
|
void kmovw(Opmask mask_reg, size_t mask) {
|
|
mov(regw_tmp, mask);
|
|
jit_generator_t::kmovw(mask_reg, regw_tmp);
|
|
}
|
|
|
|
void transpose_f32(reg64_t dst, reg64_t src, int nrows, int ncolumns);
|
|
void transpose_bf16(reg64_t dst, reg64_t src, int nrows, int ncolumns);
|
|
void deploy_transpose(reg64_t dst, reg64_t src, int nrows, int ncolumns);
|
|
void init_masks();
|
|
void generate() override;
|
|
};
|
|
|
|
template <>
|
|
void jit_brgemm_matmul_copy_a_transposed_impl_t<Xbyak::Zmm>::transpose_bf16(
|
|
reg64_t dst, reg64_t src, int nrows, int ncolumns) {
|
|
assert(nrows >= 0 && nrows <= rows_step && ncolumns >= 0
|
|
&& ncolumns <= columns_step);
|
|
if (!nrows) return;
|
|
|
|
auto src_zmm = [](int i) { return Zmm(i); };
|
|
|
|
auto src_ymm = [](int i) {
|
|
assert(i >= 0 && i < 16);
|
|
return Ymm(i);
|
|
};
|
|
|
|
Label transpose_bf16_done;
|
|
const bool dynamic_column_size = ncolumns == 0 && is_dynamic_src_ld;
|
|
auto kmovx
|
|
= [this, dynamic_column_size](Opmask k, unsigned w,
|
|
bool load_mask_stage = false, bool use_word_sz = false) {
|
|
if (dynamic_column_size && load_mask_stage) {
|
|
// reg_opmask_shift_compute is rcx, and we need cl for the shift
|
|
mov(reg_opmask_shift_compute, reg_loop_m);
|
|
mov(regq_tmp, 1);
|
|
shl(regq_tmp, cl);
|
|
sub(regq_tmp, 1);
|
|
} else
|
|
mov(regw_tmp, w);
|
|
if (use_word_sz)
|
|
jit_generator_t::kmovw(k, regw_tmp);
|
|
else
|
|
jit_generator_t::kmovd(k, regw_tmp);
|
|
};
|
|
|
|
auto store = [this, dst](Zmm r, int i) {
|
|
auto addr = EVEX_compress_addr(dst, i * dst_stride);
|
|
vmovdqu16(addr, r | kTail);
|
|
};
|
|
|
|
const int load_mask
|
|
= ncolumns < columns_step ? (1 << ncolumns) - 1 : 0xffff;
|
|
kmovx(kFFFF, load_mask, true, is_bf32);
|
|
|
|
for (int i = 0; i < nrows / 2; i++) {
|
|
auto idx0 = 2 * i;
|
|
auto idx1 = 2 * i + 1;
|
|
auto zmm_src0 = src_zmm(idx0);
|
|
auto zmm_src1 = src_zmm(idx1);
|
|
if (is_dynamic_src_ld) {
|
|
if (i == 0) {
|
|
mov(reg_aux_src0, src);
|
|
mov(reg_aux_src1, src);
|
|
add(reg_aux_src1, ptr[rsp + dynamic_src_ld_offt_]);
|
|
} else {
|
|
add(reg_aux_src0, ptr[rsp + dynamic_src_ld_x_2_offt_]);
|
|
add(reg_aux_src1, ptr[rsp + dynamic_src_ld_x_2_offt_]);
|
|
}
|
|
}
|
|
auto src_addr_0 = is_dynamic_src_ld
|
|
? ptr[reg_aux_src0]
|
|
: EVEX_compress_addr(src, idx0 * src_stride);
|
|
auto src_addr_1 = is_dynamic_src_ld
|
|
? ptr[reg_aux_src1]
|
|
: EVEX_compress_addr(src, idx1 * src_stride);
|
|
if (is_bf32) {
|
|
vmovups(zmm_src0 | kFFFF | T_z, src_addr_0);
|
|
vmovups(zmm_src1 | kFFFF | T_z, src_addr_1);
|
|
vcvtne2ps2bf16(zmm_src0, zmm_src1, zmm_src0);
|
|
} else {
|
|
auto src1 = src_ymm(idx1);
|
|
vmovdqu16(zmm_src0 | kFFFF | T_z, src_addr_0);
|
|
vmovdqu16(zmm_src1 | kFFFF | T_z, src_addr_1);
|
|
vinsertf64x4(zmm_src0, zmm_src0, src1, 1);
|
|
}
|
|
vpermw(zmm_src0, vidx5, zmm_src0);
|
|
}
|
|
|
|
// for odd numbers we need to mix row with zeroes
|
|
if (nrows % 2) {
|
|
int i = nrows / 2;
|
|
auto zmm_src0 = src_zmm(2 * i);
|
|
if (is_dynamic_src_ld) {
|
|
if (i == 0) {
|
|
mov(reg_aux_src0, src);
|
|
} else {
|
|
add(reg_aux_src0, ptr[rsp + dynamic_src_ld_x_2_offt_]);
|
|
}
|
|
}
|
|
auto src_addr = is_dynamic_src_ld
|
|
? ptr[reg_aux_src0]
|
|
: EVEX_compress_addr(src, 2 * i * src_stride);
|
|
if (is_bf32) {
|
|
vmovups(zmm_src0 | kFFFF | T_z, src_addr);
|
|
vcvtneps2bf16(Ymm(zmm_src0.getIdx()), zmm_src0);
|
|
} else
|
|
vmovdqu16(zmm_src0 | kFFFF | T_z, src_addr);
|
|
vpermw(zmm_src0, vidx5, zmm_src0);
|
|
}
|
|
|
|
for (int i = rnd_up(nrows, 2); i < rows_step; i += 2) {
|
|
vpxord(src_zmm(i), src_zmm(i), src_zmm(i));
|
|
}
|
|
|
|
// swap 1
|
|
for (int i = 0; i < 4; i++) {
|
|
auto zmm0 = src_zmm(4 * i);
|
|
auto zmm1 = src_zmm(4 * i + 2);
|
|
auto tmp0 = src_zmm(4 * i + 1);
|
|
auto tmp1 = src_zmm(4 * i + 3);
|
|
|
|
vmovups(tmp0, zmm0);
|
|
vmovups(tmp1, zmm1);
|
|
|
|
vpermps(tmp0 | kAAAA, vidx3, zmm1);
|
|
vpermps(tmp1 | k5555, vidx3, zmm0);
|
|
}
|
|
// swap 2
|
|
int base_idx;
|
|
base_idx = 0;
|
|
for (int i = 0; i < 2; i++) {
|
|
auto zmm0 = src_zmm(base_idx + 2 * i + 1);
|
|
auto zmm1 = src_zmm(base_idx + 2 * i + 5);
|
|
|
|
auto tmp0 = src_zmm(base_idx + 2 * i);
|
|
auto tmp1 = src_zmm(base_idx + 2 * i + 4);
|
|
|
|
vmovupd(tmp0, zmm0);
|
|
vmovupd(tmp1, zmm1);
|
|
|
|
vpermpd(tmp0 | kAA, vidx2, zmm1);
|
|
vpermpd(tmp1 | k55, vidx2, zmm0);
|
|
}
|
|
base_idx = 8;
|
|
for (int i = 0; i < 2; i++) {
|
|
auto zmm0 = src_zmm(base_idx + 2 * i + 1);
|
|
auto zmm1 = src_zmm(base_idx + 2 * i + 5);
|
|
|
|
auto tmp0 = src_zmm(base_idx + 2 * i);
|
|
auto tmp1 = src_zmm(base_idx + 2 * i + 4);
|
|
|
|
vmovupd(tmp0, zmm0);
|
|
vmovupd(tmp1, zmm1);
|
|
|
|
vpermpd(tmp0 | kAA, vidx2, zmm1);
|
|
vpermpd(tmp1 | k55, vidx2, zmm0);
|
|
}
|
|
|
|
// swap 3
|
|
for (int i = 0; i < 4; i++) {
|
|
auto zmm0 = src_zmm(2 * i);
|
|
auto zmm1 = src_zmm(2 * i + 8);
|
|
|
|
auto tmp0 = src_zmm(2 * i + 1);
|
|
auto tmp1 = src_zmm(2 * i + 9);
|
|
|
|
vmovupd(tmp0, zmm0);
|
|
vmovupd(tmp1, zmm1);
|
|
|
|
vpermpd(tmp0 | kCC, vidx1, zmm1);
|
|
vpermpd(tmp1 | k33, vidx1, zmm0);
|
|
}
|
|
|
|
// all stores
|
|
for (int i = 0; i < 8; i++)
|
|
vextracti64x4(src_ymm(2 * i), src_zmm(2 * i + 1), 1);
|
|
|
|
auto get_vec_idx = [this](int col_idx) {
|
|
MAYBE_UNUSED(this);
|
|
assert(col_idx < columns_step && col_idx >= 0);
|
|
const int blk_sz = 4;
|
|
const int blk_idx = col_idx / blk_sz;
|
|
const int idx_within_blk = col_idx % blk_sz;
|
|
|
|
// 0 1 2 3 -> 0 2 1 3
|
|
const int mapped_blk_idx = 2 * blk_idx - (blk_idx / 2) * 3;
|
|
// 0 1 2 3 -> 1 0 3 2
|
|
const int mapped_idx_within_blk
|
|
= idx_within_blk + 1 - 2 * (idx_within_blk % 2);
|
|
return blk_sz * mapped_blk_idx + mapped_idx_within_blk;
|
|
};
|
|
const int rows_to_store = rnd_up(nrows, 2);
|
|
const int store_mask
|
|
= rows_to_store < rows_step ? (1 << rows_to_store) - 1 : 0xffff;
|
|
kmovx(kTail, store_mask);
|
|
|
|
const int columns_to_store = dynamic_column_size ? columns_step : ncolumns;
|
|
for (int col_idx = 0; col_idx < columns_to_store; col_idx++) {
|
|
store(src_zmm(get_vec_idx(col_idx)), col_idx);
|
|
if (dynamic_column_size) {
|
|
dec(reg_opmask_shift_compute);
|
|
jz(transpose_bf16_done, T_NEAR);
|
|
}
|
|
}
|
|
|
|
L(transpose_bf16_done);
|
|
}
|
|
|
|
template <typename Vmm>
|
|
void jit_brgemm_matmul_copy_a_transposed_impl_t<Vmm>::transpose_bf16(
|
|
reg64_t dst, reg64_t src, int nrows, int ncolumns) {
|
|
assert(!"unsupported transpose_bf16 copy_a_transposed_impl");
|
|
}
|
|
|
|
template <typename Vmm>
|
|
void jit_brgemm_matmul_copy_a_transposed_impl_t<Vmm>::transpose_f32(
|
|
reg64_t reg_dst, reg64_t reg_src, int nrows, int ncolumns) {
|
|
assert(!"unsupported transpose_f32 copy_a_transposed_impl");
|
|
}
|
|
|
|
template <>
|
|
void jit_brgemm_matmul_copy_a_transposed_impl_t<Xbyak::Ymm>::transpose_f32(
|
|
reg64_t reg_dst, reg64_t reg_src, int nrows, int ncolumns) {
|
|
Ymm ymm_tail_mask = ymm15;
|
|
Ymm ymm_upper_tail_mask = ymm14;
|
|
Xmm xmm_upper_tail_mask = xmm14;
|
|
Ymm ymm_tmp = ymm13;
|
|
|
|
// avx2 transpose is 8x8, but we need 16x16 transpose. We use four 8x8
|
|
// transposes as below.
|
|
// _ _T _ _
|
|
// |A, B| => |At, Ct|
|
|
// |C, D| |Bt, Dt|
|
|
|
|
constexpr int avx2_transpose_size = 8;
|
|
const int tail_size = ncolumns % avx2_transpose_size;
|
|
if (tail_size > 0) {
|
|
Xbyak::Reg64 reg_tmp = regq_tmp;
|
|
init_f32_avx2_mask_ymm(ymm_tail_mask, reg_tmp, tail_size);
|
|
const int upper_xmm_tail_size = tail_size - 4;
|
|
if (upper_xmm_tail_size > 0)
|
|
init_f32_avx2_mask_ymm(
|
|
ymm_upper_tail_mask, reg_tmp, upper_xmm_tail_size);
|
|
}
|
|
|
|
const int A_rows = nstl::min(avx2_transpose_size, nrows);
|
|
const int A_columns = nstl::min(avx2_transpose_size, ncolumns);
|
|
jit_generator_t::transpose(reg_src, reg_dst, src_stride, dst_stride, A_rows,
|
|
A_columns, data_type::f32, ymm_tmp, ymm_tail_mask,
|
|
xmm_upper_tail_mask);
|
|
if (rows_step <= 8) return;
|
|
|
|
const dim_t src_B_offset = sizeof(float) * avx2_transpose_size;
|
|
const dim_t dst_B_offset = dst_stride * avx2_transpose_size;
|
|
const int B_rows = nstl::min(avx2_transpose_size, nrows);
|
|
const int B_columns = nstl::max(ncolumns - avx2_transpose_size, 0);
|
|
add(reg_src, src_B_offset);
|
|
add(reg_dst, dst_B_offset);
|
|
jit_generator_t::transpose(reg_src, reg_dst, src_stride, dst_stride, B_rows,
|
|
B_columns, data_type::f32, ymm_tmp, ymm_tail_mask,
|
|
xmm_upper_tail_mask);
|
|
|
|
const dim_t src_C_offset = src_stride * avx2_transpose_size;
|
|
const dim_t dst_C_offset = sizeof(float) * avx2_transpose_size;
|
|
const int C_rows = nstl::max(nrows - avx2_transpose_size, 0);
|
|
const int C_columns = nstl::min(avx2_transpose_size, ncolumns);
|
|
add(reg_src, -src_B_offset + src_C_offset);
|
|
add(reg_dst, -dst_B_offset + dst_C_offset);
|
|
jit_generator_t::transpose(reg_src, reg_dst, src_stride, dst_stride, C_rows,
|
|
C_columns, data_type::f32, ymm_tmp, ymm_tail_mask,
|
|
xmm_upper_tail_mask);
|
|
|
|
const dim_t src_D_offset = src_stride * avx2_transpose_size
|
|
+ sizeof(float) * avx2_transpose_size;
|
|
const dim_t dst_D_offset = dst_stride * avx2_transpose_size
|
|
+ sizeof(float) * avx2_transpose_size;
|
|
const int D_rows = nstl::max(nrows - avx2_transpose_size, 0);
|
|
const int D_columns = nstl::max(ncolumns - avx2_transpose_size, 0);
|
|
add(reg_src, -src_C_offset + src_D_offset);
|
|
add(reg_dst, -dst_C_offset + dst_D_offset);
|
|
jit_generator_t::transpose(reg_src, reg_dst, src_stride, dst_stride, D_rows,
|
|
D_columns, data_type::f32, ymm_tmp, ymm_tail_mask,
|
|
xmm_upper_tail_mask);
|
|
sub(reg_src, src_D_offset);
|
|
sub(reg_dst, dst_D_offset);
|
|
}
|
|
|
|
template <>
|
|
void jit_brgemm_matmul_copy_a_transposed_impl_t<Xbyak::Zmm>::transpose_f32(
|
|
reg64_t dst, reg64_t src, int nrows, int ncolumns) {
|
|
assert(nrows >= 0 && nrows <= rows_step && ncolumns >= 0
|
|
&& ncolumns <= columns_step);
|
|
if (!nrows) return;
|
|
Label transpose_f32_done;
|
|
const bool dynamic_column_size = ncolumns == 0 && is_dynamic_src_ld;
|
|
auto kmovw = [this, dynamic_column_size](
|
|
Opmask k, size_t q, bool load_mask = false) {
|
|
if (dynamic_column_size && load_mask) {
|
|
// reg_opmask_shift_compute is rcx, and we need cl for the shift
|
|
mov(reg_opmask_shift_compute, reg_loop_m);
|
|
mov(regq_tmp, 1);
|
|
shl(regq_tmp, cl);
|
|
sub(regq_tmp, 1);
|
|
} else
|
|
mov(regw_tmp, q);
|
|
jit_generator_t::kmovw(k, regw_tmp);
|
|
};
|
|
|
|
const int load_mask
|
|
= ncolumns < columns_step ? (1 << ncolumns) - 1 : 0xffff;
|
|
kmovw(kTail, load_mask, true);
|
|
|
|
auto src_zmm = [](int i) {
|
|
assert(i >= 0 && i < 16);
|
|
return Zmm(i);
|
|
};
|
|
|
|
auto tmp_zmm = [](int i) {
|
|
assert(i >= 0 && i < 16);
|
|
return Zmm(16 + i);
|
|
};
|
|
|
|
auto load = [this, src, nrows, src_zmm](int i) {
|
|
const auto addr = is_dynamic_src_ld
|
|
? ptr[i % 2 == 0 ? reg_aux_src0 : reg_aux_src1]
|
|
: EVEX_compress_addr(src, i * src_stride);
|
|
if (i < nrows) {
|
|
if (use_fp16_instructions_)
|
|
vcvtph2psx(src_zmm(i) | kTail | T_z, addr);
|
|
else
|
|
vmovups(src_zmm(i) | kTail | T_z, addr);
|
|
} else {
|
|
vpxord(src_zmm(i), src_zmm(i), src_zmm(i));
|
|
}
|
|
};
|
|
|
|
auto store = [this, dst](Zmm r, int i) {
|
|
auto addr = EVEX_compress_addr(dst, i * dst_stride);
|
|
vmovups(addr, r | kTail);
|
|
};
|
|
|
|
auto transpose16x8 = [&](int base_idx) {
|
|
assert(base_idx == 0 || base_idx == 8);
|
|
|
|
// swap 1
|
|
for (int i = 0; i < 4; i++) {
|
|
int src_idx0 = base_idx + i * 2;
|
|
int src_idx1 = src_idx0 + 1;
|
|
|
|
int next_src_idx0 = src_idx0 + 2;
|
|
int next_src_idx1 = src_idx1 + 2;
|
|
bool load_next = base_idx == 0 || i < 3;
|
|
|
|
if (base_idx == 0 && i == 0) {
|
|
if (is_dynamic_src_ld) {
|
|
mov(reg_aux_src0, src);
|
|
mov(reg_aux_src1, src);
|
|
add(reg_aux_src1, ptr[rsp + dynamic_src_ld_offt_]);
|
|
}
|
|
load(src_idx0);
|
|
load(src_idx1);
|
|
}
|
|
|
|
auto tmp0 = tmp_zmm(src_idx0);
|
|
auto tmp1 = tmp_zmm(src_idx1);
|
|
auto src0 = src_zmm(src_idx0);
|
|
auto src1 = src_zmm(src_idx1);
|
|
|
|
if (next_src_idx0 < nrows && load_next) {
|
|
if (is_dynamic_src_ld)
|
|
add(reg_aux_src0, ptr[rsp + dynamic_src_ld_x_2_offt_]);
|
|
load(next_src_idx0);
|
|
}
|
|
valignd(tmp0, src0, src0, 0x1);
|
|
|
|
if (next_src_idx1 < nrows && load_next) {
|
|
if (is_dynamic_src_ld)
|
|
add(reg_aux_src1, ptr[rsp + dynamic_src_ld_x_2_offt_]);
|
|
load(next_src_idx1);
|
|
}
|
|
valignd(tmp1, src1, src1, 0xf);
|
|
|
|
vmovaps(src0 | kAAAA, tmp1);
|
|
vmovaps(src1 | k5555, tmp0);
|
|
}
|
|
|
|
// swap 2
|
|
for (int i = 0; i < 4; i++) {
|
|
int select_half = (i < 2) ? 0 : 2;
|
|
int src_idx0 = base_idx + i + select_half + 0;
|
|
int src_idx2 = src_idx0 + 2;
|
|
|
|
auto tmp0 = tmp_zmm(src_idx0);
|
|
auto tmp1 = tmp_zmm(src_idx2);
|
|
auto src0 = src_zmm(src_idx0);
|
|
auto src2 = src_zmm(src_idx2);
|
|
|
|
valignd(tmp0, src0, src0, 0x2);
|
|
valignd(tmp1, src2, src2, 0xe);
|
|
vmovaps(src2 | k3333, tmp0);
|
|
vmovaps(src0 | kCCCC, tmp1);
|
|
}
|
|
|
|
// swap 4
|
|
for (int i = 0; i < 4; i++) {
|
|
int src_idx0 = base_idx + i;
|
|
int src_idx4 = src_idx0 + 4;
|
|
|
|
auto tmp0 = tmp_zmm(src_idx0);
|
|
auto src0 = src_zmm(src_idx0);
|
|
auto src4 = src_zmm(src_idx4);
|
|
|
|
vmovaps(tmp0, src0);
|
|
vshuff32x4(src0 | kF0F0, src4, src4, 0xb1);
|
|
vshuff32x4(src4 | k0F0F, tmp0, tmp0, 0xb1);
|
|
}
|
|
};
|
|
|
|
auto fixup16x16 = [&]() {
|
|
const int store_mask = nrows < rows_step ? (1 << nrows) - 1 : 0xffff;
|
|
kmovw(kTail, store_mask);
|
|
|
|
// swap 8
|
|
int columns_to_store = dynamic_column_size ? 8 : nstl::min(8, ncolumns);
|
|
for (int i = 0; i < columns_to_store; i++) {
|
|
auto tmp = tmp_zmm(i);
|
|
auto src0 = src_zmm(i);
|
|
auto src8 = src_zmm(8 + i);
|
|
vshuff64x2(tmp, src0, src8, 0x44);
|
|
store(tmp, i);
|
|
if (dynamic_column_size) {
|
|
dec(reg_opmask_shift_compute);
|
|
jz(transpose_f32_done, T_NEAR);
|
|
}
|
|
}
|
|
|
|
columns_to_store = dynamic_column_size ? 8 : nstl::max(0, ncolumns - 8);
|
|
for (int i = 0; i < columns_to_store; i++) {
|
|
auto tmp = tmp_zmm(8 + i);
|
|
auto src0 = src_zmm(i);
|
|
auto src8 = src_zmm(8 + i);
|
|
vshuff64x2(tmp, src0, src8, 0xee);
|
|
store(tmp, 8 + i);
|
|
if (dynamic_column_size) {
|
|
dec(reg_opmask_shift_compute);
|
|
jz(transpose_f32_done, T_NEAR);
|
|
}
|
|
}
|
|
};
|
|
|
|
transpose16x8(0);
|
|
transpose16x8(8);
|
|
fixup16x16();
|
|
L(transpose_f32_done);
|
|
}
|
|
|
|
template <typename Vmm>
|
|
void jit_brgemm_matmul_copy_a_transposed_impl_t<Vmm>::deploy_transpose(
|
|
reg64_t dst, reg64_t src, int nrows, int ncolumns) {
|
|
if (is_f32 || use_fp16_instructions_)
|
|
transpose_f32(dst, src, nrows, ncolumns);
|
|
else
|
|
transpose_bf16(dst, src, nrows, ncolumns);
|
|
}
|
|
|
|
template <typename Vmm>
|
|
void jit_brgemm_matmul_copy_a_transposed_impl_t<Vmm>::init_masks() {
|
|
alignas(64) static constexpr const int64_t idx1[8]
|
|
= {2, 3, 0, 1, 6, 7, 4, 5};
|
|
alignas(64) static constexpr const int64_t idx2[8]
|
|
= {1, 0, 3, 2, 5, 4, 7, 6};
|
|
alignas(64) static constexpr const int32_t idx3[16]
|
|
= {1, 0, 3, 2, 5, 4, 7, 6, 9, 8, 11, 10, 13, 12, 15, 14};
|
|
alignas(64) static constexpr const int32_t idx4[16]
|
|
= {8, 10, 12, 14, 0, 2, 4, 6, 9, 11, 13, 15, 1, 3, 5, 7};
|
|
alignas(64) static constexpr const uint16_t idx5[32]
|
|
= {0, 16, 2, 18, 8, 24, 10, 26, 4, 20, 6, 22, 12, 28, 14, 30, 1, 17,
|
|
3, 19, 9, 25, 11, 27, 5, 21, 7, 23, 13, 29, 15, 31};
|
|
if (is_superset(conf_->isa, avx512_core)) {
|
|
if (is_f32) {
|
|
kmovw(k3333, 0x3333); // 0011001100110011
|
|
kmovw(k5555, 0x5555); // 0101010101010101
|
|
kmovw(kAAAA, 0xaaaa); // 1010101010101010
|
|
kmovw(kCCCC, 0xcccc); // 1100110011001100
|
|
kmovw(k0F0F, 0x0f0f); // 0000111100001111
|
|
kmovw(kF0F0, 0xf0f0); // 1111000011110000
|
|
} else {
|
|
kmovw(kFFFF, 0xffff);
|
|
kmovw(k5555, 0x5555);
|
|
kmovw(kAAAA, 0xaaaa);
|
|
kmovw(kAA, 0xaa);
|
|
kmovw(k55, 0x55);
|
|
kmovw(kCC, 0xcc);
|
|
kmovw(k33, 0x33);
|
|
}
|
|
if (!is_f32) {
|
|
vmovdqa64(vidx1, idx1);
|
|
vmovdqa64(vidx2, idx2);
|
|
vmovdqa32(vidx3, idx3);
|
|
vmovdqa32(vidx4, idx4);
|
|
vmovdqa32(vidx5, (const int32_t *)idx5);
|
|
}
|
|
}
|
|
}
|
|
|
|
template <typename Vmm>
|
|
void jit_brgemm_matmul_copy_a_transposed_impl_t<Vmm>::generate() {
|
|
|
|
// only bf16, f16 and f32 supported for now
|
|
if (!one_of(conf_->src_dt, data_type::bf16, data_type::f32, data_type::f16))
|
|
return;
|
|
preamble();
|
|
sub(rsp, stack_space_needed_);
|
|
mov(regq_tmp, ptr[param1 + GET_OFF(current_M_blk)]);
|
|
mov(ptr[rsp + current_M_blk_offt_], regq_tmp);
|
|
mov(regq_tmp, ptr[param1 + GET_OFF(src)]);
|
|
mov(ptr[rsp + src_offt_], regq_tmp);
|
|
mov(regq_tmp, ptr[param1 + GET_OFF(tr_src)]);
|
|
mov(ptr[rsp + tr_src_offt_], regq_tmp);
|
|
mov(regq_tmp, ptr[param1 + GET_OFF(current_K_blk)]);
|
|
mov(ptr[rsp + current_K_blk_offt_], regq_tmp);
|
|
if (is_dynamic_src_ld) {
|
|
// dynamic src_stride
|
|
mov(regq_tmp, ptr[param1 + GET_OFF(dynamic_src_ld)]);
|
|
mov(ptr[rsp + dynamic_src_ld_offt_], regq_tmp);
|
|
|
|
// src_stride * 2
|
|
shl(regq_tmp, 1);
|
|
mov(ptr[rsp + dynamic_src_ld_x_2_offt_], regq_tmp);
|
|
|
|
// src_stride * rows_step
|
|
assert(rows_step == 16);
|
|
shl(regq_tmp, 3);
|
|
mov(ptr[rsp + dynamic_src_ld_x_kstep_offt_], regq_tmp);
|
|
}
|
|
|
|
init_masks();
|
|
|
|
const int k_block_tail = conf_->K_blk % rows_step;
|
|
const int last_k_block_tail = (conf_->K % conf_->K_blk) % rows_step;
|
|
const int m_block_tail = conf_->M_blk % columns_step;
|
|
const int last_m_block_tail = conf_->M_tail % columns_step;
|
|
|
|
auto compute_m_loop = [&](reg64_t ®_base, reg64_t ®_tr_base,
|
|
int nrows) {
|
|
mov(reg_loop_m, ptr[rsp + current_M_blk_offt_]);
|
|
mov(reg_m_src, reg_base);
|
|
mov(reg_m_dst, reg_tr_base);
|
|
|
|
Label m_loop_tail_or_done, m_loop, compute_m_loop_done;
|
|
cmp(reg_loop_m, columns_step);
|
|
jl(m_loop_tail_or_done, T_NEAR);
|
|
|
|
L(m_loop);
|
|
{
|
|
deploy_transpose(reg_m_dst, reg_m_src, nrows, columns_step);
|
|
add(reg_m_src, m_loop_src_shift);
|
|
add(reg_m_dst, m_loop_dst_shift);
|
|
}
|
|
sub(reg_loop_m, columns_step);
|
|
cmp(reg_loop_m, columns_step);
|
|
jge(m_loop, T_NEAR);
|
|
|
|
if (m_block_tail > 0 || last_m_block_tail > 0 || is_dynamic_src_ld)
|
|
jz(compute_m_loop_done, T_NEAR);
|
|
|
|
L(m_loop_tail_or_done);
|
|
|
|
if (m_block_tail > 0) {
|
|
Label m_block_tail_done;
|
|
cmp(reg_loop_m, m_block_tail);
|
|
jne(m_block_tail_done, T_NEAR);
|
|
|
|
deploy_transpose(reg_m_dst, reg_m_src, nrows, m_block_tail);
|
|
jmp(compute_m_loop_done, T_NEAR);
|
|
|
|
L(m_block_tail_done);
|
|
}
|
|
if (IMPLICATION(
|
|
last_m_block_tail <= 0 || last_m_block_tail == m_block_tail,
|
|
is_dynamic_src_ld)) {
|
|
Label last_m_block_tail_done;
|
|
cmp(reg_loop_m, 0);
|
|
jle(last_m_block_tail_done, T_NEAR);
|
|
|
|
deploy_transpose(reg_m_dst, reg_m_src, nrows,
|
|
is_dynamic_src_ld ? 0 : last_m_block_tail);
|
|
|
|
L(last_m_block_tail_done);
|
|
}
|
|
|
|
L(compute_m_loop_done);
|
|
};
|
|
|
|
auto compute_k_loop = [&]() {
|
|
mov(reg_k_src, ptr[rsp + src_offt_]);
|
|
mov(reg_k_dst, ptr[rsp + tr_src_offt_]);
|
|
mov(reg_loop_k, ptr[rsp + current_K_blk_offt_]);
|
|
|
|
Label k_tail_or_done, k_loop, compute_k_loop_done;
|
|
cmp(reg_loop_k, rows_step);
|
|
jl(k_tail_or_done, T_NEAR);
|
|
|
|
L(k_loop);
|
|
{
|
|
compute_m_loop(reg_k_src, reg_k_dst, rows_step);
|
|
if (is_dynamic_src_ld)
|
|
add(reg_k_src, ptr[rsp + dynamic_src_ld_x_kstep_offt_]);
|
|
else
|
|
add(reg_k_src, k_loop_src_shift);
|
|
add(reg_k_dst, k_loop_dst_shift);
|
|
}
|
|
sub(reg_loop_k, rows_step);
|
|
cmp(reg_loop_k, rows_step);
|
|
jge(k_loop, T_NEAR);
|
|
|
|
if (k_block_tail > 0 || last_k_block_tail > 0)
|
|
jz(compute_k_loop_done, T_NEAR);
|
|
|
|
L(k_tail_or_done);
|
|
|
|
if (k_block_tail > 0) {
|
|
Label k_block_tail_done;
|
|
cmp(reg_loop_k, k_block_tail);
|
|
jne(k_block_tail_done, T_NEAR);
|
|
|
|
compute_m_loop(reg_k_src, reg_k_dst, k_block_tail);
|
|
jmp(compute_k_loop_done, T_NEAR);
|
|
|
|
L(k_block_tail_done);
|
|
}
|
|
if (last_k_block_tail > 0 && last_k_block_tail != k_block_tail) {
|
|
Label last_k_block_tail_done;
|
|
cmp(reg_loop_k, last_k_block_tail);
|
|
jne(last_k_block_tail_done, T_NEAR);
|
|
|
|
compute_m_loop(reg_k_src, reg_k_dst, last_k_block_tail);
|
|
jmp(compute_k_loop_done, T_NEAR);
|
|
|
|
L(last_k_block_tail_done);
|
|
}
|
|
|
|
L(compute_k_loop_done);
|
|
};
|
|
|
|
compute_k_loop();
|
|
|
|
add(rsp, stack_space_needed_);
|
|
postamble();
|
|
}
|
|
|
|
struct jit_brgemm_matmul_copy_a_transposed_int8_impl_t
|
|
: public jit_brgemm_matmul_copy_a_t,
|
|
public jit_generator_t {
|
|
DECLARE_CPU_JIT_AUX_FUNCTIONS(
|
|
jit_brgemm_matmul_copy_a_transposed_int8_impl_t)
|
|
|
|
jit_brgemm_matmul_copy_a_transposed_int8_impl_t(
|
|
const brgemm_matmul_conf_t *conf)
|
|
: jit_brgemm_matmul_copy_a_t(conf)
|
|
, jit_generator_t(jit_name())
|
|
, src_stride_(conf_->copy_A_src_stride)
|
|
, dst_stride_(conf_->LDA * conf_->tr_a_dt_sz)
|
|
, m_loop_src_shift_(columns_step_ * conf_->a_dt_sz)
|
|
, m_loop_dst_shift_(columns_step_ * dst_stride_)
|
|
, k_loop_src_shift_(rows_step_ * src_stride_)
|
|
, k_loop_dst_shift_(rows_step_ * conf_->tr_a_dt_sz)
|
|
, has_vpermb_(cpu().has(Xbyak::util::Cpu::tAVX512_VBMI))
|
|
, is_dynamic_src_ld_(conf_->is_runtime_M)
|
|
, k_block_tail_(conf_->K_blk % rows_step_)
|
|
, last_k_block_tail_((conf_->K % conf_->K_blk) % rows_step_)
|
|
, m_block_tail_(conf_->M_blk % columns_step_)
|
|
, last_m_block_tail_(conf_->M_tail % columns_step_)
|
|
, do_compute_compensation_(conf_->has_zero_point_b) {}
|
|
|
|
void operator()(ctx_t *ctx) override { jit_generator_t::operator()(ctx); }
|
|
status_t create_kernel() override {
|
|
return jit_generator_t::create_kernel();
|
|
}
|
|
|
|
private:
|
|
constexpr static int rows_step_ = 16;
|
|
constexpr static int columns_step_ = 16;
|
|
constexpr static int current_M_blk_offt_ = 0;
|
|
constexpr static int current_K_blk_offt_ = 8;
|
|
constexpr static int src_offt_ = 16;
|
|
constexpr static int tr_src_offt_ = 24;
|
|
constexpr static int dynamic_src_ld_offt_ = 32;
|
|
constexpr static int dynamic_src_ld_x_2_offt_ = 40;
|
|
constexpr static int dynamic_src_ld_x_kstep_offt_ = 48;
|
|
constexpr static int stack_space_needed_ = 56;
|
|
|
|
const dim_t src_stride_, dst_stride_;
|
|
const dim_t m_loop_src_shift_, m_loop_dst_shift_;
|
|
const dim_t k_loop_src_shift_, k_loop_dst_shift_;
|
|
const bool has_vpermb_;
|
|
const bool is_dynamic_src_ld_;
|
|
|
|
const int k_block_tail_, last_k_block_tail_;
|
|
const int m_block_tail_, last_m_block_tail_;
|
|
|
|
const bool do_compute_compensation_;
|
|
|
|
Opmask kFFFF_ = k1;
|
|
Opmask k5555_ = k2;
|
|
Opmask kAAAA_ = k3;
|
|
Opmask kAA_ = k4;
|
|
Opmask k55_ = k5;
|
|
Opmask kCC_ = k6;
|
|
Opmask k33_ = k7;
|
|
Opmask kTail_ = k1;
|
|
|
|
Reg64 reg_tmp_ = r15;
|
|
Reg64 reg_k_src_ = r14;
|
|
Reg64 reg_k_dst_ = r13;
|
|
Reg64 reg_m_src_ = r12;
|
|
Reg64 reg_m_dst_ = r11;
|
|
Reg64 reg_aux_src0_ = r10;
|
|
Reg64 reg_aux_src1_ = r9;
|
|
Reg64 reg_zp_comp_res_ptr_ = r8;
|
|
Reg64 reg_zp_comp_prev_data_ = rdx;
|
|
Reg64 reg_loop_m_ = rbx;
|
|
Reg64 reg_loop_k_ = rax;
|
|
// Note: this must be assigned to rcx as it's used in shl instruction,
|
|
// clashes with abi_param1 on Windows OS
|
|
Reg64 reg_opmask_shift_compute_ = rcx;
|
|
|
|
// Indices used in permutations
|
|
Zmm zmm_idx_1_ = zmm31;
|
|
Zmm zmm_idx_2_ = zmm30;
|
|
Zmm zmm_idx_3_ = zmm29;
|
|
Zmm zmm_idx_4_ = zmm28;
|
|
|
|
// zmm_idx_5_ is used in vpermb implementation only
|
|
// zmm_conversion_tmp_ is used in vpermw implementation only
|
|
Zmm zmm_idx_5_ = zmm27;
|
|
Zmm zmm_conversion_tmp_ = zmm27;
|
|
|
|
// Required for zero-point
|
|
Zmm zmm_comp_temp_ = zmm26;
|
|
Zmm zmm_comp_acc_ = zmm25;
|
|
|
|
Zmm get_zmm_src(int i) {
|
|
assert(i >= 0 && i < columns_step_);
|
|
return Zmm(i);
|
|
}
|
|
void kmovd(const bool dynamic_column_size, Opmask k, unsigned w,
|
|
bool load_mask_stage = false) {
|
|
if (dynamic_column_size && load_mask_stage) {
|
|
// reg_opmask_shift_compute_ is rcx, and we need cl for the shift
|
|
mov(reg_opmask_shift_compute_, reg_loop_m_);
|
|
mov(reg_tmp_, 1);
|
|
shl(reg_tmp_, cl);
|
|
sub(reg_tmp_, 1);
|
|
} else
|
|
mov(reg_tmp_, w);
|
|
jit_generator_t::kmovd(k, reg_tmp_.cvt32());
|
|
}
|
|
|
|
void transpose_int8_vpermb(Reg64 dst, Reg64 src, int nrows, int ncolumns);
|
|
void transpose_int8_vpermw(Reg64 dst, Reg64 src, int nrows, int ncolumns);
|
|
|
|
inline void deploy_transpose(
|
|
Reg64 dst, Reg64 src, int nrows, int ncolumns) {
|
|
if (has_vpermb_)
|
|
transpose_int8_vpermb(dst, src, nrows, ncolumns);
|
|
else
|
|
transpose_int8_vpermw(dst, src, nrows, ncolumns);
|
|
}
|
|
|
|
void reset_compensation_accumulator() {
|
|
if (do_compute_compensation_)
|
|
uni_vpxor(zmm_comp_acc_, zmm_comp_acc_, zmm_comp_acc_);
|
|
}
|
|
void accumulate_compensation(Zmm zmm_copy) {
|
|
if (do_compute_compensation_) {
|
|
if (conf_->src_dt == data_type::s8)
|
|
vpmovsxbd(zmm_comp_temp_, zmm_copy);
|
|
else
|
|
vpmovzxbd(zmm_comp_temp_, zmm_copy);
|
|
vpaddd(zmm_comp_acc_, zmm_comp_acc_, zmm_comp_temp_);
|
|
}
|
|
}
|
|
void save_partial_compensation() {
|
|
if (do_compute_compensation_) {
|
|
Label no_previous_data;
|
|
test(reg_zp_comp_prev_data_, reg_zp_comp_prev_data_);
|
|
je(no_previous_data);
|
|
vpaddd(zmm_comp_acc_, zmm_comp_acc_, ptr[reg_zp_comp_res_ptr_]);
|
|
L(no_previous_data);
|
|
|
|
vmovups(ptr[reg_zp_comp_res_ptr_], zmm_comp_acc_);
|
|
add(reg_zp_comp_res_ptr_, sizeof(int32_t) * 16);
|
|
}
|
|
}
|
|
|
|
void compute_m_loop(int nrows);
|
|
void compute_k_loop(bool is_first_K_iter, bool is_last_K_iter);
|
|
void generate() override;
|
|
};
|
|
|
|
void jit_brgemm_matmul_copy_a_transposed_int8_impl_t::transpose_int8_vpermb(
|
|
Reg64 dst, Reg64 src, int nrows, int ncolumns) {
|
|
assert(nrows >= 0 && nrows <= rows_step_ && ncolumns >= 0
|
|
&& ncolumns <= columns_step_);
|
|
if (!nrows) return;
|
|
|
|
auto load = [this, src](Zmm r, int i, Reg64 reg_aux) {
|
|
auto addr = is_dynamic_src_ld_
|
|
? ptr[reg_aux]
|
|
: EVEX_compress_addr(src, i * src_stride_);
|
|
vmovdqu8(r | kFFFF_ | T_z, addr);
|
|
accumulate_compensation(r);
|
|
};
|
|
|
|
auto store = [this, dst](Zmm r, int i) {
|
|
auto addr = EVEX_compress_addr(dst, i * dst_stride_);
|
|
vmovdqu8(addr, r | kTail_);
|
|
};
|
|
|
|
Label transpose_int8_done;
|
|
|
|
reset_compensation_accumulator();
|
|
|
|
const bool dynamic_column_size = ncolumns == 0 && is_dynamic_src_ld_;
|
|
const int load_mask
|
|
= ncolumns < columns_step_ ? (1 << ncolumns) - 1 : 0xffff;
|
|
kmovd(dynamic_column_size, kFFFF_, load_mask, true);
|
|
|
|
// load rows and swap bytes
|
|
for (int i = 0; i < nrows; i += 4) {
|
|
auto idx0 = i;
|
|
auto zmm_src0 = get_zmm_src(idx0);
|
|
if (is_dynamic_src_ld_) {
|
|
if (i == 0)
|
|
mov(reg_aux_src0_, src);
|
|
else
|
|
add(reg_aux_src0_, ptr[rsp + dynamic_src_ld_offt_]);
|
|
}
|
|
load(zmm_src0, idx0, reg_aux_src0_);
|
|
|
|
auto idx1 = i + 1;
|
|
auto zmm_src1 = get_zmm_src(idx1);
|
|
if (is_dynamic_src_ld_)
|
|
add(reg_aux_src0_, ptr[rsp + dynamic_src_ld_offt_]);
|
|
if (idx1 < nrows)
|
|
load(zmm_src1, idx1, reg_aux_src0_);
|
|
else
|
|
vpxord(zmm_src1, zmm_src1, zmm_src1);
|
|
|
|
auto idx2 = i + 2;
|
|
auto zmm_src2 = get_zmm_src(idx2);
|
|
if (is_dynamic_src_ld_)
|
|
add(reg_aux_src0_, ptr[rsp + dynamic_src_ld_offt_]);
|
|
if (idx2 < nrows)
|
|
load(zmm_src2, idx2, reg_aux_src0_);
|
|
else
|
|
vpxord(zmm_src2, zmm_src2, zmm_src2);
|
|
|
|
auto idx3 = i + 3;
|
|
auto zmm_src3 = get_zmm_src(idx3);
|
|
if (is_dynamic_src_ld_)
|
|
add(reg_aux_src0_, ptr[rsp + dynamic_src_ld_offt_]);
|
|
if (idx3 < nrows)
|
|
load(zmm_src3, idx3, reg_aux_src0_);
|
|
else
|
|
vpxord(zmm_src3, zmm_src3, zmm_src3);
|
|
|
|
// concatenate 4 rows
|
|
vinserti64x2(Ymm(zmm_src0.getIdx()), Ymm(zmm_src0.getIdx()),
|
|
Xmm(zmm_src1.getIdx()), 1);
|
|
vinserti64x2(Ymm(zmm_src2.getIdx()), Ymm(zmm_src2.getIdx()),
|
|
Xmm(zmm_src3.getIdx()), 1);
|
|
vinserti64x4(zmm_src0, zmm_src0, Ymm(zmm_src2.getIdx()), 1);
|
|
|
|
// swap bytes
|
|
vpermb(zmm_src0, zmm_idx_1_, zmm_src0);
|
|
}
|
|
|
|
// swap doubles
|
|
for (int i = 0; i < 2; i++) {
|
|
auto idx0 = 8 * i;
|
|
auto idx1 = idx0 + 4;
|
|
|
|
auto zmm_src0 = get_zmm_src(idx0);
|
|
auto zmm_src1 = get_zmm_src(idx1);
|
|
|
|
auto zmm_tmp0 = get_zmm_src(idx0 + 1);
|
|
auto zmm_tmp1 = get_zmm_src(idx1 + 1);
|
|
|
|
vmovups(zmm_tmp0, zmm_idx_2_);
|
|
vmovups(zmm_tmp1, zmm_idx_3_);
|
|
|
|
vpermi2d(zmm_tmp0, zmm_src0, zmm_src1);
|
|
vpermi2d(zmm_tmp1, zmm_src0, zmm_src1);
|
|
}
|
|
|
|
// swap quads
|
|
for (int i = 0; i < 2; i++) {
|
|
auto idx0 = 4 * i;
|
|
auto idx1 = idx0 + 8;
|
|
|
|
auto zmm_src0 = get_zmm_src(idx0 + 1);
|
|
auto zmm_src1 = get_zmm_src(idx1 + 1);
|
|
|
|
auto zmm_tmp0 = get_zmm_src(idx0);
|
|
auto zmm_tmp1 = get_zmm_src(idx1);
|
|
|
|
vmovups(zmm_tmp0, zmm_idx_4_);
|
|
vmovups(zmm_tmp1, zmm_idx_5_);
|
|
|
|
vpermi2q(zmm_tmp0, zmm_src0, zmm_src1);
|
|
vpermi2q(zmm_tmp1, zmm_src0, zmm_src1);
|
|
}
|
|
|
|
// extract columns
|
|
for (int i = 0; i < 16; i += 4) {
|
|
vextracti64x4(
|
|
Ymm(get_zmm_src(i + 2).getIdx()) | T_z, get_zmm_src(i), 1);
|
|
vextracti32x4(
|
|
Xmm(get_zmm_src(i + 1).getIdx()) | T_z, get_zmm_src(i), 1);
|
|
vextracti32x4(Xmm(get_zmm_src(i + 3).getIdx()) | T_z,
|
|
Ymm(get_zmm_src(i + 2).getIdx()), 1);
|
|
}
|
|
|
|
// store columns
|
|
const int rows_to_store = rnd_up(nrows, 2);
|
|
const int store_mask
|
|
= rows_to_store < rows_step_ ? (1 << rows_to_store) - 1 : 0xffff;
|
|
kmovd(dynamic_column_size, kTail_, store_mask);
|
|
|
|
auto get_vec_idx = [](int col_idx) {
|
|
assert(col_idx < columns_step_ && col_idx >= 0);
|
|
|
|
const auto div = col_idx / 4;
|
|
const auto mod = col_idx % 4;
|
|
|
|
return mod * 4 + div;
|
|
};
|
|
|
|
const int columns_to_store = dynamic_column_size ? columns_step_ : ncolumns;
|
|
for (int col_idx = 0; col_idx < columns_to_store; col_idx++) {
|
|
store(get_zmm_src(get_vec_idx(col_idx)), col_idx);
|
|
if (dynamic_column_size) {
|
|
dec(reg_opmask_shift_compute_);
|
|
jz(transpose_int8_done, T_NEAR);
|
|
}
|
|
}
|
|
|
|
L(transpose_int8_done);
|
|
|
|
save_partial_compensation();
|
|
}
|
|
|
|
void jit_brgemm_matmul_copy_a_transposed_int8_impl_t::transpose_int8_vpermw(
|
|
Reg64 dst, Reg64 src, int nrows, int ncolumns) {
|
|
assert(nrows >= 0 && nrows <= rows_step_ && ncolumns >= 0
|
|
&& ncolumns <= columns_step_);
|
|
if (!nrows) return;
|
|
|
|
auto load = [this, src](Zmm r, int i, Reg64 reg_aux) {
|
|
auto addr = is_dynamic_src_ld_
|
|
? ptr[reg_aux]
|
|
: EVEX_compress_addr(src, i * src_stride_);
|
|
vmovdqu8(zmm_conversion_tmp_ | kFFFF_ | T_z, addr);
|
|
accumulate_compensation(zmm_conversion_tmp_);
|
|
if (conf_->src_dt == data_type::s8)
|
|
vpmovsxbw(r, zmm_conversion_tmp_);
|
|
else
|
|
vpmovzxbw(r, zmm_conversion_tmp_);
|
|
};
|
|
|
|
auto store = [this, dst](Zmm r, int i) {
|
|
if (conf_->src_dt == data_type::s8)
|
|
vpmovswb(Ymm(zmm_conversion_tmp_.getIdx()), r);
|
|
else
|
|
vpmovuswb(Ymm(zmm_conversion_tmp_.getIdx()), r);
|
|
auto addr = EVEX_compress_addr(dst, i * dst_stride_);
|
|
vmovdqu8(addr, zmm_conversion_tmp_ | kTail_);
|
|
};
|
|
|
|
Label transpose_int8_done;
|
|
|
|
reset_compensation_accumulator();
|
|
|
|
const bool dynamic_column_size = ncolumns == 0 && is_dynamic_src_ld_;
|
|
const int load_mask
|
|
= ncolumns < columns_step_ ? (1 << ncolumns) - 1 : 0xffff;
|
|
kmovd(dynamic_column_size, kFFFF_, load_mask, true);
|
|
|
|
for (int i = 0; i < nrows / 2; i++) {
|
|
auto idx0 = 2 * i;
|
|
auto idx1 = 2 * i + 1;
|
|
auto zmm_src0 = get_zmm_src(idx0);
|
|
auto zmm_src1 = get_zmm_src(idx1);
|
|
if (is_dynamic_src_ld_) {
|
|
if (i == 0) {
|
|
mov(reg_aux_src0_, src);
|
|
mov(reg_aux_src1_, src);
|
|
add(reg_aux_src1_, ptr[rsp + dynamic_src_ld_offt_]);
|
|
} else {
|
|
add(reg_aux_src0_, ptr[rsp + dynamic_src_ld_x_2_offt_]);
|
|
add(reg_aux_src1_, ptr[rsp + dynamic_src_ld_x_2_offt_]);
|
|
}
|
|
}
|
|
|
|
load(zmm_src0, idx0, reg_aux_src0_);
|
|
load(zmm_src1, idx1, reg_aux_src1_);
|
|
|
|
vinserti64x4(zmm_src0, zmm_src0, Ymm(zmm_src1.getIdx()), 1);
|
|
vpermw(zmm_src0, zmm_idx_1_, zmm_src0);
|
|
}
|
|
|
|
// for odd numbers we need to mix row with zeroes
|
|
if (nrows % 2) {
|
|
int i = nrows / 2;
|
|
auto zmm_src0 = get_zmm_src(2 * i);
|
|
if (is_dynamic_src_ld_) {
|
|
if (i == 0)
|
|
mov(reg_aux_src0_, src);
|
|
else
|
|
add(reg_aux_src0_, ptr[rsp + dynamic_src_ld_x_2_offt_]);
|
|
}
|
|
|
|
load(zmm_src0, 2 * i, reg_aux_src0_);
|
|
|
|
vpermw(zmm_src0, zmm_idx_1_, zmm_src0);
|
|
}
|
|
|
|
for (int i = rnd_up(nrows, 2); i < rows_step_; i += 2)
|
|
vpxord(get_zmm_src(i), get_zmm_src(i), get_zmm_src(i));
|
|
|
|
// swap 1
|
|
for (int i = 0; i < 4; i++) {
|
|
auto zmm0 = get_zmm_src(4 * i);
|
|
auto zmm1 = get_zmm_src(4 * i + 2);
|
|
auto tmp0 = get_zmm_src(4 * i + 1);
|
|
auto tmp1 = get_zmm_src(4 * i + 3);
|
|
|
|
vmovups(tmp0, zmm0);
|
|
vmovups(tmp1, zmm1);
|
|
|
|
vpermps(tmp0 | kAAAA_, zmm_idx_2_, zmm1);
|
|
vpermps(tmp1 | k5555_, zmm_idx_2_, zmm0);
|
|
}
|
|
|
|
// swap 2
|
|
int base_idx;
|
|
base_idx = 0;
|
|
for (int i = 0; i < 2; i++) {
|
|
auto zmm0 = get_zmm_src(base_idx + 2 * i + 1);
|
|
auto zmm1 = get_zmm_src(base_idx + 2 * i + 5);
|
|
|
|
auto tmp0 = get_zmm_src(base_idx + 2 * i);
|
|
auto tmp1 = get_zmm_src(base_idx + 2 * i + 4);
|
|
|
|
vmovupd(tmp0, zmm0);
|
|
vmovupd(tmp1, zmm1);
|
|
|
|
vpermpd(tmp0 | kAA_, zmm_idx_3_, zmm1);
|
|
vpermpd(tmp1 | k55_, zmm_idx_3_, zmm0);
|
|
}
|
|
base_idx = 8;
|
|
for (int i = 0; i < 2; i++) {
|
|
auto zmm0 = get_zmm_src(base_idx + 2 * i + 1);
|
|
auto zmm1 = get_zmm_src(base_idx + 2 * i + 5);
|
|
|
|
auto tmp0 = get_zmm_src(base_idx + 2 * i);
|
|
auto tmp1 = get_zmm_src(base_idx + 2 * i + 4);
|
|
|
|
vmovupd(tmp0, zmm0);
|
|
vmovupd(tmp1, zmm1);
|
|
|
|
vpermpd(tmp0 | kAA_, zmm_idx_3_, zmm1);
|
|
vpermpd(tmp1 | k55_, zmm_idx_3_, zmm0);
|
|
}
|
|
|
|
// swap 3
|
|
for (int i = 0; i < 4; i++) {
|
|
auto zmm0 = get_zmm_src(2 * i);
|
|
auto zmm1 = get_zmm_src(2 * i + 8);
|
|
|
|
auto tmp0 = get_zmm_src(2 * i + 1);
|
|
auto tmp1 = get_zmm_src(2 * i + 9);
|
|
|
|
vmovupd(tmp0, zmm0);
|
|
vmovupd(tmp1, zmm1);
|
|
|
|
vpermpd(tmp0 | kCC_, zmm_idx_4_, zmm1);
|
|
vpermpd(tmp1 | k33_, zmm_idx_4_, zmm0);
|
|
}
|
|
|
|
// all stores
|
|
for (int i = 0; i < 8; i++)
|
|
vextracti64x4(Ymm(get_zmm_src(2 * i).getIdx()) | T_z,
|
|
get_zmm_src(2 * i + 1), 1);
|
|
|
|
const int rows_to_store = rnd_up(nrows, 2);
|
|
const int store_mask
|
|
= rows_to_store < rows_step_ ? (1 << rows_to_store) - 1 : 0xffff;
|
|
kmovd(dynamic_column_size, kTail_, store_mask);
|
|
|
|
auto get_vec_idx = [](int col_idx) {
|
|
assert(col_idx < columns_step_ && col_idx >= 0);
|
|
const int blk_sz = 4;
|
|
const int blk_idx = col_idx / blk_sz;
|
|
const int idx_within_blk = col_idx % blk_sz;
|
|
|
|
// 0 1 2 3 -> 0 2 1 3
|
|
const int mapped_blk_idx = 2 * blk_idx - (blk_idx / 2) * 3;
|
|
// 0 1 2 3 -> 1 0 3 2
|
|
const int mapped_idx_within_blk
|
|
= idx_within_blk + 1 - 2 * (idx_within_blk % 2);
|
|
return blk_sz * mapped_blk_idx + mapped_idx_within_blk;
|
|
};
|
|
|
|
const int columns_to_store = dynamic_column_size ? columns_step_ : ncolumns;
|
|
for (int col_idx = 0; col_idx < columns_to_store; col_idx++) {
|
|
store(get_zmm_src(get_vec_idx(col_idx)), col_idx);
|
|
if (dynamic_column_size) {
|
|
dec(reg_opmask_shift_compute_);
|
|
jz(transpose_int8_done, T_NEAR);
|
|
}
|
|
}
|
|
|
|
L(transpose_int8_done);
|
|
|
|
save_partial_compensation();
|
|
}
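// compute_m_loop(): walks the M dimension in columns_step_ chunks, calling
// deploy_transpose() for each chunk and handling both the static and the
// dynamic M-block tails; zero-point compensation state is tracked per chunk
// when it is required.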
|
|
|
|
void jit_brgemm_matmul_copy_a_transposed_int8_impl_t::compute_m_loop(
|
|
int nrows) {
|
|
mov(reg_loop_m_, ptr[rsp + current_M_blk_offt_]);
|
|
mov(reg_m_src_, reg_k_src_);
|
|
mov(reg_m_dst_, reg_k_dst_);
|
|
|
|
if (do_compute_compensation_)
|
|
mov(reg_zp_comp_res_ptr_,
|
|
ptr[param1 + GET_OFF(zp_a_compensation_result_ptr)]);
|
|
|
|
Label m_loop_tail_or_done, m_loop, compute_m_loop_done;
|
|
cmp(reg_loop_m_, columns_step_);
|
|
jl(m_loop_tail_or_done, T_NEAR);
|
|
|
|
L(m_loop);
|
|
{
|
|
deploy_transpose(reg_m_dst_, reg_m_src_, nrows, columns_step_);
|
|
add(reg_m_src_, m_loop_src_shift_);
|
|
add(reg_m_dst_, m_loop_dst_shift_);
|
|
}
|
|
sub(reg_loop_m_, columns_step_);
|
|
cmp(reg_loop_m_, columns_step_);
|
|
jge(m_loop, T_NEAR);
|
|
|
|
if (m_block_tail_ > 0 || last_m_block_tail_ > 0 || is_dynamic_src_ld_)
|
|
jz(compute_m_loop_done, T_NEAR);
|
|
|
|
L(m_loop_tail_or_done);
|
|
|
|
if (m_block_tail_ > 0) {
|
|
Label m_block_tail_done;
|
|
cmp(reg_loop_m_, m_block_tail_);
|
|
jne(m_block_tail_done, T_NEAR);
|
|
|
|
deploy_transpose(reg_m_dst_, reg_m_src_, nrows, m_block_tail_);
|
|
jmp(compute_m_loop_done, T_NEAR);
|
|
|
|
L(m_block_tail_done);
|
|
}
|
|
if (IMPLICATION(
|
|
last_m_block_tail_ <= 0 || last_m_block_tail_ == m_block_tail_,
|
|
is_dynamic_src_ld_)) {
|
|
Label last_m_block_tail_done;
|
|
cmp(reg_loop_m_, 0);
|
|
jle(last_m_block_tail_done, T_NEAR);
|
|
|
|
deploy_transpose(reg_m_dst_, reg_m_src_, nrows,
|
|
is_dynamic_src_ld_ ? 0 : last_m_block_tail_);
|
|
L(last_m_block_tail_done);
|
|
}
|
|
|
|
L(compute_m_loop_done);
|
|
|
|
if (do_compute_compensation_) mov(reg_zp_comp_prev_data_, 1);
|
|
}
|
|
|
|
void jit_brgemm_matmul_copy_a_transposed_int8_impl_t::compute_k_loop(
|
|
bool is_first_K_iter, bool is_last_K_iter) {
|
|
mov(reg_k_src_, ptr[rsp + src_offt_]);
|
|
mov(reg_k_dst_, ptr[rsp + tr_src_offt_]);
|
|
mov(reg_loop_k_, ptr[rsp + current_K_blk_offt_]);
|
|
|
|
if (do_compute_compensation_) {
|
|
if (is_first_K_iter)
|
|
mov(reg_zp_comp_prev_data_, 0);
|
|
else
|
|
mov(reg_zp_comp_prev_data_, 1);
|
|
}
|
|
|
|
Label k_tail_or_done, k_loop, compute_k_loop_done;
|
|
cmp(reg_loop_k_, rows_step_);
|
|
jl(k_tail_or_done, T_NEAR);
|
|
|
|
L(k_loop);
|
|
{
|
|
compute_m_loop(rows_step_);
|
|
if (is_dynamic_src_ld_)
|
|
add(reg_k_src_, ptr[rsp + dynamic_src_ld_x_kstep_offt_]);
|
|
else
|
|
add(reg_k_src_, k_loop_src_shift_);
|
|
add(reg_k_dst_, k_loop_dst_shift_);
|
|
}
|
|
sub(reg_loop_k_, rows_step_);
|
|
cmp(reg_loop_k_, rows_step_);
|
|
|
|
jge(k_loop, T_NEAR);
|
|
|
|
if (k_block_tail_ > 0 || last_k_block_tail_ > 0)
|
|
jz(compute_k_loop_done, T_NEAR);
|
|
|
|
L(k_tail_or_done);
|
|
|
|
if (k_block_tail_ > 0) {
|
|
Label k_block_tail_done;
|
|
cmp(reg_loop_k_, k_block_tail_);
|
|
jne(k_block_tail_done, T_NEAR);
|
|
|
|
compute_m_loop(k_block_tail_);
|
|
jmp(compute_k_loop_done, T_NEAR);
|
|
|
|
L(k_block_tail_done);
|
|
}
|
|
if (last_k_block_tail_ > 0 && last_k_block_tail_ != k_block_tail_) {
|
|
Label last_k_block_tail_done;
|
|
cmp(reg_loop_k_, last_k_block_tail_);
|
|
jne(last_k_block_tail_done, T_NEAR);
|
|
|
|
compute_m_loop(last_k_block_tail_);
|
|
jmp(compute_k_loop_done, T_NEAR);
|
|
|
|
L(last_k_block_tail_done);
|
|
}
|
|
|
|
L(compute_k_loop_done);
|
|
|
|
if (do_compute_compensation_ && is_last_K_iter) {
|
|
mov(reg_zp_comp_res_ptr_,
|
|
ptr[param1 + GET_OFF(zp_a_compensation_result_ptr)]);
|
|
|
|
auto calculate_final_compensation = [this]() {
|
|
// load accumulated compensation
|
|
vmovups(zmm_comp_acc_, ptr[reg_zp_comp_res_ptr_]);
|
|
|
|
// add -K * zp_a_val as mixed ab_compensation component
|
|
if (conf_->src_zp_type != brgemm_broadcast_t::none) {
|
|
mov(reg_tmp_, ptr[param1 + GET_OFF(zp_ab_comp_ptr)]);
|
|
vbroadcastss(get_zmm_src(0), ptr[reg_tmp_]);
|
|
vpaddd(zmm_comp_acc_, zmm_comp_acc_, get_zmm_src(0));
|
|
}
|
|
|
|
// multiply by zp_b_val
|
|
mov(reg_tmp_, ptr[param1 + GET_OFF(zp_b_neg_val_ptr)]);
|
|
vbroadcastss(get_zmm_src(0), ptr[reg_tmp_]);
|
|
vpmulld(zmm_comp_acc_, zmm_comp_acc_, get_zmm_src(0));
|
|
|
|
// store the final result value
|
|
vmovups(ptr[reg_zp_comp_res_ptr_], zmm_comp_acc_);
|
|
add(reg_zp_comp_res_ptr_, sizeof(int32_t) * 16);
|
|
};
|
|
|
|
Label m_loop, m_loop_tail_or_done, compute_m_loop_done;
|
|
|
|
mov(reg_loop_m_, ptr[rsp + current_M_blk_offt_]);
|
|
cmp(reg_loop_m_, columns_step_);
|
|
jl(m_loop_tail_or_done, T_NEAR);
|
|
|
|
L(m_loop);
|
|
calculate_final_compensation();
|
|
sub(reg_loop_m_, columns_step_);
|
|
cmp(reg_loop_m_, columns_step_);
|
|
jge(m_loop, T_NEAR);
|
|
|
|
if (m_block_tail_ > 0 || last_m_block_tail_ > 0 || is_dynamic_src_ld_)
|
|
jz(compute_m_loop_done, T_NEAR);
|
|
|
|
L(m_loop_tail_or_done);
|
|
|
|
if (m_block_tail_ > 0) {
|
|
Label m_block_tail_done;
|
|
cmp(reg_loop_m_, m_block_tail_);
|
|
jne(m_block_tail_done, T_NEAR);
|
|
|
|
calculate_final_compensation();
|
|
jmp(compute_m_loop_done, T_NEAR);
|
|
|
|
L(m_block_tail_done);
|
|
}
|
|
if (IMPLICATION(last_m_block_tail_ <= 0
|
|
|| last_m_block_tail_ == m_block_tail_,
|
|
is_dynamic_src_ld_)) {
|
|
Label last_m_block_tail_done;
|
|
cmp(reg_loop_m_, 0);
|
|
jle(last_m_block_tail_done, T_NEAR);
|
|
|
|
calculate_final_compensation();
|
|
L(last_m_block_tail_done);
|
|
}
|
|
|
|
L(compute_m_loop_done);
|
|
}
|
|
}
|
|
|
|
void jit_brgemm_matmul_copy_a_transposed_int8_impl_t::generate() {
|
|
preamble();
|
|
sub(rsp, stack_space_needed_);
|
|
mov(reg_tmp_, ptr[param1 + GET_OFF(current_M_blk)]);
|
|
mov(ptr[rsp + current_M_blk_offt_], reg_tmp_);
|
|
mov(reg_tmp_, ptr[param1 + GET_OFF(src)]);
|
|
mov(ptr[rsp + src_offt_], reg_tmp_);
|
|
mov(reg_tmp_, ptr[param1 + GET_OFF(tr_src)]);
|
|
mov(ptr[rsp + tr_src_offt_], reg_tmp_);
|
|
mov(reg_tmp_, ptr[param1 + GET_OFF(current_K_blk)]);
|
|
mov(ptr[rsp + current_K_blk_offt_], reg_tmp_);
|
|
if (is_dynamic_src_ld_) {
|
|
// dynamic src_stride
|
|
mov(reg_tmp_, ptr[param1 + GET_OFF(dynamic_src_ld)]);
|
|
mov(ptr[rsp + dynamic_src_ld_offt_], reg_tmp_);
|
|
|
|
// src_stride * 2
|
|
shl(reg_tmp_, 1);
|
|
mov(ptr[rsp + dynamic_src_ld_x_2_offt_], reg_tmp_);
|
|
|
|
// src_stride * rows_step_
|
|
assert(rows_step_ == 16);
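        // reg_tmp_ already holds src_stride * 2, so shifting left by 3 more
        // bits yields src_stride * 16 (== src_stride * rows_step_)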
|
|
shl(reg_tmp_, 3);
|
|
mov(ptr[rsp + dynamic_src_ld_x_kstep_offt_], reg_tmp_);
|
|
}
|
|
|
|
auto kmovw = [this](Opmask k, unsigned w) {
|
|
mov(reg_tmp_, w);
|
|
jit_generator_t::kmovw(k, reg_tmp_.cvt32());
|
|
};
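    // The alternating opmask patterns set up here drive the swap stages of the
    // transpose: 0x5555/0xAAAA select even/odd dwords, 0x55/0xAA even/odd
    // qwords, and 0x33/0xCC even/odd qword pairs; kTail_ is recomputed per
    // block for partial stores.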
|
|
|
|
kmovw(kFFFF_, 0xffff);
|
|
kmovw(k5555_, 0x5555);
|
|
kmovw(kAAAA_, 0xaaaa);
|
|
kmovw(kAA_, 0xaa);
|
|
kmovw(k55_, 0x55);
|
|
kmovw(kCC_, 0xcc);
|
|
kmovw(k33_, 0x33);
|
|
|
|
auto vmovdqa64 = [this](Zmm z, const int64_t *addr) {
|
|
mov(reg_tmp_, reinterpret_cast<size_t>(addr));
|
|
jit_generator_t::vmovdqa64(z, ptr[reg_tmp_]);
|
|
};
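    // Load the permutation tables: with vpermb available a byte-level permute
    // can interleave the rows directly; otherwise a word-level vpermw
    // interleave is used and the remaining reordering is done with dword/qword
    // permutes in the swap stages.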
|
|
|
|
if (has_vpermb_) {
|
|
alignas(64) static constexpr const uint8_t idx1[64] = {0, 16, 32, 48, 1,
|
|
17, 33, 49, 2, 18, 34, 50, 3, 19, 35, 51, 4, 20, 36, 52, 5, 21,
|
|
37, 53, 6, 22, 38, 54, 7, 23, 39, 55, 8, 24, 40, 56, 9, 25, 41,
|
|
57, 10, 26, 42, 58, 11, 27, 43, 59, 12, 28, 44, 60, 13, 29, 45,
|
|
61, 14, 30, 46, 62, 15, 31, 47, 63};
|
|
alignas(64) static constexpr const uint32_t idx2[16]
|
|
= {0, 16, 2, 18, 4, 20, 6, 22, 8, 24, 10, 26, 12, 28, 14, 30};
|
|
alignas(64) static constexpr const uint32_t idx3[16]
|
|
= {1, 17, 3, 19, 5, 21, 7, 23, 9, 25, 11, 27, 13, 29, 15, 31};
|
|
alignas(64) static constexpr const uint64_t idx4[8]
|
|
= {0, 8, 2, 10, 4, 12, 6, 14};
|
|
alignas(64) static constexpr const uint64_t idx5[8]
|
|
= {1, 9, 3, 11, 5, 13, 7, 15};
|
|
|
|
vmovdqa64(zmm_idx_1_, (const int64_t *)idx1);
|
|
vmovdqa64(zmm_idx_2_, (const int64_t *)idx2);
|
|
vmovdqa64(zmm_idx_3_, (const int64_t *)idx3);
|
|
vmovdqa64(zmm_idx_4_, (const int64_t *)idx4);
|
|
vmovdqa64(zmm_idx_5_, (const int64_t *)idx5);
|
|
} else {
|
|
alignas(64) static constexpr const uint16_t idx1[32]
|
|
= {0, 16, 2, 18, 8, 24, 10, 26, 4, 20, 6, 22, 12, 28, 14, 30, 1,
|
|
17, 3, 19, 9, 25, 11, 27, 5, 21, 7, 23, 13, 29, 15, 31};
|
|
alignas(64) static constexpr const uint32_t idx2[16]
|
|
= {1, 0, 3, 2, 5, 4, 7, 6, 9, 8, 11, 10, 13, 12, 15, 14};
|
|
alignas(64) static constexpr const uint64_t idx3[8]
|
|
= {1, 0, 3, 2, 5, 4, 7, 6};
|
|
alignas(64) static constexpr const uint64_t idx4[8]
|
|
= {2, 3, 0, 1, 6, 7, 4, 5};
|
|
|
|
vmovdqa64(zmm_idx_1_, (const int64_t *)idx1);
|
|
vmovdqa64(zmm_idx_2_, (const int64_t *)idx2);
|
|
vmovdqa64(zmm_idx_3_, (const int64_t *)idx3);
|
|
vmovdqa64(zmm_idx_4_, (const int64_t *)idx4);
|
|
}
|
|
|
|
Label done;
|
|
if (do_compute_compensation_) {
|
|
assert(conf_->wei_zp_type == brgemm_broadcast_t::per_tensor);
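        // Dispatch to the proper compute_k_loop() variant depending on whether
        // the current K chunk is the first and/or the last one: the first
        // chunk starts the compensation accumulation from scratch and the last
        // one finalizes and stores the result.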
|
|
|
|
mov(reg_tmp_, ptr[param1 + GET_OFF(current_K_start)]);
|
|
const auto last_K_threshold
|
|
= rnd_up(conf_->K, conf_->K_blk) - conf_->K_blk;
|
|
Label not_first, not_first_not_last;
|
|
cmp(reg_tmp_, 0);
|
|
jne(not_first, T_NEAR);
|
|
{
|
|
Label first_not_last;
|
|
cmp(reg_tmp_, last_K_threshold);
|
|
jl(first_not_last, T_NEAR);
|
|
compute_k_loop(true, true);
|
|
jmp(done, T_NEAR);
|
|
|
|
L(first_not_last);
|
|
compute_k_loop(true, false);
|
|
jmp(done, T_NEAR);
|
|
}
|
|
|
|
L(not_first);
|
|
cmp(reg_tmp_, last_K_threshold);
|
|
jl(not_first_not_last, T_NEAR);
|
|
|
|
compute_k_loop(false, true);
|
|
jmp(done, T_NEAR);
|
|
L(not_first_not_last);
|
|
}
|
|
compute_k_loop(false, false);
|
|
L(done);
|
|
|
|
add(rsp, stack_space_needed_);
|
|
postamble();
|
|
}
|
|
template struct jit_brgemm_matmul_copy_a_transposed_impl_t<Zmm>;
|
|
template struct jit_brgemm_matmul_copy_a_transposed_impl_t<Ymm>;
|
|
|
|
/**
|
|
* @brief Common class for BRGEMM B matrix copy operations
|
|
*
|
|
 * This class contains methods and properties shared by all copy B kernels.
 * It currently provides `load_value` and `decompress_reg`, and is intended
 * to collect all methods common to the copy B kernels.
 * The default scenario for weight decompression is:
|
|
* 1. load_value()
|
|
* 2. apply zero point shift (if needed)
|
|
* 3. convert to f32
|
|
* 4. apply scaling (if needed)
|
|
* 5. down convert if destination datatype is not f32.
|
|
*/
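// The numbered steps above map onto the protected helpers of this class:
// load_value() -> apply_shift() -> upconvert_to_f32() -> apply_scales() ->
// downconvert_to_dst_dt(); decompress_reg() and decompress_and_downcvt_reg()
// chain them together.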
|
|
struct jit_brgemm_matmul_copy_b_common_t : public jit_brgemm_matmul_copy_b_t,
|
|
public jit_generator_t {
|
|
DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_brgemm_matmul_copy_b_common_t)
|
|
|
|
jit_brgemm_matmul_copy_b_common_t(const brgemm_matmul_conf_t *conf)
|
|
: jit_brgemm_matmul_copy_b_t(conf), jit_generator_t(jit_name()) {}
|
|
|
|
protected:
|
|
/**
|
|
* @brief Conditionally applies a mask to a vector register for tail or specialized processing
|
|
*
|
|
* @tparam Vmm Vector register type (Zmm, Ymm, etc.)
|
|
* @param vmm The vector register to potentially mask
|
|
* @param is_tail Flag indicating if this is tail processing that requires masking
|
|
* @param is_int4 Flag indicating if this is INT4 data requiring specialized masking
|
|
* @return The potentially masked vector register (original if no masking needed)
|
|
*/
|
|
template <typename Vmm>
|
|
Vmm maybe_mask(Vmm vmm, bool is_tail, bool is_int4 = false) {
|
|
// Transposed kernel uses the same kTail mask for both cases
|
|
const auto tail_mask
|
|
= is_int4 && !conf_->transposed_B ? kTail_int4 : kTail;
|
|
const auto unmask_tail
|
|
= one_of(conf_->wei_dt, data_type::bf16, data_type::f16)
|
|
&& !conf_->transposed_B;
|
|
if (isa_has_masks(conf_->isa))
|
|
return is_tail ? vmm | tail_mask | T_z
|
|
                // bf16 and f16 require masking for tail
                // to avoid AVX512F issues with zeroing upper bits
                // of ZMM registers when using vpmovzxwd/vpmovzxbd
                // instructions
|
|
: unmask_tail ? vmm | kFFFF | T_z
|
|
: vmm;
|
|
else
|
|
return vmm;
|
|
}
|
|
|
|
/**
|
|
* @brief Inserts a YMM register into the upper half of a ZMM register
|
|
* @param zmm Destination ZMM register where data will be inserted
|
|
* @param ymm_half Source YMM register containing data for the upper half
|
|
*/
|
|
void copy_half_int4(const Zmm &zmm, const Ymm &ymm_half) {
|
|
vinserti64x4(zmm, zmm, ymm_half, 1);
|
|
}
|
|
|
|
/**
|
|
* @brief Inserts an XMM register into the upper half of a YMM register
|
|
* @param ymm Destination YMM register where data will be inserted
|
|
* @param xmm_half Source XMM register containing data for the upper half
|
|
*/
|
|
void copy_half_int4(const Ymm &ymm, const Xmm &xmm_half) {
|
|
vinserti128(ymm, ymm, xmm_half, 1);
|
|
}
|
|
|
|
/**
|
|
* @brief Loads and converts data of various types into vector registers with appropriate handling
|
|
*
|
|
* @tparam Vmm Vector register type for computation
|
|
* @param reg Destination vector register
|
|
* @param op Source memory operand to load from
|
|
* @param vmm_permd Vector register containing permutation indices for INT4 processing
|
|
* @param dt Data type being loaded
|
|
* @param is_tail Flag indicating if tail processing is needed
|
|
* @param vmm_f4_lut Vector register containing lookup table for FP4 conversion
|
|
* (default is Vmm(4) for kernel jit_brgemm_matmul_copy_b_f32_t)
|
|
*/
|
|
template <typename Vmm>
|
|
void load_value(const Vmm ®, const Xbyak::Operand &op,
|
|
const Vmm &vmm_permd, data_type_t dt, bool is_tail = false,
|
|
const Vmm &vmm_f4_lut = Vmm(4)) {
|
|
using Vmm_lower_t = typename vreg_traits_t<Vmm>::Vmm_lower_t;
|
|
const auto vmm_in = maybe_mask(reg, is_tail);
|
|
const auto vmm_lower = Vmm_lower_t(vmm_in.getIdx());
|
|
MAYBE_UNUSED(vmm_lower);
|
|
|
|
const bool is_xf16
|
|
= one_of(conf_->wei_dt, data_type::bf16, data_type::f16);
|
|
|
|
switch (dt) {
|
|
case data_type::s32: vmovdqu32(vmm_in, op); break;
|
|
case data_type::f32: {
|
|
if (conf_->transposed_B)
|
|
vmovdqu8(vmm_in, op);
|
|
else
|
|
uni_vmovups(vmm_in, op);
|
|
break;
|
|
}
|
|
case data_type::f16:
|
|
if (!is_xf16) {
|
|
uni_vcvtph2psx(vmm_in, op);
|
|
break;
|
|
}
|
|
case data_type::bf16:
|
|
if (is_xf16) {
|
|
uni_vmovdqu16(vmm_in, op);
|
|
} else {
|
|
uni_vpmovzxwd(vmm_in, op);
|
|
uni_vpslld(vmm_in, vmm_in, 16);
|
|
}
|
|
break;
|
|
case data_type::s8: uni_vpmovsxbd(vmm_in, op); break;
|
|
case data_type::u8: uni_vpmovzxbd(vmm_in, op); break;
|
|
            // For int4, two int4 values are viewed as one int8 and extended
            // to int32: the low halves land in the lower bytes of the vmm and
            // the high halves in the higher bytes, then a permute restores the
            // correct order. Finally, the extended bytes are fixed up for
            // s4/u4 accordingly (sign vs. zero extension).
|
|
case data_type::s4:
|
|
uni_vpmovsxbd(
|
|
maybe_mask(vmm_lower, is_tail, /* is_int4 = */ true),
|
|
op);
|
|
copy_half_int4(reg, vmm_lower);
|
|
|
|
vpermd(reg, vmm_permd, reg);
|
|
uni_vpslld(reg | k5555, reg, 28);
|
|
vpsrad(reg | k5555, reg, 28);
|
|
vpsrad(reg | kAAAA, reg, 4);
|
|
break;
|
|
case data_type::u4:
|
|
uni_vpmovzxbd(
|
|
maybe_mask(vmm_lower, is_tail, /* is_int4 = */ true),
|
|
op);
|
|
copy_half_int4(reg, vmm_lower);
|
|
|
|
vpermd(reg, vmm_permd, reg);
|
|
uni_vpslld(reg | k5555, reg, 28);
|
|
vpsrld(reg | k5555, reg, 28);
|
|
vpsrld(reg | kAAAA, reg, 4);
|
|
break;
|
|
case data_type::f4_e2m1:
|
|
case data_type::f4_e3m0:
|
|
uni_vpmovzxbd(maybe_mask(vmm_lower, is_tail), op);
|
|
copy_half_int4(vmm_in, vmm_lower);
|
|
vpermd(vmm_in, vmm_permd, vmm_in);
|
|
uni_vpslld(vmm_in | k5555, vmm_in, 28);
|
|
vpsrld(vmm_in | k5555, vmm_in, 28);
|
|
vpsrld(vmm_in | kAAAA, vmm_in, 4);
|
|
vpermps(vmm_in, vmm_in, vmm_f4_lut);
|
|
break;
|
|
default: assert(!"unsupported data type");
|
|
}
|
|
}
|
|
|
|
/** @brief Loads common zero point value and broadcasts over `zp_vmm` register.
|
|
* Handles only per_k and common values.
|
|
* @tparam Vmm Vector register type (Zmm, Ymm, etc.)
|
|
* @param zp_vmm Vector register to load and broadcast zero point value into
|
|
* @param ptr_reg Register containing pointer to zero point value in memory
|
|
**/
|
|
template <typename Vmm>
|
|
void load_common_zp_value(const Vmm &zp_vmm, const Xbyak::Reg64 &ptr_reg) {
|
|
using Vmm_lower_t = typename vreg_traits_t<Vmm>::Vmm_lower_t;
|
|
// Handle only per_k and common values
|
|
const bool only_per_k
|
|
= conf_->is_wei_zp_per_k && !conf_->is_wei_zp_per_n;
|
|
const bool require_load = conf_->has_zero_point_b
|
|
&& (conf_->is_wei_zp_common || only_per_k);
|
|
if (!require_load) return;
|
|
|
|
const auto zp_dt = conf_->wei_zp_dt;
|
|
const auto tmp_xmm = Xmm(zp_vmm.getIdx());
|
|
const auto vmm_lower = Vmm_lower_t(zp_vmm.getIdx());
|
|
MAYBE_UNUSED(tmp_xmm);
|
|
MAYBE_UNUSED(vmm_lower);
|
|
const auto &addr = ptr[ptr_reg];
|
|
|
|
const bool need_upconvert = one_of(zp_dt, data_type::s8, data_type::u8,
|
|
data_type::u4, data_type::s4);
|
|
|
|
if (need_upconvert) {
|
|
uni_vpinsrb(tmp_xmm, tmp_xmm, addr, 0);
|
|
if (one_of(zp_dt, data_type::s4, data_type::s8))
|
|
uni_vpmovsxbd(tmp_xmm, tmp_xmm);
|
|
else
|
|
uni_vpmovzxbd(tmp_xmm, tmp_xmm);
|
|
|
|
            // For 4-bit ints, shift left by 28 bits
|
|
if (one_of(zp_dt, data_type::s4, data_type::u4))
|
|
uni_vpslld(tmp_xmm, tmp_xmm, 28);
|
|
            // then shift back to the right by 28 bits (logical for u4,
            // arithmetic for s4)
|
|
if (zp_dt == data_type::u4) vpsrld(tmp_xmm, tmp_xmm, 28);
|
|
if (zp_dt == data_type::s4) vpsrad(tmp_xmm, tmp_xmm, 28);
|
|
}
|
|
|
|
const auto &op = need_upconvert
|
|
? static_cast<const Xbyak::Operand &>(tmp_xmm)
|
|
: static_cast<const Xbyak::Operand &>(addr);
|
|
uni_vpbroadcastd(zp_vmm, op);
|
|
}
|
|
|
|
/** @brief Loads common scale value and broadcasts over `scale_vmm` register.
|
|
* Handles only per_k and common values.
|
|
* @tparam Vmm Vector register type (Zmm, Ymm, etc.)
|
|
* @param scale_vmm Vector register to load and broadcast scale value into
|
|
* @param ptr_reg Register containing pointer to scale value in memory
|
|
**/
|
|
template <typename Vmm>
|
|
void load_common_scale_value(
|
|
const Vmm &scale_vmm, const Xbyak::Reg64 &ptr_reg) {
|
|
const bool only_per_k
|
|
= conf_->is_wei_scale_per_k && !conf_->is_wei_scale_per_n;
|
|
const bool require_scales = conf_->apply_scales_in_buffer_b
|
|
&& (conf_->is_wei_scale_common || only_per_k);
|
|
if (!require_scales) return;
|
|
|
|
const auto &scales_dt = conf_->wei_scales_dt;
|
|
const auto &addr = ptr[ptr_reg];
|
|
switch (scales_dt) {
|
|
case data_type::f32: uni_vbroadcastss(scale_vmm, addr); break;
|
|
case data_type::bf16:
|
|
vpbroadcastw(scale_vmm, addr);
|
|
uni_vpslld(scale_vmm, scale_vmm, 16);
|
|
break;
|
|
case data_type::f16: vcvtph2psx(scale_vmm, addr); break;
|
|
default: assert(!"unsupported wei_scales data type");
|
|
}
|
|
}
|
|
|
|
/** @brief Helper method to load scales into vector register.
|
|
* Supports f32, bf16 and f16 data types.
|
|
* @tparam Vmm Vector register type (Zmm, Ymm, etc.)
|
|
* @param vmm Vector register to load scale value into
|
|
* @param op Operand to load scale value from
|
|
**/
|
|
template <typename Vmm>
|
|
void load_scale_value(const Vmm &vmm, const Xbyak::Operand &op,
|
|
data_type_t dt, bool is_tail = false) {
|
|
const auto masked_vmm = maybe_mask(vmm, is_tail);
|
|
switch (dt) {
|
|
case data_type::f32: uni_vmovups(masked_vmm, op); break;
|
|
case data_type::bf16:
|
|
uni_vpmovzxwd(masked_vmm, op);
|
|
uni_vpslld(vmm, vmm, 16);
|
|
break;
|
|
case data_type::f16: vcvtph2ps(masked_vmm, op); break;
|
|
default: assert(!"unsupported wei_scales data type");
|
|
}
|
|
}
|
|
|
|
/**
|
|
* @brief Applies zero point shift to vector register
|
|
* Shifts input values by subtracting zero point values.
|
|
*
|
|
* @tparam Vmm Vector register type (Zmm, Ymm, etc.)
|
|
* @param input Vector register containing input values to be shifted
|
|
* @param zp Vector register containing zero point values to be subtracted
|
|
* @param src_dt Source data type that determines which instruction to use
|
|
*/
|
|
template <typename Vmm>
|
|
void apply_shift(const Vmm &input, const Vmm &zp, data_type_t src_dt) {
|
|
if (!conf_->has_zero_point_b) return;
|
|
|
|
const bool is_int_shift = one_of(src_dt, data_type::s8, data_type::u8,
|
|
data_type::s4, data_type::u4);
|
|
const bool is_fp_shift = one_of(
|
|
src_dt, data_type::bf16, data_type::f16, data_type::f32);
|
|
|
|
if (is_int_shift)
|
|
uni_vpsubd(input, input, zp);
|
|
else if (is_fp_shift)
|
|
uni_vsubps(input, input, zp);
|
|
// should be unreachable
|
|
else
|
|
assert(!"Unable to shift input for zero-point");
|
|
}
|
|
|
|
/**
|
|
* @brief Converts integer data types to floating-point format
|
|
*
|
|
* @tparam Vmm Vector register type (Zmm, Ymm, etc.)
|
|
* @param input Vector register containing data to be converted
|
|
* @param src_dt Source data type of the values to be converted
|
|
*/
|
|
template <typename Vmm>
|
|
void upconvert_to_f32(const Vmm &input, data_type_t src_dt) {
|
|
switch (src_dt) {
|
|
case data_type::s8:
|
|
case data_type::u8:
|
|
case data_type::s4:
|
|
case data_type::u4: uni_vcvtdq2ps(input, input); break;
|
|
case data_type::bf16:
|
|
case data_type::f16:
|
|
case data_type::f32:
|
|
// bf16 and f16 already converted into f32 while loading
|
|
break;
|
|
default: assert(!"Unsupported source data type for decompression");
|
|
}
|
|
}
|
|
|
|
/**
|
|
* @brief Applies scale factors to input values when configured
|
|
* The operation is only performed if apply_scales_in_buffer_b flag is set
|
|
* in the configuration.
|
|
*
|
|
* @tparam Vmm Vector register type (Zmm, Ymm, etc.)
|
|
* @param input Vector register containing values to be scaled
|
|
* @param scale_op Operand containing scale factors to apply
|
|
*/
|
|
template <typename Vmm>
|
|
void apply_scales(const Vmm &input, const Xbyak::Operand &scale_op) {
|
|
if (conf_->apply_scales_in_buffer_b) uni_vmulps(input, input, scale_op);
|
|
}
|
|
|
|
/**
|
|
* @brief Converts floating-point data from F32 to specified destination data type
|
|
*
|
|
* @tparam Vmm Vector register type (Zmm, Ymm, etc.)
|
|
* @param input Vector register containing F32 values to be converted
|
|
* @param dst_dt Destination data type for conversion
|
|
*/
|
|
template <typename Vmm>
|
|
void downconvert_to_dst_dt(const Vmm &input, data_type_t dst_dt) {
|
|
using Vmm_lower_t = typename vreg_traits_t<Vmm>::Vmm_lower_t;
|
|
switch (dst_dt) {
|
|
case data_type::bf16:
|
|
vcvtneps2bf16(Ymm(input.getIdx()), input);
|
|
break;
|
|
case data_type::f16:
|
|
vcvtps2phx(Vmm_lower_t(input.getIdx()), input);
|
|
break;
|
|
case data_type::f32:
|
|
// f32 is already in the correct format
|
|
break;
|
|
default:
|
|
assert(!"Unsupported destination data type for decompression");
|
|
}
|
|
}
|
|
|
|
template <typename Vmm>
|
|
void downconvert_to_dst_dt(
|
|
const Vmm ®1, const Vmm ®2, data_type_t dst_dt) {
|
|
using Vmm_lower_t = typename vreg_traits_t<Vmm>::Vmm_lower_t;
|
|
switch (dst_dt) {
|
|
case data_type::bf16: vcvtne2ps2bf16(reg1, reg2, reg1); break;
|
|
case data_type::f16: {
|
|
const auto src_vmm_lower0 = Vmm_lower_t(reg1.getIdx());
|
|
const auto src_vmm_lower1 = Vmm_lower_t(reg2.getIdx());
|
|
vcvtps2phx(src_vmm_lower0, reg1);
|
|
vcvtps2phx(src_vmm_lower1, reg2);
|
|
vinsertf64x4(reg1, reg1, src_vmm_lower1, 1);
|
|
break;
|
|
}
|
|
case data_type::f32:
|
|
// f32 is already in the correct format
|
|
break;
|
|
default:
|
|
assert(!"Unsupported destination data type for decompression");
|
|
}
|
|
}
|
|
|
|
/**
|
|
* @brief Decompresses values by applying zero point shift, data type conversion, and scaling.
|
|
* @tparam Vmm Vector register type used for computation
|
|
* @param input Vector register containing input values to be decompressed
|
|
* @param zp Vector register containing zero point values
|
|
* @param scale_op Operand containing scales to be applied
|
|
* @param src_dt Source data type of the values being decompressed
|
|
* @param dst_dt Destination data type for the decompressed values
|
|
*/
|
|
template <typename Vmm>
|
|
void decompress_reg(const Vmm &input, const Vmm &zp,
|
|
const Xbyak::Operand &scale_op, data_type_t src_dt) {
|
|
if (src_dt == data_type::f32)
|
|
return; // Decompression doesn't support f32
|
|
apply_shift(input, zp, src_dt);
|
|
upconvert_to_f32(input, src_dt);
|
|
apply_scales(input, scale_op);
|
|
}
|
|
|
|
template <typename Vmm>
|
|
void decompress_and_downcvt_reg(const Vmm &input, const Vmm &zp,
|
|
const Xbyak::Operand &scale_op, data_type_t src_dt,
|
|
data_type_t dst_dt) {
|
|
if (src_dt == dst_dt) return;
|
|
decompress_reg(input, zp, scale_op, src_dt);
|
|
downconvert_to_dst_dt(input, dst_dt);
|
|
}
|
|
|
|
template <typename Vmm>
|
|
void decompress_and_downcvt_2reg(const Vmm &input1, const Vmm &input2,
|
|
const Vmm &zp1, const Vmm &zp2, const Xbyak::Operand &scale_op1,
|
|
const Xbyak::Operand &scale_op2, data_type_t src_dt,
|
|
data_type_t dst_dt) {
|
|
if (src_dt == dst_dt) return;
|
|
decompress_reg(input1, zp1, scale_op1, src_dt);
|
|
decompress_reg(input2, zp2, scale_op2, src_dt);
|
|
downconvert_to_dst_dt(input1, input2, dst_dt);
|
|
}
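    // Note: the two-register overload above exists so that the bf16 path can
    // use vcvtne2ps2bf16, which packs both f32 inputs into a single register
    // in one instruction; the f16 path converts each half and re-inserts the
    // upper one.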
|
|
|
|
// Common used masks to permute data
|
|
using opmask_t = const Xbyak::Opmask;
|
|
opmask_t k3333 = k1;
|
|
opmask_t k5555 = k2;
|
|
opmask_t kAAAA = k3;
|
|
opmask_t kCCCC = k4;
|
|
opmask_t k0F0F = k5;
|
|
opmask_t kF0F0 = k6;
|
|
opmask_t kFFFF = k6;
|
|
opmask_t kTail = k7;
|
|
opmask_t kTail_int4 = k5; // TODO: refactor: use kTail for both cases
|
|
};
|
|
|
|
template <typename Vmm>
|
|
struct jit_brgemm_matmul_copy_b_int8_t : public jit_brgemm_matmul_copy_b_t,
|
|
public jit_generator_t {
|
|
DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_brgemm_matmul_copy_b_int8_t)
|
|
|
|
jit_brgemm_matmul_copy_b_int8_t(const brgemm_matmul_conf_t *conf)
|
|
: jit_brgemm_matmul_copy_b_t(conf)
|
|
, jit_generator_t(jit_name())
|
|
, src_stride_(conf->copy_B_wei_stride)
|
|
, tr_src_stride_(conf->LDB * k_blk_step_ * sizeof(int8_t))
|
|
, is_amx_(mayiuse(avx512_core_amx))
|
|
, do_compute_compensation_(
|
|
conf->s8s8_compensation_required || conf->has_zero_point_a)
|
|
, avx512_core_dot_product_(
|
|
do_compute_compensation_ && !isa_has_int8_vnni(conf->isa))
|
|
, is_dynamic_stride_(is_runtime_value(src_stride_))
|
|
, is_dynamic_N_(conf->is_runtime_N)
|
|
, comp_acc_idx_(is_ymm_ ? 13
|
|
: avx512_core_dot_product_ ? 23
|
|
: 25) {}
|
|
|
|
void operator()(ctx_t *ctx) override { jit_generator_t::operator()(ctx); }
|
|
status_t create_kernel() override {
|
|
return jit_generator_t::create_kernel();
|
|
}
|
|
|
|
protected:
|
|
using reg64_t = const Xbyak::Reg64;
|
|
using reg32_t = const Xbyak::Reg32;
|
|
|
|
static constexpr bool is_ymm_ = std::is_same<Vmm, Xbyak::Ymm>::value;
|
|
static constexpr int k_blk_step_ = 4;
|
|
static constexpr int n_blk_step_ = 64;
|
|
static constexpr int blk_sz_ = 6;
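    // Six vmm registers are used per unrolled K block: four for the loaded
    // rows (k_blk_step_) and two as shuffle temporaries.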
|
|
static constexpr int simd_w_ = vreg_traits_t<Vmm>::vlen;
|
|
|
|
const dim_t src_stride_;
|
|
const dim_t tr_src_stride_;
|
|
const bool is_amx_;
|
|
const bool do_compute_compensation_;
|
|
const bool avx512_core_dot_product_;
|
|
const bool is_dynamic_stride_;
|
|
const bool is_dynamic_N_;
|
|
|
|
constexpr static int reg_src_offs_ = 0;
|
|
constexpr static int reg_tr_src_offs_ = 8;
|
|
constexpr static int reg_current_K_pad_offs_ = 16;
|
|
constexpr static int stack_space_needed_ = 24;
|
|
|
|
const int comp_acc_idx_;
|
|
|
|
const Xbyak::Opmask kTail = k7;
|
|
|
|
reg64_t reg_src = rax;
|
|
reg64_t reg_tr_src = rbx;
|
|
reg64_t reg_comp_ptr = rdx;
|
|
reg64_t reg_zp_comp_ptr = r11;
|
|
reg64_t reg_zp_a_neg_val_ptr = r12;
|
|
|
|
reg64_t reg_K_iters = r8;
|
|
reg64_t reg_N_blk = r9;
|
|
reg64_t reg_K_start = r10;
|
|
reg64_t reg_src_stride = r13;
|
|
reg64_t reg_src_backup = r14;
|
|
reg64_t reg_tmp = r15;
|
|
|
|
reg64_t reg_copy_block_n_shift = rsi;
|
|
|
|
reg64_t reg_dynamic_tail = rcx;
|
|
Xbyak::Reg8 reg8_mask_shift = reg_dynamic_tail.cvt8();
|
|
|
|
// Required in every dot product for INT8 non-VNNI computation.
|
|
Vmm vmm_ones_words = Vmm(24);
|
|
Vmm vmm_dot_product_temp = Vmm(25);
|
|
|
|
// ZMM stuff
|
|
Vmm vreg_idx_lo_256 = Vmm(26);
|
|
Vmm vreg_idx_hi_256 = Vmm(27);
|
|
Vmm vreg_idx_lo_128 = Vmm(28);
|
|
Vmm vreg_idx_hi_128 = Vmm(29);
|
|
|
|
// Shared
|
|
Vmm vmm_comp_mul = Vmm(is_ymm_ ? 14 : 30);
|
|
Vmm vmm_zero = Vmm(is_ymm_ ? 15 : 31);
|
|
|
|
Vmm get_comp_acc(int i) { return Vmm(comp_acc_idx_ - i); }
|
|
Vmm get_vmm_zp_comp_res(int i) { return get_comp_acc(i); }
|
|
Vmm get_vmm_wei_scale_comp_res(int i) { return Vmm(i); }
|
|
|
|
inline void vmovdqa64(Vmm vmm, const void *addr) {
|
|
mov(reg_tmp, reinterpret_cast<size_t>(addr));
|
|
jit_generator_t::vmovdqa64(vmm, ptr[reg_tmp]);
|
|
}
|
|
|
|
inline Vmm get_vmm(int blk, int idx) {
|
|
if (idx < 0 || idx >= isa_num_vregs(is_ymm_ ? avx2 : avx512_core))
|
|
assert(!"idx > vregs");
|
|
assert(IMPLICATION(!is_ymm_, idx < blk_sz_ && blk >= 0));
|
|
auto reg_idx = blk_sz_ * blk + idx;
|
|
return Vmm(reg_idx);
|
|
}
|
|
inline void load(int blk, int i, bool is_tail) {}
|
|
inline void kmovq(Opmask k, size_t q) {}
|
|
virtual void init_permute() {}
|
|
virtual void copy_block(
|
|
int nrows, int ncolumns, bool n_tail, bool zeropad) {
|
|
UNUSED(n_tail);
|
|
copy_4x64(nrows, ncolumns, zeropad);
|
|
}
|
|
virtual void copy_4x64(int nrows, int ncolumns, bool zeropad) {}
|
|
inline void dot_product(Vmm v1, Vmm v2, Vmm v3) {
|
|
if (!avx512_core_dot_product_)
|
|
vpdpbusd(v1, v2, v3, get_encoding());
|
|
else {
|
|
vpmaddubsw(vmm_dot_product_temp, v2, v3);
|
|
vpmaddwd(
|
|
vmm_dot_product_temp, vmm_dot_product_temp, vmm_ones_words);
|
|
vpaddd(v1, v1, vmm_dot_product_temp);
|
|
}
|
|
}
|
|
void generate() override;
|
|
};
|
|
|
|
template <>
|
|
inline void jit_brgemm_matmul_copy_b_int8_t<Zmm>::load(
|
|
int blk, int i, bool is_tail) {
|
|
auto vmm_src = get_vmm(blk, i % k_blk_step_);
|
|
auto src_load = is_tail ? vmm_src | kTail | T_z : vmm_src;
|
|
const auto offset = is_dynamic_stride_ ? 0 : i * src_stride_;
|
|
vmovdqu8(src_load, EVEX_compress_addr(reg_src, offset));
|
|
if (is_dynamic_stride_) add(reg_src, reg_src_stride);
|
|
}
|
|
|
|
template <>
|
|
inline void jit_brgemm_matmul_copy_b_int8_t<Zmm>::kmovq(Opmask k, size_t q) {
|
|
if (is_dynamic_N_) {
|
|
mov(reg_tmp, 1);
|
|
shl(reg_tmp, reg8_mask_shift /* reg_dynamic_tail.cvt8() == cl */);
|
|
sub(reg_tmp, 1);
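        // reg_tmp = (1 << dynamic_tail) - 1, i.e. a mask with the low
        // 'dynamic_tail' bits set (e.g. a tail of 3 columns gives 0b111)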
|
|
} else
|
|
mov(reg_tmp, q);
|
|
jit_generator_t::kmovq(k, reg_tmp);
|
|
}
|
|
|
|
template struct jit_brgemm_matmul_copy_b_int8_t<Zmm>;
|
|
template struct jit_brgemm_matmul_copy_b_int8_t<Ymm>;
|
|
|
|
struct jit_amx_brgemm_matmul_copy_b_int8_t
|
|
: public jit_brgemm_matmul_copy_b_int8_t<Xbyak::Zmm> {
|
|
|
|
jit_amx_brgemm_matmul_copy_b_int8_t(const brgemm_matmul_conf_t *conf)
|
|
: jit_brgemm_matmul_copy_b_int8_t<Xbyak::Zmm>(conf)
|
|
, do_N_loop_(conf->LDB < conf->N_blk) {}
|
|
|
|
private:
|
|
const bool do_N_loop_;
|
|
|
|
void init_permute() override {
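        // Byte-permutation tables used by copy_4x64(): vpermi2b with the
        // *_16 tables interleaves bytes from two loaded rows, and the *_8
        // tables then gather the interleaved results into the final
        // 16-column vnni-4 output groups.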
|
|
alignas(64) static constexpr const uint8_t idx_lo_16[64] = {0, 1, 64,
|
|
65, 4, 5, 68, 69, 2, 3, 66, 67, 6, 7, 70, 71, 8, 9, 72, 73, 12,
|
|
13, 76, 77, 10, 11, 74, 75, 14, 15, 78, 79, 16, 17, 80, 81, 20,
|
|
21, 84, 85, 18, 19, 82, 83, 22, 23, 86, 87, 24, 25, 88, 89, 28,
|
|
29, 92, 93, 26, 27, 90, 91, 30, 31, 94, 95};
|
|
|
|
alignas(64) static constexpr const uint8_t idx_hi_16[64] = {32, 33, 96,
|
|
97, 36, 37, 100, 101, 34, 35, 98, 99, 38, 39, 102, 103, 40, 41,
|
|
104, 105, 44, 45, 108, 109, 42, 43, 106, 107, 46, 47, 110, 111,
|
|
48, 49, 112, 113, 52, 53, 116, 117, 50, 51, 114, 115, 54, 55,
|
|
118, 119, 56, 57, 120, 121, 60, 61, 124, 125, 58, 59, 122, 123,
|
|
62, 63, 126, 127};
|
|
|
|
alignas(64) static constexpr const uint8_t idx_lo_8[64] = {0, 64, 2, 66,
|
|
1, 65, 3, 67, 8, 72, 10, 74, 9, 73, 11, 75, 4, 68, 6, 70, 5, 69,
|
|
7, 71, 12, 76, 14, 78, 13, 77, 15, 79, 16, 80, 18, 82, 17, 81,
|
|
19, 83, 24, 88, 26, 90, 25, 89, 27, 91, 20, 84, 22, 86, 21, 85,
|
|
23, 87, 28, 92, 30, 94, 29, 93, 31, 95};
|
|
|
|
alignas(64) static constexpr const uint8_t idx_hi_8[64] = {32, 96, 34,
|
|
98, 33, 97, 35, 99, 40, 104, 42, 106, 41, 105, 43, 107, 36, 100,
|
|
38, 102, 37, 101, 39, 103, 44, 108, 46, 110, 45, 109, 47, 111,
|
|
48, 112, 50, 114, 49, 113, 51, 115, 56, 120, 58, 122, 57, 121,
|
|
59, 123, 52, 116, 54, 118, 53, 117, 55, 119, 60, 124, 62, 126,
|
|
61, 125, 63, 127};
|
|
|
|
vmovdqa64(vreg_idx_lo_256, (const void *)idx_lo_16);
|
|
vmovdqa64(vreg_idx_hi_256, (const void *)idx_hi_16);
|
|
vmovdqa64(vreg_idx_lo_128, (const void *)idx_lo_8);
|
|
vmovdqa64(vreg_idx_hi_128, (const void *)idx_hi_8);
|
|
}
|
|
|
|
void copy_block(
|
|
int nrows, int ncolumns, bool n_tail, bool zeropad) override {
|
|
|
|
if (!do_N_loop_ && (!is_dynamic_N_ || !n_tail)) {
|
|
copy_4x64(nrows, ncolumns, zeropad);
|
|
return;
|
|
}
|
|
|
|
mov(reg_dynamic_tail, reg_N_blk);
|
|
        // Dynamic tail processing: run the main loop with ncolumns == n_blk_step,
        // then process the tail (< n_blk_step) with a dynamically computed mask.
        // NOTE: in the dynamic_stride case copy_4x64() shifts the reg_src
        // pointer, so its value has to be backed up and restored on every
        // iteration over n except the last one.
|
|
|
|
mov(ptr[rsp + reg_tr_src_offs_], reg_tr_src);
|
|
xor_(reg_copy_block_n_shift, reg_copy_block_n_shift);
|
|
const auto typesize = sizeof(int8_t);
|
|
|
|
Label loop_row_start, loop_row_tail, loop_row_done;
|
|
cmp(reg_dynamic_tail, n_blk_step_);
|
|
jl(loop_row_tail, T_NEAR);
|
|
L(loop_row_start);
|
|
{
|
|
mov(ptr[rsp + reg_src_offs_], reg_src);
|
|
add(reg_src, reg_copy_block_n_shift);
|
|
copy_4x64(nrows, n_blk_step_, zeropad);
|
|
add(reg_copy_block_n_shift, n_blk_step_ * typesize);
|
|
add(reg_src, n_blk_step_ * typesize);
|
|
|
|
if (do_N_loop_)
|
|
                // (n_blk_step_ / conf_->LDB) --> number of LDB blocks handled by copy_4x64
|
|
add(reg_tr_src,
|
|
(n_blk_step_ / conf_->LDB) * conf_->LDB2 * typesize);
|
|
else
|
|
add(reg_tr_src, n_blk_step_ * k_blk_step_ * typesize);
|
|
|
|
sub(reg_dynamic_tail, n_blk_step_);
|
|
|
|
cmp(reg_dynamic_tail, 0);
|
|
jle(loop_row_done, T_NEAR);
|
|
|
|
mov(reg_src, ptr[rsp + reg_src_offs_]);
|
|
|
|
cmp(reg_dynamic_tail, n_blk_step_);
|
|
jl(loop_row_tail, T_NEAR);
|
|
|
|
jmp(loop_row_start, T_NEAR);
|
|
}
|
|
|
|
L(loop_row_tail);
|
|
{
|
|
cmp(reg_dynamic_tail, 0);
|
|
jle(loop_row_done, T_NEAR);
|
|
|
|
add(reg_src, reg_copy_block_n_shift);
|
|
if (do_N_loop_ && !is_dynamic_N_)
|
|
copy_4x64(nrows, ncolumns % n_blk_step_, zeropad);
|
|
else
|
|
copy_4x64(nrows, 1 /* to force tail case */, zeropad);
|
|
}
|
|
L(loop_row_done);
|
|
|
|
// restore pointers
|
|
sub(reg_src, reg_copy_block_n_shift);
|
|
mov(reg_tr_src, ptr[rsp + reg_tr_src_offs_]);
|
|
}
|
|
|
|
void copy_4x64(int nrows, int ncolumns, bool zeropad) override {
|
|
|
|
auto tr_src_off_n = [&](int n_elem) {
|
|
return ((n_elem / conf_->LDB) * conf_->LDB2
|
|
+ (n_elem % conf_->LDB) * k_blk_step_);
|
|
};
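        // tr_src_off_n() maps a column index within the 64-wide block to its
        // byte offset in the transposed buffer: full LDB panels advance by
        // LDB2 elements, columns within a panel by k_blk_step_ elements
        // (e.g. with LDB == 32, column 40 maps to LDB2 + 8 * k_blk_step_).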
|
|
|
|
const bool is_tail = ncolumns < n_blk_step_;
|
|
const auto tail_mask = size_t(((size_t)1 << ncolumns) - 1);
|
|
|
|
if (is_tail) kmovq(kTail, tail_mask);
|
|
|
|
const int max_unroll = (do_compute_compensation_ ? 21 : 25) / blk_sz_;
|
|
|
|
for_(int kb = 0; kb < div_up(nrows, max_unroll * k_blk_step_); kb++)
|
|
for (int k = 0; k < nstl::min(max_unroll,
|
|
div_up(nrows - kb * max_unroll * k_blk_step_,
|
|
k_blk_step_));
|
|
k++) {
|
|
const int row_start = (kb * max_unroll + k) * k_blk_step_;
|
|
const int row_end = nstl::min(row_start + k_blk_step_, nrows);
|
|
dim_t tr_src_off_base = (kb * max_unroll + k) * tr_src_stride_;
|
|
|
|
if (!zeropad) {
|
|
for (int i = row_start; i < row_end; i++)
|
|
load(k, i, is_tail);
|
|
if (row_end == nrows && nrows % k_blk_step_ > 0) {
|
|
for (int i = nrows; i < rnd_up(nrows, k_blk_step_); i++) {
|
|
auto src_reg = get_vmm(k, i % k_blk_step_);
|
|
vpxord(src_reg, src_reg, src_reg);
|
|
}
|
|
}
|
|
vmovups(get_vmm(k, 4), vreg_idx_lo_256);
|
|
vpermi2b(get_vmm(k, 4), get_vmm(k, 0), get_vmm(k, 2));
|
|
vmovups(get_vmm(k, 5), vreg_idx_hi_256);
|
|
vpermi2b(get_vmm(k, 5), get_vmm(k, 0), get_vmm(k, 2));
|
|
vmovups(get_vmm(k, 0), vreg_idx_lo_256);
|
|
vpermi2b(get_vmm(k, 0), get_vmm(k, 1), get_vmm(k, 3));
|
|
vmovups(get_vmm(k, 2), vreg_idx_hi_256);
|
|
vpermi2b(get_vmm(k, 2), get_vmm(k, 1), get_vmm(k, 3));
|
|
|
|
vmovups(get_vmm(k, 1), vreg_idx_lo_128);
|
|
vpermi2b(get_vmm(k, 1), get_vmm(k, 4), get_vmm(k, 0));
|
|
vmovups(EVEX_compress_addr(reg_tr_src, tr_src_off_base),
|
|
get_vmm(k, 1));
|
|
if (do_compute_compensation_)
|
|
vpdpbusd(get_comp_acc(0), vmm_comp_mul, get_vmm(k, 1));
|
|
} else {
|
|
vmovups(EVEX_compress_addr(reg_tr_src, tr_src_off_base),
|
|
vmm_zero);
|
|
}
|
|
const bool dynamic_tail = is_dynamic_N_ && ncolumns < n_blk_step_;
|
|
|
|
Label k_loop_done;
|
|
if (dynamic_tail) {
|
|
cmp(reg_dynamic_tail, 16);
|
|
jle(k_loop_done, T_NEAR);
|
|
}
|
|
if (!zeropad && (ncolumns > 16 || dynamic_tail)) {
|
|
vmovups(get_vmm(k, 3), vreg_idx_hi_128);
|
|
vpermi2b(get_vmm(k, 3), get_vmm(k, 4), get_vmm(k, 0));
|
|
vmovups(EVEX_compress_addr(
|
|
reg_tr_src, tr_src_off_base + tr_src_off_n(16)),
|
|
get_vmm(k, 3));
|
|
if (do_compute_compensation_)
|
|
vpdpbusd(get_comp_acc(1), vmm_comp_mul, get_vmm(k, 3));
|
|
} else if (conf_->wei_n_blk > 16) {
|
|
vmovups(EVEX_compress_addr(
|
|
reg_tr_src, tr_src_off_base + tr_src_off_n(16)),
|
|
vmm_zero);
|
|
}
|
|
|
|
if (dynamic_tail) {
|
|
cmp(reg_dynamic_tail, 32);
|
|
jle(k_loop_done, T_NEAR);
|
|
}
|
|
if (!zeropad && (ncolumns > 32 || dynamic_tail)) {
|
|
vmovups(get_vmm(k, 4), vreg_idx_lo_128);
|
|
vpermi2b(get_vmm(k, 4), get_vmm(k, 5), get_vmm(k, 2));
|
|
vmovups(EVEX_compress_addr(
|
|
reg_tr_src, tr_src_off_base + tr_src_off_n(32)),
|
|
get_vmm(k, 4));
|
|
if (do_compute_compensation_)
|
|
vpdpbusd(get_comp_acc(2), vmm_comp_mul, get_vmm(k, 4));
|
|
} else if (conf_->wei_n_blk > 32) {
|
|
vmovups(EVEX_compress_addr(
|
|
reg_tr_src, tr_src_off_base + tr_src_off_n(32)),
|
|
vmm_zero);
|
|
}
|
|
|
|
if (dynamic_tail) {
|
|
cmp(reg_dynamic_tail, 48);
|
|
jle(k_loop_done, T_NEAR);
|
|
}
|
|
if (!zeropad && (ncolumns > 48 || dynamic_tail)) {
|
|
vmovups(get_vmm(k, 0), vreg_idx_hi_128);
|
|
vpermi2b(get_vmm(k, 0), get_vmm(k, 5), get_vmm(k, 2));
|
|
vmovups(EVEX_compress_addr(
|
|
reg_tr_src, tr_src_off_base + tr_src_off_n(48)),
|
|
get_vmm(k, 0));
|
|
if (do_compute_compensation_)
|
|
vpdpbusd(get_comp_acc(3), vmm_comp_mul, get_vmm(k, 0));
|
|
} else if (conf_->wei_n_blk > 48) {
|
|
vmovups(EVEX_compress_addr(
|
|
reg_tr_src, tr_src_off_base + tr_src_off_n(48)),
|
|
vmm_zero);
|
|
}
|
|
L(k_loop_done);
|
|
}
|
|
}
|
|
};
|
|
|
|
struct jit_avx512_core_brgemm_matmul_copy_b_int8_t
|
|
: public jit_brgemm_matmul_copy_b_int8_t<Xbyak::Zmm> {
|
|
|
|
jit_avx512_core_brgemm_matmul_copy_b_int8_t(
|
|
const brgemm_matmul_conf_t *conf)
|
|
: jit_brgemm_matmul_copy_b_int8_t<Xbyak::Zmm>(conf) {}
|
|
|
|
private:
|
|
void init_permute() override {
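        // 64-bit element permutation tables: the *_256 tables pick the low or
        // high 256-bit halves of a register pair, and the *_128 tables then
        // interleave their 128-bit quarters into the final store order.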
|
|
alignas(64) static constexpr const int64_t idx_lo_256[8]
|
|
= {0, 1, 2, 3, 8, 9, 10, 11};
|
|
alignas(64) static constexpr const int64_t idx_hi_256[8]
|
|
= {4, 5, 6, 7, 12, 13, 14, 15};
|
|
|
|
alignas(64) static constexpr const int64_t idx_lo_128[8]
|
|
= {0, 1, 8, 9, 4, 5, 12, 13};
|
|
alignas(64) static constexpr const int64_t idx_hi_128[8]
|
|
= {2, 3, 10, 11, 6, 7, 14, 15};
|
|
|
|
vmovdqa64(vreg_idx_lo_256, (const void *)idx_lo_256);
|
|
vmovdqa64(vreg_idx_hi_256, (const void *)idx_hi_256);
|
|
vmovdqa64(vreg_idx_lo_128, (const void *)idx_lo_128);
|
|
vmovdqa64(vreg_idx_hi_128, (const void *)idx_hi_128);
|
|
}
|
|
|
|
void copy_4x64(int nrows, int ncolumns, bool zeropad) override {
|
|
const bool is_tail = ncolumns < n_blk_step_;
|
|
if (is_tail) {
|
|
const auto tail_mask = size_t(((size_t)1 << ncolumns) - 1);
|
|
kmovq(kTail, tail_mask);
|
|
}
|
|
|
|
const int max_unroll = (do_compute_compensation_ ? 21 : 25) / blk_sz_;
|
|
|
|
for_(int kb = 0; kb < div_up(nrows, max_unroll * k_blk_step_); kb++)
|
|
for (int k = 0; k < nstl::min(max_unroll,
|
|
div_up(nrows - kb * max_unroll * k_blk_step_,
|
|
k_blk_step_));
|
|
k++) {
|
|
const int row_start = (kb * max_unroll + k) * k_blk_step_;
|
|
const int row_end = nstl::min(row_start + k_blk_step_, nrows);
|
|
dim_t tr_src_off_base = (kb * max_unroll + k) * tr_src_stride_;
|
|
|
|
if (!zeropad) {
|
|
for (int i = row_start; i < row_end; i++)
|
|
load(k, i, is_tail);
|
|
if (row_end == nrows && nrows % k_blk_step_ > 0) {
|
|
for (int i = nrows; i < rnd_up(nrows, k_blk_step_); i++) {
|
|
auto src_reg = get_vmm(k, i % k_blk_step_);
|
|
vpxord(src_reg, src_reg, src_reg);
|
|
}
|
|
}
|
|
|
|
vpunpcklbw(get_vmm(k, 4), get_vmm(k, 0), get_vmm(k, 1));
|
|
vpunpckhbw(get_vmm(k, 5), get_vmm(k, 0), get_vmm(k, 1));
|
|
vpunpcklbw(get_vmm(k, 0), get_vmm(k, 2), get_vmm(k, 3));
|
|
vpunpckhbw(get_vmm(k, 1), get_vmm(k, 2), get_vmm(k, 3));
|
|
|
|
vpunpcklwd(get_vmm(k, 2), get_vmm(k, 4), get_vmm(k, 0));
|
|
vpunpckhwd(get_vmm(k, 3), get_vmm(k, 4), get_vmm(k, 0));
|
|
vpunpcklwd(get_vmm(k, 4), get_vmm(k, 5), get_vmm(k, 1));
|
|
vpunpckhwd(get_vmm(k, 5), get_vmm(k, 5), get_vmm(k, 1));
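                // After the byte/word unpacks each register holds columns
                // interleaved within 128-bit lanes; the vpermi2q permutes
                // below rearrange those lanes into contiguous 16-column
                // vnni-4 blocks for the stores.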
|
|
|
|
vmovups(get_vmm(k, 0), vreg_idx_lo_256);
|
|
vpermi2q(get_vmm(k, 0), get_vmm(k, 2), get_vmm(k, 4));
|
|
vmovups(get_vmm(k, 1), vreg_idx_hi_256);
|
|
vpermi2q(get_vmm(k, 1), get_vmm(k, 2), get_vmm(k, 4));
|
|
vmovups(get_vmm(k, 2), vreg_idx_lo_256);
|
|
vpermi2q(get_vmm(k, 2), get_vmm(k, 3), get_vmm(k, 5));
|
|
vmovups(get_vmm(k, 4), vreg_idx_hi_256);
|
|
vpermi2q(get_vmm(k, 4), get_vmm(k, 3), get_vmm(k, 5));
|
|
|
|
vmovups(get_vmm(k, 3), vreg_idx_lo_128);
|
|
vpermi2q(get_vmm(k, 3), get_vmm(k, 0), get_vmm(k, 2));
|
|
vmovups(EVEX_compress_addr(reg_tr_src, tr_src_off_base),
|
|
get_vmm(k, 3));
|
|
if (do_compute_compensation_)
|
|
dot_product(get_comp_acc(0), vmm_comp_mul, get_vmm(k, 3));
|
|
} else {
|
|
vmovups(EVEX_compress_addr(reg_tr_src, tr_src_off_base),
|
|
vmm_zero);
|
|
}
|
|
|
|
if (!zeropad && ncolumns > 16) {
|
|
vmovups(get_vmm(k, 5), vreg_idx_hi_128);
|
|
vpermi2q(get_vmm(k, 5), get_vmm(k, 0), get_vmm(k, 2));
|
|
vmovups(EVEX_compress_addr(reg_tr_src, tr_src_off_base + 64),
|
|
get_vmm(k, 5));
|
|
if (do_compute_compensation_)
|
|
dot_product(get_comp_acc(1), vmm_comp_mul, get_vmm(k, 5));
|
|
} else if (conf_->wei_n_blk > 16) {
|
|
vmovups(EVEX_compress_addr(reg_tr_src, tr_src_off_base + 64),
|
|
vmm_zero);
|
|
}
|
|
|
|
if (!zeropad && ncolumns > 32) {
|
|
vmovups(get_vmm(k, 0), vreg_idx_lo_128);
|
|
vpermi2q(get_vmm(k, 0), get_vmm(k, 1), get_vmm(k, 4));
|
|
vmovups(EVEX_compress_addr(reg_tr_src, tr_src_off_base + 128),
|
|
get_vmm(k, 0));
|
|
if (do_compute_compensation_)
|
|
dot_product(get_comp_acc(2), vmm_comp_mul, get_vmm(k, 0));
|
|
} else if (conf_->wei_n_blk > 32) {
|
|
vmovups(EVEX_compress_addr(reg_tr_src, tr_src_off_base + 128),
|
|
vmm_zero);
|
|
}
|
|
|
|
if (!zeropad && ncolumns > 48) {
|
|
vmovups(get_vmm(k, 2), vreg_idx_hi_128);
|
|
vpermi2q(get_vmm(k, 2), get_vmm(k, 1), get_vmm(k, 4));
|
|
vmovups(EVEX_compress_addr(reg_tr_src, tr_src_off_base + 192),
|
|
get_vmm(k, 2));
|
|
if (do_compute_compensation_)
|
|
dot_product(get_comp_acc(3), vmm_comp_mul, get_vmm(k, 2));
|
|
} else if (conf_->wei_n_blk > 48) {
|
|
vmovups(EVEX_compress_addr(reg_tr_src, tr_src_off_base + 192),
|
|
vmm_zero);
|
|
}
|
|
}
|
|
}
|
|
};
|
|
|
|
struct jit_avx2_vnni_brgemm_matmul_copy_b_int8_t
|
|
: public jit_brgemm_matmul_copy_b_int8_t<Xbyak::Ymm> {
|
|
|
|
jit_avx2_vnni_brgemm_matmul_copy_b_int8_t(const brgemm_matmul_conf_t *conf)
|
|
: jit_brgemm_matmul_copy_b_int8_t<Xbyak::Ymm>(conf) {}
|
|
|
|
private:
|
|
static constexpr int perm2i128_l
|
|
            = 0x20; // dst[127:0] = src1_low_128; dst[255:128] = src2_low_128
|
|
static constexpr int perm2i128_h
|
|
            = 0x31; // dst[127:0] = src1_hi_128; dst[255:128] = src2_hi_128
|
|
|
|
Xbyak::Ymm get_ymm(int idx) { return get_vmm(0, idx); }
|
|
|
|
void load_ymm(int ymm_idx, size_t offset, bool is_tail, size_t tail_sz) {
|
|
Xbyak::Ymm vmm_src = Xbyak::Ymm(ymm_idx);
|
|
if (is_tail) {
|
|
load_bytes(vmm_src, reg_src, offset, tail_sz);
|
|
} else
|
|
uni_vmovups(vmm_src, ptr[reg_src + offset]);
|
|
}
|
|
|
|
void copy_4x64(int nrows, int ncolumns, bool zeropad) override {
|
|
const bool is_tail = ncolumns < n_blk_step_;
|
|
const int k_end = div_up(nrows, k_blk_step_);
|
|
for_(int k = 0; k < k_end; k++)
|
|
for (int pass = 0; pass < 2; ++pass) {
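            // Two passes over the 64-column block: pass 0 handles source
            // columns [0, 32), pass 1 handles [32, 64); each pass loads up to
            // four rows of 32 int8 values and interleaves them with
            // unpack + vperm2i128 into the vnni-4 layout.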
|
|
if (pass == 0 && ncolumns >= simd_w_) mov(reg_src_backup, reg_src);
|
|
assert(one_of(pass, 0, 1));
|
|
const dim_t tr_src_off_base = k * tr_src_stride_;
|
|
const int set_1_tr_src_offset
|
|
= tr_src_off_base + pass * 2 * n_blk_step_;
|
|
const int row_start = k * k_blk_step_;
|
|
const int row_end = nstl::min(row_start + k_blk_step_, nrows);
|
|
if (!zeropad) {
|
|
for (int i = row_start; i < rnd_up(row_end, k_blk_step_); i++) {
|
|
const bool do_load = i < row_end
|
|
&& IMPLICATION(pass == 1, ncolumns >= simd_w_);
|
|
if (do_load) {
|
|
const bool do_tail = is_tail
|
|
&& IMPLICATION(pass == 0, ncolumns < simd_w_);
|
|
const auto offset
|
|
= (is_dynamic_stride_ ? 0 : i * src_stride_)
|
|
+ pass * simd_w_;
|
|
load_ymm(i % 4, offset, do_tail,
|
|
ncolumns - pass * simd_w_);
|
|
if (is_dynamic_stride_) add(reg_src, reg_src_stride);
|
|
} else {
|
|
const auto src_ymm_1 = get_ymm(i % 4);
|
|
uni_vpxor(src_ymm_1, src_ymm_1, src_ymm_1);
|
|
}
|
|
}
|
|
if (pass == 0 && ncolumns >= simd_w_)
|
|
mov(reg_src, reg_src_backup);
|
|
|
|
vpunpcklbw(get_ymm(4), get_ymm(0), get_ymm(1));
|
|
vpunpckhbw(get_ymm(5), get_ymm(0), get_ymm(1));
|
|
vpunpcklbw(get_ymm(0), get_ymm(2), get_ymm(3));
|
|
vpunpckhbw(get_ymm(1), get_ymm(2), get_ymm(3));
|
|
|
|
vpunpcklwd(get_ymm(2), get_ymm(4), get_ymm(0));
|
|
vpunpckhwd(get_ymm(3), get_ymm(4), get_ymm(0));
|
|
vpunpcklwd(get_ymm(4), get_ymm(5), get_ymm(1));
|
|
vpunpckhwd(get_ymm(5), get_ymm(5), get_ymm(1));
|
|
}
|
|
|
|
auto get_accum
|
|
= [&](int idx) { return get_comp_acc(idx + pass * 4); };
|
|
|
|
if (!zeropad
|
|
&& IMPLICATION(pass == 1,
|
|
ncolumns > 32)) { // check against {0, 32}
|
|
vperm2i128(get_ymm(0), get_ymm(2), get_ymm(3), perm2i128_l);
|
|
vperm2i128(get_ymm(1), get_ymm(4), get_ymm(5), perm2i128_l);
|
|
uni_vmovups(ptr[reg_tr_src + set_1_tr_src_offset], get_ymm(0));
|
|
uni_vmovups(ptr[reg_tr_src + set_1_tr_src_offset + simd_w_],
|
|
get_ymm(1));
|
|
if (do_compute_compensation_) {
|
|
vpdpbusd(get_accum(0), vmm_comp_mul, get_ymm(0),
|
|
VexEncoding);
|
|
vpdpbusd(get_accum(1), vmm_comp_mul, get_ymm(1),
|
|
VexEncoding);
|
|
}
|
|
} else if (conf_->wei_n_blk > 32) {
|
|
uni_vmovups(ptr[reg_tr_src + set_1_tr_src_offset], vmm_zero);
|
|
uni_vmovups(ptr[reg_tr_src + set_1_tr_src_offset + simd_w_],
|
|
vmm_zero);
|
|
}
|
|
|
|
const int set_2_tr_src_offset = set_1_tr_src_offset + n_blk_step_;
|
|
const int upper_check = 16 + pass * 32; // check against {16, 48}
|
|
if (!zeropad && ncolumns > upper_check) {
|
|
vperm2i128(get_ymm(2), get_ymm(2), get_ymm(3), perm2i128_h);
|
|
vperm2i128(get_ymm(3), get_ymm(4), get_ymm(5), perm2i128_h);
|
|
uni_vmovups(ptr[reg_tr_src + set_2_tr_src_offset], get_ymm(2));
|
|
uni_vmovups(ptr[reg_tr_src + set_2_tr_src_offset + simd_w_],
|
|
get_ymm(3));
|
|
if (do_compute_compensation_) {
|
|
vpdpbusd(get_accum(2), vmm_comp_mul, get_ymm(2),
|
|
VexEncoding);
|
|
vpdpbusd(get_accum(3), vmm_comp_mul, get_ymm(3),
|
|
VexEncoding);
|
|
}
|
|
} else if (conf_->wei_n_blk > upper_check) {
|
|
uni_vmovups(ptr[reg_tr_src + set_2_tr_src_offset], vmm_zero);
|
|
uni_vmovups(ptr[reg_tr_src + set_2_tr_src_offset + simd_w_],
|
|
vmm_zero);
|
|
}
|
|
}
|
|
}
|
|
};
|
|
|
|
template <typename Vmm>
|
|
void jit_brgemm_matmul_copy_b_int8_t<Vmm>::generate() {
|
|
preamble();
|
|
sub(rsp, stack_space_needed_);
|
|
|
|
if (avx512_core_dot_product_) {
|
|
mov(reg_tmp.cvt16(), 1);
|
|
vpbroadcastw(vmm_ones_words, reg_tmp.cvt16());
|
|
}
|
|
|
|
uni_vpxor(vmm_zero, vmm_zero, vmm_zero);
|
|
mov(reg_src, ptr[param1 + GET_OFF(src)]);
|
|
mov(reg_tr_src, ptr[param1 + GET_OFF(tr_src)]);
|
|
mov(reg_N_blk, ptr[param1 + GET_OFF(current_N_blk)]);
|
|
if (is_dynamic_stride_) {
|
|
mov(reg_src_stride, ptr[param1 + GET_OFF(dynamic_src_stride)]);
|
|
}
|
|
|
|
init_permute();
|
|
|
|
if (do_compute_compensation_) {
|
|
int n_iters = div_up(conf_->wei_n_blk, 16) * (is_ymm_ ? 2 : 1);
|
|
for (int i = 0; i < n_iters; i++)
|
|
uni_vpxor(get_comp_acc(i), get_comp_acc(i), get_comp_acc(i));
|
|
mov(reg_tmp, 1);
|
|
uni_vpbroadcastb(vmm_comp_mul, reg_tmp.cvt8());
|
|
}
|
|
|
|
auto compute_K_loop_body = [&](const reg64_t ®_K, int ncolumns,
|
|
bool is_N_tail, bool zeropad) {
|
|
const int k_unroll = 4;
|
|
Label K_loop_unrolled, K_loop_single, K_loop_tail_or_done;
|
|
cmp(reg_K, k_unroll * k_blk_step_);
|
|
jl(K_loop_single, T_NEAR);
|
|
|
|
L(K_loop_unrolled);
|
|
copy_block(k_unroll * k_blk_step_, ncolumns, is_N_tail, zeropad);
|
|
if (!zeropad && !is_dynamic_stride_)
|
|
add(reg_src, k_unroll * k_blk_step_ * src_stride_);
|
|
add(reg_tr_src, k_unroll * tr_src_stride_);
|
|
|
|
sub(reg_K, k_unroll * k_blk_step_);
|
|
cmp(reg_K, k_unroll * k_blk_step_);
|
|
jge(K_loop_unrolled, T_NEAR);
|
|
|
|
L(K_loop_single);
|
|
cmp(reg_K, k_blk_step_);
|
|
jl(K_loop_tail_or_done, T_NEAR);
|
|
|
|
copy_block(k_blk_step_, ncolumns, is_N_tail, zeropad);
|
|
if (!zeropad && !is_dynamic_stride_)
|
|
add(reg_src, k_blk_step_ * src_stride_);
|
|
add(reg_tr_src, tr_src_stride_);
|
|
|
|
sub(reg_K, k_blk_step_);
|
|
jmp(K_loop_single, T_NEAR);
|
|
|
|
L(K_loop_tail_or_done);
|
|
|
|
int k_blk_tail = conf_->K % k_blk_step_;
|
|
if (k_blk_tail > 0) {
|
|
Label K_loop_done;
|
|
cmp(reg_K, 0);
|
|
jle(K_loop_done, T_NEAR);
|
|
|
|
copy_block(k_blk_tail, ncolumns, is_N_tail, zeropad);
|
|
add(reg_tr_src, tr_src_stride_);
|
|
sub(reg_K, k_blk_tail);
|
|
L(K_loop_done);
|
|
}
|
|
};
|
|
|
|
auto compute_K_loop = [&](bool is_N_tail) {
|
|
int ncolumns = is_N_tail ? conf_->N_tail : conf_->N_blk;
|
|
        // The 'param1' register (rcx on Windows) is overwritten in
        // compute_K_loop_body, so read the 'current_K_pad' parameter and keep
        // it on the stack before the call.
|
|
mov(reg_K_iters, ptr[param1 + GET_OFF(current_K_pad)]);
|
|
mov(ptr[rsp + reg_current_K_pad_offs_], reg_K_iters);
|
|
mov(reg_K_iters, ptr[param1 + GET_OFF(current_K_iters)]);
|
|
compute_K_loop_body(reg_K_iters, ncolumns, is_N_tail, false);
|
|
mov(reg_K_iters, ptr[rsp + reg_current_K_pad_offs_]);
|
|
compute_K_loop_body(reg_K_iters, ncolumns, is_N_tail, true);
|
|
};
|
|
|
|
Label done;
|
|
cmp(reg_N_blk, 0);
|
|
jle(done, T_NEAR);
|
|
|
|
if (conf_->N_tail > 0 || is_dynamic_N_) {
|
|
Label main_N_blk;
|
|
cmp(reg_N_blk, conf_->N_blk);
|
|
je(main_N_blk, T_NEAR);
|
|
compute_K_loop(true);
|
|
jmp(done, T_NEAR);
|
|
|
|
L(main_N_blk);
|
|
}
|
|
|
|
compute_K_loop(false);
|
|
L(done);
|
|
|
|
if (do_compute_compensation_) {
|
|
const bool req_s8s8_comp = conf_->s8s8_compensation_required;
|
|
const bool req_zp_comp = conf_->has_zero_point_a;
|
|
int n_iters = div_up(conf_->wei_n_blk, 16);
|
|
assert(IMPLICATION(req_zp_comp,
|
|
conf_->src_zp_type == brgemm_broadcast_t::per_tensor));
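        // Compensation epilogue: fold the accumulated per-16-column dot
        // products into the s8s8 and/or zero-point compensation buffers; on
        // the last K chunk additionally apply the final factors (x(-128) for
        // s8s8, x(-zp_a) for the zero-point part) before storing.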
|
|
|
|
if (req_s8s8_comp)
|
|
mov(reg_comp_ptr, ptr[param1 + GET_OFF(compensation_ptr)]);
|
|
if (req_zp_comp)
|
|
mov(reg_zp_comp_ptr, ptr[param1 + GET_OFF(zp_a_compensation_ptr)]);
|
|
mov(reg_K_start, ptr[param1 + GET_OFF(current_K_start)]);
|
|
|
|
        // YMM note: 16 vmm registers would be needed, so compute in two halves
|
|
const bool do_outer_unroll = req_s8s8_comp;
|
|
const int outer_unroll = is_ymm_ && do_outer_unroll ? 2 : 1;
|
|
const int inner_unroll = is_ymm_ && (!do_outer_unroll) ? 2 : 1;
|
|
for (int out_ur = 0; out_ur < outer_unroll; ++out_ur) {
|
|
|
|
// copy 'comp_acc' into s8s8_comp accumulator
|
|
if (req_s8s8_comp) {
|
|
for (int i = 0; i < n_iters; i++) {
|
|
const int accum_idx = i + out_ur * n_iters;
|
|
uni_vmovups(get_vmm_wei_scale_comp_res(i),
|
|
get_comp_acc(accum_idx));
|
|
}
|
|
}
|
|
|
|
Label skip_acc, store;
|
|
cmp(reg_K_start, 0);
|
|
je(skip_acc, T_NEAR);
|
|
if (req_s8s8_comp) {
|
|
for (int i = 0; i < n_iters; i++) {
|
|
const int idx = i + out_ur * n_iters;
|
|
const auto vmm_acc = get_comp_acc(idx);
|
|
const auto vmm_res = get_vmm_wei_scale_comp_res(i);
|
|
const auto addr = !is_ymm_
|
|
? EVEX_compress_addr(reg_comp_ptr, idx * simd_w_)
|
|
: ptr[reg_comp_ptr + idx * simd_w_];
|
|
uni_vpaddd(vmm_res, vmm_acc, addr);
|
|
}
|
|
}
|
|
|
|
if (req_zp_comp) {
|
|
for_(int i = 0; i < n_iters; i++)
|
|
for (int in_ur = 0; in_ur < inner_unroll; ++in_ur) {
|
|
const int idx = i * inner_unroll + in_ur + out_ur * n_iters;
|
|
const auto vmm_acc = get_comp_acc(idx);
|
|
const auto vmm_res = get_vmm_zp_comp_res(idx);
|
|
const auto addr = !is_ymm_
|
|
? EVEX_compress_addr(reg_zp_comp_ptr, idx * simd_w_)
|
|
: ptr[reg_zp_comp_ptr + idx * simd_w_];
|
|
uni_vpaddd(vmm_res, vmm_acc, addr);
|
|
}
|
|
}
|
|
|
|
L(skip_acc);
|
|
cmp(reg_K_start, rnd_up(conf_->K, conf_->K_blk) - conf_->K_blk);
|
|
jl(store, T_NEAR);
|
|
|
|
if (req_s8s8_comp) {
|
|
mov(reg_tmp, 0xffffffff);
|
|
const auto vmm_all_bits_1 = vmm_comp_mul;
|
|
uni_vpbroadcastd(vmm_all_bits_1, reg_tmp.cvt32());
|
|
mov(reg_tmp, 0x1);
|
|
const auto vmm_one_s32 = vmm_zero;
|
|
uni_vpbroadcastd(vmm_one_s32, reg_tmp.cvt32());
|
|
|
|
for (int i = 0; i < n_iters; i++) {
|
|
const auto vmm_res = get_vmm_wei_scale_comp_res(i);
|
|
// multiply by 128
|
|
uni_vpslld(vmm_res, vmm_res, 7);
|
|
// change sign
|
|
uni_vpandnd(vmm_res, vmm_res, vmm_all_bits_1);
|
|
uni_vpaddd(vmm_res, vmm_res, vmm_one_s32);
|
|
}
|
|
}
|
|
|
|
if (req_zp_comp) {
|
|
mov(reg_zp_a_neg_val_ptr,
|
|
ptr[param1 + GET_OFF(zp_a_neg_value_ptr)]);
|
|
const auto vmm_zp_a_neg_val = vmm_zero;
|
|
uni_vbroadcastss(vmm_zp_a_neg_val, ptr[reg_zp_a_neg_val_ptr]);
|
|
|
|
for_(int i = 0; i < n_iters; i++)
|
|
for (int in_ur = 0; in_ur < inner_unroll; ++in_ur) {
|
|
const int idx = i * inner_unroll + in_ur + out_ur * n_iters;
|
|
const auto vmm_res = get_vmm_zp_comp_res(idx);
|
|
uni_vpmulld(vmm_res, vmm_res, vmm_zp_a_neg_val);
|
|
}
|
|
}
|
|
|
|
L(store);
|
|
if (req_s8s8_comp) {
|
|
for (int i = 0; i < n_iters; i++) {
|
|
const auto vmm_res = get_vmm_wei_scale_comp_res(i);
|
|
const int idx_offset = i + out_ur * n_iters;
|
|
const auto addr = !is_ymm_
|
|
? EVEX_compress_addr(
|
|
reg_comp_ptr, idx_offset * simd_w_)
|
|
: ptr[reg_comp_ptr + idx_offset * simd_w_];
|
|
uni_vmovups(addr, vmm_res);
|
|
}
|
|
}
|
|
if (req_zp_comp) {
|
|
for_(int i = 0; i < n_iters; i++)
|
|
for (int in_ur = 0; in_ur < inner_unroll; ++in_ur) {
|
|
const int idx = i * inner_unroll + in_ur + out_ur * n_iters;
|
|
const auto vmm_res = get_vmm_zp_comp_res(idx);
|
|
const auto addr = !is_ymm_
|
|
? EVEX_compress_addr(reg_zp_comp_ptr, idx * simd_w_)
|
|
: ptr[reg_zp_comp_ptr + idx * simd_w_];
|
|
uni_vmovups(addr, vmm_res);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
add(rsp, stack_space_needed_);
|
|
postamble();
|
|
}
|
|
|
|
template <typename Vmm>
|
|
struct jit_brgemm_matmul_copy_b_bf16_t
|
|
: public jit_brgemm_matmul_copy_b_common_t {
|
|
DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_brgemm_matmul_copy_b_bf16_t)
|
|
|
|
jit_brgemm_matmul_copy_b_bf16_t(const brgemm_matmul_conf_t *conf)
|
|
: jit_brgemm_matmul_copy_b_common_t(conf)
|
|
, typesize(conf->b_dt_sz)
|
|
, tr_typesize(conf->tr_b_dt_sz)
|
|
, wei_scales_typesize(conf->wei_scales_dt_sz)
|
|
, src_stride(conf->copy_B_wei_stride)
|
|
, tr_src_stride(conf_->LDB * k_blk_step * tr_typesize)
|
|
, is_src_int4(one_of(conf->orig_wei_dt, data_type::s4, data_type::u4))
|
|
, is_dynamic_stride(is_runtime_value(src_stride))
|
|
, is_dynamic_N(conf->is_runtime_N)
|
|
, do_N_loop(conf->LDB < conf->N_blk)
|
|
, req_cvtps2bf16(conf->is_bf32 || conf->is_bf16_with_int_wei)
|
|
, req_zp_b_shift(conf->has_zero_point_b && conf->with_wei_decompression)
|
|
, req_apply_wei_scales(conf->apply_scales_in_buffer_b)
|
|
, is_wei_grouped_over_k(
|
|
conf_->is_wei_zp_per_k || conf_->is_wei_scale_per_k)
|
|
, elems_per_byte(is_src_int4 ? 2 : 1) {}
|
|
|
|
void operator()(ctx_t *ctx) override { jit_generator_t::operator()(ctx); }
|
|
status_t create_kernel() override {
|
|
return jit_generator_t::create_kernel();
|
|
}
|
|
|
|
private:
|
|
using reg64_t = const Xbyak::Reg64;
|
|
using reg32_t = const Xbyak::Reg32;
|
|
using opmask_t = const Xbyak::Opmask;
|
|
using zmm = const Xbyak::Zmm;
|
|
using ymm = const Xbyak::Ymm;
|
|
using Vmm_lower_t = typename vreg_traits_t<Vmm>::Vmm_lower_t;
|
|
|
|
enum { k_blk_step = 2, n_blk_step = 16 };
|
|
const int typesize, tr_typesize, wei_scales_typesize;
|
|
const dim_t src_stride, tr_src_stride;
|
|
const bool is_src_int4;
|
|
const bool is_dynamic_stride;
|
|
const bool is_dynamic_N;
|
|
const bool do_N_loop;
|
|
const bool req_cvtps2bf16;
|
|
const bool req_zp_b_shift;
|
|
const bool req_apply_wei_scales;
|
|
const bool is_wei_grouped_over_k;
|
|
const dim_t elems_per_byte;
|
|
|
|
constexpr static int reg_src_offs = 0;
|
|
constexpr static int reg_tr_src_offs = 8;
|
|
|
|
constexpr static int reg_k_iters_offs_ = 16;
|
|
constexpr static int reg_current_K_pad_offs_ = 24;
|
|
|
|
constexpr static int reg_K_start_offs_ = 32;
|
|
|
|
constexpr static int stack_space_needed = 40;
|
|
|
|
reg64_t reg_src = rax;
|
|
reg64_t reg_tr_src = rbx;
|
|
|
|
reg64_t reg_K_iters = r8;
|
|
reg64_t reg_N_blk = r9;
|
|
reg64_t reg_K_start = r10;
|
|
reg64_t reg_src_stride = r11;
|
|
reg64_t reg_src_stride_x2 = r12;
|
|
reg64_t reg_src_load_0 = r13;
|
|
reg64_t reg_src_load_1 = r14;
|
|
reg64_t reg_tmp = r15;
|
|
|
|
reg64_t reg_copy_block_n_shift = rsi;
|
|
reg64_t reg_wei_scales = rdx;
|
|
reg64_t reg_zp_ptr = r13;
|
|
|
|
reg64_t reg_dynamic_tail = rcx;
|
|
Xbyak::Reg8 reg8_mask_shift = reg_dynamic_tail.cvt8();
|
|
|
|
Vmm vmm_zero = Vmm(0);
|
|
Vmm vmm_permw = Vmm(1);
|
|
Vmm vmm_tmp = Vmm(1); // used only for avx2_vnni_2
|
|
Vmm vmm_zp_b_shift = Vmm(2);
|
|
Vmm vmm_permd = Vmm(3);
|
|
Vmm vmm_wei_scales = Vmm(4);
|
|
|
|
void kmovx(Opmask k, unsigned w) {
|
|
if (!isa_has_masks(conf_->isa)) return;
|
|
const auto regw_tmp = reg_tmp.cvt32();
|
|
if (is_dynamic_N) {
|
|
mov(reg_tmp, 1);
|
|
shl(reg_tmp, reg8_mask_shift /* reg_dynamic_tail.cvt8() == cl */);
|
|
sub(reg_tmp, 1);
|
|
} else
|
|
mov(regw_tmp, w);
|
|
if (req_cvtps2bf16)
|
|
jit_generator_t::kmovw(k, regw_tmp);
|
|
else
|
|
jit_generator_t::kmovd(k, regw_tmp);
|
|
}
|
|
|
|
void copy_block(int nrows, int ncolumns, bool n_tail, bool zeropad);
|
|
void copy_2x32(int nrows, int ncolumns, bool zeropad);
|
|
void init_masks();
|
|
void generate() override;
|
|
};
|
|
|
|
template <typename Vmm>
|
|
void jit_brgemm_matmul_copy_b_bf16_t<Vmm>::copy_2x32(
|
|
int nrows, int ncolumns, bool zeropad) {
|
|
|
|
const int columns_tail = ncolumns % n_blk_step;
|
|
if (columns_tail > 0 && columns_tail < n_blk_step) {
|
|
const auto tail_mask = (1 << columns_tail) - 1;
|
|
kmovx(kTail, tail_mask);
|
|
if (is_src_int4) {
|
|
const auto int4_tail_mask = (1 << (columns_tail / 2)) - 1;
|
|
kmovx(kTail_int4, int4_tail_mask);
|
|
}
|
|
}
|
|
|
|
static constexpr int blk_sz = k_blk_step;
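    // Descriptive note: vmm0..vmm4 hold constants (zero, the permute tables,
    // the zero-point shift and the weights scales); how many of them are
    // actually live depends on the enabled decompression features, and the
    // remaining registers are used for copying.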
|
|
const int reserved_regs = req_apply_wei_scales ? 5
|
|
: is_src_int4 ? 4
|
|
: req_zp_b_shift ? 3
|
|
: 2;
|
|
const int max_isa_regs = isa_num_vregs(conf_->isa);
|
|
const int max_regs_available = max_isa_regs - reserved_regs;
|
|
const int max_unroll = max_regs_available / blk_sz;
|
|
|
|
auto get_vmm = [max_unroll, max_isa_regs, reserved_regs](int blk, int idx) {
|
|
assert(idx >= 0 && idx < blk_sz && blk >= 0);
|
|
auto reg_idx = reserved_regs + max_unroll * ((idx + 1) % blk_sz) + blk;
|
|
UNUSED(max_isa_regs);
|
|
assert(reg_idx >= reserved_regs && reg_idx < max_isa_regs);
|
|
return Vmm(reg_idx);
|
|
};
|
|
|
|
    /** Loads zero points when is_wei_zp_per_n is set.
|
|
     * The number of zero points over the N dimension always equals N.
|
|
*/
|
|
auto load_zero_point = [this, ncolumns, columns_tail](int n) {
|
|
if (!conf_->is_wei_zp_per_n) return;
|
|
const bool is_tail = (ncolumns - n) < n_blk_step;
|
|
const auto zp_dt = conf_->wei_zp_dt;
|
|
const auto zp_dt_sz = types::data_type_size(zp_dt);
|
|
const auto elems_per_byte
|
|
= one_of(zp_dt, data_type::s4, data_type::u4) ? 2 : 1;
|
|
const auto offset = n * zp_dt_sz / elems_per_byte;
|
|
const auto addr = maybe_EVEX_compress_addr(reg_zp_ptr, offset);
|
|
if (is_tail && !isa_has_masks(conf_->isa)) {
|
|
load_bytes(vmm_zp_b_shift, addr, columns_tail / elems_per_byte);
|
|
load_value(vmm_zp_b_shift, vmm_zp_b_shift, vmm_permd, zp_dt);
|
|
}
|
|
load_value(vmm_zp_b_shift, addr, vmm_permd, zp_dt, is_tail);
|
|
};
|
|
|
|
    /** Loads scales when is_wei_scale_per_n is set.
|
|
     * The number of scales over the N dimension always equals N.
|
|
*/
|
|
auto load_scales = [this, ncolumns, columns_tail](int n) {
|
|
if (!conf_->is_wei_scale_per_n || !conf_->apply_scales_in_buffer_b)
|
|
return;
|
|
|
|
const bool is_tail = (ncolumns - n) < n_blk_step;
|
|
const auto &scales_dt = conf_->wei_scales_dt;
|
|
const auto scales_dt_sz = types::data_type_size(scales_dt);
|
|
const auto offset = n * scales_dt_sz;
|
|
const auto addr = maybe_EVEX_compress_addr(reg_wei_scales, offset);
|
|
if (is_tail && !isa_has_masks(conf_->isa)) {
|
|
load_bytes(
|
|
vmm_wei_scales, addr, columns_tail * wei_scales_typesize);
|
|
load_scale_value(vmm_wei_scales, vmm_wei_scales, scales_dt,
|
|
/*is_tail=*/false);
|
|
}
|
|
load_scale_value(vmm_wei_scales, addr, scales_dt, is_tail);
|
|
};
|
|
|
|
auto load = [this, get_vmm, ncolumns, columns_tail, load_scales,
|
|
load_zero_point](int blk, int k, int n) {
|
|
auto src_reg = get_vmm(blk, k % k_blk_step);
|
|
const bool is_tail = ncolumns - n < n_blk_step;
|
|
auto src_load = maybe_mask(src_reg, is_tail);
|
|
const auto offset
|
|
= ((is_dynamic_stride ? 0 : k * src_stride) + (n * typesize))
|
|
/ elems_per_byte;
|
|
const auto reg_src_load
|
|
= is_dynamic_stride && k % 2 != 0 ? reg_src_load_1 : reg_src;
|
|
auto load_addr = maybe_EVEX_compress_addr(reg_src_load, offset);
|
|
if (!isa_has_masks(conf_->isa)) {
|
|
if (is_tail)
|
|
load_bytes(src_load, load_addr, columns_tail * tr_typesize);
|
|
else
|
|
uni_vmovups(src_load, load_addr);
|
|
} else {
|
|
load_value(
|
|
src_reg, load_addr, vmm_permd, conf_->orig_wei_dt, is_tail);
|
|
}
|
|
load_zero_point(n);
|
|
load_scales(n);
|
|
decompress_and_downcvt_reg(src_reg, vmm_zp_b_shift, vmm_wei_scales,
|
|
conf_->orig_wei_dt, conf_->wei_dt);
|
|
};
|
|
|
|
    /** Stores half of the block using a mask for the case when vnni_granularity == 2 */
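    // With vnni_granularity == 2 every 32-bit lane of the output packs two
    // consecutive K rows as 16-bit halves. When only a single K row is
    // available, the helper below writes the odd 16-bit lanes (kAAAA mask)
    // for an odd global K index, and performs a plain store with the data in
    // the even lanes otherwise.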
|
|
auto store_half_block = [&](const Vmm &src_vmm0, const Vmm &src_vmm1,
|
|
const Xbyak::Address &store_addr) {
|
|
const auto zmm1 = zmm(src_vmm1.getIdx());
|
|
const auto zmm0 = zmm(src_vmm0.getIdx());
|
|
uni_vxorps(zmm1, zmm1, zmm1);
|
|
        // If k % 2 == 1, only the odd 16-bit lanes are stored;
|
|
        // otherwise the data occupies the even lanes.
|
|
Label even_k, end_permute;
|
|
mov(reg_tmp, ptr[rsp + reg_K_start_offs_]);
|
|
test(reg_tmp, 1);
|
|
jz(even_k, T_NEAR);
|
|
vinsertf64x4(zmm0, zmm1, ymm(src_vmm0.getIdx()), 1);
|
|
vpermw(zmm0, vmm_permw, zmm0);
|
|
uni_vmovdqu16(store_addr | kAAAA, zmm0);
|
|
jmp(end_permute);
|
|
L(even_k);
|
|
vinsertf64x4(zmm0, zmm1, ymm(src_vmm0.getIdx()), 0);
|
|
vpermw(zmm0, vmm_permw, zmm0);
|
|
uni_vmovdqu16(store_addr, zmm0);
|
|
L(end_permute);
|
|
};
|
|
|
|
    // The case when only half of a block has to be stored:
|
|
    // when weights are grouped over K and K == 1.
|
|
const auto kernel_early_stop = is_wei_grouped_over_k && nrows == 1;
|
|
|
|
int iter = 0;
|
|
int n_iters;
|
|
if (is_dynamic_N || do_N_loop) {
|
|
n_iters = ncolumns;
|
|
} else {
|
|
n_iters = conf_->wei_n_blk;
|
|
}
|
|
|
|
for_(int k = 0; k < nrows; k += k_blk_step)
|
|
for (int n = 0; n < n_iters; n += n_blk_step) {
|
|
const int k_blk = k / k_blk_step;
|
|
const dim_t tr_src_off
|
|
= k_blk * tr_src_stride + n * k_blk_step * tr_typesize;
|
|
|
|
const auto store_addr
|
|
= maybe_EVEX_compress_addr(reg_tr_src, tr_src_off);
|
|
const auto store_addr_ymm1
|
|
= ptr[reg_tr_src + tr_src_off + vreg_traits_t<Vmm>::vlen];
|
|
const int blk_idx = iter % max_unroll;
|
|
const auto src_vmm0 = get_vmm(blk_idx, 0);
|
|
const auto src_zmm0 = zmm(src_vmm0.getIdx());
|
|
const auto src_vmm1 = get_vmm(blk_idx, 1);
|
|
if (is_dynamic_stride && n == 0) {
|
|
if (k == 0) {
|
|
mov(reg_src_load_1, reg_src);
|
|
add(reg_src_load_1, reg_src_stride);
|
|
} else {
|
|
add(reg_src, reg_src_stride_x2);
|
|
add(reg_src_load_1, reg_src_stride_x2);
|
|
}
|
|
}
|
|
|
|
if (ncolumns - n <= 0 || zeropad) {
|
|
uni_vmovups(store_addr, vmm_zero);
|
|
if (!is_superset(conf_->isa, avx512_core))
|
|
uni_vmovups(store_addr_ymm1, vmm_zero);
|
|
continue;
|
|
}
|
|
|
|
load(blk_idx, k, n);
|
|
|
|
// Store only half block
|
|
if (kernel_early_stop) {
|
|
store_half_block(src_vmm0, src_vmm1, store_addr);
|
|
iter++;
|
|
continue;
|
|
}
|
|
|
|
// Load second K half blk and downconvert if required.
|
|
if (nrows - k >= k_blk_step)
|
|
load(blk_idx, k + 1, n);
|
|
else
|
|
uni_vxorps(src_vmm1, src_vmm1, src_vmm1);
|
|
|
|
if (is_superset(conf_->isa, avx512_core)) {
|
|
const auto src_ymm1 = ymm(src_vmm1.getIdx());
|
|
vinsertf64x4(src_zmm0, src_zmm0, src_ymm1, 1);
|
|
vpermw(src_zmm0, vmm_permw, src_zmm0);
|
|
uni_vmovups(store_addr, src_zmm0);
|
|
} else {
|
|
assert(is_superset(conf_->isa, avx2));
|
|
vpunpcklwd(vmm_tmp, src_vmm0, src_vmm1);
|
|
vpunpckhwd(src_vmm1, src_vmm0, src_vmm1);
|
|
vperm2i128(src_vmm0, vmm_tmp, src_vmm1, 0x20);
|
|
vperm2i128(src_vmm1, vmm_tmp, src_vmm1, 0x31);
|
|
uni_vmovups(store_addr, src_vmm0);
|
|
uni_vmovups(store_addr_ymm1, src_vmm1);
|
|
}
|
|
|
|
iter++;
|
|
}
|
|
if (is_dynamic_stride && nrows > 0) {
|
|
add(reg_src, nrows % 2 == 0 ? reg_src_stride_x2 : reg_src_stride);
|
|
}
|
|
}
|
|
|
|
template <typename Vmm>
|
|
void jit_brgemm_matmul_copy_b_bf16_t<Vmm>::init_masks() {
|
|
alignas(64) static constexpr const int16_t bf16_vnni_permute[32]
|
|
= {0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23, 8, 24, 9,
|
|
25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31};
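    // vpermw with these indices interleaves the two 256-bit halves word-wise:
    // with row k in words 0..15 and row k + 1 in words 16..31, the result is
    // {k[0], k+1[0], k[1], k+1[1], ...}, i.e. the vnni-2 layout expected by
    // the brgemm kernel.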
|
|
|
|
if (is_superset(conf_->isa, avx512_core)) {
|
|
kxnorw(kFFFF, kFFFF, kFFFF); // 1111 1111 1111 1111
|
|
|
|
mov(reg_tmp, reinterpret_cast<size_t>(bf16_vnni_permute));
|
|
vmovdqa64(vmm_permw, ptr[reg_tmp]);
|
|
|
|
if (isa_has_masks(conf_->isa)) {
|
|
            // 64-bit mask is also used when is_wei_[zp/scales]_per_k
|
|
mov(reg_tmp, 0xAAAAAAAAAAAAAAAA);
|
|
kmovq(kAAAA, reg_tmp);
|
|
mov(reg_tmp, 0x5555555555555555);
|
|
kmovq(k5555, reg_tmp);
|
|
}
|
|
|
|
if (is_src_int4) {
|
|
alignas(64) static constexpr const uint32_t int4_permute[16]
|
|
= {0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15};
|
|
mov(reg_tmp, reinterpret_cast<size_t>(int4_permute));
|
|
vmovdqa32(vmm_permd, ptr[reg_tmp]);
|
|
}
|
|
}
|
|
}
|
|
|
|
template <typename Vmm>
|
|
void jit_brgemm_matmul_copy_b_bf16_t<Vmm>::copy_block(
|
|
int nrows, int ncolumns, bool n_tail, bool zeropad) {
|
|
if (!do_N_loop && (!is_dynamic_N || !n_tail)) {
|
|
copy_2x32(nrows, ncolumns, zeropad);
|
|
return;
|
|
}
|
|
|
|
mov(reg_dynamic_tail, reg_N_blk);
|
|
// dynamic tail processing: main loop with ncolumns = n_blk_step and
|
|
// finally process tail < n_blk_step with dynamically computed mask
|
|
    // NOTE: for the dynamic_stride case copy_2x32() shifts the reg_src pointer,
|
|
    // so its value has to be backed up and restored for every iteration over n
|
|
// except the last one
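    // Illustrative example (hypothetical sizes): for a runtime N of 40 with
    // n_blk_step = 16 the loop below runs twice over full 16-column blocks and
    // then handles the remaining 8 columns via the dynamically computed mask.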
|
|
|
|
mov(ptr[rsp + reg_tr_src_offs], reg_tr_src);
|
|
xor_(reg_copy_block_n_shift, reg_copy_block_n_shift);
|
|
|
|
int current_n_blk_step = do_N_loop ? conf_->LDB : n_blk_step;
|
|
|
|
Label loop_row_start, loop_row_tail, loop_row_done;
|
|
cmp(reg_dynamic_tail, current_n_blk_step);
|
|
jl(loop_row_tail, T_NEAR);
|
|
L(loop_row_start);
|
|
{
|
|
mov(ptr[rsp + reg_src_offs], reg_src);
|
|
add(reg_src, reg_copy_block_n_shift);
|
|
copy_2x32(nrows, current_n_blk_step, zeropad);
|
|
|
|
if (do_N_loop) {
|
|
add(reg_tr_src,
|
|
(current_n_blk_step / conf_->LDB) * conf_->LDB2
|
|
* tr_typesize);
|
|
add(reg_src,
|
|
conf_->B_strides[0] == typesize
|
|
? current_n_blk_step * typesize
|
|
: conf_->B_strides[0]);
|
|
add(reg_copy_block_n_shift,
|
|
conf_->B_strides[0] == typesize
|
|
? current_n_blk_step * typesize
|
|
: conf_->B_strides[0]);
|
|
|
|
} else {
|
|
add(reg_src, current_n_blk_step * typesize);
|
|
add(reg_tr_src, current_n_blk_step * k_blk_step * tr_typesize);
|
|
add(reg_copy_block_n_shift, current_n_blk_step * typesize);
|
|
}
|
|
|
|
sub(reg_dynamic_tail, current_n_blk_step);
|
|
|
|
cmp(reg_dynamic_tail, 0);
|
|
jle(loop_row_done, T_NEAR);
|
|
|
|
mov(reg_src, ptr[rsp + reg_src_offs]);
|
|
|
|
cmp(reg_dynamic_tail, current_n_blk_step);
|
|
jl(loop_row_tail, T_NEAR);
|
|
|
|
jmp(loop_row_start, T_NEAR);
|
|
}
|
|
|
|
L(loop_row_tail);
|
|
{
|
|
cmp(reg_dynamic_tail, 0);
|
|
jle(loop_row_done, T_NEAR);
|
|
|
|
add(reg_src, reg_copy_block_n_shift);
|
|
if (do_N_loop) {
|
|
copy_2x32(nrows, ncolumns % current_n_blk_step, zeropad);
|
|
} else {
|
|
copy_2x32(nrows, 1 /* to force tail case */, zeropad);
|
|
}
|
|
}
|
|
L(loop_row_done);
|
|
|
|
// restore pointers
|
|
sub(reg_src, reg_copy_block_n_shift);
|
|
mov(reg_tr_src, ptr[rsp + reg_tr_src_offs]);
|
|
}
|
|
|
|
template <typename Vmm>
|
|
void jit_brgemm_matmul_copy_b_bf16_t<Vmm>::generate() {
|
|
assert(tr_typesize == sizeof(bfloat16_t));
|
|
preamble();
|
|
sub(rsp, stack_space_needed);
|
|
uni_vxorps(vmm_zero, vmm_zero, vmm_zero);
|
|
|
|
mov(reg_src, ptr[param1 + GET_OFF(src)]);
|
|
mov(reg_tr_src, ptr[param1 + GET_OFF(tr_src)]);
|
|
mov(ptr[rsp + reg_tr_src_offs], reg_tr_src);
|
|
mov(reg_N_blk, ptr[param1 + GET_OFF(current_N_blk)]);
|
|
mov(reg_wei_scales, ptr[param1 + GET_OFF(wei_scales_ptr)]);
|
|
    // Due to the lack of registers, save k_iters and k_pad to the stack space
|
|
mov(reg_tmp, ptr[param1 + GET_OFF(current_K_iters)]);
|
|
mov(ptr[rsp + reg_k_iters_offs_], reg_tmp);
|
|
mov(reg_tmp, ptr[param1 + GET_OFF(current_K_pad)]);
|
|
mov(ptr[rsp + reg_current_K_pad_offs_], reg_tmp);
|
|
mov(reg_tmp, ptr[param1 + GET_OFF(current_K_start)]);
|
|
mov(ptr[rsp + reg_K_start_offs_], reg_tmp);
|
|
mov(reg_tmp, 0);
|
|
|
|
if (is_dynamic_stride) {
|
|
mov(reg_src_stride, ptr[param1 + GET_OFF(dynamic_src_stride)]);
|
|
mov(reg_src_stride_x2, ptr[param1 + GET_OFF(dynamic_src_stride)]);
|
|
shl(reg_src_stride_x2, 1);
|
|
}
|
|
|
|
mov(reg_zp_ptr, ptr[param1 + GET_OFF(zp_b_value_ptr)]);
|
|
load_common_zp_value(vmm_zp_b_shift, reg_zp_ptr);
|
|
load_common_scale_value(vmm_wei_scales, reg_wei_scales);
|
|
|
|
init_masks();
|
|
|
|
auto compute_K_loop_body = [&](const reg64_t ®_K, int ncolumns,
|
|
bool is_N_tail, bool zeropad) {
|
|
// Compute special K-loop for per-k attributes
|
|
// Only when k_group_size < k_blk_step
|
|
// Otherwise default K-loop is used
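        // Illustrative example: with a zero-point group size of 1 and
        // k_blk_step = 2, the body below copies a single K row and returns,
        // so the per-K zero points / scales stay constant within the copied
        // block.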
|
|
if (is_wei_grouped_over_k) {
|
|
const int k_group_size = conf_->is_wei_zp_per_k
|
|
? conf_->wei_zp_k_gsize
|
|
: conf_->wei_scales_k_gsize;
|
|
if (k_group_size < k_blk_step) {
|
|
if (zeropad) return;
|
|
copy_block(
|
|
k_group_size, ncolumns, is_N_tail, /*zeropad= */ false);
|
|
return;
|
|
}
|
|
}
|
|
|
|
const int k_unroll = 8;
|
|
Label K_loop_unrolled, K_loop_single, K_loop_tail_or_done;
|
|
|
|
cmp(reg_K, k_unroll * k_blk_step);
|
|
jl(K_loop_single, T_NEAR);
|
|
|
|
L(K_loop_unrolled);
|
|
copy_block(k_unroll * k_blk_step, ncolumns, is_N_tail, zeropad);
|
|
|
|
if (!zeropad && !is_dynamic_stride)
|
|
add(reg_src, (k_unroll * k_blk_step * src_stride) / elems_per_byte);
|
|
add(reg_tr_src, k_unroll * tr_src_stride);
|
|
|
|
sub(reg_K, k_unroll * k_blk_step);
|
|
cmp(reg_K, k_unroll * k_blk_step);
|
|
jge(K_loop_unrolled, T_NEAR);
|
|
|
|
L(K_loop_single);
|
|
cmp(reg_K, k_blk_step);
|
|
jl(K_loop_tail_or_done, T_NEAR);
|
|
|
|
copy_block(k_blk_step, ncolumns, is_N_tail, zeropad);
|
|
if (!zeropad && !is_dynamic_stride)
|
|
add(reg_src, (k_blk_step * src_stride) / elems_per_byte);
|
|
add(reg_tr_src, tr_src_stride);
|
|
|
|
sub(reg_K, k_blk_step);
|
|
jmp(K_loop_single, T_NEAR);
|
|
|
|
L(K_loop_tail_or_done);
|
|
|
|
int k_blk_tail = conf_->K % k_blk_step;
|
|
if (k_blk_tail > 0) {
|
|
Label K_loop_done;
|
|
cmp(reg_K, 0);
|
|
jle(K_loop_done, T_NEAR);
|
|
|
|
copy_block(k_blk_tail, ncolumns, is_N_tail, zeropad);
|
|
add(reg_tr_src, tr_src_stride);
|
|
sub(reg_K, k_blk_tail);
|
|
L(K_loop_done);
|
|
}
|
|
};
|
|
|
|
auto compute_K_loop = [&](bool is_N_tail) {
|
|
int ncolumns = is_N_tail ? conf_->N_tail : conf_->N_blk;
|
|
        // The 'param1' register (rcx on Windows) is overwritten in compute_K_loop_body,
|
|
        // so the 'current_K_pad' parameter has to be read and kept on the stack before
|
|
// the call
|
|
mov(reg_K_iters, ptr[rsp + reg_k_iters_offs_]);
|
|
compute_K_loop_body(reg_K_iters, ncolumns, is_N_tail, false);
|
|
mov(reg_K_iters, ptr[rsp + reg_current_K_pad_offs_]);
|
|
compute_K_loop_body(reg_K_iters, ncolumns, is_N_tail, true);
|
|
};
|
|
|
|
Label done;
|
|
cmp(reg_N_blk, 0);
|
|
jle(done, T_NEAR);
|
|
|
|
if (conf_->N_tail > 0 || is_dynamic_N) {
|
|
Label main_N_blk;
|
|
cmp(reg_N_blk, conf_->N_blk);
|
|
je(main_N_blk, T_NEAR);
|
|
compute_K_loop(true);
|
|
jmp(done, T_NEAR);
|
|
|
|
L(main_N_blk);
|
|
}
|
|
|
|
compute_K_loop(false);
|
|
L(done);
|
|
|
|
add(rsp, stack_space_needed);
|
|
postamble();
|
|
}
|
|
|
|
template struct jit_brgemm_matmul_copy_b_bf16_t<Zmm>;
|
|
template struct jit_brgemm_matmul_copy_b_bf16_t<Ymm>;
|
|
|
|
template <typename Vmm>
|
|
struct jit_brgemm_matmul_copy_b_f32_t
|
|
: public jit_brgemm_matmul_copy_b_common_t {
|
|
DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_brgemm_matmul_copy_b_f32_t)
|
|
|
|
jit_brgemm_matmul_copy_b_f32_t(const brgemm_matmul_conf_t *conf)
|
|
: jit_brgemm_matmul_copy_b_common_t(conf)
|
|
, dt_in_(conf->orig_wei_dt)
|
|
, simd_w_(vreg_traits_t<Vmm>::vlen / sizeof(float))
|
|
, is_src_f4_(one_of(
|
|
conf->orig_wei_dt, data_type::f4_e2m1, data_type::f4_e3m0))
|
|
, is_src_int4_(one_of(conf->orig_wei_dt, data_type::s4, data_type::u4))
|
|
, req_zp_b_shift_(
|
|
conf->has_zero_point_b && conf->with_wei_decompression)
|
|
, req_apply_wei_scales_(conf->apply_scales_in_buffer_b)
|
|
, typesize_in_(types::data_type_size(dt_in_))
|
|
, src_elems_per_byte_(is_src_int4_ || is_src_f4_ ? 2 : 1)
|
|
, wei_scales_typesize_(conf_->wei_scales_dt_sz)
|
|
, src_stride_(conf_->copy_B_wei_stride)
|
|
, tr_src_stride_(conf_->LDB * typesize_out_) {}
|
|
|
|
void operator()(ctx_t *ctx) override { jit_generator_t::operator()(ctx); }
|
|
status_t create_kernel() override {
|
|
return jit_generator_t::create_kernel();
|
|
}
|
|
|
|
private:
|
|
using reg64_t = const Xbyak::Reg64;
|
|
using reg32_t = const Xbyak::Reg32;
|
|
using opmask_t = const Xbyak::Opmask;
|
|
using Vmm_lower_t = typename vreg_traits_t<Vmm>::Vmm_lower_t;
|
|
|
|
const data_type_t dt_in_;
|
|
const int simd_w_;
|
|
const bool is_src_f4_, is_src_int4_, req_zp_b_shift_, req_apply_wei_scales_;
|
|
const size_t typesize_in_, src_elems_per_byte_, wei_scales_typesize_;
|
|
const size_t typesize_out_ = sizeof(float);
|
|
dim_t src_stride_, tr_src_stride_;
|
|
|
|
reg64_t reg_src = rax;
|
|
reg64_t reg_tr_src = rbx;
|
|
|
|
reg64_t reg_K_iters = r8;
|
|
reg64_t reg_N_blk = r9;
|
|
reg64_t reg_K_start = r10;
|
|
reg64_t reg_tmp = r15;
|
|
reg32_t regw_tmp = r15d;
|
|
reg64_t reg_wei_scales = rdx;
|
|
reg64_t reg_zp_ptr = r11;
|
|
|
|
Vmm vmm_zero = Vmm(0);
|
|
Vmm vmm_wei_scales = Vmm(1);
|
|
Vmm vmm_permd = Vmm(2);
|
|
Vmm vmm_zp_b_shift = Vmm(3);
|
|
Vmm vmm_f4_lut = Vmm(4);
|
|
Ymm ymm_tail_mask = ymm1;
|
|
|
|
inline void kmovw(Opmask k, unsigned w) {
|
|
if (!isa_has_masks(conf_->isa)) return;
|
|
mov(regw_tmp, w);
|
|
jit_generator_t::kmovd(k, regw_tmp);
|
|
}
|
|
|
|
void copy_16_x_n_block(int nrows, int ncolumns);
|
|
void compute_k_loop(int ncolumns);
|
|
void generate() override;
|
|
};
|
|
|
|
template <typename Vmm>
|
|
void jit_brgemm_matmul_copy_b_f32_t<Vmm>::copy_16_x_n_block(
|
|
int nrows, int ncolumns) {
|
|
const int max_isa_regs = isa_num_vregs(conf_->isa);
|
|
const int reserved_regs = is_src_f4_ ? 5
|
|
: req_zp_b_shift_ ? 4
|
|
: is_src_int4_ ? 3
|
|
: 2;
|
|
const int max_regs_available = max_isa_regs - reserved_regs;
|
|
|
|
auto get_vmm = [max_regs_available, reserved_regs](int reg_idx) {
|
|
MAYBE_UNUSED(max_regs_available);
|
|
MAYBE_UNUSED(reserved_regs); // some compilers detect it as unused
|
|
assert(reg_idx >= 0 && reg_idx < max_regs_available);
|
|
return Vmm(reg_idx + reserved_regs);
|
|
};
|
|
|
|
auto load = [this, get_vmm, ncolumns](int blk, int k, int n) {
|
|
auto src_vmm = get_vmm(blk);
|
|
const bool is_tail = ncolumns - n < simd_w_;
|
|
auto addr = maybe_EVEX_compress_addr(reg_src,
|
|
(k * src_stride_ + n * typesize_in_) / src_elems_per_byte_);
|
|
if (is_tail && !isa_has_masks(conf_->isa))
|
|
vmaskmovps(src_vmm, ymm_tail_mask, addr);
|
|
else
|
|
load_value(src_vmm, addr, vmm_permd, conf_->orig_wei_dt, is_tail);
|
|
|
|
decompress_reg(maybe_mask(src_vmm, is_tail), vmm_zp_b_shift,
|
|
vmm_wei_scales, conf_->orig_wei_dt);
|
|
};
|
|
|
|
    /** Loads zero points when is_wei_zp_per_n is set.
|
|
     * The number of zero points over the N dimension always equals N.
|
|
*/
|
|
auto load_zero_point = [this, ncolumns](int n) {
|
|
if (!conf_->is_wei_zp_per_n) return;
|
|
|
|
const bool is_tail = (ncolumns - n) < simd_w_;
|
|
const auto zp_dt = conf_->wei_zp_dt;
|
|
const auto zp_dt_sz = types::data_type_size(zp_dt);
|
|
const auto elems_per_byte
|
|
= one_of(zp_dt, data_type::s4, data_type::u4) ? 2 : 1;
|
|
const auto offset = n * zp_dt_sz / elems_per_byte;
|
|
const auto addr = maybe_EVEX_compress_addr(reg_zp_ptr, offset);
|
|
load_value(vmm_zp_b_shift, addr, vmm_permd, zp_dt, is_tail);
|
|
};
|
|
|
|
    /** Loads scales when is_wei_scale_per_n is set.
|
|
     * The number of scales over the N dimension always equals N.
|
|
*/
|
|
auto load_scales = [this, ncolumns](int n) {
|
|
if (!conf_->is_wei_scale_per_n || !conf_->apply_scales_in_buffer_b)
|
|
return;
|
|
|
|
const bool is_tail = (ncolumns - n) < simd_w_;
|
|
const auto &scales_dt = conf_->wei_scales_dt;
|
|
const auto scales_dt_sz = types::data_type_size(scales_dt);
|
|
const auto offset = n * scales_dt_sz;
|
|
const auto addr = maybe_EVEX_compress_addr(reg_wei_scales, offset);
|
|
load_scale_value(vmm_wei_scales, addr, scales_dt, is_tail);
|
|
};
|
|
|
|
const int columns_tail = ncolumns % simd_w_;
|
|
if (columns_tail < simd_w_) {
|
|
if (isa_has_masks(conf_->isa)) {
|
|
const auto tail_mask = (1 << columns_tail) - 1;
|
|
kmovw(kTail, tail_mask);
|
|
if (is_src_int4_ || is_src_f4_) {
|
|
const auto tail_mask_4bit
|
|
= (1 << (columns_tail / src_elems_per_byte_)) - 1;
|
|
kmovw(kTail_int4, tail_mask_4bit);
|
|
}
|
|
} else {
|
|
init_f32_avx2_mask_ymm(ymm_tail_mask, reg_tmp, columns_tail);
|
|
}
|
|
}
|
|
|
|
int iter = 0;
|
|
for_(int k = 0; k < nrows; k++)
|
|
for (int n = 0; n < conf_->wei_n_blk; n += simd_w_) {
|
|
const dim_t tr_src_off = k * tr_src_stride_ + n * typesize_out_;
|
|
const auto store_addr
|
|
= maybe_EVEX_compress_addr(reg_tr_src, tr_src_off);
|
|
|
|
const int zero_padding = ncolumns - n;
|
|
if (zero_padding <= 0) {
|
|
uni_vmovups(store_addr, vmm_zero);
|
|
continue;
|
|
}
|
|
|
|
load_zero_point(n);
|
|
load_scales(n);
|
|
const int blk_idx = iter % max_regs_available;
|
|
load(blk_idx, k, n);
|
|
|
|
const auto src_vmm0 = get_vmm(blk_idx);
|
|
uni_vmovups(store_addr, src_vmm0);
|
|
iter++;
|
|
}
|
|
}
|
|
|
|
template <typename Vmm>
|
|
void jit_brgemm_matmul_copy_b_f32_t<Vmm>::compute_k_loop(int ncolumns) {
|
|
|
|
auto compute_uni_k_loop = [&](int unroll) {
|
|
Label K_start_label, K_end_label;
|
|
|
|
L(K_start_label);
|
|
cmp(reg_K_iters, unroll);
|
|
jl(K_end_label, T_NEAR);
|
|
|
|
copy_16_x_n_block(unroll, ncolumns);
|
|
add(reg_src, (unroll * src_stride_) / src_elems_per_byte_);
|
|
add(reg_tr_src, unroll * tr_src_stride_);
|
|
|
|
sub(reg_K_iters, unroll);
|
|
jmp(K_start_label, T_NEAR);
|
|
|
|
L(K_end_label);
|
|
};
|
|
|
|
constexpr int k_unroll = 16;
|
|
compute_uni_k_loop(k_unroll);
|
|
compute_uni_k_loop(1);
|
|
}
|
|
|
|
template <typename Vmm>
|
|
void jit_brgemm_matmul_copy_b_f32_t<Vmm>::generate() {
|
|
preamble();
|
|
uni_vxorps(vmm_zero, vmm_zero, vmm_zero);
|
|
|
|
mov(reg_src, ptr[param1 + GET_OFF(src)]);
|
|
mov(reg_tr_src, ptr[param1 + GET_OFF(tr_src)]);
|
|
mov(reg_K_iters, ptr[param1 + GET_OFF(current_K_iters)]);
|
|
mov(reg_N_blk, ptr[param1 + GET_OFF(current_N_blk)]);
|
|
mov(reg_wei_scales, ptr[param1 + GET_OFF(wei_scales_ptr)]);
|
|
mov(reg_zp_ptr, ptr[param1 + GET_OFF(zp_b_value_ptr)]);
|
|
kmovw(kFFFF, 0xffff); // 1111111111111111
|
|
if (is_src_int4_ || is_src_f4_) {
|
|
alignas(64) static constexpr const uint32_t int4_permute[16]
|
|
= {0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15};
|
|
mov(reg_tmp, reinterpret_cast<size_t>(int4_permute));
|
|
vmovdqa32(vmm_permd, ptr[reg_tmp]);
|
|
|
|
kmovw(kAAAA, 0xaaaa);
|
|
kmovw(k5555, 0x5555);
|
|
}
|
|
if (is_src_f4_) {
|
|
|
|
alignas(64) static constexpr const float f4_e2m1_table[16]
|
|
= {0.0f, .5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f, -0.0f, -.5f,
|
|
-1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f};
|
|
alignas(64) static constexpr const float f4_e3m0_table[16]
|
|
= {0.0f, .25f, .5f, 1.0f, 2.0f, 4.0f, 8.0f, 16.0f, -0.0f, -.25f,
|
|
-.5f, -1.0f, -2.0f, -4.0f, -8.0f, -16.0f};
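        // Each 4-bit code indexes the 16-entry table above to recover the f32
        // value: codes 0..7 are the positive values and 8..15 their negated
        // counterparts (e2m1 has one mantissa bit, e3m0 none, hence the two
        // value sets).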
|
|
switch (dt_in_) {
|
|
case data_type::f4_e2m1:
|
|
mov(reg_tmp, reinterpret_cast<size_t>(f4_e2m1_table));
|
|
break;
|
|
case data_type::f4_e3m0:
|
|
mov(reg_tmp, reinterpret_cast<size_t>(f4_e3m0_table));
|
|
break;
|
|
|
|
default: break;
|
|
}
|
|
vmovdqa32(vmm_f4_lut, ptr[reg_tmp]);
|
|
}
|
|
|
|
load_common_zp_value(vmm_zp_b_shift, reg_zp_ptr);
|
|
load_common_scale_value(vmm_wei_scales, reg_wei_scales);
|
|
|
|
Label done;
|
|
if (conf_->N_tail > 0) {
|
|
Label not_N_tail;
|
|
cmp(reg_N_blk, conf_->N_tail);
|
|
jne(not_N_tail, T_NEAR);
|
|
compute_k_loop(conf_->N_tail);
|
|
jmp(done, T_NEAR);
|
|
|
|
L(not_N_tail);
|
|
}
|
|
|
|
compute_k_loop(conf_->N_blk);
|
|
L(done);
|
|
|
|
postamble();
|
|
}
|
|
|
|
template struct jit_brgemm_matmul_copy_b_f32_t<Zmm>;
|
|
template struct jit_brgemm_matmul_copy_b_f32_t<Ymm>;
|
|
|
|
template <typename Vmm>
|
|
struct jit_brgemm_matmul_copy_b_transposed_t
|
|
: public jit_brgemm_matmul_copy_b_common_t {
|
|
DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_brgemm_matmul_copy_b_transposed_t)
|
|
|
|
jit_brgemm_matmul_copy_b_transposed_t(const brgemm_matmul_conf_t *conf)
|
|
: jit_brgemm_matmul_copy_b_common_t(conf)
|
|
, typesize_(conf_->b_dt_sz)
|
|
, tr_typesize_(conf_->tr_b_dt_sz)
|
|
, wei_scales_typesize_(conf_->wei_scales_dt_sz)
|
|
, vnni_granularity_(data_type_vnni_granularity(conf_->wei_dt))
|
|
, k_blk_step_(vlen_ / tr_typesize_)
|
|
, is_wei_grouped_over_k_(
|
|
conf_->is_wei_zp_per_k || conf_->is_wei_scale_per_k)
|
|
, do_compute_compensation_(
|
|
conf_->has_zero_point_a || conf_->s8s8_compensation_required)
|
|
, is_bf32_(conf->is_bf32)
|
|
, is_bf16_with_int_wei_(conf->is_bf16_with_int_wei)
|
|
, is_src_int4_(one_of(conf->orig_wei_dt, data_type::s4, data_type::u4))
|
|
, req_cvtps2xf16_(conf->is_bf32 || conf->is_bf16_with_int_wei
|
|
|| (conf->is_f16_with_int_wei
|
|
&& conf->wei_dt == data_type::f16))
|
|
, req_zp_comp_(conf_->has_zero_point_a)
|
|
, req_s8s8_comp_(conf_->s8s8_compensation_required)
|
|
, req_zp_b_shift_(
|
|
conf_->has_zero_point_b && conf_->with_wei_decompression)
|
|
, req_apply_wei_scales_(conf_->apply_scales_in_buffer_b)
|
|
, avx512_core_dot_product_(
|
|
do_compute_compensation_ && !isa_has_int8_vnni(conf->isa))
|
|
// See the note in `create_brgemm_matmul_copy_b` why `orig_wei_dt` used.
|
|
, use_fp16_instructions_(is_subset(conf_->isa, avx512_core_fp16)
|
|
&& conf_->orig_wei_dt == data_type::f16
|
|
&& conf_->wei_dt == data_type::f32)
|
|
        // This variable enables upconversion from bf16
|
|
// to f32 similarly to f16, mostly for proper tail handling.
|
|
, use_bf16_instructions_(is_subset(conf_->isa, avx512_core_bf16)
|
|
&& conf_->orig_wei_dt == data_type::bf16
|
|
&& conf_->wei_dt == data_type::f32)
|
|
, max_tmp_idx(16
|
|
- (avx512_core_dot_product_
|
|
? 8
|
|
: (do_compute_compensation_ ? 6
|
|
: is_src_int4_ ? 2
|
|
: req_zp_b_shift_ ? 1
|
|
: 0)))
|
|
, src_stride_(conf_->copy_B_wei_stride)
|
|
, tr_src_stride_(conf_->LDB * vnni_granularity_ * tr_typesize_)
|
|
, src_elems_per_byte_(is_src_int4_ ? 2 : 1)
|
|
, is_dynamic_N_(conf->is_runtime_N) {}
|
|
|
|
void operator()(ctx_t *ctx) override { jit_generator_t::operator()(ctx); }
|
|
status_t create_kernel() override {
|
|
return jit_generator_t::create_kernel();
|
|
}
|
|
|
|
private:
|
|
using reg64_t = const Xbyak::Reg64;
|
|
using reg32_t = const Xbyak::Reg32;
|
|
using opmask_t = const Xbyak::Opmask;
|
|
using Vmm_lower_t = typename vreg_traits_t<Vmm>::Vmm_lower_t;
|
|
|
|
static constexpr bool is_ymm_ = std::is_same<Vmm, Xbyak::Ymm>::value;
|
|
static constexpr cpu_isa_t isa_ = is_ymm_ ? avx2 : avx512_core;
|
|
static constexpr int max_vmm_regs_ = cpu_isa_traits_t<isa_>::n_vregs;
|
|
static constexpr int vlen_ = vreg_traits_t<Vmm>::vlen;
|
|
static constexpr int n_blk_step_ = is_ymm_ ? 8 : 16;
|
|
static constexpr int req_cvt_bf16_k_blk_step_ = 16;
|
|
static constexpr size_t comp_shift_ = vlen_;
|
|
|
|
const int typesize_;
|
|
const int tr_typesize_;
|
|
const int wei_scales_typesize_;
|
|
const int vnni_granularity_;
|
|
const int k_blk_step_;
|
|
const bool is_wei_grouped_over_k_;
|
|
const bool do_compute_compensation_;
|
|
const bool is_bf32_;
|
|
const bool is_bf16_with_int_wei_;
|
|
const bool is_src_int4_;
|
|
const bool req_cvtps2xf16_;
|
|
const bool req_zp_comp_;
|
|
const bool req_s8s8_comp_;
|
|
const bool req_zp_b_shift_;
|
|
const bool req_apply_wei_scales_;
|
|
const bool avx512_core_dot_product_;
|
|
const bool use_fp16_instructions_;
|
|
const bool use_bf16_instructions_;
|
|
const int max_tmp_idx;
|
|
|
|
const dim_t src_stride_, tr_src_stride_, src_elems_per_byte_;
|
|
const bool is_dynamic_N_;
|
|
|
|
constexpr static int ldb_step_idx_offs = 0;
|
|
constexpr static int stack_space_needed = 8;
|
|
|
|
reg64_t reg_src_base = rax;
|
|
reg64_t reg_tr_src_base = rbx;
|
|
reg64_t reg_comp_ptr = rdx;
|
|
reg64_t reg_zp_ptr = rdx;
|
|
|
|
reg64_t reg_K_iters = r8;
|
|
reg64_t reg_N_iters = r9;
|
|
reg64_t reg_src = r10;
|
|
reg64_t reg_tr_src = r11;
|
|
reg64_t reg_zp_comp_ptr = r12;
|
|
reg64_t reg_zp_a_neg_val_ptr = r13;
|
|
reg64_t reg_K_start = r14;
|
|
reg64_t reg_wei_scales = rsi;
|
|
|
|
reg64_t regq_tmp = r15;
|
|
reg32_t regw_tmp = r15d;
|
|
reg64_t imm_addr64 = abi_not_param1;
|
|
|
|
// Note: for the AVX2 implementation, reserve Ymm(8) and Ymm(9) as
|
|
// temporary compute registers.
|
|
Vmm vmm_comp_mul = Vmm(max_vmm_regs_ - 1);
|
|
Vmm vmm_comp_acc = Vmm(max_vmm_regs_ - 2);
|
|
Vmm vmm_zp_a_neg_val = Vmm(max_vmm_regs_ - 3);
|
|
Vmm vmm_s8s8_comp_acc = Vmm(max_vmm_regs_ - 4);
|
|
Vmm vmm_all_bits_1 = Vmm(max_vmm_regs_ - 5);
|
|
Vmm vmm_one_s32 = Vmm(max_vmm_regs_ - 6);
|
|
|
|
// Required in every dot product for INT8 non-VNNI computation.
|
|
Vmm vmm_ones_words = Vmm(max_vmm_regs_ - 7);
|
|
Vmm vmm_dot_product_temp = Vmm(max_vmm_regs_ - 8);
|
|
|
|
Vmm vmm_zp_b_val = Vmm(max_vmm_regs_ - 1);
|
|
Vmm vmm_permd = Vmm(max_vmm_regs_ - 2);
|
|
    // Collides with `vmm_zp_a_neg_val`, as the two should never be needed
|
|
    // at the same time.
|
|
Vmm vmm_wei_scales = Vmm(max_vmm_regs_ - 3);
|
|
|
|
void kmovw(Opmask k, unsigned w) {
|
|
mov(regw_tmp, w);
|
|
jit_generator_t::kmovw(k, regw_tmp);
|
|
};
|
|
|
|
void kmovq(Opmask k, size_t q) {
|
|
mov(regq_tmp, q);
|
|
jit_generator_t::kmovq(k, regq_tmp);
|
|
};
|
|
|
|
Vmm src_vmm(int i) {
|
|
assert(i >= 0 && i < n_blk_step_);
|
|
return Vmm(i);
|
|
}
|
|
|
|
Vmm tmp_vmm(int i) {
|
|
        // If compensation computation is required, the last 6 zmms are reserved for it
|
|
assert(i >= 0 && IMPLICATION(!is_ymm_, i < max_tmp_idx)
|
|
&& IMPLICATION(is_ymm_, i < 2));
|
|
return Vmm(n_blk_step_ + i);
|
|
}
|
|
|
|
void init_tail_mask(const int columns_tail, const bool use_int4_mask);
|
|
bool preload_int4(const Xmm &xmm_in, const int i, const int columns_tail,
|
|
const bool is_tail, const dim_t offset);
|
|
void copy_row_x_col(int nrows, int ncolumns);
|
|
void compute_K_loop(bool is_N_tail, int curr_K_tail, bool is_first_K_iter,
|
|
bool is_last_K_iter);
|
|
void compute_N_loop(
|
|
int curr_K_tail, bool is_first_K_iter, bool is_last_K_iter);
|
|
|
|
inline void dot_product(Vmm v1, Vmm v2, Vmm v3) {
|
|
if (!avx512_core_dot_product_)
|
|
vpdpbusd(v1, v2, v3, get_encoding());
|
|
else {
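            // Emulate vpdpbusd on ISAs without int8 VNNI: vpmaddubsw multiplies
            // u8 by s8 and adds adjacent pairs into 16-bit results, vpmaddwd
            // with a vector of ones widens and adds adjacent pairs into 32
            // bits, and vpaddd accumulates into v1.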
|
|
vpmaddubsw(vmm_dot_product_temp, v2, v3);
|
|
vpmaddwd(
|
|
vmm_dot_product_temp, vmm_dot_product_temp, vmm_ones_words);
|
|
vpaddd(v1, v1, vmm_dot_product_temp);
|
|
}
|
|
}
|
|
inline bool valid_to_load_next(int next_row_idx, int num_rows) {
|
|
const bool dynamic_tail = is_dynamic_N_ && num_rows < n_blk_step_;
|
|
return next_row_idx < num_rows || dynamic_tail;
|
|
}
|
|
|
|
/**
|
|
     * Loads the zero-point value and broadcasts it over a Vmm register.
|
|
* Supported data types: s4/u4/s8/u8/s32.
|
|
*
|
|
* @param n N-dimension local index.
|
|
     * @param is_tail Bool flag indicating whether the tail is being processed.
|
|
*/
|
|
void load_zero_point(int n, bool is_tail) {
|
|
if (!conf_->is_wei_zp_per_n) return;
|
|
const auto zp_dt = conf_->wei_zp_dt;
|
|
const auto zp_dt_sz = types::data_type_size(zp_dt);
|
|
const auto elems_per_byte
|
|
= one_of(zp_dt, data_type::s4, data_type::u4) ? 2 : 1;
|
|
const auto offset = n * zp_dt_sz / elems_per_byte;
|
|
const auto addr = maybe_EVEX_compress_addr(reg_zp_ptr, offset);
|
|
|
|
const bool is_odd_index = n % elems_per_byte == 1;
|
|
const auto tmp_xmm = Xmm(vmm_zp_b_val.getIdx());
|
|
MAYBE_UNUSED(tmp_xmm);
|
|
MAYBE_UNUSED(is_odd_index);
|
|
|
|
const bool need_upconvert = one_of(zp_dt, data_type::s8, data_type::u8,
|
|
data_type::s4, data_type::u4);
|
|
if (need_upconvert) {
|
|
uni_vpinsrb(tmp_xmm, tmp_xmm, addr, 0);
|
|
if (one_of(zp_dt, data_type::s8, data_type::s4))
|
|
uni_vpmovsxbd(tmp_xmm, tmp_xmm);
|
|
else
|
|
uni_vpmovzxbd(tmp_xmm, tmp_xmm);
|
|
|
|
            // The 4-bit integer must be shifted left depending on
|
|
            // which of the two packed elements is required.
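            // Illustrative: for the even element the low nibble is shifted
            // from bits 0..3 up to bits 28..31 (shift by 28); for the odd
            // element the high nibble already sits in bits 4..7, so a shift by
            // 24 places it there. The right shift by 28 below then
            // zero-extends (u4) or sign-extends (s4) it into a full dword.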
|
|
if (one_of(zp_dt, data_type::s4, data_type::u4))
|
|
uni_vpslld(tmp_xmm, tmp_xmm, 28 - is_odd_index * 4);
|
|
            // Then shift back right by 28 bits.
|
|
if (zp_dt == data_type::u4) vpsrld(tmp_xmm, tmp_xmm, 28);
|
|
if (zp_dt == data_type::s4) vpsrad(tmp_xmm, tmp_xmm, 28);
|
|
}
|
|
const auto &op = need_upconvert
|
|
? static_cast<const Xbyak::Operand &>(tmp_xmm)
|
|
: static_cast<const Xbyak::Operand &>(addr);
|
|
const auto masked_vmm = maybe_mask(vmm_zp_b_val, is_tail);
|
|
uni_vpbroadcastd(masked_vmm, op);
|
|
}
|
|
|
|
/**
|
|
     * Loads a scale value and broadcasts it over a Vmm register.
|
|
* Supported data types: f32, bf16, f16.
|
|
*
|
|
* @param n N-dimension local index.
|
|
     * @param is_tail Bool flag indicating whether the tail is being processed.
|
|
*/
|
|
void load_scales(int n, bool is_tail) {
|
|
if (!conf_->is_wei_scale_per_n || !conf_->apply_scales_in_buffer_b)
|
|
return;
|
|
|
|
const auto &scales_dt = conf_->wei_scales_dt;
|
|
const auto &scales_dt_sz = conf_->wei_scales_dt_sz;
|
|
const auto offset = n * scales_dt_sz;
|
|
const auto masked_vmm = maybe_mask(vmm_wei_scales, is_tail);
|
|
const auto addr = EVEX_compress_addr(
|
|
reg_wei_scales, offset, scales_dt == data_type::f16);
|
|
vpxord(vmm_wei_scales, vmm_wei_scales, vmm_wei_scales);
|
|
switch (scales_dt) {
|
|
case data_type::f32: uni_vbroadcastss(vmm_wei_scales, addr); break;
|
|
case data_type::bf16:
|
|
vpbroadcastw(masked_vmm, addr);
|
|
uni_vpslld(vmm_wei_scales, vmm_wei_scales, 16);
|
|
break;
|
|
case data_type::f16: vcvtph2psx(vmm_wei_scales, addr); break;
|
|
default: assert(!"unsupported wei_scales data type");
|
|
}
|
|
}
|
|
|
|
    /** Stores half of the block using a mask for the case when vnni_granularity == 2 */
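    // With vnni_granularity == 2 a lone K row fills only one 16-bit half of
    // each 32-bit output lane: for an odd K index the values are shifted left
    // by 16 bits into the upper halves and stored through the kAAAA mask, for
    // an even index they are stored as-is.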
|
|
void store_half_block(const Zmm &r, const Xbyak::Address &store_addr) {
|
|
Label even_k, end_permute;
|
|
test(reg_K_start, 1);
|
|
jz(even_k, T_NEAR);
|
|
        // Shift left by 16 bits so the data lands in the odd 16-bit lanes
|
|
uni_vpslld(r, r, 16);
|
|
vmovdqu16(store_addr | kAAAA, r);
|
|
jmp(end_permute);
|
|
L(even_k);
|
|
// Store even indices, odd set to zero.
|
|
vmovdqu16(store_addr, r);
|
|
L(end_permute);
|
|
}
|
|
|
|
void generate() override;
|
|
};
|
|
|
|
template <typename Vmm>
|
|
void jit_brgemm_matmul_copy_b_transposed_t<Vmm>::init_tail_mask(
|
|
const int columns_tail, const bool use_int4_mask) {
|
|
assert(IMPLICATION(use_int4_mask, is_src_int4_));
|
|
if (columns_tail > 0) {
|
|
const int dt_step = req_cvtps2xf16_ || use_fp16_instructions_
|
|
|| use_bf16_instructions_
|
|
? 1
|
|
: typesize_;
|
|
const auto tail_mask = use_int4_mask
|
|
? size_t(((size_t)1 << div_up(dt_step * columns_tail, 2)) - 1)
|
|
: size_t(((size_t)1 << dt_step * columns_tail) - 1);
|
|
if (req_cvtps2xf16_)
|
|
kmovw(kTail, tail_mask);
|
|
else
|
|
kmovq(kTail, tail_mask);
|
|
}
|
|
}
|
|
|
|
/**
|
|
 * @brief Handles the special case of loading INT4 data with an unaligned src_stride
|
|
*
|
|
* When processing INT4 data and (i * src_stride_) % 2 != 0, we need to perform additional operations
|
|
* to handle the unaligned half-byte boundaries:
|
|
*
|
|
* - For small loads (< 8 bytes or tail cases): We perform a simple right shift by 4 bits
|
|
* to eliminate the unnecessary half-byte at the front
|
|
*
|
|
* - For large loads (8 bytes): We need two registers to correctly reconstruct the data:
|
|
* 1. Shift the first register right by 4 bits to remove leading half-byte
|
|
* 2. Shift the second register left by 4 bits to preserve trailing half-byte
|
|
* 3. Combine both registers with logical OR to get the correctly aligned result
|
|
*
|
|
* @param xmm_in The target XMM register to store the aligned data
|
|
* @param i Current row index
|
|
* @param columns_tail Number of remaining columns in the tail case
|
|
* @param is_tail Flag indicating if processing a tail section
|
|
* @param offset Memory offset for loading from source
|
|
* @return True if INT4 preloading was performed, false otherwise
|
|
*/
|
|
template <typename Vmm>
|
|
bool jit_brgemm_matmul_copy_b_transposed_t<Vmm>::preload_int4(const Xmm &xmm_in,
|
|
const int i, const int columns_tail, const bool is_tail,
|
|
const dim_t offset) {
|
|
const auto addr = EVEX_compress_addr(reg_src, offset);
|
|
const bool need_preload_int4 = is_src_int4_ && (i * src_stride_) % 2 != 0;
|
|
const auto max_shift_sz = 8;
|
|
if (need_preload_int4) {
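        // Worked example (assuming low-nibble-first packing): an 8-byte load
        // yields nibbles n0..n15; vpsrlq by 4 leaves n1..n15 with the top
        // nibble cleared. The second load from offset + 1 contributes
        // n2..n17; after vpsllq by 4 its nibble n16 lands in the top position,
        // so OR-ing both registers produces n1..n16, the 16 elements starting
        // at the odd half-byte boundary.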
|
|
const auto load_sz = is_tail ? div_up(columns_tail, 2)
|
|
: req_cvtps2xf16_ ? req_cvt_bf16_k_blk_step_ / 2
|
|
: k_blk_step_ / 2;
|
|
assert(load_sz <= max_shift_sz);
|
|
if (load_sz < max_shift_sz || is_tail) {
|
|
load_bytes(xmm_in, addr, load_sz);
|
|
vpsrlq(xmm_in, xmm_in, 4);
|
|
} else {
|
|
const auto xmm_tmp = Xmm(tmp_vmm(3).getIdx());
|
|
load_bytes(xmm_in, addr, load_sz);
|
|
load_bytes(
|
|
xmm_tmp, EVEX_compress_addr(reg_src, offset + 1), load_sz);
|
|
vpsrlq(xmm_in, xmm_in, 4);
|
|
vpsllq(xmm_tmp, xmm_tmp, 4);
|
|
vpord(xmm_in, xmm_in, xmm_tmp);
|
|
}
|
|
return true;
|
|
}
|
|
|
|
    // The case when the weights are grouped over K and the kernel needs to load odd or even columns
|
|
const auto preload_for_k_1_blk
|
|
= is_src_int4_ && is_wei_grouped_over_k_ && columns_tail == 1;
|
|
if (preload_for_k_1_blk) {
|
|
// Unconditionally load 1 byte, then shift if odd index
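        // For an odd K index the high nibble (odd element) is moved down by a
        // right shift of 4; for an even index a left-then-right shift by 4
        // clears the high nibble and keeps the even element.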
|
|
load_bytes(xmm_in, addr, 1);
|
|
Label load_done, even_k;
|
|
test(reg_K_start, 1);
|
|
jz(even_k, T_NEAR);
|
|
vpsrlq(xmm_in, xmm_in, 4);
|
|
jmp(load_done);
|
|
L(even_k);
|
|
vpsllq(xmm_in, xmm_in, 4);
|
|
vpsrlq(xmm_in, xmm_in, 4);
|
|
L(load_done);
|
|
return true;
|
|
}
|
|
return false;
|
|
}
|
|
|
|
template <typename Vmm>
|
|
void jit_brgemm_matmul_copy_b_transposed_t<Vmm>::copy_row_x_col(
|
|
int nrows, int ncolumns) {
|
|
assert(nrows >= 0 && nrows <= n_blk_step_ && ncolumns >= 0
|
|
&& ncolumns <= k_blk_step_);
|
|
if (!nrows) return;
|
|
|
|
const auto cur_k_blk_step
|
|
= req_cvtps2xf16_ ? req_cvt_bf16_k_blk_step_ : k_blk_step_;
|
|
|
|
const int columns_tail = ncolumns % cur_k_blk_step;
|
|
|
|
init_tail_mask(columns_tail, false);
|
|
|
|
auto load2bf16 = [this, nrows, columns_tail, ncolumns](int i) {
|
|
auto src_reg = src_vmm(i);
|
|
auto src_reg_next = tmp_vmm(2);
|
|
|
|
Label load_done;
|
|
if (is_dynamic_N_ && nrows < n_blk_step_) {
|
|
Label general_load;
|
|
cmp(reg_N_iters, i);
|
|
jg(general_load); // i < dynamic nrows -> general load
|
|
|
|
// i >= dynamic nrows -> zero out values in src_reg
|
|
vpxord(src_reg, src_reg, src_reg);
|
|
jmp(load_done);
|
|
|
|
L(general_load);
|
|
} else if (i >= nrows) {
|
|
vpxord(src_reg, src_reg, src_reg);
|
|
return;
|
|
}
|
|
|
|
        // Check whether the K tail exists and lies within the first zmm
|
|
const auto is_tail
|
|
= columns_tail > 0 && ncolumns < req_cvt_bf16_k_blk_step_;
|
|
auto src_reg_masked = maybe_mask(src_reg, is_tail);
|
|
const auto src_offset = (i * src_stride_) / src_elems_per_byte_;
|
|
const auto addr = EVEX_compress_addr(reg_src, src_offset);
|
|
if (is_bf32_)
|
|
vmovups(src_reg_masked, addr);
|
|
else if (is_bf16_with_int_wei_ || conf_->is_f16_with_int_wei) {
|
|
const auto xmm_preload = Xmm(src_reg.getIdx());
|
|
MAYBE_UNUSED(xmm_preload);
|
|
const bool preloaded_int4 = preload_int4(
|
|
xmm_preload, i, columns_tail, is_tail, src_offset);
|
|
const auto &src_op = preloaded_int4
|
|
? static_cast<const Xbyak::Operand &>(xmm_preload)
|
|
: static_cast<const Xbyak::Operand &>(addr);
|
|
if (is_src_int4_) init_tail_mask(columns_tail, true);
|
|
load_value(src_reg, src_op, vmm_permd, conf_->orig_wei_dt, is_tail);
|
|
if (is_src_int4_) init_tail_mask(columns_tail, false);
|
|
load_zero_point(i, is_tail);
|
|
load_scales(i, is_tail);
|
|
decompress_reg(src_reg_masked, vmm_zp_b_val, vmm_wei_scales,
|
|
conf_->orig_wei_dt);
|
|
} else
|
|
assert(!"Unsupported data type in loading");
|
|
|
|
if (ncolumns <= req_cvt_bf16_k_blk_step_) {
|
|
vpxord(src_reg_next, src_reg_next, src_reg_next);
|
|
} else {
|
|
const auto is_tail = columns_tail > 0;
|
|
auto src_next_masked = maybe_mask(src_reg_next, is_tail);
|
|
const auto next_src_offset
|
|
= (i * src_stride_ + req_cvt_bf16_k_blk_step_ * typesize_)
|
|
/ src_elems_per_byte_;
|
|
const auto next_addr = EVEX_compress_addr(reg_src, next_src_offset);
|
|
if (is_bf32_)
|
|
vmovups(src_next_masked, next_addr);
|
|
else if (is_bf16_with_int_wei_ || conf_->is_f16_with_int_wei) {
|
|
const auto xmm_preload = Xmm(src_reg_next.getIdx());
|
|
MAYBE_UNUSED(xmm_preload);
|
|
const bool preloaded_int4 = preload_int4(
|
|
xmm_preload, i, columns_tail, is_tail, src_offset);
|
|
const auto &src_op = preloaded_int4
|
|
? static_cast<const Xbyak::Operand &>(xmm_preload)
|
|
: static_cast<const Xbyak::Operand &>(next_addr);
|
|
if (is_src_int4_) init_tail_mask(columns_tail, true);
|
|
load_value(src_reg_next, src_op, vmm_permd, conf_->orig_wei_dt,
|
|
is_tail);
|
|
if (is_src_int4_) init_tail_mask(columns_tail, false);
|
|
load_zero_point(i, is_tail);
|
|
load_scales(i, is_tail);
|
|
decompress_reg(src_next_masked, vmm_zp_b_val, vmm_wei_scales,
|
|
conf_->orig_wei_dt);
|
|
} else
|
|
assert(!"Unsupported data type in loading");
|
|
}
|
|
downconvert_to_dst_dt(src_reg, src_reg_next, conf_->wei_dt);
|
|
L(load_done);
|
|
};
|
|
|
|
auto load = [this, nrows, columns_tail](int i, int base_idx) {
|
|
Label load_done;
|
|
|
|
auto src_reg = src_vmm(i);
|
|
if (is_dynamic_N_ && nrows < n_blk_step_) {
|
|
Label general_load;
|
|
cmp(reg_N_iters, i);
|
|
jg(general_load); // i < dynamic nrows -> general load
|
|
|
|
// i >= dynamic nrows -> zero out values in src_reg
|
|
vpxord(src_reg, src_reg, src_reg);
|
|
jmp(load_done);
|
|
|
|
L(general_load);
|
|
} else if (i >= nrows) {
|
|
vpxord(src_reg, src_reg, src_reg);
|
|
return;
|
|
}
|
|
|
|
const auto is_tail = columns_tail > 0;
|
|
const auto src_offset = (i * src_stride_) / src_elems_per_byte_;
|
|
const auto addr = EVEX_compress_addr(reg_src, src_offset);
|
|
auto src_masked_reg = maybe_mask(src_reg, is_tail);
|
|
if ((conf_->is_f16_with_int_wei || conf_->is_f32_with_int_wei)
|
|
&& conf_->wei_dt == data_type::f32) {
|
|
const auto xmm_preload = Xmm(src_reg.getIdx());
|
|
MAYBE_UNUSED(xmm_preload);
|
|
const bool preloaded_int4 = preload_int4(
|
|
xmm_preload, i, columns_tail, is_tail, src_offset);
|
|
const auto &src_op = preloaded_int4
|
|
? static_cast<const Xbyak::Operand &>(xmm_preload)
|
|
: static_cast<const Xbyak::Operand &>(addr);
|
|
if (is_src_int4_) init_tail_mask(columns_tail, true);
|
|
load_value(src_reg, src_op, vmm_permd, conf_->orig_wei_dt, is_tail);
|
|
if (is_src_int4_) init_tail_mask(columns_tail, false);
|
|
load_zero_point(i, is_tail);
|
|
load_scales(i, is_tail);
|
|
decompress_reg(src_masked_reg, vmm_zp_b_val, vmm_wei_scales,
|
|
conf_->orig_wei_dt);
|
|
} else if (use_fp16_instructions_) {
|
|
if (conf_->isa == avx512_core_fp16) {
|
|
vcvtph2psx(src_masked_reg, addr);
|
|
} else {
|
|
vcvtph2ps(src_masked_reg, addr);
|
|
}
|
|
} else if (use_bf16_instructions_) {
|
|
// Upconvert: load 16 bits and move them 16 bits left.
|
|
uni_vpmovzxwd(src_masked_reg, addr);
|
|
uni_vpslld(src_masked_reg, src_masked_reg, 16);
|
|
} else {
|
|
vmovdqu8(src_masked_reg, addr);
|
|
}
|
|
L(load_done);
|
|
};
|
|
|
|
auto store = [this, columns_tail, ncolumns, cur_k_blk_step](Zmm r, int i) {
|
|
auto addr = EVEX_compress_addr(reg_tr_src, i * tr_src_stride_);
|
|
if (is_wei_grouped_over_k_) {
|
|
const bool is_tail = columns_tail > 0 && ncolumns < cur_k_blk_step;
|
|
if (is_tail && i >= columns_tail) return;
|
|
if (vnni_granularity_ == 2 && ncolumns == 1)
|
|
store_half_block(r, addr);
|
|
else
|
|
vmovups(addr, r);
|
|
} else {
|
|
vmovups(addr, r);
|
|
}
|
|
};
|
|
|
|
auto transpose16x8 = [&](int base_idx) {
|
|
assert(base_idx == 0 || base_idx == 8);
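        // One half of the 16x16 transpose: within these 8 zmm rows, 32-bit
        // elements are exchanged at distances 1, 2 and 4 using valignd /
        // vshuff32x4 under alternating k-masks; the remaining distance-8
        // exchange between the two register halves happens later in
        // fixup16x16 via vshuff64x2.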
|
|
|
|
// swap 1
|
|
if (req_cvtps2xf16_) {
|
|
for (int i = 0; i < 4; i++) {
|
|
const int src_idx0 = base_idx + i * 2;
|
|
const int src_idx1 = src_idx0 + 1;
|
|
|
|
if (base_idx == 0 && i == 0) {
|
|
load2bf16(src_idx0);
|
|
load2bf16(src_idx1);
|
|
}
|
|
|
|
const int next_src_idx0 = src_idx0 + 2;
|
|
const int next_src_idx1 = src_idx1 + 2;
|
|
|
|
const bool load_next = base_idx == 0 || i < 3;
|
|
|
|
const auto tmp0 = tmp_vmm(0);
|
|
const auto tmp1 = tmp_vmm(1);
|
|
const auto src0 = src_vmm(src_idx0);
|
|
const auto src1 = src_vmm(src_idx1);
|
|
|
|
if (valid_to_load_next(next_src_idx0, nrows) && load_next)
|
|
load2bf16(next_src_idx0);
|
|
valignd(tmp0, src0, src0, 0x1);
|
|
|
|
if (valid_to_load_next(next_src_idx1, nrows) && load_next)
|
|
load2bf16(next_src_idx1);
|
|
valignd(tmp1, src1, src1, 0xf);
|
|
|
|
vmovaps(src0 | kAAAA, tmp1);
|
|
vmovaps(src1 | k5555, tmp0);
|
|
}
|
|
} else {
|
|
for (int i = 0; i < 4; i++) {
|
|
const int src_idx0 = base_idx + i * 2;
|
|
const int src_idx1 = src_idx0 + 1;
|
|
|
|
const int next_src_idx0 = src_idx0 + 2;
|
|
const int next_src_idx1 = src_idx1 + 2;
|
|
const bool load_next = base_idx == 0 || i < 3;
|
|
|
|
if (base_idx == 0 && i == 0) {
|
|
load(src_idx0, base_idx);
|
|
load(src_idx1, base_idx);
|
|
}
|
|
|
|
const auto tmp0 = tmp_vmm(0);
|
|
const auto tmp1 = tmp_vmm(1);
|
|
const auto src0 = src_vmm(src_idx0);
|
|
const auto src1 = src_vmm(src_idx1);
|
|
|
|
if (valid_to_load_next(next_src_idx0, nrows) && load_next)
|
|
load(next_src_idx0, base_idx);
|
|
valignd(tmp0, src0, src0, 0x1);
|
|
|
|
if (valid_to_load_next(next_src_idx1, nrows) && load_next)
|
|
load(next_src_idx1, base_idx);
|
|
valignd(tmp1, src1, src1, 0xf);
|
|
|
|
vmovaps(src0 | kAAAA, tmp1);
|
|
vmovaps(src1 | k5555, tmp0);
|
|
}
|
|
}
|
|
// swap 2
|
|
for (int i = 0; i < 4; i++) {
|
|
const int select_half = (i < 2) ? 0 : 2;
|
|
const int src_idx0 = base_idx + i + select_half + 0;
|
|
const int src_idx2 = src_idx0 + 2;
|
|
|
|
const auto tmp0 = tmp_vmm(0);
|
|
const auto tmp1 = tmp_vmm(1);
|
|
const auto src0 = src_vmm(src_idx0);
|
|
const auto src2 = src_vmm(src_idx2);
|
|
|
|
valignd(tmp0, src0, src0, 0x2);
|
|
valignd(tmp1, src2, src2, 0xe);
|
|
vmovaps(src2 | k3333, tmp0);
|
|
vmovaps(src0 | kCCCC, tmp1);
|
|
}
|
|
// swap 4
|
|
for (int i = 0; i < 4; i++) {
|
|
const int src_idx0 = base_idx + i;
|
|
const int src_idx4 = src_idx0 + 4;
|
|
|
|
const auto tmp0 = tmp_vmm(0);
|
|
const auto src0 = src_vmm(src_idx0);
|
|
const auto src4 = src_vmm(src_idx4);
|
|
|
|
vmovaps(tmp0, src0);
|
|
vshuff32x4(src0 | kF0F0, src4, src4, 0xb1);
|
|
vshuff32x4(src4 | k0F0F, tmp0, tmp0, 0xb1);
|
|
}
|
|
};
|
|
|
|
auto fixup16x16 = [&]() {
|
|
for (int i = 0; i < 8; i++) {
|
|
const auto tmp = tmp_vmm(0);
|
|
const auto src0 = src_vmm(i);
|
|
const auto src8 = src_vmm(8 + i);
|
|
vshuff64x2(tmp, src0, src8, 0x44);
|
|
if (do_compute_compensation_)
|
|
dot_product(vmm_comp_acc, vmm_comp_mul, tmp);
|
|
store(tmp, i);
|
|
}
|
|
|
|
for (int i = 0; i < 8; i++) {
|
|
const auto tmp = tmp_vmm(0);
|
|
const auto src0 = src_vmm(i);
|
|
const auto src8 = src_vmm(8 + i);
|
|
vshuff64x2(tmp, src0, src8, 0xee);
|
|
if (do_compute_compensation_)
|
|
dot_product(vmm_comp_acc, vmm_comp_mul, tmp);
|
|
store(tmp, 8 + i);
|
|
}
|
|
};
|
|
|
|
transpose16x8(0);
|
|
transpose16x8(8);
|
|
fixup16x16();
|
|
}
|
|
|
|
template <>
|
|
void jit_brgemm_matmul_copy_b_transposed_t<Ymm>::copy_row_x_col(
|
|
int nrows, int ncolumns) {
|
|
assert(nrows >= 0 && nrows <= n_blk_step_ && ncolumns >= 0
|
|
&& ncolumns <= k_blk_step_);
|
|
if (!nrows) return;
|
|
|
|
const int columns_tail = ncolumns % k_blk_step_;
|
|
|
|
auto load = [this, nrows, columns_tail](int i) {
|
|
auto vmm_src = src_vmm(i);
|
|
|
|
Label load_done;
|
|
if (is_dynamic_N_ && nrows < n_blk_step_) {
|
|
Label general_load;
|
|
cmp(reg_N_iters, i);
|
|
jg(general_load); // i < dynamic nrows -> general load
|
|
|
|
// i >= dynamic nrows -> zero out values in src_reg
|
|
vpxord(vmm_src, vmm_src, vmm_src);
|
|
jmp(load_done);
|
|
|
|
L(general_load);
|
|
} else if (i >= nrows) {
|
|
uni_vpxor(vmm_src, vmm_src, vmm_src);
|
|
return;
|
|
}
|
|
|
|
if (columns_tail > 0) {
|
|
load_bytes(vmm_src, reg_src, i * src_stride_,
|
|
columns_tail * typesize_);
|
|
if (use_fp16_instructions_) {
|
|
                // For the f32:f16 case the raw bytes produced by `load_bytes` have to be
|
|
                // converted into f32 values.
|
|
vcvtph2ps(vmm_src, Xmm(vmm_src.getIdx()));
|
|
} else if (use_bf16_instructions_) {
|
|
// Upconvert: move loaded 16 bits left.
|
|
uni_vpmovzxwd(vmm_src, vmm_src);
|
|
uni_vpslld(vmm_src, vmm_src, 16);
|
|
}
|
|
} else {
|
|
if (use_fp16_instructions_) {
|
|
                // For the non-tail case the convert instruction can be used directly.
|
|
vcvtph2ps(vmm_src, ptr[reg_src + i * src_stride_]);
|
|
} else if (use_bf16_instructions_) {
|
|
// Upconvert: load 16 bits and move them 16 bits left.
|
|
uni_vpmovzxwd(vmm_src, ptr[reg_src + i * src_stride_]);
|
|
uni_vpslld(vmm_src, vmm_src, 16);
|
|
} else {
|
|
uni_vmovups(vmm_src, ptr[reg_src + i * src_stride_]);
|
|
}
|
|
}
|
|
|
|
L(load_done);
|
|
};
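    // 8x8 transpose of 32-bit elements across the ymm registers, done as
    // three exchange stages at distances 1, 2 and 4 (vperm2i128 + vpalignr +
    // vpblendd for the first two, vperm2i128 + vpblendd for the last).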
|
|
|
|
// swap 1
|
|
for (int i = 0; i < 4; ++i) {
|
|
const int src_idx0 = i * 2;
|
|
const int src_idx1 = src_idx0 + 1;
|
|
|
|
const int next_src_idx0 = src_idx0 + 2;
|
|
const int next_src_idx1 = src_idx1 + 2;
|
|
const bool load_next = i < 3;
|
|
|
|
if (i == 0) {
|
|
load(src_idx0);
|
|
load(src_idx1);
|
|
}
|
|
|
|
const auto tmp0 = tmp_vmm(0);
|
|
const auto tmp1 = tmp_vmm(1);
|
|
const auto src0 = src_vmm(src_idx0);
|
|
const auto src1 = src_vmm(src_idx1);
|
|
|
|
if (valid_to_load_next(next_src_idx0, nrows) && load_next) {
|
|
load(next_src_idx0);
|
|
}
|
|
vperm2i128(tmp0, src0, src0, 0x1);
|
|
vpalignr(tmp0, tmp0, src0, 0x4);
|
|
|
|
if (valid_to_load_next(next_src_idx1, nrows) && load_next) {
|
|
load(next_src_idx1);
|
|
}
|
|
vperm2i128(tmp1, src1, src1, 0x1);
|
|
vpalignr(tmp1, src1, tmp1, 0xC);
|
|
|
|
vpblendd(src0, src0, tmp1, 0xAA);
|
|
vpblendd(src1, src1, tmp0, 0x55);
|
|
}
|
|
// swap 2
|
|
for (int i = 0; i < 4; ++i) {
|
|
const int select_half = (i < 2) ? 0 : 2;
|
|
const int src_idx0 = i + select_half;
|
|
const int src_idx2 = src_idx0 + 2;
|
|
|
|
const auto tmp0 = tmp_vmm(0);
|
|
const auto tmp1 = tmp_vmm(1);
|
|
const auto src0 = src_vmm(src_idx0);
|
|
const auto src2 = src_vmm(src_idx2);
|
|
|
|
vperm2i128(tmp0, src0, src0, 0x1);
|
|
vpalignr(tmp0, tmp0, src0, 0x8);
|
|
|
|
vperm2i128(tmp1, src2, src2, 0x1);
|
|
vpalignr(tmp1, src2, tmp1, 0x8);
|
|
|
|
vpblendd(src2, src2, tmp0, 0x33);
|
|
vpblendd(src0, src0, tmp1, 0xCC);
|
|
}
|
|
// swap 4
|
|
for (int i = 0; i < 4; ++i) {
|
|
const int src_idx0 = i;
|
|
const int src_idx4 = src_idx0 + 4;
|
|
|
|
const auto tmp0 = tmp_vmm(0);
|
|
const auto tmp4 = tmp_vmm(1);
|
|
const auto src0 = src_vmm(src_idx0);
|
|
const auto src4 = src_vmm(src_idx4);
|
|
|
|
vperm2i128(tmp4, src4, src4, 0x01);
|
|
vperm2i128(tmp0, src0, src0, 0x01);
|
|
vpblendd(src0, src0, tmp4, 0xF0);
|
|
vpblendd(src4, src4, tmp0, 0x0F);
|
|
}
|
|
// swap 8
|
|
for (int i = 0; i < 8; i++) {
|
|
const auto src0 = src_vmm(i);
|
|
if (do_compute_compensation_)
|
|
dot_product(vmm_comp_acc, vmm_comp_mul, src0);
|
|
uni_vmovups(ptr[reg_tr_src + i * tr_src_stride_], src0);
|
|
}
|
|
}
|
|
|
|
template <typename Vmm>
|
|
void jit_brgemm_matmul_copy_b_transposed_t<Vmm>::compute_K_loop(bool is_N_tail,
|
|
int curr_K_tail, bool is_first_K_iter, bool is_last_K_iter) {
|
|
MAYBE_UNUSED(is_first_K_iter);
|
|
MAYBE_UNUSED(is_last_K_iter);
|
|
const int N_chunk_tail = is_dynamic_N_
|
|
? 1 /* just to force tail processing */
|
|
: conf_->N % n_blk_step_;
|
|
const int nrows = is_N_tail ? N_chunk_tail : n_blk_step_;
|
|
if (do_compute_compensation_)
|
|
uni_vpxor(vmm_comp_acc, vmm_comp_acc, vmm_comp_acc);
|
|
|
|
Label K_loop, K_loop_tail_or_done;
|
|
mov(reg_K_iters, ptr[param1 + GET_OFF(current_K_iters)]);
|
|
|
|
mov(reg_src, reg_src_base);
|
|
mov(reg_tr_src, reg_tr_src_base);
|
|
|
|
if (curr_K_tail > 0) {
|
|
cmp(reg_K_iters, k_blk_step_);
|
|
jl(K_loop_tail_or_done, T_NEAR);
|
|
}
|
|
|
|
L(K_loop);
|
|
copy_row_x_col(nrows, k_blk_step_);
|
|
add(reg_src, (k_blk_step_ * typesize_) / src_elems_per_byte_);
|
|
add(reg_tr_src, k_blk_step_ / vnni_granularity_ * tr_src_stride_);
|
|
|
|
sub(reg_K_iters, k_blk_step_);
|
|
cmp(reg_K_iters, k_blk_step_);
|
|
jge(K_loop, T_NEAR);
|
|
|
|
L(K_loop_tail_or_done);
|
|
|
|
if (curr_K_tail > 0) copy_row_x_col(nrows, curr_K_tail);
|
|
|
|
if (req_s8s8_comp_) {
|
|
const auto addr = ptr[reg_comp_ptr];
|
|
if (!is_first_K_iter)
|
|
uni_vpaddd(vmm_s8s8_comp_acc, vmm_comp_acc, addr);
|
|
else
|
|
uni_vmovups(vmm_s8s8_comp_acc, vmm_comp_acc);
|
|
|
|
if (is_last_K_iter) {
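            // Finalize the s8s8 compensation: the accumulated sum of B values
            // over K is multiplied by 128 (left shift by 7) and negated via
            // the two's complement identity -x == ~x + 1, giving
            // comp = -128 * sum(B).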
|
|
// multiply by 128
|
|
uni_vpslld(vmm_s8s8_comp_acc, vmm_s8s8_comp_acc, 7);
|
|
// change sign
|
|
uni_vpandnd(vmm_s8s8_comp_acc, vmm_s8s8_comp_acc, vmm_all_bits_1);
|
|
uni_vpaddd(vmm_s8s8_comp_acc, vmm_s8s8_comp_acc, vmm_one_s32);
|
|
}
|
|
uni_vmovups(addr, vmm_s8s8_comp_acc);
|
|
}
|
|
if (req_zp_comp_) {
|
|
const auto addr = ptr[reg_zp_comp_ptr];
|
|
if (!is_first_K_iter) vpaddd(vmm_comp_acc, vmm_comp_acc, addr);
|
|
if (is_last_K_iter)
|
|
uni_vpmulld(vmm_comp_acc, vmm_comp_acc, vmm_zp_a_neg_val);
|
|
uni_vmovups(addr, vmm_comp_acc);
|
|
}
|
|
}
|
|
|
|
template <typename Vmm>
|
|
void jit_brgemm_matmul_copy_b_transposed_t<Vmm>::compute_N_loop(
|
|
int curr_K_tail, bool is_first_K_iter, bool is_last_K_iter) {
|
|
const bool generate_N_tail = is_dynamic_N_ || (conf_->N % n_blk_step_ > 0);
|
|
|
|
Label N_loop, N_loop_tail_or_done;
|
|
if (generate_N_tail) {
|
|
cmp(reg_N_iters, n_blk_step_);
|
|
jl(N_loop_tail_or_done, T_NEAR);
|
|
}
|
|
if (conf_->LDB2 > 0) {
|
|
mov(regq_tmp, 0);
|
|
mov(ptr[rsp + ldb_step_idx_offs], regq_tmp);
|
|
}
|
|
|
|
L(N_loop);
|
|
compute_K_loop(false, curr_K_tail, is_first_K_iter, is_last_K_iter);
|
|
|
|
add(reg_src_base, (n_blk_step_ * src_stride_) / src_elems_per_byte_);
|
|
if (conf_->LDB2 > 0) {
|
|
Label small_increment, ldb_off_done;
|
|
mov(regq_tmp, ptr[rsp + ldb_step_idx_offs]);
|
|
add(regq_tmp, n_blk_step_);
|
|
|
|
cmp(regq_tmp, conf_->LDB);
|
|
jne(small_increment, T_NEAR);
|
|
|
|
add(reg_tr_src_base,
|
|
-conf_->LDB * vnni_granularity_ * tr_typesize_
|
|
+ n_blk_step_ * vnni_granularity_ * tr_typesize_
|
|
+ conf_->LDB2 * tr_typesize_);
|
|
mov(regq_tmp, 0);
|
|
mov(ptr[rsp + ldb_step_idx_offs], regq_tmp);
|
|
jmp(ldb_off_done, T_NEAR);
|
|
L(small_increment);
|
|
add(reg_tr_src_base, n_blk_step_ * vnni_granularity_ * tr_typesize_);
|
|
mov(ptr[rsp + ldb_step_idx_offs], regq_tmp);
|
|
L(ldb_off_done);
|
|
} else {
|
|
add(reg_tr_src_base, n_blk_step_ * vnni_granularity_ * tr_typesize_);
|
|
}
|
|
|
|
if (conf_->is_wei_scale_per_n) {
|
|
const auto &scales_dt_sz = conf_->wei_scales_dt_sz;
|
|
const auto offset = n_blk_step_ * scales_dt_sz;
|
|
add(reg_wei_scales, offset);
|
|
}
|
|
|
|
if (conf_->is_wei_zp_per_n) {
|
|
const auto &zp_dt = conf_->wei_zp_dt;
|
|
const auto zp_dt_sz = types::data_type_size(zp_dt);
|
|
const auto elems_per_byte
|
|
= one_of(zp_dt, data_type::s4, data_type::u4) ? 2 : 1;
|
|
const auto offset = n_blk_step_ * zp_dt_sz / elems_per_byte;
|
|
add(reg_zp_ptr, offset);
|
|
}
|
|
|
|
if (req_zp_comp_) add(reg_zp_comp_ptr, comp_shift_);
|
|
if (req_s8s8_comp_) add(reg_comp_ptr, comp_shift_);
|
|
|
|
sub(reg_N_iters, n_blk_step_);
|
|
cmp(reg_N_iters, n_blk_step_);
|
|
jge(N_loop, T_NEAR);
|
|
|
|
L(N_loop_tail_or_done);
|
|
if (generate_N_tail) {
|
|
Label N_loop_done;
|
|
cmp(reg_N_iters, 0);
|
|
jle(N_loop_done, T_NEAR);
|
|
|
|
compute_K_loop(true, curr_K_tail, is_first_K_iter, is_last_K_iter);
|
|
L(N_loop_done);
|
|
}
|
|
}
|
|
|
|
template <typename Vmm>
|
|
void jit_brgemm_matmul_copy_b_transposed_t<Vmm>::generate() {
|
|
preamble();
|
|
sub(rsp, stack_space_needed);
|
|
|
|
if (avx512_core_dot_product_) {
|
|
mov(regq_tmp.cvt16(), 1);
|
|
vpbroadcastw(vmm_ones_words, regq_tmp.cvt16());
|
|
}
|
|
|
|
mov(reg_src_base, ptr[param1 + GET_OFF(src)]);
|
|
mov(reg_tr_src_base, ptr[param1 + GET_OFF(tr_src)]);
|
|
mov(reg_K_iters, ptr[param1 + GET_OFF(current_K_iters)]);
|
|
mov(reg_N_iters, ptr[param1 + GET_OFF(current_N_blk)]);
|
|
mov(reg_wei_scales, ptr[param1 + GET_OFF(wei_scales_ptr)]);
|
|
mov(reg_zp_ptr, ptr[param1 + GET_OFF(zp_b_value_ptr)]);
|
|
mov(reg_K_start, ptr[param1 + GET_OFF(current_K_start)]);
|
|
|
|
if (!is_ymm_) {
|
|
        // 64-bit mask is also used when is_wei_[zp/scales]_per_k
|
|
kmovq(kAAAA, 0xAAAAAAAAAAAAAAAA);
|
|
kmovq(k5555, 0x5555555555555555);
|
|
|
|
kmovw(k3333, 0x3333);
|
|
kmovw(kCCCC, 0xcccc);
|
|
kmovw(k0F0F, 0x0f0f);
|
|
kmovw(kF0F0, 0xf0f0);
|
|
}
|
|
if (is_src_int4_ && is_superset(conf_->isa, avx512_core)) {
|
|
alignas(64) static constexpr const uint32_t int4_permute[16]
|
|
= {0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15};
|
|
mov(regq_tmp, reinterpret_cast<size_t>(int4_permute));
|
|
vmovdqa32(vmm_permd, ptr[regq_tmp]);
|
|
}
|
|
|
|
load_common_zp_value(vmm_zp_b_val, reg_zp_ptr);
|
|
load_common_scale_value(vmm_wei_scales, reg_wei_scales);
|
|
|
|
const dim_t N_chunk_elems = conf_->N_chunk_elems;
|
|
assert(N_chunk_elems % n_blk_step_ == 0 || N_chunk_elems == conf_->N);
|
|
UNUSED(N_chunk_elems);
|
|
|
|
const auto &k_blk = conf_->K_blk;
|
|
const auto K_blk_tail = nstl::min(conf_->K, k_blk) % k_blk_step_;
|
|
const auto K_tail_tail = (conf_->K % k_blk) % k_blk_step_;
|
|
|
|
const auto grouped_k = is_wei_grouped_over_k_
|
|
? (conf_->is_wei_zp_per_k ? conf_->wei_zp_k_gsize
|
|
: conf_->wei_scales_k_gsize)
|
|
: 0;
|
|
|
|
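
    // compute_body() selects the K remainder passed down to compute_N_loop:
    // the zp/scales group size when per-K groups are smaller than
    // k_blk_step_, the last-chunk remainder (K_tail_tail) when the current
    // chunk is the K tail, and the regular remainder (K_blk_tail) otherwise.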
    auto compute_body = [&](bool is_first_K_iter, bool is_last_K_iter) {
        if (is_last_K_iter) {
            if (req_s8s8_comp_) {
                mov(imm_addr64, 0xffffffff);
                uni_vpbroadcastd(vmm_all_bits_1, imm_addr64.cvt32());
                mov(imm_addr64, 0x1);
                uni_vpbroadcastd(vmm_one_s32, imm_addr64.cvt32());
            }
            if (req_zp_comp_) {
                mov(reg_zp_a_neg_val_ptr,
                        ptr[param1 + GET_OFF(zp_a_neg_value_ptr)]);
                uni_vbroadcastss(vmm_zp_a_neg_val, ptr[reg_zp_a_neg_val_ptr]);
            }
        }

        if (is_wei_grouped_over_k_ && grouped_k < k_blk_step_) {
            compute_N_loop(grouped_k, is_first_K_iter, is_last_K_iter);
            return;
        }

        Label compute_body_done;
        if (conf_->K_tail > 0 && K_blk_tail != K_tail_tail) {
            Label not_K_tail;
            cmp(reg_K_iters, k_blk);
            je(not_K_tail, T_NEAR);
            compute_N_loop(K_tail_tail, is_first_K_iter, is_last_K_iter);
            jmp(compute_body_done, T_NEAR);

            L(not_K_tail);
        }

        compute_N_loop(K_blk_tail, is_first_K_iter, is_last_K_iter);
        L(compute_body_done);
    };
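
    // When B compensation is computed, dispatch at runtime on whether the
    // current K chunk is the first and/or the last one (based on
    // current_K_start), since the is_first/is_last flags steer how the
    // compensation accumulators are initialized and stored inside the K loop.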
    Label done;
    if (do_compute_compensation_) {
        assert(IMPLICATION(req_zp_comp_,
                conf_->src_zp_type == brgemm_broadcast_t::per_tensor));

        mov(reg_K_start, ptr[param1 + GET_OFF(current_K_start)]);
        if (req_s8s8_comp_)
            mov(reg_comp_ptr, ptr[param1 + GET_OFF(compensation_ptr)]);
        if (req_zp_comp_)
            mov(reg_zp_comp_ptr, ptr[param1 + GET_OFF(zp_a_compensation_ptr)]);

        mov(regq_tmp, 1);
        uni_vpbroadcastb(vmm_comp_mul, regq_tmp.cvt8());

        const auto last_K_threshold = rnd_up(conf_->K, k_blk) - k_blk;
        Label not_first, not_first_not_last;
        cmp(reg_K_start, 0);
        jne(not_first, T_NEAR);
        {
            // first K iteration
            Label first_not_last;
            cmp(reg_K_start, last_K_threshold);
            jl(first_not_last, T_NEAR);
            compute_body(true, true);
            jmp(done, T_NEAR);

            L(first_not_last);
            compute_body(true, false);
            jmp(done, T_NEAR);
        }

        L(not_first);
        cmp(reg_K_start, last_K_threshold);
        jl(not_first_not_last, T_NEAR);

        compute_body(false, true);
        jmp(done, T_NEAR);
        L(not_first_not_last);
    }

    compute_body(false, false);
    L(done);

    add(rsp, stack_space_needed);
    postamble();
}

template struct jit_brgemm_matmul_copy_b_transposed_t<Zmm>;
template struct jit_brgemm_matmul_copy_b_transposed_t<Ymm>;
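
// Copy kernel for a non-transposed, already-blocked B matrix whose weights
// are stored in a narrower type (e.g. int4/int8): values are decompressed
// (zero point, scale) and converted to the 16-bit target weight type in
// VNNI layout (k_blk_step = 2). Selected by create_brgemm_matmul_copy_b()
// below for the *_with_int_wei blocked-B cases.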
template <typename Vmm>
struct jit_brgemm_matmul_copy_b_cvt_bf16_t
    : public jit_brgemm_matmul_copy_b_common_t {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_brgemm_matmul_copy_b_cvt_bf16_t)

    jit_brgemm_matmul_copy_b_cvt_bf16_t(const brgemm_matmul_conf_t *conf)
        : jit_brgemm_matmul_copy_b_common_t(conf)
        , typesize_(conf->b_dt_sz)
        , tr_typesize_(conf->tr_b_dt_sz)
        , wei_scales_typesize_(conf_->wei_scales_dt_sz)
        , is_src_int4_(one_of(conf->orig_wei_dt, data_type::s4, data_type::u4))
        , src_elems_per_byte_(is_src_int4_ ? 2 : 1)
        , src_stride_(
                  (conf->LDB * k_blk_step * typesize_) / src_elems_per_byte_)
        , tr_src_stride_(conf_->LDB * k_blk_step * tr_typesize_)
        , req_zp_b_shift_(
                  conf_->has_zero_point_b && conf_->with_wei_decompression)
        , req_apply_wei_scales_(conf_->apply_scales_in_buffer_b)
        , reserved_regs_(req_apply_wei_scales_ ? 6
                        : req_zp_b_shift_      ? 4
                        : is_src_int4_         ? 1
                                               : 0)
        , is_wei_grouped_over_k_(
                  conf_->is_wei_zp_per_k || conf_->is_wei_scale_per_k) {}

    void operator()(ctx_t *ctx) override { jit_generator_t::operator()(ctx); }
    status_t create_kernel() override {
        return jit_generator_t::create_kernel();
    }

private:
    using reg64_t = const Xbyak::Reg64;
    using reg32_t = const Xbyak::Reg32;
    using opmask_t = const Xbyak::Opmask;
    using Vmm_lower_t = typename vreg_traits_t<Vmm>::Vmm_lower_t;
    using zmm = const Xbyak::Zmm;
    using ymm = const Xbyak::Ymm;

    enum { k_blk_step = 2, n_blk_step = 16 };
    const int typesize_, tr_typesize_, wei_scales_typesize_;
    const bool is_src_int4_;
    const dim_t src_elems_per_byte_, src_stride_, tr_src_stride_;
    const bool req_zp_b_shift_;
    const bool req_apply_wei_scales_;
    const int reserved_regs_;
    const bool is_wei_grouped_over_k_;

    reg64_t reg_src = rax;
    reg64_t reg_tr_src = rbx;

    reg64_t reg_K_iters = r8;
    reg64_t reg_N_blk = r9;
    reg64_t reg_wei_scales = r10;
    reg64_t reg_tmp = r11;
    reg32_t regw_tmp = r11d;

    reg64_t reg_src_back = r12;
    reg64_t reg_tr_src_back = r13;

    reg64_t reg_wei_zp = r14;
    reg64_t reg_k_start = r15;

    Vmm vmm_permd = Vmm(0);
    Vmm vmm_zp_b_val0 = Vmm(1);
    Vmm vmm_zp_b_val1 = Vmm(2);
    Vmm vmm_tmp = Vmm(3);
    Vmm vmm_wei_scales0 = Vmm(4);
    Vmm vmm_wei_scales1 = Vmm(5);
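
    // Vmm(0) .. Vmm(reserved_regs_ - 1) hold the permute table and the
    // broadcast zero-point/scale values; the remaining registers form two
    // banks of `max_unroll` registers, one bank per row (idx) of the
    // k_blk_step = 2 VNNI pair.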
    Vmm get_vmm(const int blk, const int idx) {
        const int max_isa_regs = isa_num_vregs(conf_->isa);
        const int max_unroll = (max_isa_regs - reserved_regs_) / k_blk_step;
        assert(idx >= 0 && idx < k_blk_step && blk >= 0);
        const auto reg_idx
                = max_unroll * ((idx + 1) % k_blk_step) + blk + reserved_regs_;
        assert(reg_idx >= reserved_regs_ && reg_idx < max_isa_regs);
        return Vmm(reg_idx);
    }

    void init_masks();
    void get_wei_scales(
            const int n, const bool is_n_tail, const bool is_k_tail);
    void get_zero_points(const int n, const bool is_tail, const bool is_k_tail);
    void copy_block(const int nrows, const int ncolumns, bool zeropad);

    /** Adjusts strides for weights grouped over the K dimension.
     * k_blk_step is a constant 2, so this handles the nrows = 1 case:
     * when k_start is odd (k_start % 2 == 1), move the src pointer back
     * to the beginning of the 2x32 VNNI block.
     **/
    void maybe_update_strides(int nrows) {
        if (is_wei_grouped_over_k_ && nrows < k_blk_step) {
            Label even_k;
            test(reg_k_start, 1);
            jz(even_k, T_NEAR);
            // Shift back to start of the vnni block
            sub(reg_src, typesize_ / src_elems_per_byte_);
            L(even_k);
        }
    }
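
    /** Stores a single K row into the 2-row VNNI block at store_addr:
     * when k_start is odd only the odd 16-bit lanes are written, otherwise
     * the block is zeroed first and only the even lanes are written.
     */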
    void save_half_block(const int blk_idx, const Xbyak::Address &store_addr) {
        const auto src0 = get_vmm(blk_idx, 0);
        const auto zmm0 = zmm(src0.getIdx());
        // If k % 2 == 1, save only the odd indices;
        // otherwise save only the even ones, using masks.
        Label even_k, end_permute;
        test(reg_k_start, 1);
        jz(even_k, T_NEAR);
        // Odd indices case
        vmovdqu16(store_addr | kAAAA, zmm0);
        jmp(end_permute);
        L(even_k);
        // Clean the whole block before storing
        uni_vpxor(vmm_tmp, vmm_tmp, vmm_tmp);
        vmovdqu16(store_addr, vmm_tmp);
        // Store only even indices
        vmovdqu16(store_addr | k5555 | T_z, zmm0);
        L(end_permute);
    }

    void generate() override;
};
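
// The dword permute table {0, 8, 1, 9, ...} below interleaves the low and
// high 256-bit halves of a zmm at 32-bit granularity. get_wei_scales() and
// get_zero_points() use it to duplicate per-N values so that they line up
// with the VNNI-interleaved weights, and it is also passed to load_value()
// for int4 sources. The kAAAA/k5555 masks select the odd/even 16-bit lanes
// for the half-block stores above.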
template <typename Vmm>
void jit_brgemm_matmul_copy_b_cvt_bf16_t<Vmm>::init_masks() {
    alignas(64) static constexpr const uint32_t bf16_vnni_permute[16]
            = {0, 8, 1, 9, 2, 10, 3, 11, 4, 12, 5, 13, 6, 14, 7, 15};

    if (is_superset(conf_->isa, avx512_core)) {
        kxnorw(kFFFF, kFFFF, kFFFF); // 1111 1111 1111 1111

        mov(reg_tmp, reinterpret_cast<size_t>(bf16_vnni_permute));
        vmovdqa32(vmm_permd, ptr[reg_tmp]);

        // 64-bit mask is also used when is_wei_[zp\scales]_per_k
        mov(reg_tmp, 0xAAAAAAAAAAAAAAAA);
        kmovq(kAAAA, reg_tmp);
        mov(reg_tmp, 0x5555555555555555);
        kmovq(k5555, reg_tmp);
    }
}

/** Loads scales into 2 registers and permutes them.
 * Since groups over the K dimension are handled outside the kernel,
 * both registers are loaded from the same address.
 */
template <typename Vmm>
void jit_brgemm_matmul_copy_b_cvt_bf16_t<Vmm>::get_wei_scales(
        const int n, const bool is_n_tail, const bool is_k_tail) {

    if (!req_apply_wei_scales_ || !conf_->is_wei_scale_per_n) return;

    const auto zmm_wei_scales1 = maybe_mask(vmm_wei_scales1, is_n_tail);
    const auto zmm_tmp = maybe_mask(vmm_tmp, is_n_tail);
    const auto base_offset
            = [&](int n_idx) { return n_idx * wei_scales_typesize_; };

    auto wei_scales_addr
            = maybe_EVEX_compress_addr(reg_wei_scales, base_offset(n));

    load_scale_value(zmm_tmp, wei_scales_addr, conf_->wei_scales_dt, is_n_tail);

    uni_vmovups(vmm_wei_scales1, vmm_tmp);

    vinsertf64x4(vmm_wei_scales0, vmm_tmp, Ymm(vmm_wei_scales1.getIdx()), 1);
    vextractf64x4(Ymm(vmm_tmp.getIdx()), vmm_tmp, 1);
    vinsertf64x4(vmm_wei_scales1, zmm_wei_scales1, Ymm(vmm_tmp.getIdx()), 0);
    vpermd(vmm_wei_scales0, vmm_permd, vmm_wei_scales0);
    vpermd(vmm_wei_scales1, vmm_permd, vmm_wei_scales1);
}

/** Loads zero points into 2 registers and permutes them.
 * Since groups over the K dimension are handled outside the kernel,
 * both registers are loaded from the same address.
 */
template <typename Vmm>
void jit_brgemm_matmul_copy_b_cvt_bf16_t<Vmm>::get_zero_points(
        const int n, const bool is_n_tail, const bool is_k_tail) {
    if (!conf_->is_wei_zp_per_n) return;

    const auto zp_dt = conf_->wei_zp_dt;

    const auto base_offset = [&](int n_idx) {
        const auto zp_dt_sz = types::data_type_size(zp_dt);
        const auto elems_per_byte
                = one_of(zp_dt, data_type::s4, data_type::u4) ? 2 : 1;
        return n_idx * zp_dt_sz / elems_per_byte;
    };

    const auto addr = maybe_EVEX_compress_addr(reg_wei_zp, base_offset(n));
    load_value(vmm_tmp, addr, vmm_permd, zp_dt, is_n_tail);

    uni_vmovups(vmm_zp_b_val1, vmm_tmp);

    const auto zmm_zp_b_val1 = maybe_mask(vmm_zp_b_val1, is_n_tail);
    vinserti64x4(vmm_zp_b_val0, vmm_tmp, Ymm(vmm_zp_b_val1.getIdx()), 1);
    vextracti64x4(Ymm(vmm_tmp.getIdx()), vmm_tmp, 1);
    vinserti64x4(vmm_zp_b_val1, zmm_zp_b_val1, Ymm(vmm_tmp.getIdx()), 0);
    vpermd(vmm_zp_b_val0, vmm_permd, vmm_zp_b_val0);
    vpermd(vmm_zp_b_val1, vmm_permd, vmm_zp_b_val1);
}
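
// copy_block() walks an nrows x ncolumns tile of the already VNNI-blocked
// source: for every 16 N columns it loads two vectors covering a (k, k+1)
// pair, decompresses them with the zero points and scales prepared above,
// down-converts to the 16-bit target type, and stores a single full VNNI
// vector to the destination buffer.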
template <typename Vmm>
void jit_brgemm_matmul_copy_b_cvt_bf16_t<Vmm>::copy_block(
        const int nrows, int ncolumns, bool zeropad) {
    const int columns_tail = ncolumns % n_blk_step;
    if (columns_tail > 0 && columns_tail < n_blk_step) {
        const auto regw_tmp = reg_tmp.cvt32();
        const auto tail_mask = (1 << columns_tail) - 1;
        mov(regw_tmp, tail_mask);
        kmovw(kTail, regw_tmp);
    }

    static constexpr int blk_sz = k_blk_step;
    const int max_regs_available = isa_num_vregs(conf_->isa) - reserved_regs_;
    const int max_unroll = max_regs_available / blk_sz;

    // Every load converts unroll * k_blk_step * n_blk_step elements
    auto load = [this, nrows, ncolumns](int blk, int k, int n) {
        const int k_blk = k / k_blk_step;
        const auto src_vmm0 = get_vmm(blk, 0);
        const auto src_vmm1 = get_vmm(blk, 1);
        const dim_t offset = k_blk * src_stride_
                + (n * k_blk_step * typesize_) / src_elems_per_byte_;
        const auto stride = (n_blk_step * typesize_) / src_elems_per_byte_;
        auto load_addr0 = maybe_EVEX_compress_addr(reg_src, offset);
        auto load_addr1 = maybe_EVEX_compress_addr(reg_src, offset + stride);

        const bool is_n_tail = ncolumns - n < n_blk_step;
        const bool is_k_tail = nrows - k < k_blk_step;

        load_value(src_vmm0, load_addr0, vmm_permd, conf_->orig_wei_dt);
        load_value(src_vmm1, load_addr1, vmm_permd, conf_->orig_wei_dt);
        get_wei_scales(n, is_n_tail, is_k_tail);
        get_zero_points(n, is_n_tail, is_k_tail);
        decompress_and_downcvt_2reg(src_vmm0, src_vmm1, vmm_zp_b_val0,
                vmm_zp_b_val1, vmm_wei_scales0, vmm_wei_scales1,
                conf_->orig_wei_dt, conf_->wei_dt);
    };

    maybe_update_strides(nrows);

    int iter = 0;
    for_(int k = 0; k < nrows; k += k_blk_step)
    for (int n = 0; n < ncolumns; n += n_blk_step) {
        const int k_blk = k / k_blk_step;
        const dim_t tr_src_off
                = k_blk * tr_src_stride_ + n * k_blk_step * tr_typesize_;
        const auto store_addr
                = maybe_EVEX_compress_addr(reg_tr_src, tr_src_off);
        const int blk_idx = iter % max_unroll;

        const auto store_vmm = get_vmm(blk_idx, 0);
        if (zeropad)
            uni_vpxor(store_vmm, store_vmm, store_vmm);
        else
            load(blk_idx, k, n);

        // Special case for grouped zp/scales when nrows == 1
        if (is_wei_grouped_over_k_ && nrows == 1) {
            save_half_block(blk_idx, store_addr);
            iter++;
            continue;
        }

        uni_vmovups(store_addr, store_vmm);

        iter++;
    }
}

template <typename Vmm>
void jit_brgemm_matmul_copy_b_cvt_bf16_t<Vmm>::generate() {
    assert(tr_typesize_ == sizeof(bfloat16_t));
    preamble();

    init_masks();

    mov(reg_src, ptr[param1 + GET_OFF(src)]);
    mov(reg_tr_src, ptr[param1 + GET_OFF(tr_src)]);
    mov(reg_N_blk, ptr[param1 + GET_OFF(current_N_blk)]);
    mov(reg_wei_scales, ptr[param1 + GET_OFF(wei_scales_ptr)]);
    mov(reg_wei_zp, ptr[param1 + GET_OFF(zp_b_value_ptr)]);
    mov(reg_k_start, ptr[param1 + GET_OFF(current_K_start)]);

    load_common_zp_value(vmm_zp_b_val0, reg_wei_zp);
    load_common_zp_value(vmm_zp_b_val1, reg_wei_zp);
    load_common_scale_value(vmm_wei_scales0, reg_wei_scales);
    load_common_scale_value(vmm_wei_scales1, reg_wei_scales);

    auto compute_K_loop_body = [&](const reg64_t &reg_K, int ncolumns,
                                       bool zeropad) {
        // Generate a special K loop for per-K attributes, but only when
        // k_group_size < k_blk_step; otherwise the default K loop is used.
        if (is_wei_grouped_over_k_) {
            const int k_group_size = conf_->is_wei_zp_per_k
                    ? conf_->wei_zp_k_gsize
                    : conf_->wei_scales_k_gsize;
            if (k_group_size == 1) {
                if (zeropad) return;
                copy_block(k_group_size, ncolumns, /*zeropad=*/false);
                return;
            }
        }
        const int k_unroll = 8;

        Label K_loop_unrolled, K_loop_single, K_loop_tail_or_done;
        cmp(reg_K, k_unroll * k_blk_step);
        jl(K_loop_single, T_NEAR);

        L(K_loop_unrolled);
        copy_block(k_unroll * k_blk_step, ncolumns, zeropad);
        add(reg_src, k_unroll * src_stride_);
        add(reg_tr_src, k_unroll * tr_src_stride_);

        sub(reg_K, k_unroll * k_blk_step);
        cmp(reg_K, k_unroll * k_blk_step);
        jge(K_loop_unrolled, T_NEAR);

        L(K_loop_single);
        cmp(reg_K, k_blk_step);
        jl(K_loop_tail_or_done, T_NEAR);

        copy_block(k_blk_step, ncolumns, zeropad);
        add(reg_src, src_stride_);
        add(reg_tr_src, tr_src_stride_);

        sub(reg_K, k_blk_step);
        jmp(K_loop_single, T_NEAR);

        L(K_loop_tail_or_done);

        const int k_blk_tail = conf_->K % k_blk_step;
        if (k_blk_tail > 0) {
            Label K_loop_done;
            cmp(reg_K, 0);
            jle(K_loop_done, T_NEAR);
            copy_block(k_blk_tail, ncolumns, zeropad);
            add(reg_tr_src, tr_src_stride_);
            sub(reg_K, k_blk_tail);
            L(K_loop_done);
        }
    };

    auto compute_K_loop = [&](const int ncolumns) {
        mov(reg_src_back, reg_src);
        mov(reg_tr_src_back, reg_tr_src);

        mov(reg_K_iters, ptr[param1 + GET_OFF(current_K_iters)]);
        compute_K_loop_body(reg_K_iters, ncolumns, false);
        mov(reg_K_iters, ptr[param1 + GET_OFF(current_K_pad)]);
        compute_K_loop_body(reg_K_iters, ncolumns, true);

        mov(reg_src, reg_src_back);
        mov(reg_tr_src, reg_tr_src_back);
    };

    Label done;
    cmp(reg_N_blk, 0);
    jle(done, T_NEAR);
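
    // With a blocked N layout (LDB2 != 0) the kernel processes one LDB-wide
    // panel at a time, stepping src/tr_src by LDB2 elements between panels;
    // otherwise a single N_blk (or N_tail) block is copied.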
    if (conf_->LDB2 != 0) {
        Label main_N_loop, main_N_loop_tail;
        int tail = conf_->N % conf_->LDB;

        if (tail != 0) {
            cmp(reg_N_blk, conf_->LDB);
            jl(main_N_loop_tail, T_NEAR);
        }

        L(main_N_loop);
        compute_K_loop(conf_->LDB);
        add(reg_src, conf_->LDB2 * typesize_);
        add(reg_tr_src, conf_->LDB2 * tr_typesize_);

        sub(reg_N_blk, conf_->LDB);
        cmp(reg_N_blk, conf_->LDB);
        jge(main_N_loop, T_NEAR);

        if (tail != 0) {
            L(main_N_loop_tail);
            cmp(reg_N_blk, 0);
            jle(done, T_NEAR);
            compute_K_loop(tail);
        }

    } else {
        if (conf_->N_tail > 0) {
            Label main_N_blk;
            cmp(reg_N_blk, conf_->N_blk);
            je(main_N_blk, T_NEAR);
            compute_K_loop(conf_->N_tail);
            jmp(done, T_NEAR);

            L(main_N_blk);
        }

        compute_K_loop(conf_->N_blk);
    }
    L(done);
    postamble();
}

template struct jit_brgemm_matmul_copy_b_cvt_bf16_t<Zmm>;
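
// Selects the B copy kernel: transposed B goes to the transposed copy
// kernels; blocked B with integer weights goes to the bf16/f16 conversion
// kernel above; otherwise a bf16/f16/f32/int8 copy kernel is chosen based
// on the data types and ISA.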
status_t create_brgemm_matmul_copy_b(
        std::unique_ptr<jit_brgemm_matmul_copy_b_t> &copy_ker,
        const brgemm_matmul_conf_t *conf) {
    const bool is_bf16
            = everyone_is(data_type::bf16, conf->src_dt, conf->wei_dt);
    const bool is_f32 = everyone_is(data_type::f32, conf->src_dt, conf->wei_dt);
    // Note: f16 support through avx512_core_fp16 sets src_dt and wei_dt as f32
    // to imply upconverting. So, the assumption is `is_f16` below evaluates to
    // `false` on avx512_core_fp16.
    const bool is_f16 = everyone_is(data_type::f16, conf->src_dt, conf->wei_dt);
    if (conf->transposed_B) {
        if (is_superset(conf->isa, avx512_core))
            CHECK(safe_ptr_assign(copy_ker,
                    new jit_brgemm_matmul_copy_b_transposed_t<Zmm>(conf)));
        else {
            assert(is_superset(conf->isa, avx2));
            CHECK(safe_ptr_assign(copy_ker,
                    new jit_brgemm_matmul_copy_b_transposed_t<Ymm>(conf)));
        }
    } else {
        if ((conf->is_bf16_with_int_wei
                    || (conf->is_f16_with_int_wei
                            && conf->isa != avx512_core_fp16))
                && conf->blocked_B) {
            if (is_superset(conf->isa, avx512_core))
                CHECK(safe_ptr_assign(copy_ker,
                        new jit_brgemm_matmul_copy_b_cvt_bf16_t<Zmm>(conf)));
            else {
                assert(!"Unsupported isa for bf16_with_int_wei");
                return status::unimplemented;
            }
        } else if (is_bf16 || is_f16 || conf->is_bf32
                || (conf->is_f16_with_int_wei
                        && conf->isa != avx512_core_fp16)) {
            if (is_superset(conf->isa, avx512_core))
                CHECK(safe_ptr_assign(copy_ker,
                        new jit_brgemm_matmul_copy_b_bf16_t<Zmm>(conf)));
            else
                CHECK(safe_ptr_assign(copy_ker,
                        new jit_brgemm_matmul_copy_b_bf16_t<Ymm>(conf)));
        } else if (is_f32
                || (conf->isa == avx512_core_fp16
                        && conf->orig_wei_dt == data_type::f16)) {
            // See the note above why `orig_wei_dt` is used.
            if (is_superset(conf->isa, avx512_core))
                CHECK(safe_ptr_assign(copy_ker,
                        new jit_brgemm_matmul_copy_b_f32_t<Zmm>(conf)));
            else
                CHECK(safe_ptr_assign(copy_ker,
                        new jit_brgemm_matmul_copy_b_f32_t<Ymm>(conf)));
        } else {
            if (mayiuse(avx512_core_amx))
                CHECK(safe_ptr_assign(copy_ker,
                        new jit_amx_brgemm_matmul_copy_b_int8_t(conf)));
            else if (is_superset(conf->isa, avx512_core))
                CHECK(safe_ptr_assign(copy_ker,
                        new jit_avx512_core_brgemm_matmul_copy_b_int8_t(conf)));
            else {
                // TODO: jit_avx2_vnni_brgemm_matmul_copy_b_int8_t can handle
                // avx2 if no compensation is required. Consider enabling it
                // for avx2 and renaming the kernel (drop "vnni" part).
                const bool is_comp_required = conf->s8s8_compensation_required
                        || conf->has_zero_point_a;
                MAYBE_UNUSED(is_comp_required);
                assert(one_of(conf->isa, avx2_vnni, avx2_vnni_2, avx2)
                        && IMPLICATION(conf->isa == avx2, !is_comp_required));

                CHECK(safe_ptr_assign(copy_ker,
                        new jit_avx2_vnni_brgemm_matmul_copy_b_int8_t(conf)));
            }
        }
    }

    return copy_ker->create_kernel();
}
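
// Selects the A copy kernel: transposed int8 A uses a dedicated kernel,
// other transposed cases use the generic transposed copy, and the plain
// (non-transposed) path dispatches on ISA width.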
status_t create_brgemm_matmul_copy_a(
        std::unique_ptr<jit_brgemm_matmul_copy_a_t> &copy_ker,
        const brgemm_matmul_conf_t *conf) {
    if (conf->transposed_A) {
        if (utils::one_of(conf->src_dt, data_type::s8, data_type::u8))
            CHECK(safe_ptr_assign(copy_ker,
                    new jit_brgemm_matmul_copy_a_transposed_int8_impl_t(conf)));
        else if (is_superset(conf->isa, avx512_core))
            CHECK(safe_ptr_assign(copy_ker,
                    new jit_brgemm_matmul_copy_a_transposed_impl_t<Zmm>(conf)));
        else
            CHECK(safe_ptr_assign(copy_ker,
                    new jit_brgemm_matmul_copy_a_transposed_impl_t<Ymm>(conf)));
    } else {
        if (is_superset(conf->isa, avx512_core))
            CHECK(safe_ptr_assign(
                    copy_ker, new jit_brgemm_matmul_copy_a_impl_t<Zmm>(conf)));
        else {
            if (is_superset(conf->isa, avx2)) {
                CHECK(safe_ptr_assign(copy_ker,
                        new jit_brgemm_matmul_copy_a_impl_t<Ymm>(conf)));
            } else {
                assert(!"Unsupported isa for jit_brgemm_matmul_copy_a_impl_t");
                return status::unimplemented;
            }
        }
    }

    return copy_ker->create_kernel();
}

} // namespace matmul
} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl