mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "[ARM][feat]: Add 4 bit dynamic quantization matmuls & KleidiAI Backend (#134124)"
This reverts commit 4b82251011f85f9d1395b451d61e976af844d9b1. Reverted https://github.com/pytorch/pytorch/pull/134124 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it breaks lots of internal build ([comment](https://github.com/pytorch/pytorch/pull/134124#issuecomment-2555953189))
This commit is contained in:
3
.gitmodules
vendored
3
.gitmodules
vendored
@ -131,6 +131,3 @@
|
||||
path = third_party/composable_kernel
|
||||
url = https://github.com/ROCm/composable_kernel.git
|
||||
branch = develop
|
||||
[submodule "third_party/kleidiai"]
|
||||
path = third_party/kleidiai
|
||||
url = https://git.gitlab.arm.com/kleidi/kleidiai.git
|
||||
|
@ -254,7 +254,6 @@ filegroup(
|
||||
# target that generates these sources...
|
||||
)
|
||||
|
||||
# TODO: Enable support for KleidiAI bazel build
|
||||
header_template_rule(
|
||||
name = "aten_src_ATen_config",
|
||||
src = "aten/src/ATen/Config.h.in",
|
||||
@ -274,7 +273,6 @@ header_template_rule(
|
||||
"@AT_PARALLEL_NATIVE@": "1",
|
||||
"@AT_BLAS_F2C@": "0",
|
||||
"@AT_BLAS_USE_CBLAS_DOT@": "1",
|
||||
"@AT_KLEIDIAI_ENABLED@": "0",
|
||||
},
|
||||
)
|
||||
|
||||
|
@ -377,8 +377,6 @@ cmake_dependent_option(
|
||||
cmake_dependent_option(BUILD_FUNCTORCH "Build Functorch" ON "BUILD_PYTHON" OFF)
|
||||
cmake_dependent_option(BUILD_BUNDLE_PTXAS "Bundle PTX into torch/bin fodler"
|
||||
OFF "USE_CUDA" OFF)
|
||||
cmake_dependent_option(USE_KLEIDIAI "Use KleidiAI for the ARM CPU & AARCH64 architecture." ON
|
||||
"CPU_AARCH64" OFF)
|
||||
|
||||
option(USE_MIMALLOC "Use mimalloc" OFF)
|
||||
# Enable third party mimalloc library to improve memory allocation performance
|
||||
@ -420,8 +418,6 @@ endif()
|
||||
if(WIN32)
|
||||
set(USE_TENSORPIPE OFF)
|
||||
message(WARNING "TensorPipe cannot be used on Windows. Set it to OFF")
|
||||
set(USE_KLEIDIAI OFF)
|
||||
message(WARNING "KleidiAI cannot be used on Windows. Set it to OFF")
|
||||
|
||||
if(USE_DISTRIBUTED AND NOT DEFINED ENV{libuv_ROOT})
|
||||
find_library(
|
||||
@ -671,9 +667,6 @@ if(ANDROID
|
||||
message(WARNING "INTERN_BUILD_MOBILE is on, disabling BUILD_LAZY_TS_BACKEND")
|
||||
set(BUILD_LAZY_TS_BACKEND OFF)
|
||||
|
||||
set(USE_KLEIDIAI OFF)
|
||||
message(WARNING "KleidiAI cannot be used on Mobile builds. Set it to OFF")
|
||||
|
||||
# Set -ffunction-sections and -fdata-sections so that each method has its own
|
||||
# text section. This allows the linker to remove unused section when the flag
|
||||
# -Wl,-gc-sections is provided at link time.
|
||||
|
@ -309,12 +309,6 @@ local_repository(
|
||||
path = "third_party/gemmlowp/gemmlowp",
|
||||
)
|
||||
|
||||
local_repository(
|
||||
name = "kleidiai",
|
||||
path = "third_party/kleidiai",
|
||||
repo_mapping = {"@com_google_googletest": "@com_google_benchmark"},
|
||||
)
|
||||
|
||||
### Unused repos start
|
||||
|
||||
# `unused` repos are defined to hide bazel files from submodules of submodules.
|
||||
|
@ -199,10 +199,6 @@ endif()
|
||||
# XNNPACK
|
||||
file(GLOB native_xnnpack "native/xnnpack/*.cpp")
|
||||
|
||||
# KLEIDIAI
|
||||
file(GLOB native_kleidiai "native/kleidiai/*.cpp")
|
||||
file(GLOB native_kleidiai_h "native/kleidiai/*.h")
|
||||
|
||||
# Add files needed from jit folders
|
||||
append_filelist("jit_core_headers" ATen_CORE_HEADERS)
|
||||
append_filelist("jit_core_sources" ATen_CORE_SRCS)
|
||||
@ -232,10 +228,6 @@ endif()
|
||||
if(AT_MKL_ENABLED)
|
||||
set(all_cpu_cpp ${all_cpu_cpp} ${mkl_cpp})
|
||||
endif()
|
||||
if(AT_KLEIDIAI_ENABLED)
|
||||
set(all_cpu_cpp ${all_cpu_cpp} ${native_kleidiai})
|
||||
include_directories(SYSTEM INTERFACE ${KLEIDIAI_INCLUDE_DIRS})
|
||||
endif()
|
||||
if(AT_MKLDNN_ENABLED)
|
||||
set(all_cpu_cpp ${all_cpu_cpp} ${mkldnn_cpp})
|
||||
endif()
|
||||
@ -619,7 +611,7 @@ install(FILES "${CMAKE_CURRENT_BINARY_DIR}/cmake-exports/ATenConfig.cmake"
|
||||
|
||||
set(INSTALL_HEADERS ${base_h} ${ATen_CORE_HEADERS} ${native_nested_h} ${ATen_TRANSFORMER_HEADERS})
|
||||
if(NOT INTERN_BUILD_MOBILE)
|
||||
list(APPEND INSTALL_HEADERS ${native_h} ${native_cpu_h} ${native_ao_sparse_h} ${native_quantized_h} ${cuda_h} ${native_cuda_h} ${native_hip_h} ${cudnn_h} ${hip_h} ${xpu_h} ${mps_h} ${native_kleidiai_h} ${native_mps_h} ${native_utils_h} ${miopen_h} ${mkldnn_xpu_h})
|
||||
list(APPEND INSTALL_HEADERS ${native_h} ${native_cpu_h} ${native_ao_sparse_h} ${native_quantized_h} ${cuda_h} ${native_cuda_h} ${native_hip_h} ${cudnn_h} ${hip_h} ${xpu_h} ${mps_h} ${native_mps_h} ${native_utils_h} ${miopen_h} ${mkldnn_xpu_h})
|
||||
# Metal
|
||||
if(USE_PYTORCH_METAL_EXPORT)
|
||||
# Add files needed from exporting metal models(optimized_for_mobile)
|
||||
|
@ -19,4 +19,3 @@
|
||||
#define AT_PARALLEL_NATIVE @AT_PARALLEL_NATIVE@
|
||||
#define AT_BLAS_F2C() @AT_BLAS_F2C@
|
||||
#define AT_BLAS_USE_CBLAS_DOT() @AT_BLAS_USE_CBLAS_DOT@
|
||||
#define AT_KLEIDIAI_ENABLED() @AT_KLEIDIAI_ENABLED@
|
||||
|
@ -376,10 +376,6 @@ bool Context::hasMKLDNN() {
|
||||
#endif
|
||||
}
|
||||
|
||||
bool Context::hasKleidiAI() {
|
||||
return AT_KLEIDIAI_ENABLED();
|
||||
}
|
||||
|
||||
bool Context::hasOpenMP() {
|
||||
#ifdef _OPENMP
|
||||
return true;
|
||||
|
@ -118,7 +118,6 @@ class TORCH_API Context {
|
||||
|
||||
static bool hasOpenMP();
|
||||
static bool hasMKL();
|
||||
static bool hasKleidiAI();
|
||||
static bool hasLAPACK();
|
||||
static bool hasMKLDNN();
|
||||
static bool hasMAGMA() {
|
||||
@ -539,10 +538,6 @@ inline bool hasMKL() {
|
||||
return globalContext().hasMKL();
|
||||
}
|
||||
|
||||
inline bool hasKleidiAI() {
|
||||
return globalContext().hasKleidiAI();
|
||||
}
|
||||
|
||||
inline bool hasLAPACK() {
|
||||
return globalContext().hasLAPACK();
|
||||
}
|
||||
|
@ -33,8 +33,6 @@
|
||||
#include <ATen/ops/_addmm_activation_native.h>
|
||||
#include <ATen/ops/_compute_linear_combination_native.h>
|
||||
#include <ATen/ops/_convert_weight_to_int4pack_for_cpu_native.h>
|
||||
#include <ATen/ops/_dyn_quant_matmul_4bit_native.h>
|
||||
#include <ATen/ops/_dyn_quant_pack_4bit_weight_native.h>
|
||||
#include <ATen/ops/_int_mm_native.h>
|
||||
#include <ATen/ops/_linalg_check_errors.h>
|
||||
#include <ATen/ops/_linalg_det.h>
|
||||
@ -3431,8 +3429,6 @@ Tensor kron(const Tensor& self, const Tensor& other) {
|
||||
DEFINE_DISPATCH(weight_to_int4pack_stub);
|
||||
DEFINE_DISPATCH(int4pack_mm_stub);
|
||||
DEFINE_DISPATCH(int8pack_mm_stub);
|
||||
DEFINE_DISPATCH(dyn_quant_pack_4bit_weight_stub);
|
||||
DEFINE_DISPATCH(dyn_quant_matmul_4bit_stub);
|
||||
|
||||
Tensor _convert_weight_to_int4pack_cpu(
|
||||
const Tensor& in,
|
||||
@ -3496,69 +3492,6 @@ Tensor _weight_int4pack_mm_cpu(
|
||||
return C;
|
||||
}
|
||||
|
||||
Tensor _dyn_quant_pack_4bit_weight_cpu(
|
||||
const Tensor& weights,
|
||||
const Tensor& scales_zeros,
|
||||
const std::optional<Tensor>& bias,
|
||||
const int64_t block_size,
|
||||
const int64_t in_features,
|
||||
const int64_t out_features) {
|
||||
TORCH_CHECK(
|
||||
weights.dtype() == at::kByte, __func__, " : expect weight to be kByte.");
|
||||
TORCH_CHECK(
|
||||
block_size == in_features ||
|
||||
(!(block_size % 32) && !(in_features % block_size)),
|
||||
__func__,
|
||||
": Group size should be multiple of 32, in_features [",
|
||||
in_features,
|
||||
"]. Provided ",
|
||||
block_size);
|
||||
Tensor packed_weights =
|
||||
at::empty(weights.sizes(), weights.options().dtype(at::kByte));
|
||||
dyn_quant_pack_4bit_weight_stub(
|
||||
kCPU,
|
||||
packed_weights,
|
||||
weights,
|
||||
scales_zeros,
|
||||
bias,
|
||||
out_features,
|
||||
in_features,
|
||||
block_size);
|
||||
return packed_weights;
|
||||
}
|
||||
|
||||
Tensor _dyn_quant_matmul_4bit_cpu(
|
||||
const Tensor& inp,
|
||||
const Tensor& packed_weights,
|
||||
const int64_t block_size,
|
||||
const int64_t in_features,
|
||||
const int64_t out_features) {
|
||||
auto M = inp.size(0);
|
||||
TORCH_CHECK(
|
||||
inp.dtype() == kFloat,
|
||||
__func__,
|
||||
" : expect input to be 32-bit float tensor.");
|
||||
TORCH_CHECK(
|
||||
block_size == in_features ||
|
||||
(!(block_size % 32) && !(in_features % block_size)),
|
||||
__func__,
|
||||
": Group size should be multiple of 32, in_features [",
|
||||
in_features,
|
||||
"]. Provided ",
|
||||
block_size);
|
||||
auto output = at::empty({M, out_features}, inp.options());
|
||||
dyn_quant_matmul_4bit_stub(
|
||||
kCPU,
|
||||
output,
|
||||
inp,
|
||||
packed_weights,
|
||||
M,
|
||||
out_features,
|
||||
in_features,
|
||||
block_size);
|
||||
return output;
|
||||
}
|
||||
|
||||
Tensor _weight_int8pack_mm_cpu(
|
||||
const Tensor& A,
|
||||
const Tensor& B,
|
||||
|
@ -8,19 +8,8 @@
|
||||
#include <ATen/cpu/vec/vec.h>
|
||||
#include <ATen/native/cpu/int_mm_kernel.h>
|
||||
#include <ATen/native/cpu/utils.h>
|
||||
#include <c10/util/Unroll.h>
|
||||
#include <c10/util/irange.h>
|
||||
|
||||
#ifndef AT_PER_OPERATOR_HEADERS
|
||||
#include <ATen/Functions.h>
|
||||
#else
|
||||
#include <ATen/ops/cat.h>
|
||||
#endif
|
||||
|
||||
#if AT_KLEIDIAI_ENABLED()
|
||||
#include <ATen/native/kleidiai/kai_kernels.h>
|
||||
#include <cpuinfo.h>
|
||||
#endif
|
||||
#include <c10/util/Unroll.h>
|
||||
|
||||
#if (defined(_WIN32) || defined(_WIN64))
|
||||
#define RESTRICT __restrict
|
||||
@ -773,457 +762,10 @@ void int4pack_mm_kernel(
|
||||
}
|
||||
}
|
||||
|
||||
#if AT_KLEIDIAI_ENABLED()
|
||||
bool can_use_kleidiai(
|
||||
const at::Tensor& scales_zeros,
|
||||
const int64_t K,
|
||||
const int64_t block_size) {
|
||||
bool ret = false;
|
||||
if (cpuinfo_has_arm_neon_dot()) {
|
||||
// The Groupwise kernel requires BFloat16 Scales and Channelwise kernel
|
||||
// requires Float32 Scales. If not provided, we will use fallback
|
||||
// implementation.
|
||||
if ((block_size == K && scales_zeros.dtype() == at::kFloat) ||
|
||||
((block_size < K && !(block_size % 32) && !(K % block_size)) &&
|
||||
scales_zeros.dtype() == at::kBFloat16)) {
|
||||
ret = true;
|
||||
}
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
#endif
|
||||
|
||||
/**
|
||||
* The Int4 quantized weights must be represented as a uint8 tensor
|
||||
* For matrix multiplication with a weight shape of (N x K)
|
||||
* the shape of the 4-bit quantized weights is [N, K/groupsize, groupsize/2].
|
||||
*
|
||||
* For KleidiAI weight packing, the scales, biases, and Int4 quantized
|
||||
* weights are packed into a single `packed_weights` structure, optimized for
|
||||
* Arm instructions.
|
||||
*
|
||||
* In the fallback reference kernel, no special packing is required for
|
||||
* Int4 quantized weights.
|
||||
*
|
||||
* The Groupwise kernel requires BFloat16 Scales and Channelwise kernel requires
|
||||
* Float32 Scales. If not provided, we will use fallback implementation.
|
||||
*/
|
||||
void dyn_quant_pack_4bit_weight_kernel(
|
||||
Tensor& packed_weights,
|
||||
const Tensor& weights,
|
||||
const Tensor& scales_zeros,
|
||||
const std::optional<Tensor>& bias,
|
||||
const int64_t N,
|
||||
const int64_t K,
|
||||
const int64_t block_size) {
|
||||
#if AT_KLEIDIAI_ENABLED()
|
||||
if (can_use_kleidiai(scales_zeros, K, block_size)) {
|
||||
const int64_t weight_packed_size =
|
||||
kleidiai::kai_pack_rhs_int4_size(N, K, block_size);
|
||||
packed_weights.resize_({weight_packed_size});
|
||||
kleidiai::kai_pack_int4_rhs(
|
||||
packed_weights, weights, scales_zeros, bias, N, K, block_size);
|
||||
} else
|
||||
#endif
|
||||
{
|
||||
TORCH_CHECK(
|
||||
bias.has_value() == 0,
|
||||
__func__,
|
||||
" : Bias is unsupported in reference implementation");
|
||||
packed_weights = packed_weights.to(kFloat);
|
||||
auto weight_reshaped = weights.view({-1}).to(kFloat);
|
||||
auto scales_zeros_reshaped = scales_zeros.view({-1}).to(kFloat);
|
||||
auto res = at::cat({weight_reshaped, scales_zeros_reshaped}, 0);
|
||||
packed_weights.resize_(res.sizes()).copy_(res);
|
||||
}
|
||||
}
|
||||
|
||||
static void ref_dyn_quant_matmul_4bit_channelwise_kernel(
|
||||
size_t m,
|
||||
size_t n,
|
||||
size_t k,
|
||||
const float* lhs_f32,
|
||||
const uint8_t* rhs_qs4cx,
|
||||
const float* rhs_scales_f32,
|
||||
float* dst_f32,
|
||||
float scalar_min,
|
||||
float scalar_max) {
|
||||
const size_t input_size_8bit = m * (k + sizeof(int32_t) + sizeof(float));
|
||||
|
||||
auto lhs_qa8dx_buffer = std::make_unique<uint8_t[]>(input_size_8bit);
|
||||
uint8_t* lhs_qa8dx = lhs_qa8dx_buffer.get();
|
||||
|
||||
// Lambda for quantizing the fp32 input to 8 bit symmetric and pack it in
|
||||
// required format for matmul
|
||||
auto input_quant_pack_8bit_channelwise =
|
||||
[&](size_t m, size_t k, const float* lhs_f32, int8_t* lhs_qa8dx) {
|
||||
const size_t dst_stride =
|
||||
(k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t));
|
||||
|
||||
const size_t lhs_qa8dx_stride = k;
|
||||
|
||||
for (size_t m_idx = 0; m_idx < m; ++m_idx) {
|
||||
const float* src_ptr = lhs_f32 + m_idx * lhs_qa8dx_stride;
|
||||
|
||||
float max0 = -FLT_MAX;
|
||||
float min0 = FLT_MAX;
|
||||
|
||||
// Find min/max for each channel
|
||||
for (size_t k_idx = 0; k_idx < k; ++k_idx) {
|
||||
const float src0_0 = src_ptr[k_idx];
|
||||
|
||||
max0 = (std::max)(src0_0, max0);
|
||||
min0 = (std::min)(src0_0, min0);
|
||||
}
|
||||
|
||||
// Maximum/minimum int8 values
|
||||
const float qmin = (float)INT8_MIN;
|
||||
const float qmax = (float)INT8_MAX;
|
||||
|
||||
const float rmin0 = (std::min)(0.0f, min0);
|
||||
const float rmax0 = (std::max)(0.0f, max0);
|
||||
|
||||
const float scale0 =
|
||||
rmin0 == rmax0 ? 1.f : (qmax - qmin) / (rmax0 - rmin0);
|
||||
|
||||
// Reciprocal to quantize
|
||||
const float recip_scale0 = scale0 ? 1.0f / scale0 : 0.0f;
|
||||
|
||||
const float descaled_min0 = rmin0 * scale0;
|
||||
const float descaled_max0 = rmax0 * scale0;
|
||||
|
||||
const float zero_point_from_min_error0 = qmin + descaled_min0;
|
||||
const float zero_point_from_max_error0 = qmax + descaled_max0;
|
||||
|
||||
float zero_point0 =
|
||||
zero_point_from_min_error0 + zero_point_from_max_error0 > 0
|
||||
? qmin - descaled_min0
|
||||
: qmax - descaled_max0;
|
||||
|
||||
zero_point0 = (std::max)(zero_point0, qmin);
|
||||
zero_point0 = (std::min)(zero_point0, qmax);
|
||||
|
||||
// Round to nearest integer
|
||||
const int32_t nudged_zero_point0 = lrintf(zero_point0);
|
||||
|
||||
int8_t* dst_ptr = (int8_t*)lhs_qa8dx + m_idx * dst_stride;
|
||||
|
||||
// LHS offset at the beginning of the row
|
||||
*((float*)(dst_ptr)) = recip_scale0;
|
||||
dst_ptr += sizeof(float);
|
||||
*((int32_t*)(dst_ptr)) = -nudged_zero_point0;
|
||||
dst_ptr += sizeof(int32_t);
|
||||
|
||||
// Quantize the channels
|
||||
for (size_t k_idx = 0; k_idx < k; ++k_idx) {
|
||||
const float src0_0 = src_ptr[k_idx];
|
||||
|
||||
// Scale the values
|
||||
int32_t v0_s32 = (int32_t)(std::round(src0_0 * scale0));
|
||||
|
||||
v0_s32 = v0_s32 + nudged_zero_point0;
|
||||
v0_s32 = (std::max)(v0_s32, static_cast<int32_t>(INT8_MIN));
|
||||
v0_s32 = (std::min)(v0_s32, static_cast<int32_t>(INT8_MAX));
|
||||
dst_ptr[0] = (int8_t)v0_s32;
|
||||
dst_ptr += sizeof(int8_t);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Dynamically Quantize the float32 input to 8 bit assymetric
|
||||
input_quant_pack_8bit_channelwise(m, k, lhs_f32, (int8_t*)lhs_qa8dx);
|
||||
|
||||
const size_t lhs_stride =
|
||||
k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t);
|
||||
|
||||
const size_t rhs_qs4cx_stride = ((((k + 2 - 1) / 2) * 2) / 2);
|
||||
|
||||
for (size_t m_idx = 0; m_idx < m; ++m_idx) {
|
||||
const int8_t* lhs_ptr_start = (int8_t*)lhs_qa8dx + m_idx * lhs_stride;
|
||||
|
||||
for (size_t n_idx = 0; n_idx < n; ++n_idx) {
|
||||
// Main f32 accumulator
|
||||
int32_t iacc = 0;
|
||||
|
||||
const int8_t* lhs_ptr = lhs_ptr_start;
|
||||
const uint8_t* rhs_ptr = rhs_qs4cx + n_idx * rhs_qs4cx_stride;
|
||||
|
||||
// Get the LHS quantization parameters stored at the
|
||||
// beginning of each row
|
||||
const float lhs_scale = *(const float*)lhs_ptr;
|
||||
lhs_ptr += sizeof(float);
|
||||
|
||||
const int32_t lhs_offset = *(const int32_t*)lhs_ptr;
|
||||
lhs_ptr += sizeof(int32_t);
|
||||
|
||||
for (size_t k_idx = 0; k_idx < k; ++k_idx) {
|
||||
// Get the LHS values
|
||||
const int32_t lhs_v0 = (int32_t)lhs_ptr[0];
|
||||
|
||||
// Get the RHS values
|
||||
const uint8_t rhs_byte = rhs_ptr[0];
|
||||
|
||||
// Unpack the RHS values
|
||||
int32_t rhs_v0 = 0;
|
||||
if ((k_idx % 2) == 0) {
|
||||
rhs_v0 = (((int32_t)(rhs_byte & 0x0F)) - 8);
|
||||
} else {
|
||||
rhs_v0 = (((int32_t)(rhs_byte >> 4)) - 8);
|
||||
}
|
||||
|
||||
iacc += lhs_v0 * rhs_v0;
|
||||
iacc += lhs_offset * rhs_v0;
|
||||
|
||||
lhs_ptr += 1;
|
||||
|
||||
// Increment only when k_idx is not a multiple of 2
|
||||
rhs_ptr += k_idx % 2;
|
||||
}
|
||||
|
||||
// Get the RHS scale
|
||||
const float rhs_scale = rhs_scales_f32[n_idx];
|
||||
|
||||
float main_acc = iacc * rhs_scale;
|
||||
|
||||
main_acc = main_acc * lhs_scale;
|
||||
|
||||
// Clamp (min-max) operation
|
||||
main_acc = (std::max)(main_acc, scalar_min);
|
||||
main_acc = (std::min)(main_acc, scalar_max);
|
||||
|
||||
dst_f32[0] = main_acc;
|
||||
dst_f32 += 1;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
static void ref_dyn_quant_matmul_4bit_groupwise_kernel(
|
||||
size_t m,
|
||||
size_t n,
|
||||
size_t k,
|
||||
size_t bl,
|
||||
const float* lhs_f32,
|
||||
const uint8_t* rhs_qs4c32,
|
||||
const float* rhs_scales_fp32,
|
||||
float* dst_f32,
|
||||
float scalar_min,
|
||||
float scalar_max) {
|
||||
// Lambda for LHS quantization
|
||||
auto lhs_quant_pack = [&](size_t m,
|
||||
size_t k,
|
||||
const float* lhs_f32,
|
||||
int8_t* lhs_qa8dx) {
|
||||
const size_t dst_stride =
|
||||
(k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t));
|
||||
|
||||
for (size_t row_idx = 0; row_idx < m; ++row_idx) {
|
||||
const float* src_ptr = lhs_f32 + row_idx * k;
|
||||
|
||||
float max0 = -FLT_MAX;
|
||||
float min0 = FLT_MAX;
|
||||
|
||||
for (size_t k_idx = 0; k_idx < k; ++k_idx) {
|
||||
const float src0_0 = src_ptr[k_idx];
|
||||
max0 = (std::max)(src0_0, max0);
|
||||
min0 = (std::min)(src0_0, min0);
|
||||
}
|
||||
|
||||
const float qmin = (float)INT8_MIN;
|
||||
const float qmax = (float)INT8_MAX;
|
||||
|
||||
const float rmin0 = (std::min)(0.0f, min0);
|
||||
const float rmax0 = (std::max)(0.0f, max0);
|
||||
const float scale0 =
|
||||
(rmin0 == rmax0) ? 1.f : (qmax - qmin) / (rmax0 - rmin0);
|
||||
const float recip_scale0 = scale0 ? 1.0f / scale0 : 0.0f;
|
||||
|
||||
const float descaled_min0 = rmin0 * scale0;
|
||||
const float descaled_max0 = rmax0 * scale0;
|
||||
|
||||
float zero_point0 = (qmin + descaled_min0 + qmax + descaled_max0 > 0)
|
||||
? qmin - descaled_min0
|
||||
: qmax - descaled_max0;
|
||||
|
||||
zero_point0 = (std::max)(zero_point0, qmin);
|
||||
zero_point0 = (std::min)(zero_point0, qmax);
|
||||
const int32_t nudged_zero_point0 = lrintf(zero_point0);
|
||||
|
||||
int8_t* dst_ptr = (int8_t*)lhs_qa8dx + row_idx * dst_stride;
|
||||
|
||||
*((float*)(dst_ptr)) = recip_scale0;
|
||||
dst_ptr += sizeof(float);
|
||||
*((int32_t*)(dst_ptr)) = -nudged_zero_point0;
|
||||
dst_ptr += sizeof(int32_t);
|
||||
|
||||
for (size_t k_idx = 0; k_idx < k; ++k_idx) {
|
||||
const float src0_0 = src_ptr[k_idx];
|
||||
int32_t v0_s32 = (int32_t)(std::round(src0_0 * scale0));
|
||||
v0_s32 = (std::max)(
|
||||
(std::min)(
|
||||
v0_s32 + nudged_zero_point0, static_cast<int32_t>(INT8_MAX)),
|
||||
static_cast<int32_t>(INT8_MIN));
|
||||
dst_ptr[0] = (int8_t)v0_s32;
|
||||
dst_ptr += sizeof(int8_t);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
auto lhs_qa8dx_buffer = std::make_unique<int8_t[]>(
|
||||
m * (k + sizeof(float) + sizeof(int32_t))); // Allocate for LHS
|
||||
int8_t* lhs_qa8dx = lhs_qa8dx_buffer.get();
|
||||
// Quantize and pack LHS
|
||||
lhs_quant_pack(m, k, lhs_f32, lhs_qa8dx);
|
||||
|
||||
const size_t num_blocks_row = (((k + bl - 1) / bl) * bl) / bl;
|
||||
const size_t lhs_stride = k + sizeof(float) + sizeof(int32_t);
|
||||
const size_t rhs_stride = (((k + 2 - 1) / 2) * 2) / 2;
|
||||
|
||||
for (size_t row_idx = 0; row_idx < m; ++row_idx) {
|
||||
const int8_t* lhs_ptr_start = lhs_qa8dx + row_idx * lhs_stride;
|
||||
|
||||
for (size_t col_idx = 0; col_idx < n; ++col_idx) {
|
||||
float main_acc = 0.0f;
|
||||
const int8_t* lhs_ptr = lhs_ptr_start;
|
||||
const uint8_t* rhs_ptr = rhs_qs4c32 + col_idx * rhs_stride;
|
||||
|
||||
const float lhs_scale = *(const float*)lhs_ptr;
|
||||
lhs_ptr += sizeof(float);
|
||||
const int32_t lhs_offset = *(const int32_t*)lhs_ptr;
|
||||
lhs_ptr += sizeof(int32_t);
|
||||
|
||||
for (size_t block_idx = 0; block_idx < num_blocks_row; ++block_idx) {
|
||||
const float rhs_scale =
|
||||
rhs_scales_fp32[block_idx + col_idx * num_blocks_row];
|
||||
int32_t iacc = 0;
|
||||
|
||||
for (size_t i = 0; i < bl; ++i) {
|
||||
const size_t k_idx = block_idx * bl + i;
|
||||
if (k_idx >= k) {
|
||||
break;
|
||||
}
|
||||
|
||||
const int32_t lhs_v0 = (int32_t)lhs_ptr[0];
|
||||
const uint8_t rhs_byte = rhs_ptr[0];
|
||||
int32_t rhs_v0 = (k_idx % 2 == 0) ? (((int32_t)(rhs_byte & 0x0F)) - 8)
|
||||
: (((int32_t)(rhs_byte >> 4)) - 8);
|
||||
|
||||
iacc += lhs_v0 * rhs_v0;
|
||||
iacc += lhs_offset * rhs_v0;
|
||||
|
||||
lhs_ptr += 1;
|
||||
rhs_ptr += (k_idx % 2);
|
||||
}
|
||||
|
||||
main_acc += iacc * rhs_scale;
|
||||
}
|
||||
|
||||
main_acc = main_acc * lhs_scale;
|
||||
main_acc = (std::max)(main_acc, scalar_min);
|
||||
main_acc = (std::min)(main_acc, scalar_max);
|
||||
|
||||
dst_f32[0] = main_acc;
|
||||
dst_f32 += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Dynamic Input Quant 4 bit weights matmul execution flow
|
||||
(INT4 Weights + FP scales + FP32 Bias)
|
||||
FP32 Input Packed Buffer
|
||||
| |
|
||||
Quantize Cast
|
||||
to INT8 to INT8
|
||||
| |
|
||||
v v
|
||||
INT8 Input INT8 Weights
|
||||
\ /
|
||||
\ /
|
||||
\ /
|
||||
INT8 Matrix Multiplication
|
||||
|
|
||||
v
|
||||
FP32 Dequantized and Accumulate in FP32
|
||||
|
|
||||
v
|
||||
FP32 Final Output
|
||||
|
||||
* The Groupwise kernel requires BFloat16 Scales and Channelwise kernel requires
|
||||
* Float32 Scales. If not provided, we will use fallback implementation.
|
||||
*/
|
||||
void dyn_quant_matmul_4bit_kernel(
|
||||
const Tensor& output,
|
||||
const Tensor& inp,
|
||||
const Tensor& packed_weights,
|
||||
const int64_t M,
|
||||
const int64_t N,
|
||||
const int64_t K,
|
||||
const int64_t block_size) {
|
||||
#if AT_KLEIDIAI_ENABLED()
|
||||
const int64_t weight_packed_size =
|
||||
kleidiai::kai_pack_rhs_int4_size(N, K, block_size);
|
||||
if (weight_packed_size == packed_weights.numel()) {
|
||||
// KleidiAI interface intenally handles the Channelwise and groupwise
|
||||
// distinction
|
||||
kleidiai::kai_quant_pack_lhs_int4_mm(
|
||||
output, inp, packed_weights, M, N, K, block_size);
|
||||
} else
|
||||
#endif
|
||||
{
|
||||
float* lhs_f32 = reinterpret_cast<float*>(inp.data_ptr());
|
||||
const auto weights_size = N * K / 2;
|
||||
// The weights needs to be in uint8_t data type after quantization
|
||||
auto extracted_weights =
|
||||
(packed_weights.narrow(0, 0, weights_size)).to(kByte);
|
||||
auto float32_scales =
|
||||
(packed_weights.narrow(
|
||||
0, weights_size, packed_weights.size(0) - weights_size))
|
||||
.to(kFloat);
|
||||
uint8_t* rhs_4bit =
|
||||
reinterpret_cast<uint8_t*>(extracted_weights.data_ptr());
|
||||
float* rhs_scales_f32 = reinterpret_cast<float*>(float32_scales.data_ptr());
|
||||
float* dst_f32 = reinterpret_cast<float*>(output.data_ptr());
|
||||
if (block_size == K) {
|
||||
ref_dyn_quant_matmul_4bit_channelwise_kernel(
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
lhs_f32,
|
||||
rhs_4bit,
|
||||
rhs_scales_f32,
|
||||
dst_f32,
|
||||
-FLT_MAX,
|
||||
FLT_MAX);
|
||||
} else if (!(block_size % 32) && !(K % block_size)) {
|
||||
ref_dyn_quant_matmul_4bit_groupwise_kernel(
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
block_size,
|
||||
lhs_f32,
|
||||
rhs_4bit,
|
||||
rhs_scales_f32,
|
||||
dst_f32,
|
||||
-FLT_MAX,
|
||||
FLT_MAX);
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
block_size == K || (!(block_size % 32) && !(K % block_size)),
|
||||
__func__,
|
||||
": Group size should be multiple 32 or in_features [",
|
||||
K,
|
||||
"]. Provided ",
|
||||
block_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
ALSO_REGISTER_AVX512_DISPATCH(weight_to_int4pack_stub, &weight_to_int4pack_kernel)
|
||||
ALSO_REGISTER_AVX512_DISPATCH(int4pack_mm_stub, &int4pack_mm_kernel)
|
||||
REGISTER_DISPATCH(dyn_quant_pack_4bit_weight_stub, &dyn_quant_pack_4bit_weight_kernel)
|
||||
REGISTER_DISPATCH(dyn_quant_matmul_4bit_stub, &dyn_quant_matmul_4bit_kernel)
|
||||
|
||||
} // at::native
|
||||
C10_DIAGNOSTIC_POP()
|
||||
|
@ -5,34 +5,12 @@
|
||||
|
||||
namespace at::native {
|
||||
|
||||
using weight_to_int4pack_fn = void (*)(const Tensor&, const Tensor&);
|
||||
using int4pack_mm_fn =
|
||||
void (*)(const Tensor&, const Tensor&, const Tensor&, int, const Tensor&);
|
||||
using int8pack_mm_fn =
|
||||
void (*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&);
|
||||
using dyn_quant_pack_4bit_weight_fn = void (*)(
|
||||
Tensor&,
|
||||
const Tensor&,
|
||||
const Tensor&,
|
||||
const std::optional<Tensor>& bias,
|
||||
const int64_t,
|
||||
const int64_t,
|
||||
const int64_t);
|
||||
using dyn_quant_matmul_4bit_fn = void (*)(
|
||||
const Tensor&,
|
||||
const Tensor&,
|
||||
const Tensor&,
|
||||
const int64_t,
|
||||
const int64_t,
|
||||
const int64_t,
|
||||
const int64_t);
|
||||
using weight_to_int4pack_fn = void(*)(const Tensor&, const Tensor&);
|
||||
using int4pack_mm_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, int, const Tensor&);
|
||||
using int8pack_mm_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&);
|
||||
|
||||
DECLARE_DISPATCH(weight_to_int4pack_fn, weight_to_int4pack_stub)
|
||||
DECLARE_DISPATCH(int4pack_mm_fn, int4pack_mm_stub)
|
||||
DECLARE_DISPATCH(int8pack_mm_fn, int8pack_mm_stub)
|
||||
DECLARE_DISPATCH(
|
||||
dyn_quant_pack_4bit_weight_fn,
|
||||
dyn_quant_pack_4bit_weight_stub);
|
||||
DECLARE_DISPATCH(dyn_quant_matmul_4bit_fn, dyn_quant_matmul_4bit_stub);
|
||||
|
||||
} // namespace at::native
|
||||
|
@ -1,440 +0,0 @@
|
||||
#include <ATen/native/kleidiai/kai_kernels.h>
|
||||
#include <ATen/native/kleidiai/kai_pack.h>
|
||||
#include <ATen/native/kleidiai/kai_ukernel_interface.h>
|
||||
|
||||
#include <ATen/Parallel.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <cfloat>
|
||||
#include <cmath>
|
||||
#include <unordered_map>
|
||||
#if AT_KLEIDIAI_ENABLED()
|
||||
#include <cpuinfo.h>
|
||||
|
||||
namespace at::native::kleidiai {
|
||||
|
||||
void kai_pack_int4_rhs(
|
||||
const Tensor& weight_packed,
|
||||
const Tensor& weight,
|
||||
const Tensor& scales,
|
||||
const std::optional<Tensor>& bias,
|
||||
const int64_t n,
|
||||
const int64_t k,
|
||||
const int64_t bl) {
|
||||
// Prefer Channelwise kernel over Groupwise kernel for conflicting cases
|
||||
if (bl == k) {
|
||||
// Channelwise
|
||||
auto kernel_packet = kai_select_channelwise_matmul_ukernel(
|
||||
kai_kernel_id::
|
||||
matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod);
|
||||
auto& params = kernel_packet.rhs_pack_params;
|
||||
params.lhs_zero_point = 1;
|
||||
params.rhs_zero_point = 8;
|
||||
|
||||
kai_pack_rhs_channelwise_int4<kai_matmul_ukernel_f32_qa8dxp_qs4cxp>(
|
||||
kernel_packet, weight_packed, weight, scales, bias, n, k);
|
||||
} else if (!(bl % 32) && !(k % bl)) {
|
||||
// Groupwise
|
||||
auto kernel_packet = kai_select_groupwise_matmul_ukernel(
|
||||
kai_kernel_id::
|
||||
matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod);
|
||||
|
||||
const int64_t rhs_stride = kai_roundup(k, 2) / 2;
|
||||
const int64_t scale_stride = (kai_roundup(k, bl) / bl) * sizeof(uint16_t);
|
||||
auto& params = kernel_packet.rhs_pack_params;
|
||||
params.lhs_zero_point = 1;
|
||||
params.rhs_zero_point = 8;
|
||||
params.scale_dt = kai_datatype::kai_dt_bf16;
|
||||
|
||||
kai_pack_rhs_groupwise_int4<kai_matmul_ukernel_f32_qa8dxp_qs4c32p>(
|
||||
kernel_packet,
|
||||
weight_packed,
|
||||
weight,
|
||||
scales,
|
||||
bias,
|
||||
n,
|
||||
k,
|
||||
bl,
|
||||
rhs_stride,
|
||||
scale_stride);
|
||||
}
|
||||
}
|
||||
|
||||
size_t kai_pack_rhs_int4_size(
|
||||
const int64_t n,
|
||||
const int64_t k,
|
||||
const int64_t bl) {
|
||||
size_t packed_size = n * k;
|
||||
// Prefer Channelwise kernel over Groupwise kernel for conflicting cases
|
||||
if (bl == k) {
|
||||
// Channelwise
|
||||
auto kernel_packet = kai_select_channelwise_matmul_ukernel(
|
||||
kai_kernel_id::
|
||||
matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod);
|
||||
const auto& ukernel = kernel_packet.ukernel;
|
||||
const size_t nr = ukernel.get_nr();
|
||||
const size_t kr = ukernel.get_kr();
|
||||
const size_t sr = ukernel.get_sr();
|
||||
packed_size = kernel_packet.kai_get_rhs_packed_size(n, k, nr, kr, sr);
|
||||
} else if (!(bl % 32) && !(k % bl)) {
|
||||
// Groupwise
|
||||
auto kernel_packet = kai_select_groupwise_matmul_ukernel(
|
||||
kai_kernel_id::
|
||||
matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod);
|
||||
const auto& ukernel = kernel_packet.ukernel;
|
||||
const size_t nr = ukernel.get_nr();
|
||||
const size_t kr = ukernel.get_kr();
|
||||
const size_t sr = ukernel.get_sr();
|
||||
packed_size = kernel_packet.kai_get_rhs_packed_size(
|
||||
n, k, nr, kr, sr, bl, kai_datatype::kai_dt_bf16);
|
||||
}
|
||||
return packed_size;
|
||||
}
|
||||
|
||||
static void matmul_channelwise(
|
||||
kai_matmul_ukernel_f32_qa8dxp_qs4cxp& kernel_packet,
|
||||
size_t m_increment,
|
||||
size_t m_start,
|
||||
size_t m_per_thread,
|
||||
size_t n_start,
|
||||
size_t n_per_thread,
|
||||
size_t n,
|
||||
size_t k,
|
||||
size_t mr,
|
||||
size_t nr,
|
||||
size_t kr,
|
||||
size_t sr,
|
||||
size_t dst_stride,
|
||||
size_t lhs_stride,
|
||||
uint8_t* lhs_native_mtx_f32,
|
||||
uint8_t* lhs_packed_mtx_qa8dx,
|
||||
uint8_t* rhs_packed_mtx_qs4cx,
|
||||
uint8_t* dst_act_mtx_f32) {
|
||||
for (size_t m0 = 0; m0 < m_per_thread; m0 += m_increment) {
|
||||
const float* src_ptr =
|
||||
(const float*)(lhs_native_mtx_f32 +
|
||||
kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32(
|
||||
m_start + m0, lhs_stride));
|
||||
void* lhs_packed_ptr =
|
||||
(void*)(lhs_packed_mtx_qa8dx +
|
||||
kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(
|
||||
0, k, mr, kr, sr));
|
||||
const void* rhs_packed_ptr =
|
||||
(const void*)((const char*)rhs_packed_mtx_qs4cx +
|
||||
kernel_packet.ukernel.get_rhs_packed_offset(n_start, k));
|
||||
float* dst_ptr = (float*)((uint8_t*)dst_act_mtx_f32 +
|
||||
kernel_packet.ukernel.get_dst_offset(
|
||||
m_start + m0, n_start, dst_stride));
|
||||
|
||||
// Quantize and pack the Input
|
||||
kernel_packet.kai_run_lhs_quant_pack(
|
||||
m_increment, k, mr, kr, sr, 0, src_ptr, lhs_stride, lhs_packed_ptr);
|
||||
|
||||
// Run Matmul on Int4 packed weights and Quantized Packed Input
|
||||
kernel_packet.ukernel.run_matmul(
|
||||
m_increment,
|
||||
n_per_thread,
|
||||
k,
|
||||
lhs_packed_ptr,
|
||||
rhs_packed_ptr,
|
||||
dst_ptr,
|
||||
dst_stride,
|
||||
sizeof(float),
|
||||
-FLT_MAX,
|
||||
FLT_MAX);
|
||||
}
|
||||
}
|
||||
|
||||
static void matmul_groupwise(
|
||||
const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& ukernel,
|
||||
const size_t m,
|
||||
const size_t num_n_per_thread,
|
||||
const size_t n_start,
|
||||
const size_t k,
|
||||
const size_t bl,
|
||||
const size_t dst_stride,
|
||||
const void* lhs_ptr,
|
||||
uint8_t* rhs_packed,
|
||||
uint8_t* dst_data) {
|
||||
const size_t rhs_packed_offset =
|
||||
ukernel.get_rhs_packed_offset(n_start, k, bl);
|
||||
const size_t dst_offset = ukernel.get_dst_offset(0, n_start, dst_stride);
|
||||
|
||||
const void* rhs_ptr = (const void*)(rhs_packed + rhs_packed_offset);
|
||||
float* dst_ptr = (float*)((uint8_t*)dst_data + dst_offset);
|
||||
|
||||
// Run Matmul on Int4 packed weights and Quantized Packed Input
|
||||
ukernel.run_matmul(
|
||||
m,
|
||||
num_n_per_thread,
|
||||
k,
|
||||
bl,
|
||||
lhs_ptr,
|
||||
rhs_ptr,
|
||||
dst_ptr,
|
||||
dst_stride,
|
||||
sizeof(float),
|
||||
-FLT_MAX,
|
||||
FLT_MAX);
|
||||
}
|
||||
|
||||
struct ThreadDivision {
|
||||
int64_t num_threads_x;
|
||||
int64_t num_threads_y;
|
||||
bool use_gemm; // True if GEMM is selected, false if GEMV is used
|
||||
};
|
||||
|
||||
inline static unsigned int round_down_to_power_of_2(unsigned int n) {
|
||||
n |= n >> 1;
|
||||
n |= n >> 2;
|
||||
n |= n >> 4;
|
||||
n |= n >> 8;
|
||||
n |= n >> 16;
|
||||
return n - (n >> 1);
|
||||
}
|
||||
|
||||
inline static void adjust_max_threads(int64_t& max_threads) {
|
||||
// We would not like to round down to nearest power of 2 always
|
||||
// There can be possible thread split combination between powers of 2 for odd
|
||||
// shapes
|
||||
// TODO:: Decide better strategy based on hint of input and weight shapes
|
||||
max_threads = round_down_to_power_of_2(max_threads);
|
||||
}
|
||||
|
||||
static std::pair<int64_t, int64_t> split_2d(const int64_t max_threads) {
|
||||
int64_t sqrt_threads = std::sqrt(max_threads);
|
||||
|
||||
for (int64_t i = sqrt_threads; i >= 1; --i) {
|
||||
if (max_threads % i == 0) {
|
||||
return {i, max_threads / i};
|
||||
}
|
||||
}
|
||||
|
||||
return {1, max_threads}; // Theres still a possibility of 1D blocking when
|
||||
// calling GEMM kernel
|
||||
}
|
||||
|
||||
inline static ThreadDivision get_thread_division(
|
||||
int64_t max_threads,
|
||||
const int64_t m,
|
||||
const int64_t n,
|
||||
const int64_t k,
|
||||
const int64_t gemm_m_step,
|
||||
const int64_t gemm_n_step,
|
||||
const int64_t gemv_m_step,
|
||||
const int64_t gemv_n_step) {
|
||||
adjust_max_threads(max_threads);
|
||||
ThreadDivision division{1, 1, false};
|
||||
|
||||
// Split threads 2D for GEMM
|
||||
if (m % gemm_m_step == 0 && n % gemm_n_step == 0) {
|
||||
while (max_threads > 0) {
|
||||
auto [num_thread_y, num_thread_x] = split_2d(max_threads);
|
||||
if (m % num_thread_y == 0 && n % num_thread_x == 0) {
|
||||
int64_t m_per_thread = m / num_thread_y;
|
||||
int64_t n_per_thread = n / num_thread_x;
|
||||
if (m_per_thread % gemm_m_step == 0 &&
|
||||
n_per_thread % gemm_n_step == 0) {
|
||||
division = {num_thread_x, num_thread_y, true};
|
||||
return division;
|
||||
}
|
||||
}
|
||||
max_threads -= 2;
|
||||
}
|
||||
}
|
||||
// Split threads 1D for GEMV
|
||||
if (n % gemv_n_step == 0) {
|
||||
for (; max_threads > 0; max_threads -= 2) {
|
||||
if (n % max_threads == 0 && (n / max_threads) % gemv_n_step == 0) {
|
||||
division.num_threads_x = max_threads;
|
||||
return division;
|
||||
}
|
||||
}
|
||||
}
|
||||
return division;
|
||||
}
|
||||
|
||||
static void kai_quant_pack_lhs_int4_mm_groupwise(
|
||||
const Tensor& output,
|
||||
const Tensor& input,
|
||||
const Tensor& weight,
|
||||
const int64_t m,
|
||||
const int64_t n,
|
||||
const int64_t k,
|
||||
const int64_t bl) {
|
||||
kai_kernel_id id = kai_kernel_id::
|
||||
matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod;
|
||||
if (cpuinfo_has_arm_i8mm() && m > 1) {
|
||||
id =
|
||||
kai_kernel_id::matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_4x8x32_neon_i8mm;
|
||||
}
|
||||
auto kernel_packet = kai_select_groupwise_matmul_ukernel(id);
|
||||
|
||||
const auto& ukernel = kernel_packet.ukernel;
|
||||
|
||||
const size_t mr = ukernel.get_mr();
|
||||
const size_t kr = ukernel.get_kr();
|
||||
const size_t sr = ukernel.get_sr();
|
||||
const size_t n_step = ukernel.get_n_step();
|
||||
int64_t total_threads = at::get_num_threads();
|
||||
int64_t num_threads_x = 1;
|
||||
adjust_max_threads(total_threads);
|
||||
// Split threads 1D only for now
|
||||
if (n % n_step == 0) {
|
||||
for (; total_threads > 0; total_threads -= 2) {
|
||||
if (n % total_threads == 0 && (n / total_threads) % n_step == 0) {
|
||||
num_threads_x = total_threads;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const size_t num_n_per_thread = n / num_threads_x;
|
||||
|
||||
const size_t dst_stride = n * sizeof(float);
|
||||
float* lhs = reinterpret_cast<float*>(input.data_ptr());
|
||||
uint8_t* rhs_packed_mtx_qs4cx = reinterpret_cast<uint8_t*>(weight.data_ptr());
|
||||
|
||||
uint8_t* dst_act_mtx_f32 = reinterpret_cast<uint8_t*>(output.data_ptr());
|
||||
const size_t lhs_packed_size =
|
||||
kernel_packet.kai_get_lhs_packed_size(m, k, mr, kr, sr);
|
||||
auto lhs_packed = std::make_unique<uint8_t[]>(lhs_packed_size);
|
||||
|
||||
// Quantize and pack the Input
|
||||
kernel_packet.kai_run_lhs_quant_pack(
|
||||
m,
|
||||
k,
|
||||
mr,
|
||||
kr,
|
||||
sr,
|
||||
0,
|
||||
(const float*)lhs,
|
||||
k * sizeof(float),
|
||||
(void*)lhs_packed.get());
|
||||
|
||||
at::parallel_for(0, num_threads_x, 0, [&](int begin, int end) {
|
||||
for (const auto x : c10::irange(begin, end)) {
|
||||
matmul_groupwise(
|
||||
std::ref(ukernel),
|
||||
m,
|
||||
num_n_per_thread,
|
||||
x * num_n_per_thread,
|
||||
k,
|
||||
bl,
|
||||
dst_stride,
|
||||
lhs_packed.get(),
|
||||
rhs_packed_mtx_qs4cx,
|
||||
dst_act_mtx_f32);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
static void kai_quant_pack_lhs_int4_mm_channelwise(
|
||||
const Tensor& output,
|
||||
const Tensor& input,
|
||||
const Tensor& weight,
|
||||
const int64_t m,
|
||||
const int64_t n,
|
||||
const int64_t k) {
|
||||
// Kernel IDs for GEMM and GEMV
|
||||
kai_kernel_id gemm_id =
|
||||
kai_kernel_id::matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm;
|
||||
kai_kernel_id gemv_id =
|
||||
kai_kernel_id::matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod;
|
||||
|
||||
// Get the total number of threads available and choose GEMM or GEMV steps
|
||||
const int64_t total_threads = at::get_num_threads();
|
||||
auto gemm_kernel_packet = kai_select_channelwise_matmul_ukernel(gemv_id);
|
||||
if (cpuinfo_has_arm_i8mm()) {
|
||||
gemm_kernel_packet = kai_select_channelwise_matmul_ukernel(gemm_id);
|
||||
}
|
||||
auto gemv_kernel_packet = kai_select_channelwise_matmul_ukernel(gemv_id);
|
||||
|
||||
// Retrieve m_step and n_step values from GEMM and GEMV kernels
|
||||
const int64_t gemm_m_step = gemm_kernel_packet.ukernel.get_m_step();
|
||||
const int64_t gemm_n_step = gemm_kernel_packet.ukernel.get_n_step();
|
||||
const int64_t gemv_m_step = gemv_kernel_packet.ukernel.get_m_step();
|
||||
const int64_t gemv_n_step = gemv_kernel_packet.ukernel.get_n_step();
|
||||
// Determine threading and kernel type
|
||||
ThreadDivision division = get_thread_division(
|
||||
total_threads,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
gemm_m_step,
|
||||
gemm_n_step,
|
||||
gemv_m_step,
|
||||
gemv_n_step);
|
||||
// Select appropriate kernel packet based on the chosen kernel type
|
||||
auto& kernel_packet =
|
||||
division.use_gemm ? gemm_kernel_packet : gemv_kernel_packet;
|
||||
|
||||
// Thread blocking parameters
|
||||
const size_t mr = kernel_packet.ukernel.get_mr();
|
||||
const size_t nr = kernel_packet.ukernel.get_nr();
|
||||
const size_t kr = kernel_packet.ukernel.get_kr();
|
||||
const size_t sr = kernel_packet.ukernel.get_sr();
|
||||
const size_t m_increment = kernel_packet.ukernel.get_m_step();
|
||||
const size_t n_per_thread = n / division.num_threads_x;
|
||||
const size_t m_per_thread = m / division.num_threads_y;
|
||||
const int64_t num_threads = division.num_threads_y * division.num_threads_x;
|
||||
const size_t dst_stride = n * sizeof(float);
|
||||
const size_t lhs_stride = k * sizeof(float);
|
||||
|
||||
const size_t lhs_packed_size =
|
||||
kernel_packet.kai_get_lhs_packed_size(m_increment, k, mr, kr, sr);
|
||||
|
||||
uint8_t* dst_act_mtx_f32 = reinterpret_cast<uint8_t*>(output.data_ptr());
|
||||
uint8_t* lhs_native_mtx_f32 = reinterpret_cast<uint8_t*>(input.data_ptr());
|
||||
uint8_t* rhs_packed_mtx_qs4cx = reinterpret_cast<uint8_t*>(weight.data_ptr());
|
||||
auto lhs_packed = std::make_unique<uint8_t[]>(lhs_packed_size * num_threads);
|
||||
uint8_t* lhs_packed_base = lhs_packed.get();
|
||||
|
||||
at::parallel_for(0, num_threads, 0, [&](int64_t begin, int64_t end) {
|
||||
for (const auto i : c10::irange(begin, end)) {
|
||||
size_t y = i / division.num_threads_x;
|
||||
size_t x = i % division.num_threads_x;
|
||||
uint8_t* lhs_packed_ptr =
|
||||
lhs_packed_base + (x + y * division.num_threads_x) * lhs_packed_size;
|
||||
matmul_channelwise(
|
||||
std::ref(kernel_packet),
|
||||
m_increment,
|
||||
y * m_per_thread,
|
||||
m_per_thread,
|
||||
x * n_per_thread,
|
||||
n_per_thread,
|
||||
n,
|
||||
k,
|
||||
mr,
|
||||
nr,
|
||||
kr,
|
||||
sr,
|
||||
dst_stride,
|
||||
lhs_stride,
|
||||
lhs_native_mtx_f32,
|
||||
lhs_packed_ptr,
|
||||
rhs_packed_mtx_qs4cx,
|
||||
dst_act_mtx_f32);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
void kai_quant_pack_lhs_int4_mm(
|
||||
const Tensor& output,
|
||||
const Tensor& input,
|
||||
const Tensor& weight,
|
||||
const int64_t m,
|
||||
const int64_t n,
|
||||
const int64_t k,
|
||||
const int64_t bl) {
|
||||
// Prefer Channelwise kernel over Groupwise kernel for conflicting cases
|
||||
if (bl == k) {
|
||||
kleidiai::kai_quant_pack_lhs_int4_mm_channelwise(
|
||||
output, input, weight, m, n, k);
|
||||
} else if (!(bl % 32) && !(k % bl)) {
|
||||
kleidiai::kai_quant_pack_lhs_int4_mm_groupwise(
|
||||
output, input, weight, m, n, k, bl);
|
||||
}
|
||||
}
|
||||
} // namespace at::native::kleidiai
|
||||
#endif
|
@ -1,42 +0,0 @@
|
||||
#pragma once
|
||||
#include <ATen/Config.h>
|
||||
#include <ATen/core/Tensor.h>
|
||||
#if AT_KLEIDIAI_ENABLED()
|
||||
|
||||
namespace at::native::kleidiai {
|
||||
|
||||
/**
|
||||
* @brief Rearranges the quantized weight to support kleidiai inference
|
||||
* @param bl Groupsize for quantization should be multiple of 32
|
||||
*/
|
||||
void kai_pack_int4_rhs(
|
||||
const Tensor& weight_packed,
|
||||
const Tensor& weight,
|
||||
const Tensor& scales,
|
||||
const std::optional<Tensor>& bias,
|
||||
const int64_t n,
|
||||
const int64_t k,
|
||||
const int64_t bl);
|
||||
|
||||
/**
|
||||
* @brief Outputs the buffer size for the packed weights
|
||||
* @param bl Groupsize for quantization. 32 for groupwise , 0 for channelwise
|
||||
*/
|
||||
size_t kai_pack_rhs_int4_size(
|
||||
const int64_t n,
|
||||
const int64_t k,
|
||||
const int64_t bl);
|
||||
|
||||
/**
|
||||
* @brief Run 2 operations ( Input quantize and pack -> 4 bit Matmul )
|
||||
*/
|
||||
void kai_quant_pack_lhs_int4_mm(
|
||||
const Tensor& output,
|
||||
const Tensor& input,
|
||||
const Tensor& weight,
|
||||
const int64_t m,
|
||||
const int64_t n,
|
||||
const int64_t k,
|
||||
const int64_t bl);
|
||||
} // namespace at::native::kleidiai
|
||||
#endif
|
@ -1,106 +0,0 @@
|
||||
#pragma once
|
||||
#include <ATen/Config.h>
|
||||
#include <ATen/core/Tensor.h>
|
||||
#include <ATen/ops/empty.h>
|
||||
#include <torch/library.h>
|
||||
#if AT_KLEIDIAI_ENABLED()
|
||||
|
||||
namespace at::native::kleidiai {
|
||||
|
||||
template <typename T>
|
||||
void kai_pack_rhs_groupwise_int4(
|
||||
T& kernel,
|
||||
const Tensor& weight_packed,
|
||||
const Tensor& weight,
|
||||
const Tensor& scales,
|
||||
const std::optional<Tensor>& bias,
|
||||
const int64_t n,
|
||||
const int64_t k,
|
||||
const int64_t bl,
|
||||
const int64_t rhs_stride,
|
||||
const int64_t scale_stride) {
|
||||
const auto& ukernel = kernel.ukernel;
|
||||
const size_t nr = ukernel.get_nr();
|
||||
const size_t kr = ukernel.get_kr();
|
||||
const size_t sr = ukernel.get_sr();
|
||||
auto weight_packed_data =
|
||||
reinterpret_cast<uint8_t*>(weight_packed.data_ptr());
|
||||
const auto weight_data = weight.data_ptr<uint8_t>();
|
||||
auto scales_data = scales.const_data_ptr();
|
||||
|
||||
if (weight_data == nullptr) {
|
||||
AT_ERROR("kai_pack_rhs_channelwise_int4: Weight data pointer is null");
|
||||
}
|
||||
|
||||
if (scales_data == nullptr) {
|
||||
AT_ERROR("kai_pack_rhs_channelwise_int4: Scales data pointer is null");
|
||||
}
|
||||
|
||||
float* bias_ptr = bias.has_value() ? bias.value().data_ptr<float>() : NULL;
|
||||
auto& params = kernel.rhs_pack_params;
|
||||
|
||||
kernel.kai_run_rhs_pack(
|
||||
/*num_groups=*/1,
|
||||
n,
|
||||
k,
|
||||
nr,
|
||||
kr,
|
||||
sr,
|
||||
bl,
|
||||
(const uint8_t*)(weight_data),
|
||||
rhs_stride,
|
||||
bias_ptr,
|
||||
scales_data,
|
||||
scale_stride,
|
||||
weight_packed_data,
|
||||
0,
|
||||
¶ms);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void kai_pack_rhs_channelwise_int4(
|
||||
T& kernel,
|
||||
const Tensor& weight_packed,
|
||||
const Tensor& weight,
|
||||
const Tensor& scales,
|
||||
const std::optional<Tensor>& bias,
|
||||
const int64_t n,
|
||||
const int64_t k) {
|
||||
const auto& ukernel = kernel.ukernel;
|
||||
const size_t nr = ukernel.get_nr();
|
||||
const size_t kr = ukernel.get_kr();
|
||||
const size_t sr = ukernel.get_sr();
|
||||
auto weight_packed_data =
|
||||
reinterpret_cast<uint8_t*>(weight_packed.data_ptr());
|
||||
const auto weight_data = weight.data_ptr<uint8_t>();
|
||||
const auto scales_data = scales.data_ptr<float>();
|
||||
|
||||
if (weight_data == nullptr) {
|
||||
AT_ERROR("kai_pack_rhs_channelwise_int4: Weight data pointer is null");
|
||||
}
|
||||
|
||||
if (scales_data == nullptr) {
|
||||
AT_ERROR("kai_pack_rhs_channelwise_int4: Scales data pointer is null");
|
||||
}
|
||||
|
||||
float* bias_ptr = bias.has_value() ? bias.value().data_ptr<float>() : NULL;
|
||||
auto& params = kernel.rhs_pack_params;
|
||||
|
||||
kernel.kai_run_rhs_pack(
|
||||
/*num_groups=*/1,
|
||||
n,
|
||||
k,
|
||||
nr,
|
||||
kr,
|
||||
sr,
|
||||
(const uint8_t*)(weight_data),
|
||||
(const float*)(bias_ptr),
|
||||
(const float*)(scales_data),
|
||||
weight_packed_data,
|
||||
0,
|
||||
¶ms);
|
||||
}
|
||||
|
||||
} // namespace at::native::kleidiai
|
||||
|
||||
#endif
|
@ -1,72 +0,0 @@
|
||||
#include <ATen/native/kleidiai/kai_ukernel_interface.h>
|
||||
|
||||
#if AT_KLEIDIAI_ENABLED()
|
||||
|
||||
namespace at::native::kleidiai {
|
||||
|
||||
// Kernel Mapping - Groupwise
|
||||
std::unordered_map<kai_kernel_id, kai_matmul_ukernel_f32_qa8dxp_qs4c32p> groupwise_8bit_4bit_kernels =
|
||||
{{kai_kernel_id::
|
||||
matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
|
||||
{{kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
|
||||
kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
|
||||
kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
|
||||
kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
|
||||
kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
|
||||
kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
|
||||
kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
|
||||
kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
|
||||
kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
|
||||
kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
|
||||
kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod}}},
|
||||
{kai_kernel_id::matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_4x8x32_neon_i8mm,
|
||||
{{kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm,
|
||||
kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm,
|
||||
kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm,
|
||||
kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm,
|
||||
kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm,
|
||||
kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm,
|
||||
kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm,
|
||||
kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm,
|
||||
kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm,
|
||||
kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm,
|
||||
kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm}}}};
|
||||
|
||||
kai_matmul_ukernel_f32_qa8dxp_qs4c32p kai_select_groupwise_matmul_ukernel(
|
||||
kai_kernel_id id) {
|
||||
return groupwise_8bit_4bit_kernels.at(id);
|
||||
}
|
||||
|
||||
// Kernel Mapping - Channelwise
|
||||
std::unordered_map<kai_kernel_id, kai_matmul_ukernel_f32_qa8dxp_qs4cxp> channelwise_8bit_4bit_kernels =
|
||||
{{kai_kernel_id::matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod,
|
||||
{{kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod,
|
||||
kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod,
|
||||
kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod,
|
||||
kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod,
|
||||
kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod,
|
||||
kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod,
|
||||
kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod,
|
||||
kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod,
|
||||
kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod,
|
||||
kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod,
|
||||
kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod}}},
|
||||
{kai_kernel_id::matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm,
|
||||
{{kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm,
|
||||
kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm,
|
||||
kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm,
|
||||
kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm,
|
||||
kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm,
|
||||
kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm,
|
||||
kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm,
|
||||
kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm,
|
||||
kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm,
|
||||
kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm,
|
||||
kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm}}}};
|
||||
|
||||
kai_matmul_ukernel_f32_qa8dxp_qs4cxp kai_select_channelwise_matmul_ukernel(
|
||||
const kai_kernel_id id) {
|
||||
return channelwise_8bit_4bit_kernels.at(id);
|
||||
}
|
||||
} // namespace at::native::kleidiai
|
||||
#endif
|
@ -1,144 +0,0 @@
|
||||
#pragma once
|
||||
#include <ATen/Config.h>
|
||||
#include <unordered_map>
|
||||
#if AT_KLEIDIAI_ENABLED()
|
||||
|
||||
#include <kai_common.h>
|
||||
#include <kai_lhs_quant_pack_qai8dxp_f32.h>
|
||||
#include <kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h>
|
||||
#include <kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.h>
|
||||
#include <kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h>
|
||||
#include <kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.h>
|
||||
#include <kai_matmul_clamp_f32_qai8dxp_qsi4c32p_interface.h>
|
||||
#include <kai_matmul_clamp_f32_qai8dxp_qsi4cxp_interface.h>
|
||||
#include <kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h>
|
||||
#include <kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.h>
|
||||
|
||||
namespace at::native::kleidiai {
|
||||
|
||||
enum class kai_kernel_id {
|
||||
matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod =
|
||||
0, // Groupwise 4 bit GEMV
|
||||
matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_4x8x32_neon_i8mm =
|
||||
1, // Groupwise 4 bit GEMM
|
||||
matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod =
|
||||
2, // Channelwise 4 bit GEMV
|
||||
matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm =
|
||||
3 // Channelwise 4 bit GEMM
|
||||
};
|
||||
|
||||
// Channelwise Kernel mapping
|
||||
struct kai_matmul_ukernel_f32_qa8dxp_qs4cxp {
|
||||
struct kai_matmul_clamp_f32_qai8dxp_qsi4cxp_ukernel ukernel;
|
||||
struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params rhs_pack_params;
|
||||
size_t (*kai_get_lhs_packed_size)(
|
||||
size_t m,
|
||||
size_t k,
|
||||
size_t mr,
|
||||
size_t kr,
|
||||
size_t sr);
|
||||
size_t (*kai_get_rhs_packed_size)(
|
||||
size_t n,
|
||||
size_t k,
|
||||
size_t nr,
|
||||
size_t kr,
|
||||
size_t sr);
|
||||
void (*kai_run_lhs_quant_pack)(
|
||||
size_t m,
|
||||
size_t k,
|
||||
size_t mr,
|
||||
size_t kr,
|
||||
size_t sr,
|
||||
size_t m_idx_start,
|
||||
const float* lhs,
|
||||
size_t lhs_stride,
|
||||
void* lhs_packed);
|
||||
void (*kai_run_rhs_pack)(
|
||||
size_t num_groups,
|
||||
size_t n,
|
||||
size_t k,
|
||||
size_t nr,
|
||||
size_t kr,
|
||||
size_t sr,
|
||||
const uint8_t* rhs,
|
||||
const float* bias,
|
||||
const float* scale,
|
||||
void* rhs_packed,
|
||||
size_t extra_bytes,
|
||||
const struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params* params);
|
||||
|
||||
kai_matmul_ukernel_f32_qa8dxp_qs4cxp(
|
||||
const kai_matmul_clamp_f32_qai8dxp_qsi4cxp_ukernel& kernel)
|
||||
: ukernel(kernel),
|
||||
kai_get_lhs_packed_size(
|
||||
&kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32),
|
||||
kai_get_rhs_packed_size(
|
||||
&kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0),
|
||||
kai_run_lhs_quant_pack(&kai_run_lhs_quant_pack_qai8dxp_f32),
|
||||
kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0) {}
|
||||
};
|
||||
|
||||
struct kai_matmul_ukernel_f32_qa8dxp_qs4cxp
|
||||
kai_select_channelwise_matmul_ukernel(const kai_kernel_id id);
|
||||
|
||||
// Groupwise Kernel mapping
|
||||
struct kai_matmul_ukernel_f32_qa8dxp_qs4c32p {
|
||||
struct kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel ukernel;
|
||||
struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params rhs_pack_params;
|
||||
size_t (*kai_get_lhs_packed_size)(
|
||||
size_t m,
|
||||
size_t k,
|
||||
size_t mr,
|
||||
size_t kr,
|
||||
size_t sr);
|
||||
size_t (*kai_get_rhs_packed_size)(
|
||||
size_t n,
|
||||
size_t k,
|
||||
size_t nr,
|
||||
size_t kr,
|
||||
size_t sr,
|
||||
size_t bl,
|
||||
enum kai_datatype scale_dt);
|
||||
void (*kai_run_lhs_quant_pack)(
|
||||
size_t m,
|
||||
size_t k,
|
||||
size_t mr,
|
||||
size_t kr,
|
||||
size_t sr,
|
||||
size_t m_idx_start,
|
||||
const float* lhs,
|
||||
size_t lhs_stride,
|
||||
void* lhs_packed);
|
||||
void (*kai_run_rhs_pack)(
|
||||
size_t num_groups,
|
||||
size_t n,
|
||||
size_t k,
|
||||
size_t nr,
|
||||
size_t kr,
|
||||
size_t sr,
|
||||
size_t bl,
|
||||
const uint8_t* rhs,
|
||||
size_t rhs_stride,
|
||||
const float* bias,
|
||||
const void* scale,
|
||||
size_t scale_stride,
|
||||
void* rhs_packed,
|
||||
size_t extra_bytes,
|
||||
const struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params* params);
|
||||
|
||||
kai_matmul_ukernel_f32_qa8dxp_qs4c32p(
|
||||
const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& kernel)
|
||||
: ukernel(kernel),
|
||||
kai_get_lhs_packed_size(
|
||||
&kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32),
|
||||
kai_get_rhs_packed_size(
|
||||
&kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0),
|
||||
kai_run_lhs_quant_pack(&kai_run_lhs_quant_pack_qai8dxp_f32),
|
||||
kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0) {}
|
||||
};
|
||||
|
||||
struct kai_matmul_ukernel_f32_qa8dxp_qs4c32p kai_select_groupwise_matmul_ukernel(
|
||||
const kai_kernel_id id);
|
||||
|
||||
} // namespace at::native::kleidiai
|
||||
#endif
|
@ -4177,14 +4177,6 @@
|
||||
dispatch:
|
||||
CPU: _weight_int4pack_mm_cpu
|
||||
|
||||
- func: _dyn_quant_pack_4bit_weight(Tensor weights, Tensor scales_zeros, Tensor? bias, int block_size, int in_features, int out_features) -> Tensor
|
||||
dispatch:
|
||||
CPU: _dyn_quant_pack_4bit_weight_cpu
|
||||
|
||||
- func: _dyn_quant_matmul_4bit(Tensor inp, Tensor packed_weights, int block_size, int in_features, int out_features) -> Tensor
|
||||
dispatch:
|
||||
CPU: _dyn_quant_matmul_4bit_cpu
|
||||
|
||||
- func: _weight_int8pack_mm(Tensor self, Tensor mat2, Tensor scales) -> Tensor
|
||||
dispatch:
|
||||
CPU: _weight_int8pack_mm_cpu
|
||||
|
@ -1070,7 +1070,6 @@ def define_buck_targets(
|
||||
],
|
||||
)
|
||||
|
||||
# TODO: Enable support for KleidiAI bazel build
|
||||
# @lint-ignore BUCKLINT
|
||||
fb_native.genrule(
|
||||
name = "generate_aten_config",
|
||||
@ -1123,9 +1122,6 @@ def define_buck_targets(
|
||||
"--replace",
|
||||
"@AT_BLAS_USE_CBLAS_DOT@",
|
||||
"AT_BLAS_USE_CBLAS_DOT_FBXPLAT",
|
||||
"--replace",
|
||||
"@AT_KLEIDIAI_ENABLED@",
|
||||
"0",
|
||||
]),
|
||||
outs = {
|
||||
"Config.h": ["Config.h"],
|
||||
|
@ -152,7 +152,6 @@ endif()
|
||||
set(AT_MKLDNN_ACL_ENABLED 0)
|
||||
set(AT_MKLDNN_ENABLED 0)
|
||||
set(AT_MKL_ENABLED 0)
|
||||
set(AT_KLEIDIAI_ENABLED 0)
|
||||
# setting default preferred BLAS options if not already present.
|
||||
if(NOT INTERN_BUILD_MOBILE)
|
||||
set(BLAS "MKL" CACHE STRING "Selected BLAS library")
|
||||
@ -1481,35 +1480,6 @@ if(NOT INTERN_BUILD_MOBILE)
|
||||
message("disabling MKLDNN because USE_MKLDNN is not set")
|
||||
endif()
|
||||
|
||||
if(USE_KLEIDIAI)
|
||||
if(CMAKE_C_COMPILER_ID STREQUAL "Clang" AND CMAKE_C_COMPILER_VERSION VERSION_LESS "11" )
|
||||
message(WARNING "KleidiAI: Using non-supported Clang version. Expected 11 or newer, received ${CMAKE_C_COMPILER_VERSION}.")
|
||||
endif()
|
||||
if(CMAKE_C_COMPILER_ID STREQUAL "GNU" AND CMAKE_C_COMPILER_VERSION VERSION_LESS "11" )
|
||||
message(WARNING "KleidiAI: Using non-supported GCC version. Expected 11 or newer, received ${CMAKE_C_COMPILER_VERSION}.")
|
||||
endif()
|
||||
set(TEMP_BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS})
|
||||
set(BUILD_SHARED_LIBS OFF CACHE BOOL "Build shared libs" FORCE)
|
||||
set(AT_KLEIDIAI_ENABLED 1)
|
||||
set(KLEIDIAI_BUILD_TESTS OFF) # Disable building KLEIDIAI tests
|
||||
set(KLEIDIAI_SRC "${PROJECT_SOURCE_DIR}/third_party/kleidiai")
|
||||
add_subdirectory(${KLEIDIAI_SRC})
|
||||
set(KLEIDIAI_INCLUDE_DIRS
|
||||
${KLEIDIAI_SRC}/
|
||||
${KLEIDIAI_SRC}/kai/
|
||||
${KLEIDIAI_SRC}/kai/ukernels/
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/
|
||||
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/
|
||||
)
|
||||
include_directories(SYSTEM INTERFACE ${KLEIDIAI_INCLUDE_DIRS})
|
||||
list(APPEND Caffe2_DEPENDENCY_LIBS kleidiai)
|
||||
# Recover build options.
|
||||
set(BUILD_SHARED_LIBS ${TEMP_BUILD_SHARED_LIBS} CACHE BOOL "Build shared libs" FORCE)
|
||||
endif()
|
||||
|
||||
if(UNIX AND NOT APPLE)
|
||||
include(CheckLibraryExists)
|
||||
# https://github.com/libgit2/libgit2/issues/2128#issuecomment-35649830
|
||||
|
@ -153,9 +153,6 @@ function(caffe2_print_configuration_summary)
|
||||
message(STATUS " USE_MKLDNN_ACL : ${USE_MKLDNN_ACL}")
|
||||
message(STATUS " USE_MKLDNN_CBLAS : ${USE_MKLDNN_CBLAS}")
|
||||
endif()
|
||||
if(${USE_KLEIDIAI})
|
||||
message(STATUS " USE_KLEIDIAI : ${USE_KLEIDIAI}")
|
||||
endif()
|
||||
message(STATUS " USE_UCC : ${USE_UCC}")
|
||||
if(${USE_UCC})
|
||||
message(STATUS " USE_SYSTEM_UCC : ${USE_SYSTEM_UCC}")
|
||||
|
@ -97,10 +97,6 @@ else()
|
||||
append_torchlib_if_found(microkernels-prod)
|
||||
endif()
|
||||
|
||||
if(@USE_KLEIDIAI@)
|
||||
append_torchlib_if_found(kleidiai)
|
||||
endif()
|
||||
|
||||
append_torchlib_if_found(caffe2_protos protobuf-lite protobuf protoc)
|
||||
append_torchlib_if_found(onnx onnx_proto)
|
||||
|
||||
|
@ -203,7 +203,6 @@ torch.backends.openmp
|
||||
.. add anything to the rendered page for now.
|
||||
.. py:module:: torch.backends.quantized
|
||||
.. py:module:: torch.backends.xnnpack
|
||||
.. py:module:: torch.backends.kleidiai
|
||||
|
||||
|
||||
torch.backends.opt_einsum
|
||||
|
1
setup.py
1
setup.py
@ -1221,7 +1221,6 @@ def main():
|
||||
"include/ATen/native/cuda/*.cuh",
|
||||
"include/ATen/native/hip/*.h",
|
||||
"include/ATen/native/hip/*.cuh",
|
||||
"include/ATen/native/kleidiai/*.h",
|
||||
"include/ATen/native/mps/*.h",
|
||||
"include/ATen/native/mkldnn/xpu/*.h",
|
||||
"include/ATen/native/mkldnn/xpu/detail/*.h",
|
||||
|
@ -92,8 +92,6 @@ aten::_dimI
|
||||
aten::_dimV
|
||||
aten::_dirichlet_grad
|
||||
aten::_dirichlet_grad.out
|
||||
aten::_dyn_quant_matmul_4bit
|
||||
aten::_dyn_quant_pack_4bit_weight
|
||||
aten::_efficient_attention_backward
|
||||
aten::_efficient_attention_forward
|
||||
aten::_efficientzerotensor
|
||||
|
@ -76,7 +76,6 @@ from torch.testing._internal.common_device_type import (
|
||||
from torch.testing._internal.common_dtype import all_types, get_all_dtypes
|
||||
from torch.testing._internal.common_quantization import (
|
||||
_dynamically_quantize_per_channel,
|
||||
_group_quantize_tensor_symmetric,
|
||||
)
|
||||
from torch.testing._internal.common_utils import (
|
||||
DeterministicGuard,
|
||||
@ -2225,85 +2224,6 @@ class CommonTemplate:
|
||||
b_int8pack, b_scales = convert_weight_to_int8pack(b)
|
||||
self.common(fn, (a, b_int8pack, b_scales, c))
|
||||
|
||||
@xfail_if_triton_cpu
|
||||
@skipCUDAIf(True, "No _dyn_quant_pack_4bit_weight implementation on CUDA")
|
||||
@skipIfRocm
|
||||
@skipIfXpu(msg="No _dyn_quant_pack_4bit_weight implementation on XPU")
|
||||
def test__dyn_quant_pack_4bit_weight(self):
|
||||
q_group = 32
|
||||
k = 128
|
||||
n = 128
|
||||
|
||||
torch.manual_seed(1)
|
||||
b = torch.rand((k, n), dtype=torch.float32)
|
||||
in_features = b.size(0)
|
||||
out_features = b.size(1)
|
||||
|
||||
def dyn_quant_pack_4bit_weight(b, in_features, out_features):
|
||||
b_uint8, b_scales_and_zeros = _group_quantize_tensor_symmetric(
|
||||
b, n_bit=4, groupsize=q_group
|
||||
)
|
||||
|
||||
if q_group == in_features:
|
||||
b_scales_and_zeros = b_scales_and_zeros.to(torch.float)
|
||||
else:
|
||||
b_scales_and_zeros = b_scales_and_zeros.to(torch.bfloat16)
|
||||
b_int4pack = torch._dyn_quant_pack_4bit_weight(
|
||||
b_uint8, b_scales_and_zeros, None, q_group, in_features, out_features
|
||||
)
|
||||
|
||||
return b_int4pack, b_scales_and_zeros
|
||||
|
||||
def fn(b, in_features, out_features):
|
||||
b_int4pack, _ = dyn_quant_pack_4bit_weight(b, in_features, out_features)
|
||||
return b_int4pack
|
||||
|
||||
self.common(fn, (b, in_features, out_features))
|
||||
|
||||
@xfail_if_triton_cpu
|
||||
@skipCUDAIf(True, "No _dyn_quant_matmul_4bit implementation on CUDA")
|
||||
@skipIfRocm
|
||||
@skipIfXpu(msg="No _dyn_quant_matmul_4bit implementation on XPU")
|
||||
def test__dyn_quant_matmul_4bit(self):
|
||||
q_group = 32
|
||||
m = 32
|
||||
k = 128
|
||||
n = 128
|
||||
|
||||
torch.manual_seed(1)
|
||||
a = torch.rand((m, k), dtype=torch.float32)
|
||||
b = torch.rand((k, n), dtype=torch.float32)
|
||||
in_features = b.size(0)
|
||||
out_features = b.size(1)
|
||||
|
||||
def dyn_quant_pack_4bit_weight(b, in_features, out_features):
|
||||
b_uint8, b_scales_and_zeros = _group_quantize_tensor_symmetric(
|
||||
b, n_bit=4, groupsize=q_group
|
||||
)
|
||||
|
||||
if q_group == in_features:
|
||||
b_scales_and_zeros = b_scales_and_zeros.to(torch.float)
|
||||
else:
|
||||
b_scales_and_zeros = b_scales_and_zeros.to(torch.bfloat16)
|
||||
b_int4pack = torch._dyn_quant_pack_4bit_weight(
|
||||
b_uint8, b_scales_and_zeros, None, q_group, in_features, out_features
|
||||
)
|
||||
|
||||
return b_int4pack, b_scales_and_zeros
|
||||
|
||||
def fn(a, q_group, in_features, out_features):
|
||||
b_int4pack, _ = dyn_quant_pack_4bit_weight(b, in_features, out_features)
|
||||
res = torch._dyn_quant_matmul_4bit(
|
||||
a,
|
||||
b_int4pack,
|
||||
q_group,
|
||||
in_features,
|
||||
out_features,
|
||||
)
|
||||
return res
|
||||
|
||||
self.common(fn, (a, q_group, in_features, out_features))
|
||||
|
||||
def test_expanded_reduction(self):
|
||||
def fn(x, y):
|
||||
z = x * y
|
||||
|
@ -34,8 +34,7 @@ from torch.testing._internal.common_dtype import (
|
||||
)
|
||||
from torch.testing._internal.common_cuda import SM53OrLater, SM80OrLater, SM90OrLater, tf32_on_and_off, _get_magma_version, \
|
||||
_get_torch_cuda_version, CDNA2OrLater, TEST_MULTIGPU
|
||||
from torch.testing._internal.common_quantization import _group_quantize_tensor, _dynamically_quantize_per_channel, \
|
||||
_group_quantize_tensor_symmetric
|
||||
from torch.testing._internal.common_quantization import _group_quantize_tensor, _dynamically_quantize_per_channel
|
||||
from torch.testing._internal.common_mkldnn import bf32_on_and_off
|
||||
from torch.distributions.binomial import Binomial
|
||||
import torch.backends.opt_einsum as opt_einsum
|
||||
@ -910,6 +909,7 @@ class TestLinalg(TestCase):
|
||||
torch.randn((3, 52, 52), device=device, dtype=dtype),
|
||||
torch.randn((4, 2, 26, 26), device=device, dtype=dtype))
|
||||
|
||||
|
||||
ops = (torch.det, torch.Tensor.det,
|
||||
torch.linalg.det)
|
||||
for t in tensors:
|
||||
@ -1438,6 +1438,7 @@ class TestLinalg(TestCase):
|
||||
continue
|
||||
run_test_case(make_arg(shape), ord, dim, keepdim)
|
||||
|
||||
|
||||
@onlyCUDA
|
||||
@dtypes(torch.bfloat16, torch.float16)
|
||||
def test_norm_fused_type_promotion(self, device, dtype):
|
||||
@ -4343,6 +4344,7 @@ class TestLinalg(TestCase):
|
||||
triangular_solve_zero_batch_helper((batchsize, 5, 5), (batchsize, 5, 10),
|
||||
upper, unitriangular, transpose)
|
||||
|
||||
|
||||
@slowTest
|
||||
@skipCUDAIfNoMagma
|
||||
@skipCPUIfNoLapack
|
||||
@ -4463,6 +4465,7 @@ class TestLinalg(TestCase):
|
||||
self.assertTrue("An output with one or more elements was resized" in str(w[0].message))
|
||||
self.assertTrue("An output with one or more elements was resized" in str(w[1].message))
|
||||
|
||||
|
||||
def check_single_matmul(self, x, y):
|
||||
|
||||
def assertEqual(answer, expected):
|
||||
@ -5686,6 +5689,7 @@ class TestLinalg(TestCase):
|
||||
else:
|
||||
self.assertEqual(B_, X_ @ A)
|
||||
|
||||
|
||||
sizes = ((3, 3), (5, 5), (4, 2), (3, 4), (0, 0), (0, 1), (1, 0))
|
||||
batches = ((0,), (), (1,), (2,), (3,), (1, 0), (3, 5))
|
||||
# Non pivoting just implemented for CUDA
|
||||
@ -5718,6 +5722,7 @@ class TestLinalg(TestCase):
|
||||
with self.assertRaisesRegex(RuntimeError, 'LU without pivoting is not implemented on the CPU'):
|
||||
f(torch.empty(1, 2, 2), pivot=False)
|
||||
|
||||
|
||||
@precisionOverride({torch.float32: 1e-2, torch.complex64: 1e-2})
|
||||
@skipCUDAIfNoMagmaAndNoCusolver
|
||||
@skipCPUIfNoLapack
|
||||
@ -5745,6 +5750,7 @@ class TestLinalg(TestCase):
|
||||
for b, n in shapes:
|
||||
yield make_arg((b, n, n)), make_arg((b, n, rhs))
|
||||
|
||||
|
||||
for A, B in gen_matrices():
|
||||
LU, pivots = torch.linalg.lu_factor(A)
|
||||
for backend in backends:
|
||||
@ -5759,6 +5765,7 @@ class TestLinalg(TestCase):
|
||||
else:
|
||||
self.assertEqual(B_left, X @ A_adj)
|
||||
|
||||
|
||||
@onlyCPU
|
||||
@dtypes(*floating_and_complex_types())
|
||||
def test_linalg_lu_cpu_errors(self, device, dtype):
|
||||
@ -5799,6 +5806,7 @@ class TestLinalg(TestCase):
|
||||
with self.assertRaisesRegex(RuntimeError, r"between 1 and LU.size\(-2\)."):
|
||||
torch.lu_unpack(LU, pivots)
|
||||
|
||||
|
||||
# Rectangular tests
|
||||
sample = torch.randn(2, 3, 5, device=device, dtype=dtype)
|
||||
B = torch.randn(2, 3, 5, device=device, dtype=dtype)
|
||||
@ -5815,6 +5823,7 @@ class TestLinalg(TestCase):
|
||||
with self.assertRaisesRegex(RuntimeError, r"between 1 and LU.size\(-2\)."):
|
||||
torch.lu_unpack(LU, pivots)
|
||||
|
||||
|
||||
@skipCPUIfNoLapack
|
||||
@skipCUDAIfNoMagma
|
||||
@dtypes(torch.double)
|
||||
@ -6423,6 +6432,7 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
|
||||
y = torch.baddbmm(input_tensor, batch1, batch2, beta=0.0, out=out)
|
||||
self.assertEqual(out, y_ref)
|
||||
|
||||
|
||||
@unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
|
||||
@onlyCUDA
|
||||
def test_matmul_45724(self, device):
|
||||
@ -6595,6 +6605,7 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
|
||||
torch._int_mm(a_int8, b_int8, out=c_int32_result)
|
||||
self.assertEqual(c_int32_result.float(), torch.mm(a_float, b_float))
|
||||
|
||||
|
||||
@unittest.skipIf(IS_WINDOWS, "Skipped on Windows!")
|
||||
@unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
|
||||
@onlyNativeDeviceTypes
|
||||
@ -6657,6 +6668,7 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
|
||||
mean_err = ((res - ref).abs() / ref).mean()
|
||||
self.assertTrue(mean_err < 0.05)
|
||||
|
||||
|
||||
@unittest.skipIf(IS_WINDOWS, "Skipped on Windows!")
|
||||
@unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
|
||||
@onlyNativeDeviceTypes
|
||||
@ -6709,168 +6721,6 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
|
||||
mean_err = ((res - ref).abs() / ref).mean()
|
||||
self.assertTrue(mean_err < 0.05)
|
||||
|
||||
@unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
|
||||
@unittest.skipIf(TEST_WITH_ROCM and IS_REMOTE_GPU, "ROCM is unsupported")
|
||||
@onlyNativeDeviceTypes
|
||||
@parametrize("k", [64, 256])
|
||||
@parametrize("n", [32, 48, 64, 128])
|
||||
def test__dyn_quant_pack_4bit_weight(self, device, k, n):
|
||||
# TODO: Fix https://github.com/pytorch/pytorch/issues/131425 and use OpInfo instead
|
||||
# Weight shape is [K x N]
|
||||
if self.device_type == "cuda":
|
||||
self.skipTest("CUDA Backend is unsupported")
|
||||
|
||||
torch.manual_seed(1)
|
||||
block_size = 32
|
||||
b = torch.rand((k, n), dtype=torch.bfloat16, device=device)
|
||||
in_features = b.size(0)
|
||||
out_features = b.size(1)
|
||||
b_uint8, b_scales_and_zeros = _group_quantize_tensor_symmetric(
|
||||
b, n_bit=4, groupsize=block_size
|
||||
)
|
||||
b_int4pack = torch._dyn_quant_pack_4bit_weight(
|
||||
b_uint8, b_scales_and_zeros, None, block_size, in_features, out_features
|
||||
)
|
||||
b_int4pack_meta = torch._dyn_quant_pack_4bit_weight(
|
||||
b_uint8, b_scales_and_zeros, None, block_size, in_features, out_features
|
||||
)
|
||||
self.assertEqual(b_int4pack.shape, b_int4pack_meta.shape)
|
||||
|
||||
@unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
|
||||
@unittest.skipIf(TEST_WITH_ROCM and IS_REMOTE_GPU, "ROCM is unsupported")
|
||||
@onlyNativeDeviceTypes
|
||||
@parametrize("m", [1, 32])
|
||||
@parametrize("k", [64, 128])
|
||||
@parametrize("n", [4096, 11008])
|
||||
def test__dyn_quant_matmul_4bit(self, device, m, k, n):
|
||||
if self.device_type == "cuda":
|
||||
self.skipTest("CUDA is unsupported")
|
||||
|
||||
q_group = 32
|
||||
|
||||
torch.manual_seed(1)
|
||||
a_float32 = torch.rand((m, k), dtype=torch.float32, device=device)
|
||||
b_float32 = torch.rand((k, n), dtype=torch.float32, device=device)
|
||||
in_features = b_float32.size(0)
|
||||
out_features = b_float32.size(1)
|
||||
|
||||
def dyn_quant_pack_4bit_weight(b, in_features, out_features):
|
||||
b_uint8, b_scales_and_zeros = _group_quantize_tensor_symmetric(
|
||||
b, n_bit=4, groupsize=q_group
|
||||
)
|
||||
|
||||
if q_group == in_features:
|
||||
b_scales_and_zeros = b_scales_and_zeros.to(torch.float)
|
||||
else:
|
||||
b_scales_and_zeros = b_scales_and_zeros.to(torch.bfloat16)
|
||||
b_int4pack = torch._dyn_quant_pack_4bit_weight(
|
||||
b_uint8, b_scales_and_zeros, None, q_group, in_features, out_features
|
||||
)
|
||||
|
||||
return b_int4pack, b_scales_and_zeros
|
||||
|
||||
def dyn_quant_matmul_4bit(
|
||||
a, b_int4pack, q_group, in_features, out_features
|
||||
):
|
||||
return torch._dyn_quant_matmul_4bit(
|
||||
a,
|
||||
b_int4pack,
|
||||
q_group,
|
||||
in_features,
|
||||
out_features,
|
||||
)
|
||||
|
||||
b_int4pack, b_scales_and_zeros = dyn_quant_pack_4bit_weight(
|
||||
b_float32, in_features, out_features
|
||||
)
|
||||
|
||||
dtypes = [torch.float32]
|
||||
|
||||
for dtype in dtypes:
|
||||
a = a_float32.to(dtype=dtype)
|
||||
b = b_float32.to(dtype=dtype)
|
||||
ref = torch.mm(a, b)
|
||||
res = dyn_quant_matmul_4bit(
|
||||
a,
|
||||
b_int4pack,
|
||||
q_group,
|
||||
in_features,
|
||||
out_features,
|
||||
)
|
||||
mean_err = ((res - ref).abs() / ref).mean()
|
||||
self.assertTrue(mean_err < 0.05)
|
||||
elementwise_diff = (res - ref).abs()
|
||||
elementwise_relative_error = elementwise_diff / ref.abs().clamp(
|
||||
min=torch.finfo(ref.dtype).eps
|
||||
)
|
||||
all_elements_within_threshold = torch.all(elementwise_relative_error < 0.06)
|
||||
self.assertTrue(
|
||||
all_elements_within_threshold, "Some elements have error >= 0.06"
|
||||
)
|
||||
|
||||
@unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
|
||||
@unittest.skipIf(TEST_WITH_ROCM and IS_REMOTE_GPU, "ROCM is unsupported")
|
||||
@onlyNativeDeviceTypes
|
||||
@parametrize("m", [1, 32])
|
||||
@parametrize("k", [64, 128])
|
||||
@parametrize("n", [4096, 11008])
|
||||
def test_compile_dyn_quant_matmul_4bit(self, device, m, k, n):
|
||||
if self.device_type == "cuda":
|
||||
self.skipTest("CUDA is unsupported")
|
||||
|
||||
q_group = 32
|
||||
|
||||
torch.manual_seed(1)
|
||||
a_float32 = torch.rand((m, k), dtype=torch.float32, device=device)
|
||||
b_float32 = torch.rand((k, n), dtype=torch.float32, device=device)
|
||||
in_features = b_float32.size(0)
|
||||
out_features = b_float32.size(1)
|
||||
|
||||
b_uint8, b_scales_and_zeros = _group_quantize_tensor_symmetric(
|
||||
b_float32, n_bit=4, groupsize=q_group
|
||||
)
|
||||
|
||||
if q_group == in_features:
|
||||
b_scales_and_zeros = b_scales_and_zeros.to(dtype=torch.float)
|
||||
else:
|
||||
b_scales_and_zeros = b_scales_and_zeros.to(dtype=torch.bfloat16)
|
||||
|
||||
@torch.compile
|
||||
def dyn_quant_matmul_4bit(
|
||||
a, b_uint8, b_scales_and_zeros, q_group, in_features, out_features
|
||||
):
|
||||
b_int4pack = torch._dyn_quant_pack_4bit_weight(
|
||||
b_uint8, b_scales_and_zeros, None, q_group, in_features, out_features
|
||||
)
|
||||
return torch._dyn_quant_matmul_4bit(
|
||||
a,
|
||||
b_int4pack,
|
||||
q_group,
|
||||
in_features,
|
||||
out_features,
|
||||
)
|
||||
|
||||
res = dyn_quant_matmul_4bit(
|
||||
a_float32,
|
||||
b_uint8,
|
||||
b_scales_and_zeros,
|
||||
q_group,
|
||||
in_features,
|
||||
out_features,
|
||||
)
|
||||
ref = torch.mm(a_float32, b_float32)
|
||||
|
||||
mean_err = ((res - ref).abs() / ref).mean()
|
||||
self.assertTrue(mean_err < 0.05)
|
||||
elementwise_diff = (res - ref).abs()
|
||||
elementwise_relative_error = elementwise_diff / ref.abs().clamp(
|
||||
min=torch.finfo(ref.dtype).eps
|
||||
)
|
||||
all_elements_within_threshold = torch.all(elementwise_relative_error < 0.06)
|
||||
self.assertTrue(
|
||||
all_elements_within_threshold, "Some elements have error >= 0.06"
|
||||
)
|
||||
|
||||
@onlyCPU
|
||||
@parametrize("m", [32, 64])
|
||||
@parametrize("k", [32, 64])
|
||||
@ -8802,6 +8652,8 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
|
||||
|
||||
self.assertEqual(ck_out, cpu_out)
|
||||
|
||||
|
||||
|
||||
def test_permute_matmul(self):
|
||||
a = torch.ones([2, 5, 24, 24])
|
||||
b = torch.ones([3, 2, 5, 24, 24])
|
||||
@ -8881,6 +8733,7 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
|
||||
ref = alpha * A @ B + beta * C
|
||||
self.assertEqual(rc, ref)
|
||||
|
||||
|
||||
@dtypes(torch.float, torch.double)
|
||||
@precisionOverride({torch.float32: 1e-4})
|
||||
def test_1_sized_with_0_strided(self, device, dtype):
|
||||
|
1
third_party/kleidiai
vendored
1
third_party/kleidiai
vendored
Submodule third_party/kleidiai deleted from 202603f38a
@ -1299,7 +1299,6 @@ def _valgrind_toggle_and_dump_stats() -> None: ... # CALLGRIND_TOGGLE_COLLECT a
|
||||
|
||||
has_openmp: _bool
|
||||
has_mkl: _bool
|
||||
_has_kleidiai: _bool
|
||||
_has_mps: _bool
|
||||
has_lapack: _bool
|
||||
_has_cuda: _bool
|
||||
|
@ -1372,8 +1372,6 @@ torch_c_binding_in_graph_functions = dict.fromkeys(
|
||||
"torch._dim_arange",
|
||||
"torch._dirichlet_grad",
|
||||
"torch._disable_functionalization",
|
||||
"torch._dyn_quant_matmul_4bit",
|
||||
"torch._dyn_quant_pack_4bit_weight",
|
||||
"torch._efficientzerotensor",
|
||||
"torch._embedding_bag_forward_only",
|
||||
"torch._embedding_bag",
|
||||
|
@ -94,6 +94,3 @@ def register_woq_mm_ops() -> None:
|
||||
return autotune_select_algorithm(
|
||||
"_weight_int8pack_mm", choices, [mat1, mat2, scale], aten_layout
|
||||
)
|
||||
|
||||
lowering.make_fallback(aten._dyn_quant_matmul_4bit)
|
||||
lowering.make_fallback(aten._dyn_quant_pack_4bit_weight)
|
||||
|
@ -3344,161 +3344,6 @@ def meta__weight_int4pack_mm_for_cpu(x, w, q_group_size, q_scale_and_zeros):
|
||||
return x.new_empty(x.size(0), w.size(0), dtype=x.dtype)
|
||||
|
||||
|
||||
def kai_roundup(a: int, b: int) -> int:
|
||||
return ((a + b - 1) // b) * b
|
||||
|
||||
|
||||
def get_kai_packed_weight_size(n_bits, N, K, groupsize):
|
||||
if n_bits == 4:
|
||||
if groupsize == K: # channelwise
|
||||
# dotprod params only [1x8x32_neon_dotprod]
|
||||
kai_nr = 8
|
||||
kai_kr = 16
|
||||
kai_sr = 2
|
||||
kai_num_bytes_sum_rhs = 4 # sizeof(int32_t)
|
||||
kai_num_bytes_multiplier_rhs = 4 # sizeof(float)
|
||||
kai_num_bytes_bias = 4 # sizeof(float)
|
||||
|
||||
def kai_k_roundedup(k, kr, sr):
|
||||
# Since we pack a float and int32 value at the end of the row,
|
||||
# we must make sure that k is a multiple of 4 for alignment
|
||||
kr_sr_roundedup4 = kai_roundup(kr * sr, 4)
|
||||
return kai_roundup(k, kr_sr_roundedup4)
|
||||
|
||||
def kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(
|
||||
k, nr, kr, sr
|
||||
):
|
||||
k_internal = kai_k_roundedup(k, kr, sr)
|
||||
|
||||
assert (k_internal % 2) == 0, "k_internal must be even"
|
||||
|
||||
return nr * (
|
||||
(k_internal // 2)
|
||||
+ kai_num_bytes_multiplier_rhs
|
||||
+ kai_num_bytes_sum_rhs
|
||||
+ kai_num_bytes_bias
|
||||
)
|
||||
|
||||
def kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(
|
||||
n, k, nr, kr, sr
|
||||
):
|
||||
num_rows = kai_roundup(n, nr) // nr
|
||||
|
||||
return (
|
||||
num_rows
|
||||
* kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(
|
||||
k, nr, kr, sr
|
||||
)
|
||||
)
|
||||
|
||||
return kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(
|
||||
N, K, kai_nr, kai_kr, kai_sr
|
||||
)
|
||||
elif groupsize % 32 == 0 and K % groupsize == 0: # groupwise
|
||||
kai_nr = 8
|
||||
kai_kr = 16
|
||||
kai_sr = 2
|
||||
kai_num_bytes_sum_rhs = 4
|
||||
kai_num_bytes_bias = 4
|
||||
kai_nr_multiple_of = 4
|
||||
kai_bl_multiple_of = 32
|
||||
|
||||
def kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(
|
||||
n, k, nr, kr, sr, bl
|
||||
):
|
||||
assert (bl % kr) == 0
|
||||
assert (nr % kai_nr_multiple_of) == 0
|
||||
assert (bl % kai_bl_multiple_of) == 0
|
||||
|
||||
num_rows = kai_roundup(n, nr) // nr
|
||||
|
||||
return (
|
||||
num_rows
|
||||
* kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(
|
||||
k, nr, kr, sr, bl
|
||||
)
|
||||
)
|
||||
|
||||
def kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(
|
||||
k, nr, kr, sr, bl
|
||||
):
|
||||
assert (bl % kr) == 0
|
||||
assert (nr % kai_nr_multiple_of) == 0
|
||||
assert (bl % kai_bl_multiple_of) == 0
|
||||
|
||||
# kr and sr are unused in the calculation
|
||||
num_bytes_multiplier_rhs = kai_get_bf16_datatype_size_in_bytes()
|
||||
num_blocks_per_row = kai_num_blocks_per_row(k, bl)
|
||||
num_bytes_per_block = kai_num_bytes_per_block(
|
||||
bl, num_bytes_multiplier_rhs
|
||||
)
|
||||
|
||||
return nr * (
|
||||
(num_bytes_per_block * num_blocks_per_row)
|
||||
+ kai_num_bytes_sum_rhs
|
||||
+ kai_num_bytes_bias
|
||||
)
|
||||
|
||||
# This funtion retuns size of these datatypes stored as enum. We modify it to just return bf16 datatype
|
||||
# https://gitlab.arm.com/kleidi/kleidiai/-/blob/main/kai/kai_common.h?ref_type=heads#L55
|
||||
def kai_get_bf16_datatype_size_in_bytes():
|
||||
return 2 # 2 bytes
|
||||
|
||||
def kai_num_blocks_per_row(k, bl):
|
||||
assert (bl % kai_bl_multiple_of) == 0
|
||||
return kai_roundup(k, bl) // bl
|
||||
|
||||
def kai_num_bytes_per_block(bl, num_bytes_multiplier_rhs):
|
||||
assert (bl % kai_bl_multiple_of) == 0
|
||||
return (bl // 2) + num_bytes_multiplier_rhs
|
||||
|
||||
return kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(
|
||||
N, K, kai_nr, kai_kr, kai_sr, groupsize
|
||||
)
|
||||
|
||||
|
||||
@register_meta([aten._dyn_quant_pack_4bit_weight])
|
||||
def meta__dyn_quant_pack_4bit_weight(
|
||||
weights, scales_zeros, bias: Optional[Tensor], block_size, in_features, out_features
|
||||
):
|
||||
torch._check(
|
||||
weights.dtype is torch.uint8,
|
||||
lambda: f"expected w to be uint8, got {weights.dtype}",
|
||||
)
|
||||
if torch.backends.kleidiai.is_available() and (
|
||||
(block_size == in_features and scales_zeros.dtype == torch.float)
|
||||
or (
|
||||
block_size < in_features
|
||||
and block_size % 32 == 0
|
||||
and in_features % block_size == 0
|
||||
and scales_zeros.dtype == torch.bfloat16
|
||||
)
|
||||
):
|
||||
packed_weight_size = get_kai_packed_weight_size(
|
||||
4, out_features, in_features, block_size
|
||||
)
|
||||
return weights.new_empty(int(packed_weight_size), dtype=torch.uint8)
|
||||
packed_weight_size = weights.numel() + scales_zeros.numel()
|
||||
return weights.new_empty(packed_weight_size, dtype=torch.float)
|
||||
|
||||
|
||||
@register_meta([aten._dyn_quant_matmul_4bit])
|
||||
def meta__dyn_quant_matmul_4bit(
|
||||
inp,
|
||||
packed_weights,
|
||||
block_size,
|
||||
in_features,
|
||||
out_features,
|
||||
):
|
||||
torch._check(inp.dim() == 2, lambda: "input must be a 2D tensor")
|
||||
torch._check(
|
||||
inp.dtype in [torch.float32],
|
||||
lambda: f"expected input to be f32, got {inp.dtype}",
|
||||
)
|
||||
M = inp.size(0)
|
||||
return inp.new_empty(M, out_features, dtype=inp.dtype)
|
||||
|
||||
|
||||
@register_meta([aten._weight_int8pack_mm])
|
||||
def meta__weight_int8pack_mm(x, w, q_scales):
|
||||
torch._check(x.dim() == 2, lambda: "x must be a 2D tensor")
|
||||
|
@ -62,7 +62,6 @@ from torch.backends import (
|
||||
cuda as cuda,
|
||||
cudnn as cudnn,
|
||||
cusparselt as cusparselt,
|
||||
kleidiai as kleidiai,
|
||||
mha as mha,
|
||||
mkl as mkl,
|
||||
mkldnn as mkldnn,
|
||||
|
@ -1,7 +0,0 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import torch
|
||||
|
||||
|
||||
def is_available():
|
||||
r"""Return whether PyTorch is built with KleidiAI support."""
|
||||
return torch._C._has_kleidiai
|
@ -1939,8 +1939,6 @@ Call this whenever a new thread is created in order to propagate values from
|
||||
ASSERT_TRUE(
|
||||
set_module_attr("has_openmp", at::hasOpenMP() ? Py_True : Py_False));
|
||||
ASSERT_TRUE(set_module_attr("has_mkl", at::hasMKL() ? Py_True : Py_False));
|
||||
ASSERT_TRUE(
|
||||
set_module_attr("_has_kleidiai", at::hasKleidiAI() ? Py_True : Py_False));
|
||||
ASSERT_TRUE(
|
||||
set_module_attr("has_lapack", at::hasLAPACK() ? Py_True : Py_False));
|
||||
|
||||
|
@ -18,8 +18,6 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__adaptive_avg_pool3d_backward(At
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__addmm_activation(AtenTensorHandle self, AtenTensorHandle mat1, AtenTensorHandle mat2, double beta, double alpha, int32_t use_gelu, AtenTensorHandle* ret0);
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__cdist_backward(AtenTensorHandle grad, AtenTensorHandle x1, AtenTensorHandle x2, double p, AtenTensorHandle cdist, AtenTensorHandle* ret0);
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__cdist_forward(AtenTensorHandle x1, AtenTensorHandle x2, double p, int64_t* compute_mode, AtenTensorHandle* ret0);
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__dyn_quant_matmul_4bit(AtenTensorHandle inp, AtenTensorHandle packed_weights, int64_t block_size, int64_t in_features, int64_t out_features, AtenTensorHandle* ret0);
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__dyn_quant_pack_4bit_weight(AtenTensorHandle weights, AtenTensorHandle scales_zeros, AtenTensorHandle* bias, int64_t block_size, int64_t in_features, int64_t out_features, AtenTensorHandle* ret0);
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__efficientzerotensor(const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__embedding_bag(AtenTensorHandle weight, AtenTensorHandle indices, AtenTensorHandle offsets, int32_t scale_grad_by_freq, int64_t mode, int32_t sparse, AtenTensorHandle* per_sample_weights, int32_t include_last_offset, int64_t padding_idx, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3);
|
||||
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__embedding_bag_dense_backward(AtenTensorHandle grad, AtenTensorHandle indices, AtenTensorHandle offset2bag, AtenTensorHandle bag_size, AtenTensorHandle maximum_indices, int64_t num_weights, int32_t scale_grad_by_freq, int64_t mode, AtenTensorHandle* per_sample_weights, int64_t padding_idx, AtenTensorHandle* ret0);
|
||||
|
@ -498,39 +498,6 @@ def _group_quantize_tensor(w, n_bit=4, q_group_size=16):
|
||||
return out, scales_and_zeros
|
||||
|
||||
|
||||
def _group_quantize_tensor_symmetric(
|
||||
w, n_bit=4, groupsize=32
|
||||
):
|
||||
# W is of shape [K x N]
|
||||
# We transpose W as Quantization is applied on [N x K]
|
||||
w = w.transpose(0, 1).contiguous()
|
||||
assert w.dim() == 2
|
||||
assert groupsize > 1
|
||||
assert w.shape[-1] % groupsize == 0
|
||||
# Calculate scale and zeros
|
||||
to_quant = w.reshape(-1, groupsize)
|
||||
max_val = to_quant.abs().amax(dim=1, keepdim=True)
|
||||
eps = torch.finfo(max_val.dtype).eps
|
||||
max_int = 2 ** (n_bit - 1) - 1 # For 4-bit, this is 7
|
||||
scales = max_val.clamp(min=eps) / max_int
|
||||
zeros = torch.zeros_like(scales)
|
||||
|
||||
# Quantize the weight
|
||||
scales = scales.to(torch.float32).reshape(w.shape[0], -1)
|
||||
zeros = zeros.to(torch.float32).reshape(w.shape[0], -1)
|
||||
scales = scales.reshape(-1, 1)
|
||||
zeros = zeros.reshape(-1, 1)
|
||||
max_int = 2**n_bit - 1
|
||||
w_int8 = to_quant.div(scales).add(8.5).to(torch.int8).clamp(max=max_int)
|
||||
# We pack 2 signed int4 values in unsigned uint8 container.
|
||||
# This reduces the weight size by half and improves load perf
|
||||
out_uint8 = (w_int8[::, 1::2] << 4 | w_int8[::, ::2]).to(torch.uint8)
|
||||
|
||||
scales_and_zeros = scales.squeeze().contiguous()
|
||||
|
||||
return out_uint8, scales_and_zeros
|
||||
|
||||
|
||||
def _dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype):
|
||||
# source: https://github.com/pytorch-labs/gpt-fast/blob/main/quantize.py
|
||||
# default setup for affine quantization of activations
|
||||
@ -563,6 +530,7 @@ def _dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype):
|
||||
return quant, scales.to(x_dtype), zero_points
|
||||
|
||||
|
||||
|
||||
# QuantizationTestCase used as a base class for testing quantization on modules
|
||||
class QuantizationTestCase(TestCase):
|
||||
def setUp(self):
|
||||
|
@ -45,8 +45,6 @@ inductor_fallback_ops = {
|
||||
"aten.cummin.default",
|
||||
"aten.cumprod.default",
|
||||
"aten.cumsum.default",
|
||||
"aten._dyn_quant_matmul_4bit.default",
|
||||
"aten._dyn_quant_pack_4bit_weight.default",
|
||||
"aten._efficient_attention_backward.default",
|
||||
"aten._efficient_attention_forward.default",
|
||||
"aten._efficientzerotensor.default",
|
||||
|
Reference in New Issue
Block a user