[ARM][feat]: Add 4 bit dynamic quantization matmuls & KleidiAI Backend (#134124)

Description:
1. Quantize Linear Layer Weights to 4 Bits:
Quantize the weights of the Linear layer to 4 bits using symmetric quantization.
Pack two 4-bit weights into one uint8 container.
Choose a quantization scheme (channel-wise or group-wise), with the group size a multiple of 32; a sketch of this step is shown below.
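
As a rough illustration of this step, here is a minimal sketch. This is an assumption-laden reference, not the exact routine the backend uses; the new tests instead call the _group_quantize_tensor_symmetric helper from torch.testing._internal.common_quantization.
```python
import torch

def quantize_4bit_symmetric(weight: torch.Tensor, group_size: int = 32):
    # Illustrative symmetric group-wise 4-bit quantization (hypothetical helper).
    # weight: (out_features, in_features); group_size must divide in_features.
    out_features, in_features = weight.shape
    assert group_size % 32 == 0 and in_features % group_size == 0
    w = weight.reshape(out_features, in_features // group_size, group_size)
    # Symmetric scheme: no zero point; the largest magnitude in each group maps to the int4 limit 7.
    scales = w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 7.0
    q = torch.clamp(torch.round(w / scales), -8, 7).to(torch.int8)
    # Shift signed nibbles [-8, 7] to unsigned [0, 15] and pack two per uint8
    # (low nibble holds the even k index, matching the reference kernel in this PR).
    q_u = (q + 8).to(torch.uint8).reshape(out_features, -1)
    packed = q_u[:, ::2] | (q_u[:, 1::2] << 4)
    return packed, scales.reshape(out_features, -1)

packed, scales = quantize_4bit_symmetric(torch.randn(64, 128), group_size=32)
print(packed.shape, scales.shape)  # torch.Size([64, 64]) torch.Size([64, 4])
```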

2. Prepare Quantized Weights, Scales, and Optional Bias:
After quantizing, you have the quantized weights, their scales, and the chosen groupsize.
If the original Linear layer has a bias, prepare it as well.

3. Pack the Weights Efficiently:
Use torch.ops.aten._dyn_quant_pack_4bit_weight to optimally pack the weights, scales, and optional bias.
```python
packed_weights = torch.ops.aten._dyn_quant_pack_4bit_weight(weight, scales_and_zeros, bias, groupsize, in_features, out_features)
```
The in_features and out_features arguments must match the corresponding attributes of the original Linear layer.

4. Perform Dynamic Quantized Matrix Multiplication:
Use torch.ops.aten._dyn_quant_matmul_4bit to perform matrix multiplication with quantized weights.
```python
output = torch.ops.aten._dyn_quant_matmul_4bit(input, packed_weights, groupsize, in_features, out_features)
```
The required inputs are the input tensor, packed_weights, groupsize, in_features, and out_features. An end-to-end example is shown below.
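
Putting steps 2-4 together, here is a hedged end-to-end sketch modeled on the new test__dyn_quant_matmul_4bit test. It assumes the _group_quantize_tensor_symmetric helper from torch.testing._internal.common_quantization, and follows the kernels' expectation that scales are float32 when the group size equals in_features (channel-wise) and bfloat16 otherwise (group-wise).
```python
import torch
from torch.testing._internal.common_quantization import _group_quantize_tensor_symmetric

m, k, n, q_group = 32, 128, 128, 32
a = torch.rand((m, k), dtype=torch.float32)   # activations
b = torch.rand((k, n), dtype=torch.float32)   # original Linear weight, shape (k, n)
in_features, out_features = b.size(0), b.size(1)

# Step 2: quantize to packed uint8 nibbles plus per-group scales.
b_uint8, b_scales_and_zeros = _group_quantize_tensor_symmetric(
    b, n_bit=4, groupsize=q_group
)
# Channel-wise (group == in_features) expects float32 scales; group-wise expects bfloat16.
b_scales_and_zeros = (
    b_scales_and_zeros.to(torch.float)
    if q_group == in_features
    else b_scales_and_zeros.to(torch.bfloat16)
)

# Step 3: pack weights, scales, and (optional) bias into the backend-specific layout.
packed = torch.ops.aten._dyn_quant_pack_4bit_weight(
    b_uint8, b_scales_and_zeros, None, q_group, in_features, out_features
)

# Step 4: dynamically quantize the input and run the 4-bit matmul.
out = torch.ops.aten._dyn_quant_matmul_4bit(
    a, packed, q_group, in_features, out_features
)
print(out.shape)  # torch.Size([32, 128])
```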

API Usage: https://github.com/pytorch/pytorch/issues/143289

Model perf:
7B Transformer model:
Prefill : 340 t/s
Decode  : 40  t/s
2B Transformer model:
Prefill : 747 t/s
Decode  : 80  t/s

Tests:
python test/test_linalg.py -k test__dyn_quant_pack_4bit_weight
Ran 1 test in 0.016s

OK

python test/test_linalg.py -k test__dyn_quant_matmul_4bit
Ran 8 tests in 0.077s

OK

python test/test_linalg.py -k test_compile_dyn_quant_matmul_4bit
Ran 8 tests in 11.454s

Change-Id: Ia1672bad5e6ec94e64d8bb1971395d60f4b3a452

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134124
Approved by: https://github.com/digantdesai, https://github.com/malfet
Author: Nikhil Gupta
Date: 2024-12-19 18:51:23 +00:00
Committed by: PyTorch MergeBot
Parent: c5ddf5dd90
Commit: 4b82251011
37 changed files with 1898 additions and 23 deletions

.gitmodules

@ -131,3 +131,6 @@
path = third_party/composable_kernel
url = https://github.com/ROCm/composable_kernel.git
branch = develop
[submodule "third_party/kleidiai"]
path = third_party/kleidiai
url = https://git.gitlab.arm.com/kleidi/kleidiai.git


@ -254,6 +254,7 @@ filegroup(
# target that generates these sources...
)
# TODO: Enable support for KleidiAI bazel build
header_template_rule(
name = "aten_src_ATen_config",
src = "aten/src/ATen/Config.h.in",
@ -273,6 +274,7 @@ header_template_rule(
"@AT_PARALLEL_NATIVE@": "1",
"@AT_BLAS_F2C@": "0",
"@AT_BLAS_USE_CBLAS_DOT@": "1",
"@AT_KLEIDIAI_ENABLED@": "0",
},
)


@ -377,6 +377,8 @@ cmake_dependent_option(
cmake_dependent_option(BUILD_FUNCTORCH "Build Functorch" ON "BUILD_PYTHON" OFF)
cmake_dependent_option(BUILD_BUNDLE_PTXAS "Bundle PTX into torch/bin fodler"
OFF "USE_CUDA" OFF)
cmake_dependent_option(USE_KLEIDIAI "Use KleidiAI for the ARM CPU & AARCH64 architecture." ON
"CPU_AARCH64" OFF)
option(USE_MIMALLOC "Use mimalloc" OFF)
# Enable third party mimalloc library to improve memory allocation performance
@ -418,6 +420,8 @@ endif()
if(WIN32)
set(USE_TENSORPIPE OFF)
message(WARNING "TensorPipe cannot be used on Windows. Set it to OFF")
set(USE_KLEIDIAI OFF)
message(WARNING "KleidiAI cannot be used on Windows. Set it to OFF")
if(USE_DISTRIBUTED AND NOT DEFINED ENV{libuv_ROOT})
find_library(
@ -667,6 +671,9 @@ if(ANDROID
message(WARNING "INTERN_BUILD_MOBILE is on, disabling BUILD_LAZY_TS_BACKEND")
set(BUILD_LAZY_TS_BACKEND OFF)
set(USE_KLEIDIAI OFF)
message(WARNING "KleidiAI cannot be used on Mobile builds. Set it to OFF")
# Set -ffunction-sections and -fdata-sections so that each method has its own
# text section. This allows the linker to remove unused section when the flag
# -Wl,-gc-sections is provided at link time.


@ -309,6 +309,12 @@ local_repository(
path = "third_party/gemmlowp/gemmlowp", path = "third_party/gemmlowp/gemmlowp",
) )
local_repository(
name = "kleidiai",
path = "third_party/kleidiai",
repo_mapping = {"@com_google_googletest": "@com_google_benchmark"},
)
### Unused repos start ### Unused repos start
# `unused` repos are defined to hide bazel files from submodules of submodules. # `unused` repos are defined to hide bazel files from submodules of submodules.


@ -199,6 +199,10 @@ endif()
# XNNPACK
file(GLOB native_xnnpack "native/xnnpack/*.cpp")
# KLEIDIAI
file(GLOB native_kleidiai "native/kleidiai/*.cpp")
file(GLOB native_kleidiai_h "native/kleidiai/*.h")
# Add files needed from jit folders
append_filelist("jit_core_headers" ATen_CORE_HEADERS)
append_filelist("jit_core_sources" ATen_CORE_SRCS)
@ -228,6 +232,10 @@ endif()
if(AT_MKL_ENABLED)
set(all_cpu_cpp ${all_cpu_cpp} ${mkl_cpp})
endif()
if(AT_KLEIDIAI_ENABLED)
set(all_cpu_cpp ${all_cpu_cpp} ${native_kleidiai})
include_directories(SYSTEM INTERFACE ${KLEIDIAI_INCLUDE_DIRS})
endif()
if(AT_MKLDNN_ENABLED)
set(all_cpu_cpp ${all_cpu_cpp} ${mkldnn_cpp})
endif()
@ -611,7 +619,7 @@ install(FILES "${CMAKE_CURRENT_BINARY_DIR}/cmake-exports/ATenConfig.cmake"
set(INSTALL_HEADERS ${base_h} ${ATen_CORE_HEADERS} ${native_nested_h} ${ATen_TRANSFORMER_HEADERS})
if(NOT INTERN_BUILD_MOBILE)
list(APPEND INSTALL_HEADERS ${native_h} ${native_cpu_h} ${native_ao_sparse_h} ${native_quantized_h} ${cuda_h} ${native_cuda_h} ${native_hip_h} ${cudnn_h} ${hip_h} ${xpu_h} ${mps_h} ${native_kleidiai_h} ${native_mps_h} ${native_utils_h} ${miopen_h} ${mkldnn_xpu_h})
# Metal
if(USE_PYTORCH_METAL_EXPORT)
# Add files needed from exporting metal models(optimized_for_mobile)


@ -19,3 +19,4 @@
#define AT_PARALLEL_NATIVE @AT_PARALLEL_NATIVE@
#define AT_BLAS_F2C() @AT_BLAS_F2C@
#define AT_BLAS_USE_CBLAS_DOT() @AT_BLAS_USE_CBLAS_DOT@
#define AT_KLEIDIAI_ENABLED() @AT_KLEIDIAI_ENABLED@


@ -376,6 +376,10 @@ bool Context::hasMKLDNN() {
#endif
}
bool Context::hasKleidiAI() {
return AT_KLEIDIAI_ENABLED();
}
bool Context::hasOpenMP() {
#ifdef _OPENMP
return true;


@ -118,6 +118,7 @@ class TORCH_API Context {
static bool hasOpenMP();
static bool hasMKL();
static bool hasKleidiAI();
static bool hasLAPACK();
static bool hasMKLDNN();
static bool hasMAGMA() {
@ -538,6 +539,10 @@ inline bool hasMKL() {
return globalContext().hasMKL();
}
inline bool hasKleidiAI() {
return globalContext().hasKleidiAI();
}
inline bool hasLAPACK() {
return globalContext().hasLAPACK();
}


@ -33,6 +33,8 @@
#include <ATen/ops/_addmm_activation_native.h>
#include <ATen/ops/_compute_linear_combination_native.h>
#include <ATen/ops/_convert_weight_to_int4pack_for_cpu_native.h>
#include <ATen/ops/_dyn_quant_matmul_4bit_native.h>
#include <ATen/ops/_dyn_quant_pack_4bit_weight_native.h>
#include <ATen/ops/_int_mm_native.h>
#include <ATen/ops/_linalg_check_errors.h>
#include <ATen/ops/_linalg_det.h>
@ -3429,6 +3431,8 @@ Tensor kron(const Tensor& self, const Tensor& other) {
DEFINE_DISPATCH(weight_to_int4pack_stub);
DEFINE_DISPATCH(int4pack_mm_stub);
DEFINE_DISPATCH(int8pack_mm_stub);
DEFINE_DISPATCH(dyn_quant_pack_4bit_weight_stub);
DEFINE_DISPATCH(dyn_quant_matmul_4bit_stub);
Tensor _convert_weight_to_int4pack_cpu(
const Tensor& in,
@ -3492,6 +3496,69 @@ Tensor _weight_int4pack_mm_cpu(
return C;
}
Tensor _dyn_quant_pack_4bit_weight_cpu(
const Tensor& weights,
const Tensor& scales_zeros,
const std::optional<Tensor>& bias,
const int64_t block_size,
const int64_t in_features,
const int64_t out_features) {
TORCH_CHECK(
weights.dtype() == at::kByte, __func__, " : expect weight to be kByte.");
TORCH_CHECK(
block_size == in_features ||
(!(block_size % 32) && !(in_features % block_size)),
__func__,
": Group size should be multiple of 32, in_features [",
in_features,
"]. Provided ",
block_size);
Tensor packed_weights =
at::empty(weights.sizes(), weights.options().dtype(at::kByte));
dyn_quant_pack_4bit_weight_stub(
kCPU,
packed_weights,
weights,
scales_zeros,
bias,
out_features,
in_features,
block_size);
return packed_weights;
}
Tensor _dyn_quant_matmul_4bit_cpu(
const Tensor& inp,
const Tensor& packed_weights,
const int64_t block_size,
const int64_t in_features,
const int64_t out_features) {
auto M = inp.size(0);
TORCH_CHECK(
inp.dtype() == kFloat,
__func__,
" : expect input to be 32-bit float tensor.");
TORCH_CHECK(
block_size == in_features ||
(!(block_size % 32) && !(in_features % block_size)),
__func__,
": Group size should be multiple of 32, in_features [",
in_features,
"]. Provided ",
block_size);
auto output = at::empty({M, out_features}, inp.options());
dyn_quant_matmul_4bit_stub(
kCPU,
output,
inp,
packed_weights,
M,
out_features,
in_features,
block_size);
return output;
}
Tensor _weight_int8pack_mm_cpu(
const Tensor& A,
const Tensor& B,


@ -8,8 +8,19 @@
#include <ATen/cpu/vec/vec.h>
#include <ATen/native/cpu/int_mm_kernel.h>
#include <ATen/native/cpu/utils.h>
#include <c10/util/Unroll.h>
#include <c10/util/irange.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/cat.h>
#endif
#if AT_KLEIDIAI_ENABLED()
#include <ATen/native/kleidiai/kai_kernels.h>
#include <cpuinfo.h>
#endif
#if (defined(_WIN32) || defined(_WIN64))
#define RESTRICT __restrict
@ -762,10 +773,457 @@
}
}
#if AT_KLEIDIAI_ENABLED()
bool can_use_kleidiai(
const at::Tensor& scales_zeros,
const int64_t K,
const int64_t block_size) {
bool ret = false;
if (cpuinfo_has_arm_neon_dot()) {
// The Groupwise kernel requires BFloat16 Scales and Channelwise kernel
// requires Float32 Scales. If not provided, we will use fallback
// implementation.
if ((block_size == K && scales_zeros.dtype() == at::kFloat) ||
((block_size < K && !(block_size % 32) && !(K % block_size)) &&
scales_zeros.dtype() == at::kBFloat16)) {
ret = true;
}
}
return ret;
}
#endif
/**
* The Int4 quantized weights must be represented as a uint8 tensor
* For matrix multiplication with a weight shape of (N x K)
* the shape of the 4-bit quantized weights is [N, K/groupsize, groupsize/2].
*
* For KleidiAI weight packing, the scales, biases, and Int4 quantized
* weights are packed into a single `packed_weights` structure, optimized for
* Arm instructions.
*
* In the fallback reference kernel, no special packing is required for
* Int4 quantized weights.
*
* The Groupwise kernel requires BFloat16 Scales and Channelwise kernel requires
* Float32 Scales. If not provided, we will use fallback implementation.
*/
void dyn_quant_pack_4bit_weight_kernel(
Tensor& packed_weights,
const Tensor& weights,
const Tensor& scales_zeros,
const std::optional<Tensor>& bias,
const int64_t N,
const int64_t K,
const int64_t block_size) {
#if AT_KLEIDIAI_ENABLED()
if (can_use_kleidiai(scales_zeros, K, block_size)) {
const int64_t weight_packed_size =
kleidiai::kai_pack_rhs_int4_size(N, K, block_size);
packed_weights.resize_({weight_packed_size});
kleidiai::kai_pack_int4_rhs(
packed_weights, weights, scales_zeros, bias, N, K, block_size);
} else
#endif
{
TORCH_CHECK(
bias.has_value() == 0,
__func__,
" : Bias is unsupported in reference implementation");
packed_weights = packed_weights.to(kFloat);
auto weight_reshaped = weights.view({-1}).to(kFloat);
auto scales_zeros_reshaped = scales_zeros.view({-1}).to(kFloat);
auto res = at::cat({weight_reshaped, scales_zeros_reshaped}, 0);
packed_weights.resize_(res.sizes()).copy_(res);
}
}
static void ref_dyn_quant_matmul_4bit_channelwise_kernel(
size_t m,
size_t n,
size_t k,
const float* lhs_f32,
const uint8_t* rhs_qs4cx,
const float* rhs_scales_f32,
float* dst_f32,
float scalar_min,
float scalar_max) {
const size_t input_size_8bit = m * (k + sizeof(int32_t) + sizeof(float));
auto lhs_qa8dx_buffer = std::make_unique<uint8_t[]>(input_size_8bit);
uint8_t* lhs_qa8dx = lhs_qa8dx_buffer.get();
// Lambda for quantizing the fp32 input to 8 bit asymmetric and pack it in
// required format for matmul
auto input_quant_pack_8bit_channelwise =
[&](size_t m, size_t k, const float* lhs_f32, int8_t* lhs_qa8dx) {
const size_t dst_stride =
(k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t));
const size_t lhs_qa8dx_stride = k;
for (size_t m_idx = 0; m_idx < m; ++m_idx) {
const float* src_ptr = lhs_f32 + m_idx * lhs_qa8dx_stride;
float max0 = -FLT_MAX;
float min0 = FLT_MAX;
// Find min/max for each channel
for (size_t k_idx = 0; k_idx < k; ++k_idx) {
const float src0_0 = src_ptr[k_idx];
max0 = (std::max)(src0_0, max0);
min0 = (std::min)(src0_0, min0);
}
// Maximum/minimum int8 values
const float qmin = (float)INT8_MIN;
const float qmax = (float)INT8_MAX;
const float rmin0 = (std::min)(0.0f, min0);
const float rmax0 = (std::max)(0.0f, max0);
const float scale0 =
rmin0 == rmax0 ? 1.f : (qmax - qmin) / (rmax0 - rmin0);
// Reciprocal to quantize
const float recip_scale0 = scale0 ? 1.0f / scale0 : 0.0f;
const float descaled_min0 = rmin0 * scale0;
const float descaled_max0 = rmax0 * scale0;
const float zero_point_from_min_error0 = qmin + descaled_min0;
const float zero_point_from_max_error0 = qmax + descaled_max0;
float zero_point0 =
zero_point_from_min_error0 + zero_point_from_max_error0 > 0
? qmin - descaled_min0
: qmax - descaled_max0;
zero_point0 = (std::max)(zero_point0, qmin);
zero_point0 = (std::min)(zero_point0, qmax);
// Round to nearest integer
const int32_t nudged_zero_point0 = lrintf(zero_point0);
int8_t* dst_ptr = (int8_t*)lhs_qa8dx + m_idx * dst_stride;
// LHS offset at the beginning of the row
*((float*)(dst_ptr)) = recip_scale0;
dst_ptr += sizeof(float);
*((int32_t*)(dst_ptr)) = -nudged_zero_point0;
dst_ptr += sizeof(int32_t);
// Quantize the channels
for (size_t k_idx = 0; k_idx < k; ++k_idx) {
const float src0_0 = src_ptr[k_idx];
// Scale the values
int32_t v0_s32 = (int32_t)(std::round(src0_0 * scale0));
v0_s32 = v0_s32 + nudged_zero_point0;
v0_s32 = (std::max)(v0_s32, static_cast<int32_t>(INT8_MIN));
v0_s32 = (std::min)(v0_s32, static_cast<int32_t>(INT8_MAX));
dst_ptr[0] = (int8_t)v0_s32;
dst_ptr += sizeof(int8_t);
}
}
};
// Dynamically quantize the float32 input to 8-bit asymmetric
input_quant_pack_8bit_channelwise(m, k, lhs_f32, (int8_t*)lhs_qa8dx);
const size_t lhs_stride =
k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t);
const size_t rhs_qs4cx_stride = ((((k + 2 - 1) / 2) * 2) / 2);
for (size_t m_idx = 0; m_idx < m; ++m_idx) {
const int8_t* lhs_ptr_start = (int8_t*)lhs_qa8dx + m_idx * lhs_stride;
for (size_t n_idx = 0; n_idx < n; ++n_idx) {
// Main f32 accumulator
int32_t iacc = 0;
const int8_t* lhs_ptr = lhs_ptr_start;
const uint8_t* rhs_ptr = rhs_qs4cx + n_idx * rhs_qs4cx_stride;
// Get the LHS quantization parameters stored at the
// beginning of each row
const float lhs_scale = *(const float*)lhs_ptr;
lhs_ptr += sizeof(float);
const int32_t lhs_offset = *(const int32_t*)lhs_ptr;
lhs_ptr += sizeof(int32_t);
for (size_t k_idx = 0; k_idx < k; ++k_idx) {
// Get the LHS values
const int32_t lhs_v0 = (int32_t)lhs_ptr[0];
// Get the RHS values
const uint8_t rhs_byte = rhs_ptr[0];
// Unpack the RHS values
int32_t rhs_v0 = 0;
if ((k_idx % 2) == 0) {
rhs_v0 = (((int32_t)(rhs_byte & 0x0F)) - 8);
} else {
rhs_v0 = (((int32_t)(rhs_byte >> 4)) - 8);
}
iacc += lhs_v0 * rhs_v0;
iacc += lhs_offset * rhs_v0;
lhs_ptr += 1;
// Increment only when k_idx is not a multiple of 2
rhs_ptr += k_idx % 2;
}
// Get the RHS scale
const float rhs_scale = rhs_scales_f32[n_idx];
float main_acc = iacc * rhs_scale;
main_acc = main_acc * lhs_scale;
// Clamp (min-max) operation
main_acc = (std::max)(main_acc, scalar_min);
main_acc = (std::min)(main_acc, scalar_max);
dst_f32[0] = main_acc;
dst_f32 += 1;
}
}
};
static void ref_dyn_quant_matmul_4bit_groupwise_kernel(
size_t m,
size_t n,
size_t k,
size_t bl,
const float* lhs_f32,
const uint8_t* rhs_qs4c32,
const float* rhs_scales_fp32,
float* dst_f32,
float scalar_min,
float scalar_max) {
// Lambda for LHS quantization
auto lhs_quant_pack = [&](size_t m,
size_t k,
const float* lhs_f32,
int8_t* lhs_qa8dx) {
const size_t dst_stride =
(k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t));
for (size_t row_idx = 0; row_idx < m; ++row_idx) {
const float* src_ptr = lhs_f32 + row_idx * k;
float max0 = -FLT_MAX;
float min0 = FLT_MAX;
for (size_t k_idx = 0; k_idx < k; ++k_idx) {
const float src0_0 = src_ptr[k_idx];
max0 = (std::max)(src0_0, max0);
min0 = (std::min)(src0_0, min0);
}
const float qmin = (float)INT8_MIN;
const float qmax = (float)INT8_MAX;
const float rmin0 = (std::min)(0.0f, min0);
const float rmax0 = (std::max)(0.0f, max0);
const float scale0 =
(rmin0 == rmax0) ? 1.f : (qmax - qmin) / (rmax0 - rmin0);
const float recip_scale0 = scale0 ? 1.0f / scale0 : 0.0f;
const float descaled_min0 = rmin0 * scale0;
const float descaled_max0 = rmax0 * scale0;
float zero_point0 = (qmin + descaled_min0 + qmax + descaled_max0 > 0)
? qmin - descaled_min0
: qmax - descaled_max0;
zero_point0 = (std::max)(zero_point0, qmin);
zero_point0 = (std::min)(zero_point0, qmax);
const int32_t nudged_zero_point0 = lrintf(zero_point0);
int8_t* dst_ptr = (int8_t*)lhs_qa8dx + row_idx * dst_stride;
*((float*)(dst_ptr)) = recip_scale0;
dst_ptr += sizeof(float);
*((int32_t*)(dst_ptr)) = -nudged_zero_point0;
dst_ptr += sizeof(int32_t);
for (size_t k_idx = 0; k_idx < k; ++k_idx) {
const float src0_0 = src_ptr[k_idx];
int32_t v0_s32 = (int32_t)(std::round(src0_0 * scale0));
v0_s32 = (std::max)(
(std::min)(
v0_s32 + nudged_zero_point0, static_cast<int32_t>(INT8_MAX)),
static_cast<int32_t>(INT8_MIN));
dst_ptr[0] = (int8_t)v0_s32;
dst_ptr += sizeof(int8_t);
}
}
};
auto lhs_qa8dx_buffer = std::make_unique<int8_t[]>(
m * (k + sizeof(float) + sizeof(int32_t))); // Allocate for LHS
int8_t* lhs_qa8dx = lhs_qa8dx_buffer.get();
// Quantize and pack LHS
lhs_quant_pack(m, k, lhs_f32, lhs_qa8dx);
const size_t num_blocks_row = (((k + bl - 1) / bl) * bl) / bl;
const size_t lhs_stride = k + sizeof(float) + sizeof(int32_t);
const size_t rhs_stride = (((k + 2 - 1) / 2) * 2) / 2;
for (size_t row_idx = 0; row_idx < m; ++row_idx) {
const int8_t* lhs_ptr_start = lhs_qa8dx + row_idx * lhs_stride;
for (size_t col_idx = 0; col_idx < n; ++col_idx) {
float main_acc = 0.0f;
const int8_t* lhs_ptr = lhs_ptr_start;
const uint8_t* rhs_ptr = rhs_qs4c32 + col_idx * rhs_stride;
const float lhs_scale = *(const float*)lhs_ptr;
lhs_ptr += sizeof(float);
const int32_t lhs_offset = *(const int32_t*)lhs_ptr;
lhs_ptr += sizeof(int32_t);
for (size_t block_idx = 0; block_idx < num_blocks_row; ++block_idx) {
const float rhs_scale =
rhs_scales_fp32[block_idx + col_idx * num_blocks_row];
int32_t iacc = 0;
for (size_t i = 0; i < bl; ++i) {
const size_t k_idx = block_idx * bl + i;
if (k_idx >= k) {
break;
}
const int32_t lhs_v0 = (int32_t)lhs_ptr[0];
const uint8_t rhs_byte = rhs_ptr[0];
int32_t rhs_v0 = (k_idx % 2 == 0) ? (((int32_t)(rhs_byte & 0x0F)) - 8)
: (((int32_t)(rhs_byte >> 4)) - 8);
iacc += lhs_v0 * rhs_v0;
iacc += lhs_offset * rhs_v0;
lhs_ptr += 1;
rhs_ptr += (k_idx % 2);
}
main_acc += iacc * rhs_scale;
}
main_acc = main_acc * lhs_scale;
main_acc = (std::max)(main_acc, scalar_min);
main_acc = (std::min)(main_acc, scalar_max);
dst_f32[0] = main_acc;
dst_f32 += 1;
}
}
}
/**
* Dynamic Input Quant 4 bit weights matmul execution flow
(INT4 Weights + FP scales + FP32 Bias)
FP32 Input Packed Buffer
| |
Quantize Cast
to INT8 to INT8
| |
v v
INT8 Input INT8 Weights
\ /
\ /
\ /
INT8 Matrix Multiplication
|
v
FP32 Dequantized and Accumulate in FP32
|
v
FP32 Final Output
* The Groupwise kernel requires BFloat16 Scales and Channelwise kernel requires
* Float32 Scales. If not provided, we will use fallback implementation.
*/
void dyn_quant_matmul_4bit_kernel(
const Tensor& output,
const Tensor& inp,
const Tensor& packed_weights,
const int64_t M,
const int64_t N,
const int64_t K,
const int64_t block_size) {
#if AT_KLEIDIAI_ENABLED()
const int64_t weight_packed_size =
kleidiai::kai_pack_rhs_int4_size(N, K, block_size);
if (weight_packed_size == packed_weights.numel()) {
// The KleidiAI interface internally handles the channelwise and groupwise
// distinction
kleidiai::kai_quant_pack_lhs_int4_mm(
output, inp, packed_weights, M, N, K, block_size);
} else
#endif
{
float* lhs_f32 = reinterpret_cast<float*>(inp.data_ptr());
const auto weights_size = N * K / 2;
// The weights needs to be in uint8_t data type after quantization
auto extracted_weights =
(packed_weights.narrow(0, 0, weights_size)).to(kByte);
auto float32_scales =
(packed_weights.narrow(
0, weights_size, packed_weights.size(0) - weights_size))
.to(kFloat);
uint8_t* rhs_4bit =
reinterpret_cast<uint8_t*>(extracted_weights.data_ptr());
float* rhs_scales_f32 = reinterpret_cast<float*>(float32_scales.data_ptr());
float* dst_f32 = reinterpret_cast<float*>(output.data_ptr());
if (block_size == K) {
ref_dyn_quant_matmul_4bit_channelwise_kernel(
M,
N,
K,
lhs_f32,
rhs_4bit,
rhs_scales_f32,
dst_f32,
-FLT_MAX,
FLT_MAX);
} else if (!(block_size % 32) && !(K % block_size)) {
ref_dyn_quant_matmul_4bit_groupwise_kernel(
M,
N,
K,
block_size,
lhs_f32,
rhs_4bit,
rhs_scales_f32,
dst_f32,
-FLT_MAX,
FLT_MAX);
} else {
TORCH_CHECK(
block_size == K || (!(block_size % 32) && !(K % block_size)),
__func__,
": Group size should be a multiple of 32 or equal to in_features [",
K,
"]. Provided ",
block_size);
}
}
}
} // anonymous namespace
ALSO_REGISTER_AVX512_DISPATCH(weight_to_int4pack_stub, &weight_to_int4pack_kernel)
ALSO_REGISTER_AVX512_DISPATCH(int4pack_mm_stub, &int4pack_mm_kernel)
REGISTER_DISPATCH(dyn_quant_pack_4bit_weight_stub, &dyn_quant_pack_4bit_weight_kernel)
REGISTER_DISPATCH(dyn_quant_matmul_4bit_stub, &dyn_quant_matmul_4bit_kernel)
} // at::native
C10_DIAGNOSTIC_POP()


@ -5,12 +5,34 @@
namespace at::native {
using weight_to_int4pack_fn = void (*)(const Tensor&, const Tensor&);
using int4pack_mm_fn =
void (*)(const Tensor&, const Tensor&, const Tensor&, int, const Tensor&);
using int8pack_mm_fn =
void (*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&);
using dyn_quant_pack_4bit_weight_fn = void (*)(
Tensor&,
const Tensor&,
const Tensor&,
const std::optional<Tensor>& bias,
const int64_t,
const int64_t,
const int64_t);
using dyn_quant_matmul_4bit_fn = void (*)(
const Tensor&,
const Tensor&,
const Tensor&,
const int64_t,
const int64_t,
const int64_t,
const int64_t);
DECLARE_DISPATCH(weight_to_int4pack_fn, weight_to_int4pack_stub)
DECLARE_DISPATCH(int4pack_mm_fn, int4pack_mm_stub)
DECLARE_DISPATCH(int8pack_mm_fn, int8pack_mm_stub)
DECLARE_DISPATCH(
dyn_quant_pack_4bit_weight_fn,
dyn_quant_pack_4bit_weight_stub);
DECLARE_DISPATCH(dyn_quant_matmul_4bit_fn, dyn_quant_matmul_4bit_stub);
} // namespace at::native


@ -0,0 +1,440 @@
#include <ATen/native/kleidiai/kai_kernels.h>
#include <ATen/native/kleidiai/kai_pack.h>
#include <ATen/native/kleidiai/kai_ukernel_interface.h>
#include <ATen/Parallel.h>
#include <algorithm>
#include <cfloat>
#include <cmath>
#include <unordered_map>
#if AT_KLEIDIAI_ENABLED()
#include <cpuinfo.h>
namespace at::native::kleidiai {
void kai_pack_int4_rhs(
const Tensor& weight_packed,
const Tensor& weight,
const Tensor& scales,
const std::optional<Tensor>& bias,
const int64_t n,
const int64_t k,
const int64_t bl) {
// Prefer Channelwise kernel over Groupwise kernel for conflicting cases
if (bl == k) {
// Channelwise
auto kernel_packet = kai_select_channelwise_matmul_ukernel(
kai_kernel_id::
matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod);
auto& params = kernel_packet.rhs_pack_params;
params.lhs_zero_point = 1;
params.rhs_zero_point = 8;
kai_pack_rhs_channelwise_int4<kai_matmul_ukernel_f32_qa8dxp_qs4cxp>(
kernel_packet, weight_packed, weight, scales, bias, n, k);
} else if (!(bl % 32) && !(k % bl)) {
// Groupwise
auto kernel_packet = kai_select_groupwise_matmul_ukernel(
kai_kernel_id::
matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod);
const int64_t rhs_stride = kai_roundup(k, 2) / 2;
const int64_t scale_stride = (kai_roundup(k, bl) / bl) * sizeof(uint16_t);
auto& params = kernel_packet.rhs_pack_params;
params.lhs_zero_point = 1;
params.rhs_zero_point = 8;
params.scale_dt = kai_datatype::kai_dt_bf16;
kai_pack_rhs_groupwise_int4<kai_matmul_ukernel_f32_qa8dxp_qs4c32p>(
kernel_packet,
weight_packed,
weight,
scales,
bias,
n,
k,
bl,
rhs_stride,
scale_stride);
}
}
size_t kai_pack_rhs_int4_size(
const int64_t n,
const int64_t k,
const int64_t bl) {
size_t packed_size = n * k;
// Prefer Channelwise kernel over Groupwise kernel for conflicting cases
if (bl == k) {
// Channelwise
auto kernel_packet = kai_select_channelwise_matmul_ukernel(
kai_kernel_id::
matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod);
const auto& ukernel = kernel_packet.ukernel;
const size_t nr = ukernel.get_nr();
const size_t kr = ukernel.get_kr();
const size_t sr = ukernel.get_sr();
packed_size = kernel_packet.kai_get_rhs_packed_size(n, k, nr, kr, sr);
} else if (!(bl % 32) && !(k % bl)) {
// Groupwise
auto kernel_packet = kai_select_groupwise_matmul_ukernel(
kai_kernel_id::
matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod);
const auto& ukernel = kernel_packet.ukernel;
const size_t nr = ukernel.get_nr();
const size_t kr = ukernel.get_kr();
const size_t sr = ukernel.get_sr();
packed_size = kernel_packet.kai_get_rhs_packed_size(
n, k, nr, kr, sr, bl, kai_datatype::kai_dt_bf16);
}
return packed_size;
}
static void matmul_channelwise(
kai_matmul_ukernel_f32_qa8dxp_qs4cxp& kernel_packet,
size_t m_increment,
size_t m_start,
size_t m_per_thread,
size_t n_start,
size_t n_per_thread,
size_t n,
size_t k,
size_t mr,
size_t nr,
size_t kr,
size_t sr,
size_t dst_stride,
size_t lhs_stride,
uint8_t* lhs_native_mtx_f32,
uint8_t* lhs_packed_mtx_qa8dx,
uint8_t* rhs_packed_mtx_qs4cx,
uint8_t* dst_act_mtx_f32) {
for (size_t m0 = 0; m0 < m_per_thread; m0 += m_increment) {
const float* src_ptr =
(const float*)(lhs_native_mtx_f32 +
kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32(
m_start + m0, lhs_stride));
void* lhs_packed_ptr =
(void*)(lhs_packed_mtx_qa8dx +
kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(
0, k, mr, kr, sr));
const void* rhs_packed_ptr =
(const void*)((const char*)rhs_packed_mtx_qs4cx +
kernel_packet.ukernel.get_rhs_packed_offset(n_start, k));
float* dst_ptr = (float*)((uint8_t*)dst_act_mtx_f32 +
kernel_packet.ukernel.get_dst_offset(
m_start + m0, n_start, dst_stride));
// Quantize and pack the Input
kernel_packet.kai_run_lhs_quant_pack(
m_increment, k, mr, kr, sr, 0, src_ptr, lhs_stride, lhs_packed_ptr);
// Run Matmul on Int4 packed weights and Quantized Packed Input
kernel_packet.ukernel.run_matmul(
m_increment,
n_per_thread,
k,
lhs_packed_ptr,
rhs_packed_ptr,
dst_ptr,
dst_stride,
sizeof(float),
-FLT_MAX,
FLT_MAX);
}
}
static void matmul_groupwise(
const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& ukernel,
const size_t m,
const size_t num_n_per_thread,
const size_t n_start,
const size_t k,
const size_t bl,
const size_t dst_stride,
const void* lhs_ptr,
uint8_t* rhs_packed,
uint8_t* dst_data) {
const size_t rhs_packed_offset =
ukernel.get_rhs_packed_offset(n_start, k, bl);
const size_t dst_offset = ukernel.get_dst_offset(0, n_start, dst_stride);
const void* rhs_ptr = (const void*)(rhs_packed + rhs_packed_offset);
float* dst_ptr = (float*)((uint8_t*)dst_data + dst_offset);
// Run Matmul on Int4 packed weights and Quantized Packed Input
ukernel.run_matmul(
m,
num_n_per_thread,
k,
bl,
lhs_ptr,
rhs_ptr,
dst_ptr,
dst_stride,
sizeof(float),
-FLT_MAX,
FLT_MAX);
}
struct ThreadDivision {
int64_t num_threads_x;
int64_t num_threads_y;
bool use_gemm; // True if GEMM is selected, false if GEMV is used
};
inline static unsigned int round_down_to_power_of_2(unsigned int n) {
n |= n >> 1;
n |= n >> 2;
n |= n >> 4;
n |= n >> 8;
n |= n >> 16;
return n - (n >> 1);
}
inline static void adjust_max_threads(int64_t& max_threads) {
// We would not like to round down to nearest power of 2 always
// There can be possible thread split combination between powers of 2 for odd
// shapes
// TODO:: Decide better strategy based on hint of input and weight shapes
max_threads = round_down_to_power_of_2(max_threads);
}
static std::pair<int64_t, int64_t> split_2d(const int64_t max_threads) {
int64_t sqrt_threads = std::sqrt(max_threads);
for (int64_t i = sqrt_threads; i >= 1; --i) {
if (max_threads % i == 0) {
return {i, max_threads / i};
}
}
return {1, max_threads}; // There's still a possibility of 1D blocking when
// calling GEMM kernel
}
inline static ThreadDivision get_thread_division(
int64_t max_threads,
const int64_t m,
const int64_t n,
const int64_t k,
const int64_t gemm_m_step,
const int64_t gemm_n_step,
const int64_t gemv_m_step,
const int64_t gemv_n_step) {
adjust_max_threads(max_threads);
ThreadDivision division{1, 1, false};
// Split threads 2D for GEMM
if (m % gemm_m_step == 0 && n % gemm_n_step == 0) {
while (max_threads > 0) {
auto [num_thread_y, num_thread_x] = split_2d(max_threads);
if (m % num_thread_y == 0 && n % num_thread_x == 0) {
int64_t m_per_thread = m / num_thread_y;
int64_t n_per_thread = n / num_thread_x;
if (m_per_thread % gemm_m_step == 0 &&
n_per_thread % gemm_n_step == 0) {
division = {num_thread_x, num_thread_y, true};
return division;
}
}
max_threads -= 2;
}
}
// Split threads 1D for GEMV
if (n % gemv_n_step == 0) {
for (; max_threads > 0; max_threads -= 2) {
if (n % max_threads == 0 && (n / max_threads) % gemv_n_step == 0) {
division.num_threads_x = max_threads;
return division;
}
}
}
return division;
}
static void kai_quant_pack_lhs_int4_mm_groupwise(
const Tensor& output,
const Tensor& input,
const Tensor& weight,
const int64_t m,
const int64_t n,
const int64_t k,
const int64_t bl) {
kai_kernel_id id = kai_kernel_id::
matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod;
if (cpuinfo_has_arm_i8mm() && m > 1) {
id =
kai_kernel_id::matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_4x8x32_neon_i8mm;
}
auto kernel_packet = kai_select_groupwise_matmul_ukernel(id);
const auto& ukernel = kernel_packet.ukernel;
const size_t mr = ukernel.get_mr();
const size_t kr = ukernel.get_kr();
const size_t sr = ukernel.get_sr();
const size_t n_step = ukernel.get_n_step();
int64_t total_threads = at::get_num_threads();
int64_t num_threads_x = 1;
adjust_max_threads(total_threads);
// Split threads 1D only for now
if (n % n_step == 0) {
for (; total_threads > 0; total_threads -= 2) {
if (n % total_threads == 0 && (n / total_threads) % n_step == 0) {
num_threads_x = total_threads;
break;
}
}
}
const size_t num_n_per_thread = n / num_threads_x;
const size_t dst_stride = n * sizeof(float);
float* lhs = reinterpret_cast<float*>(input.data_ptr());
uint8_t* rhs_packed_mtx_qs4cx = reinterpret_cast<uint8_t*>(weight.data_ptr());
uint8_t* dst_act_mtx_f32 = reinterpret_cast<uint8_t*>(output.data_ptr());
const size_t lhs_packed_size =
kernel_packet.kai_get_lhs_packed_size(m, k, mr, kr, sr);
auto lhs_packed = std::make_unique<uint8_t[]>(lhs_packed_size);
// Quantize and pack the Input
kernel_packet.kai_run_lhs_quant_pack(
m,
k,
mr,
kr,
sr,
0,
(const float*)lhs,
k * sizeof(float),
(void*)lhs_packed.get());
at::parallel_for(0, num_threads_x, 0, [&](int begin, int end) {
for (const auto x : c10::irange(begin, end)) {
matmul_groupwise(
std::ref(ukernel),
m,
num_n_per_thread,
x * num_n_per_thread,
k,
bl,
dst_stride,
lhs_packed.get(),
rhs_packed_mtx_qs4cx,
dst_act_mtx_f32);
}
});
}
static void kai_quant_pack_lhs_int4_mm_channelwise(
const Tensor& output,
const Tensor& input,
const Tensor& weight,
const int64_t m,
const int64_t n,
const int64_t k) {
// Kernel IDs for GEMM and GEMV
kai_kernel_id gemm_id =
kai_kernel_id::matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm;
kai_kernel_id gemv_id =
kai_kernel_id::matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod;
// Get the total number of threads available and choose GEMM or GEMV steps
const int64_t total_threads = at::get_num_threads();
auto gemm_kernel_packet = kai_select_channelwise_matmul_ukernel(gemv_id);
if (cpuinfo_has_arm_i8mm()) {
gemm_kernel_packet = kai_select_channelwise_matmul_ukernel(gemm_id);
}
auto gemv_kernel_packet = kai_select_channelwise_matmul_ukernel(gemv_id);
// Retrieve m_step and n_step values from GEMM and GEMV kernels
const int64_t gemm_m_step = gemm_kernel_packet.ukernel.get_m_step();
const int64_t gemm_n_step = gemm_kernel_packet.ukernel.get_n_step();
const int64_t gemv_m_step = gemv_kernel_packet.ukernel.get_m_step();
const int64_t gemv_n_step = gemv_kernel_packet.ukernel.get_n_step();
// Determine threading and kernel type
ThreadDivision division = get_thread_division(
total_threads,
m,
n,
k,
gemm_m_step,
gemm_n_step,
gemv_m_step,
gemv_n_step);
// Select appropriate kernel packet based on the chosen kernel type
auto& kernel_packet =
division.use_gemm ? gemm_kernel_packet : gemv_kernel_packet;
// Thread blocking parameters
const size_t mr = kernel_packet.ukernel.get_mr();
const size_t nr = kernel_packet.ukernel.get_nr();
const size_t kr = kernel_packet.ukernel.get_kr();
const size_t sr = kernel_packet.ukernel.get_sr();
const size_t m_increment = kernel_packet.ukernel.get_m_step();
const size_t n_per_thread = n / division.num_threads_x;
const size_t m_per_thread = m / division.num_threads_y;
const int64_t num_threads = division.num_threads_y * division.num_threads_x;
const size_t dst_stride = n * sizeof(float);
const size_t lhs_stride = k * sizeof(float);
const size_t lhs_packed_size =
kernel_packet.kai_get_lhs_packed_size(m_increment, k, mr, kr, sr);
uint8_t* dst_act_mtx_f32 = reinterpret_cast<uint8_t*>(output.data_ptr());
uint8_t* lhs_native_mtx_f32 = reinterpret_cast<uint8_t*>(input.data_ptr());
uint8_t* rhs_packed_mtx_qs4cx = reinterpret_cast<uint8_t*>(weight.data_ptr());
auto lhs_packed = std::make_unique<uint8_t[]>(lhs_packed_size * num_threads);
uint8_t* lhs_packed_base = lhs_packed.get();
at::parallel_for(0, num_threads, 0, [&](int64_t begin, int64_t end) {
for (const auto i : c10::irange(begin, end)) {
size_t y = i / division.num_threads_x;
size_t x = i % division.num_threads_x;
uint8_t* lhs_packed_ptr =
lhs_packed_base + (x + y * division.num_threads_x) * lhs_packed_size;
matmul_channelwise(
std::ref(kernel_packet),
m_increment,
y * m_per_thread,
m_per_thread,
x * n_per_thread,
n_per_thread,
n,
k,
mr,
nr,
kr,
sr,
dst_stride,
lhs_stride,
lhs_native_mtx_f32,
lhs_packed_ptr,
rhs_packed_mtx_qs4cx,
dst_act_mtx_f32);
}
});
}
void kai_quant_pack_lhs_int4_mm(
const Tensor& output,
const Tensor& input,
const Tensor& weight,
const int64_t m,
const int64_t n,
const int64_t k,
const int64_t bl) {
// Prefer Channelwise kernel over Groupwise kernel for conflicting cases
if (bl == k) {
kleidiai::kai_quant_pack_lhs_int4_mm_channelwise(
output, input, weight, m, n, k);
} else if (!(bl % 32) && !(k % bl)) {
kleidiai::kai_quant_pack_lhs_int4_mm_groupwise(
output, input, weight, m, n, k, bl);
}
}
} // namespace at::native::kleidiai
#endif


@ -0,0 +1,42 @@
#pragma once
#include <ATen/Config.h>
#include <ATen/core/Tensor.h>
#if AT_KLEIDIAI_ENABLED()
namespace at::native::kleidiai {
/**
* @brief Rearranges the quantized weight to support kleidiai inference
* @param bl Groupsize for quantization should be multiple of 32
*/
void kai_pack_int4_rhs(
const Tensor& weight_packed,
const Tensor& weight,
const Tensor& scales,
const std::optional<Tensor>& bias,
const int64_t n,
const int64_t k,
const int64_t bl);
/**
* @brief Outputs the buffer size for the packed weights
* @param bl Groupsize for quantization. 32 for groupwise , 0 for channelwise
*/
size_t kai_pack_rhs_int4_size(
const int64_t n,
const int64_t k,
const int64_t bl);
/**
* @brief Run 2 operations ( Input quantize and pack -> 4 bit Matmul )
*/
void kai_quant_pack_lhs_int4_mm(
const Tensor& output,
const Tensor& input,
const Tensor& weight,
const int64_t m,
const int64_t n,
const int64_t k,
const int64_t bl);
} // namespace at::native::kleidiai
#endif


@ -0,0 +1,106 @@
#pragma once
#include <ATen/Config.h>
#include <ATen/core/Tensor.h>
#include <ATen/ops/empty.h>
#include <torch/library.h>
#if AT_KLEIDIAI_ENABLED()
namespace at::native::kleidiai {
template <typename T>
void kai_pack_rhs_groupwise_int4(
T& kernel,
const Tensor& weight_packed,
const Tensor& weight,
const Tensor& scales,
const std::optional<Tensor>& bias,
const int64_t n,
const int64_t k,
const int64_t bl,
const int64_t rhs_stride,
const int64_t scale_stride) {
const auto& ukernel = kernel.ukernel;
const size_t nr = ukernel.get_nr();
const size_t kr = ukernel.get_kr();
const size_t sr = ukernel.get_sr();
auto weight_packed_data =
reinterpret_cast<uint8_t*>(weight_packed.data_ptr());
const auto weight_data = weight.data_ptr<uint8_t>();
auto scales_data = scales.const_data_ptr();
if (weight_data == nullptr) {
AT_ERROR("kai_pack_rhs_channelwise_int4: Weight data pointer is null");
}
if (scales_data == nullptr) {
AT_ERROR("kai_pack_rhs_channelwise_int4: Scales data pointer is null");
}
float* bias_ptr = bias.has_value() ? bias.value().data_ptr<float>() : NULL;
auto& params = kernel.rhs_pack_params;
kernel.kai_run_rhs_pack(
/*num_groups=*/1,
n,
k,
nr,
kr,
sr,
bl,
(const uint8_t*)(weight_data),
rhs_stride,
bias_ptr,
scales_data,
scale_stride,
weight_packed_data,
0,
&params);
}
template <typename T>
void kai_pack_rhs_channelwise_int4(
T& kernel,
const Tensor& weight_packed,
const Tensor& weight,
const Tensor& scales,
const std::optional<Tensor>& bias,
const int64_t n,
const int64_t k) {
const auto& ukernel = kernel.ukernel;
const size_t nr = ukernel.get_nr();
const size_t kr = ukernel.get_kr();
const size_t sr = ukernel.get_sr();
auto weight_packed_data =
reinterpret_cast<uint8_t*>(weight_packed.data_ptr());
const auto weight_data = weight.data_ptr<uint8_t>();
const auto scales_data = scales.data_ptr<float>();
if (weight_data == nullptr) {
AT_ERROR("kai_pack_rhs_channelwise_int4: Weight data pointer is null");
}
if (scales_data == nullptr) {
AT_ERROR("kai_pack_rhs_channelwise_int4: Scales data pointer is null");
}
float* bias_ptr = bias.has_value() ? bias.value().data_ptr<float>() : NULL;
auto& params = kernel.rhs_pack_params;
kernel.kai_run_rhs_pack(
/*num_groups=*/1,
n,
k,
nr,
kr,
sr,
(const uint8_t*)(weight_data),
(const float*)(bias_ptr),
(const float*)(scales_data),
weight_packed_data,
0,
&params);
}
} // namespace at::native::kleidiai
#endif


@ -0,0 +1,72 @@
#include <ATen/native/kleidiai/kai_ukernel_interface.h>
#if AT_KLEIDIAI_ENABLED()
namespace at::native::kleidiai {
// Kernel Mapping - Groupwise
std::unordered_map<kai_kernel_id, kai_matmul_ukernel_f32_qa8dxp_qs4c32p> groupwise_8bit_4bit_kernels =
{{kai_kernel_id::
matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
{{kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod}}},
{kai_kernel_id::matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_4x8x32_neon_i8mm,
{{kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm,
kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm,
kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm,
kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm,
kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm,
kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm,
kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm,
kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm,
kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm,
kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm,
kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm}}}};
kai_matmul_ukernel_f32_qa8dxp_qs4c32p kai_select_groupwise_matmul_ukernel(
kai_kernel_id id) {
return groupwise_8bit_4bit_kernels.at(id);
}
// Kernel Mapping - Channelwise
std::unordered_map<kai_kernel_id, kai_matmul_ukernel_f32_qa8dxp_qs4cxp> channelwise_8bit_4bit_kernels =
{{kai_kernel_id::matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod,
{{kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod,
kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod,
kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod,
kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod,
kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod,
kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod,
kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod,
kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod,
kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod,
kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod,
kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod}}},
{kai_kernel_id::matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm,
{{kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm,
kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm,
kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm,
kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm,
kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm,
kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm,
kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm,
kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm,
kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm,
kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm,
kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm}}}};
kai_matmul_ukernel_f32_qa8dxp_qs4cxp kai_select_channelwise_matmul_ukernel(
const kai_kernel_id id) {
return channelwise_8bit_4bit_kernels.at(id);
}
} // namespace at::native::kleidiai
#endif


@ -0,0 +1,144 @@
#pragma once
#include <ATen/Config.h>
#include <unordered_map>
#if AT_KLEIDIAI_ENABLED()
#include <kai_common.h>
#include <kai_lhs_quant_pack_qai8dxp_f32.h>
#include <kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h>
#include <kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.h>
#include <kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h>
#include <kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.h>
#include <kai_matmul_clamp_f32_qai8dxp_qsi4c32p_interface.h>
#include <kai_matmul_clamp_f32_qai8dxp_qsi4cxp_interface.h>
#include <kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h>
#include <kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.h>
namespace at::native::kleidiai {
enum class kai_kernel_id {
matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod =
0, // Groupwise 4 bit GEMV
matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_4x8x32_neon_i8mm =
1, // Groupwise 4 bit GEMM
matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod =
2, // Channelwise 4 bit GEMV
matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm =
3 // Channelwise 4 bit GEMM
};
// Channelwise Kernel mapping
struct kai_matmul_ukernel_f32_qa8dxp_qs4cxp {
struct kai_matmul_clamp_f32_qai8dxp_qsi4cxp_ukernel ukernel;
struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params rhs_pack_params;
size_t (*kai_get_lhs_packed_size)(
size_t m,
size_t k,
size_t mr,
size_t kr,
size_t sr);
size_t (*kai_get_rhs_packed_size)(
size_t n,
size_t k,
size_t nr,
size_t kr,
size_t sr);
void (*kai_run_lhs_quant_pack)(
size_t m,
size_t k,
size_t mr,
size_t kr,
size_t sr,
size_t m_idx_start,
const float* lhs,
size_t lhs_stride,
void* lhs_packed);
void (*kai_run_rhs_pack)(
size_t num_groups,
size_t n,
size_t k,
size_t nr,
size_t kr,
size_t sr,
const uint8_t* rhs,
const float* bias,
const float* scale,
void* rhs_packed,
size_t extra_bytes,
const struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params* params);
kai_matmul_ukernel_f32_qa8dxp_qs4cxp(
const kai_matmul_clamp_f32_qai8dxp_qsi4cxp_ukernel& kernel)
: ukernel(kernel),
kai_get_lhs_packed_size(
&kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32),
kai_get_rhs_packed_size(
&kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0),
kai_run_lhs_quant_pack(&kai_run_lhs_quant_pack_qai8dxp_f32),
kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0) {}
};
struct kai_matmul_ukernel_f32_qa8dxp_qs4cxp
kai_select_channelwise_matmul_ukernel(const kai_kernel_id id);
// Groupwise Kernel mapping
struct kai_matmul_ukernel_f32_qa8dxp_qs4c32p {
struct kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel ukernel;
struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params rhs_pack_params;
size_t (*kai_get_lhs_packed_size)(
size_t m,
size_t k,
size_t mr,
size_t kr,
size_t sr);
size_t (*kai_get_rhs_packed_size)(
size_t n,
size_t k,
size_t nr,
size_t kr,
size_t sr,
size_t bl,
enum kai_datatype scale_dt);
void (*kai_run_lhs_quant_pack)(
size_t m,
size_t k,
size_t mr,
size_t kr,
size_t sr,
size_t m_idx_start,
const float* lhs,
size_t lhs_stride,
void* lhs_packed);
void (*kai_run_rhs_pack)(
size_t num_groups,
size_t n,
size_t k,
size_t nr,
size_t kr,
size_t sr,
size_t bl,
const uint8_t* rhs,
size_t rhs_stride,
const float* bias,
const void* scale,
size_t scale_stride,
void* rhs_packed,
size_t extra_bytes,
const struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params* params);
kai_matmul_ukernel_f32_qa8dxp_qs4c32p(
const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& kernel)
: ukernel(kernel),
kai_get_lhs_packed_size(
&kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32),
kai_get_rhs_packed_size(
&kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0),
kai_run_lhs_quant_pack(&kai_run_lhs_quant_pack_qai8dxp_f32),
kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0) {}
};
struct kai_matmul_ukernel_f32_qa8dxp_qs4c32p kai_select_groupwise_matmul_ukernel(
const kai_kernel_id id);
} // namespace at::native::kleidiai
#endif


@ -4177,6 +4177,14 @@
dispatch:
CPU: _weight_int4pack_mm_cpu
- func: _dyn_quant_pack_4bit_weight(Tensor weights, Tensor scales_zeros, Tensor? bias, int block_size, int in_features, int out_features) -> Tensor
dispatch:
CPU: _dyn_quant_pack_4bit_weight_cpu
- func: _dyn_quant_matmul_4bit(Tensor inp, Tensor packed_weights, int block_size, int in_features, int out_features) -> Tensor
dispatch:
CPU: _dyn_quant_matmul_4bit_cpu
- func: _weight_int8pack_mm(Tensor self, Tensor mat2, Tensor scales) -> Tensor
dispatch:
CPU: _weight_int8pack_mm_cpu


@ -1070,6 +1070,7 @@ def define_buck_targets(
],
)
# TODO: Enable support for KleidiAI bazel build
# @lint-ignore BUCKLINT
fb_native.genrule(
name = "generate_aten_config",
@ -1122,6 +1123,9 @@ def define_buck_targets(
"--replace",
"@AT_BLAS_USE_CBLAS_DOT@",
"AT_BLAS_USE_CBLAS_DOT_FBXPLAT",
"--replace",
"@AT_KLEIDIAI_ENABLED@",
"0",
]),
outs = {
"Config.h": ["Config.h"],


@ -152,6 +152,7 @@ endif()
set(AT_MKLDNN_ACL_ENABLED 0)
set(AT_MKLDNN_ENABLED 0)
set(AT_MKL_ENABLED 0)
set(AT_KLEIDIAI_ENABLED 0)
# setting default preferred BLAS options if not already present.
if(NOT INTERN_BUILD_MOBILE)
set(BLAS "MKL" CACHE STRING "Selected BLAS library")
@ -1480,6 +1481,35 @@ if(NOT INTERN_BUILD_MOBILE)
message("disabling MKLDNN because USE_MKLDNN is not set")
endif()
if(USE_KLEIDIAI)
if(CMAKE_C_COMPILER_ID STREQUAL "Clang" AND CMAKE_C_COMPILER_VERSION VERSION_LESS "11" )
message(WARNING "KleidiAI: Using non-supported Clang version. Expected 11 or newer, received ${CMAKE_C_COMPILER_VERSION}.")
endif()
if(CMAKE_C_COMPILER_ID STREQUAL "GNU" AND CMAKE_C_COMPILER_VERSION VERSION_LESS "11" )
message(WARNING "KleidiAI: Using non-supported GCC version. Expected 11 or newer, received ${CMAKE_C_COMPILER_VERSION}.")
endif()
set(TEMP_BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS})
set(BUILD_SHARED_LIBS OFF CACHE BOOL "Build shared libs" FORCE)
set(AT_KLEIDIAI_ENABLED 1)
set(KLEIDIAI_BUILD_TESTS OFF) # Disable building KLEIDIAI tests
set(KLEIDIAI_SRC "${PROJECT_SOURCE_DIR}/third_party/kleidiai")
add_subdirectory(${KLEIDIAI_SRC})
set(KLEIDIAI_INCLUDE_DIRS
${KLEIDIAI_SRC}/
${KLEIDIAI_SRC}/kai/
${KLEIDIAI_SRC}/kai/ukernels/
${KLEIDIAI_SRC}/kai/ukernels/matmul/
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/
)
include_directories(SYSTEM INTERFACE ${KLEIDIAI_INCLUDE_DIRS})
list(APPEND Caffe2_DEPENDENCY_LIBS kleidiai)
# Recover build options.
set(BUILD_SHARED_LIBS ${TEMP_BUILD_SHARED_LIBS} CACHE BOOL "Build shared libs" FORCE)
endif()
if(UNIX AND NOT APPLE)
include(CheckLibraryExists)
# https://github.com/libgit2/libgit2/issues/2128#issuecomment-35649830


@ -153,6 +153,9 @@ function(caffe2_print_configuration_summary)
message(STATUS " USE_MKLDNN_ACL : ${USE_MKLDNN_ACL}") message(STATUS " USE_MKLDNN_ACL : ${USE_MKLDNN_ACL}")
message(STATUS " USE_MKLDNN_CBLAS : ${USE_MKLDNN_CBLAS}") message(STATUS " USE_MKLDNN_CBLAS : ${USE_MKLDNN_CBLAS}")
endif() endif()
if(${USE_KLEIDIAI})
message(STATUS " USE_KLEIDIAI : ${USE_KLEIDIAI}")
endif()
message(STATUS " USE_UCC : ${USE_UCC}") message(STATUS " USE_UCC : ${USE_UCC}")
if(${USE_UCC}) if(${USE_UCC})
message(STATUS " USE_SYSTEM_UCC : ${USE_SYSTEM_UCC}") message(STATUS " USE_SYSTEM_UCC : ${USE_SYSTEM_UCC}")


@ -97,6 +97,10 @@ else()
append_torchlib_if_found(microkernels-prod)
endif()
if(@USE_KLEIDIAI@)
append_torchlib_if_found(kleidiai)
endif()
append_torchlib_if_found(caffe2_protos protobuf-lite protobuf protoc)
append_torchlib_if_found(onnx onnx_proto)


@ -203,6 +203,7 @@ torch.backends.openmp
.. add anything to the rendered page for now.
.. py:module:: torch.backends.quantized
.. py:module:: torch.backends.xnnpack
.. py:module:: torch.backends.kleidiai
torch.backends.opt_einsum torch.backends.opt_einsum


@ -1221,6 +1221,7 @@ def main():
"include/ATen/native/cuda/*.cuh", "include/ATen/native/cuda/*.cuh",
"include/ATen/native/hip/*.h", "include/ATen/native/hip/*.h",
"include/ATen/native/hip/*.cuh", "include/ATen/native/hip/*.cuh",
"include/ATen/native/kleidiai/*.h",
"include/ATen/native/mps/*.h", "include/ATen/native/mps/*.h",
"include/ATen/native/mkldnn/xpu/*.h", "include/ATen/native/mkldnn/xpu/*.h",
"include/ATen/native/mkldnn/xpu/detail/*.h", "include/ATen/native/mkldnn/xpu/detail/*.h",


@ -92,6 +92,8 @@ aten::_dimI
aten::_dimV
aten::_dirichlet_grad
aten::_dirichlet_grad.out
aten::_dyn_quant_matmul_4bit
aten::_dyn_quant_pack_4bit_weight
aten::_efficient_attention_backward
aten::_efficient_attention_forward
aten::_efficientzerotensor


@ -76,6 +76,7 @@ from torch.testing._internal.common_device_type import (
from torch.testing._internal.common_dtype import all_types, get_all_dtypes
from torch.testing._internal.common_quantization import (
_dynamically_quantize_per_channel,
_group_quantize_tensor_symmetric,
)
from torch.testing._internal.common_utils import (
DeterministicGuard,
@ -2224,6 +2225,85 @@ class CommonTemplate:
b_int8pack, b_scales = convert_weight_to_int8pack(b)
self.common(fn, (a, b_int8pack, b_scales, c))
@xfail_if_triton_cpu
@skipCUDAIf(True, "No _dyn_quant_pack_4bit_weight implementation on CUDA")
@skipIfRocm
@skipIfXpu(msg="No _dyn_quant_pack_4bit_weight implementation on XPU")
def test__dyn_quant_pack_4bit_weight(self):
q_group = 32
k = 128
n = 128
torch.manual_seed(1)
b = torch.rand((k, n), dtype=torch.float32)
in_features = b.size(0)
out_features = b.size(1)
def dyn_quant_pack_4bit_weight(b, in_features, out_features):
b_uint8, b_scales_and_zeros = _group_quantize_tensor_symmetric(
b, n_bit=4, groupsize=q_group
)
if q_group == in_features:
b_scales_and_zeros = b_scales_and_zeros.to(torch.float)
else:
b_scales_and_zeros = b_scales_and_zeros.to(torch.bfloat16)
b_int4pack = torch._dyn_quant_pack_4bit_weight(
b_uint8, b_scales_and_zeros, None, q_group, in_features, out_features
)
return b_int4pack, b_scales_and_zeros
def fn(b, in_features, out_features):
b_int4pack, _ = dyn_quant_pack_4bit_weight(b, in_features, out_features)
return b_int4pack
self.common(fn, (b, in_features, out_features))
@xfail_if_triton_cpu
@skipCUDAIf(True, "No _dyn_quant_matmul_4bit implementation on CUDA")
@skipIfRocm
@skipIfXpu(msg="No _dyn_quant_matmul_4bit implementation on XPU")
def test__dyn_quant_matmul_4bit(self):
q_group = 32
m = 32
k = 128
n = 128
torch.manual_seed(1)
a = torch.rand((m, k), dtype=torch.float32)
b = torch.rand((k, n), dtype=torch.float32)
in_features = b.size(0)
out_features = b.size(1)
def dyn_quant_pack_4bit_weight(b, in_features, out_features):
b_uint8, b_scales_and_zeros = _group_quantize_tensor_symmetric(
b, n_bit=4, groupsize=q_group
)
if q_group == in_features:
b_scales_and_zeros = b_scales_and_zeros.to(torch.float)
else:
b_scales_and_zeros = b_scales_and_zeros.to(torch.bfloat16)
b_int4pack = torch._dyn_quant_pack_4bit_weight(
b_uint8, b_scales_and_zeros, None, q_group, in_features, out_features
)
return b_int4pack, b_scales_and_zeros
def fn(a, q_group, in_features, out_features):
b_int4pack, _ = dyn_quant_pack_4bit_weight(b, in_features, out_features)
res = torch._dyn_quant_matmul_4bit(
a,
b_int4pack,
q_group,
in_features,
out_features,
)
return res
self.common(fn, (a, q_group, in_features, out_features))
def test_expanded_reduction(self):
def fn(x, y):
z = x * y

View File

@ -34,7 +34,8 @@ from torch.testing._internal.common_dtype import (
)
from torch.testing._internal.common_cuda import SM53OrLater, SM80OrLater, SM90OrLater, tf32_on_and_off, _get_magma_version, \
_get_torch_cuda_version, CDNA2OrLater, TEST_MULTIGPU
from torch.testing._internal.common_quantization import _group_quantize_tensor, _dynamically_quantize_per_channel, \
_group_quantize_tensor_symmetric
from torch.testing._internal.common_mkldnn import bf32_on_and_off
from torch.distributions.binomial import Binomial
import torch.backends.opt_einsum as opt_einsum
@ -909,7 +910,6 @@ class TestLinalg(TestCase):
torch.randn((3, 52, 52), device=device, dtype=dtype),
torch.randn((4, 2, 26, 26), device=device, dtype=dtype))
ops = (torch.det, torch.Tensor.det,
torch.linalg.det)
for t in tensors:
@ -1438,7 +1438,6 @@ class TestLinalg(TestCase):
continue
run_test_case(make_arg(shape), ord, dim, keepdim)
@onlyCUDA
@dtypes(torch.bfloat16, torch.float16)
def test_norm_fused_type_promotion(self, device, dtype):
@ -4344,7 +4343,6 @@ class TestLinalg(TestCase):
triangular_solve_zero_batch_helper((batchsize, 5, 5), (batchsize, 5, 10),
upper, unitriangular, transpose)
@slowTest
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@ -4465,7 +4463,6 @@ class TestLinalg(TestCase):
self.assertTrue("An output with one or more elements was resized" in str(w[0].message))
self.assertTrue("An output with one or more elements was resized" in str(w[1].message))
def check_single_matmul(self, x, y):
def assertEqual(answer, expected):
@ -5701,7 +5698,6 @@ class TestLinalg(TestCase):
else:
self.assertEqual(B_, X_ @ A)
sizes = ((3, 3), (5, 5), (4, 2), (3, 4), (0, 0), (0, 1), (1, 0))
batches = ((0,), (), (1,), (2,), (3,), (1, 0), (3, 5))
# Non pivoting just implemented for CUDA
@ -5734,7 +5730,6 @@ class TestLinalg(TestCase):
with self.assertRaisesRegex(RuntimeError, 'LU without pivoting is not implemented on the CPU'):
f(torch.empty(1, 2, 2), pivot=False)
@precisionOverride({torch.float32: 1e-2, torch.complex64: 1e-2})
@skipCUDAIfNoMagmaAndNoCusolver
@skipCPUIfNoLapack
@ -5762,7 +5757,6 @@ class TestLinalg(TestCase):
for b, n in shapes:
yield make_arg((b, n, n)), make_arg((b, n, rhs))
for A, B in gen_matrices():
LU, pivots = torch.linalg.lu_factor(A)
for backend in backends:
@ -5777,7 +5771,6 @@ class TestLinalg(TestCase):
else:
self.assertEqual(B_left, X @ A_adj)
@onlyCPU
@dtypes(*floating_and_complex_types())
def test_linalg_lu_cpu_errors(self, device, dtype):
@ -5818,7 +5811,6 @@ class TestLinalg(TestCase):
with self.assertRaisesRegex(RuntimeError, r"between 1 and LU.size\(-2\)."):
torch.lu_unpack(LU, pivots)
# Rectangular tests
sample = torch.randn(2, 3, 5, device=device, dtype=dtype)
B = torch.randn(2, 3, 5, device=device, dtype=dtype)
@ -5835,7 +5827,6 @@ class TestLinalg(TestCase):
with self.assertRaisesRegex(RuntimeError, r"between 1 and LU.size\(-2\)."):
torch.lu_unpack(LU, pivots)
@skipCPUIfNoLapack
@skipCUDAIfNoMagma
@dtypes(torch.double)
@ -6444,7 +6435,6 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
y = torch.baddbmm(input_tensor, batch1, batch2, beta=0.0, out=out)
self.assertEqual(out, y_ref)
@unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
@onlyCUDA
def test_matmul_45724(self, device):
@ -6617,7 +6607,6 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
torch._int_mm(a_int8, b_int8, out=c_int32_result)
self.assertEqual(c_int32_result.float(), torch.mm(a_float, b_float))
@unittest.skipIf(IS_WINDOWS, "Skipped on Windows!")
@unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
@onlyNativeDeviceTypes
@ -6680,7 +6669,6 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
mean_err = ((res - ref).abs() / ref).mean()
self.assertTrue(mean_err < 0.05)
@unittest.skipIf(IS_WINDOWS, "Skipped on Windows!")
@unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
@onlyNativeDeviceTypes
@ -6733,6 +6721,168 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
mean_err = ((res - ref).abs() / ref).mean()
self.assertTrue(mean_err < 0.05)
@unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
@unittest.skipIf(TEST_WITH_ROCM and IS_REMOTE_GPU, "ROCM is unsupported")
@onlyNativeDeviceTypes
@parametrize("k", [64, 256])
@parametrize("n", [32, 48, 64, 128])
def test__dyn_quant_pack_4bit_weight(self, device, k, n):
# TODO: Fix https://github.com/pytorch/pytorch/issues/131425 and use OpInfo instead
# Weight shape is [K x N]
if self.device_type == "cuda":
self.skipTest("CUDA Backend is unsupported")
torch.manual_seed(1)
block_size = 32
b = torch.rand((k, n), dtype=torch.bfloat16, device=device)
in_features = b.size(0)
out_features = b.size(1)
b_uint8, b_scales_and_zeros = _group_quantize_tensor_symmetric(
b, n_bit=4, groupsize=block_size
)
b_int4pack = torch._dyn_quant_pack_4bit_weight(
b_uint8, b_scales_and_zeros, None, block_size, in_features, out_features
)
b_int4pack_meta = torch._dyn_quant_pack_4bit_weight(
b_uint8, b_scales_and_zeros, None, block_size, in_features, out_features
)
self.assertEqual(b_int4pack.shape, b_int4pack_meta.shape)
@unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
@unittest.skipIf(TEST_WITH_ROCM and IS_REMOTE_GPU, "ROCM is unsupported")
@onlyNativeDeviceTypes
@parametrize("m", [1, 32])
@parametrize("k", [64, 128])
@parametrize("n", [4096, 11008])
def test__dyn_quant_matmul_4bit(self, device, m, k, n):
if self.device_type == "cuda":
self.skipTest("CUDA is unsupported")
q_group = 32
torch.manual_seed(1)
a_float32 = torch.rand((m, k), dtype=torch.float32, device=device)
b_float32 = torch.rand((k, n), dtype=torch.float32, device=device)
in_features = b_float32.size(0)
out_features = b_float32.size(1)
def dyn_quant_pack_4bit_weight(b, in_features, out_features):
b_uint8, b_scales_and_zeros = _group_quantize_tensor_symmetric(
b, n_bit=4, groupsize=q_group
)
if q_group == in_features:
b_scales_and_zeros = b_scales_and_zeros.to(torch.float)
else:
b_scales_and_zeros = b_scales_and_zeros.to(torch.bfloat16)
b_int4pack = torch._dyn_quant_pack_4bit_weight(
b_uint8, b_scales_and_zeros, None, q_group, in_features, out_features
)
return b_int4pack, b_scales_and_zeros
def dyn_quant_matmul_4bit(
a, b_int4pack, q_group, in_features, out_features
):
return torch._dyn_quant_matmul_4bit(
a,
b_int4pack,
q_group,
in_features,
out_features,
)
b_int4pack, b_scales_and_zeros = dyn_quant_pack_4bit_weight(
b_float32, in_features, out_features
)
dtypes = [torch.float32]
for dtype in dtypes:
a = a_float32.to(dtype=dtype)
b = b_float32.to(dtype=dtype)
ref = torch.mm(a, b)
res = dyn_quant_matmul_4bit(
a,
b_int4pack,
q_group,
in_features,
out_features,
)
mean_err = ((res - ref).abs() / ref).mean()
self.assertTrue(mean_err < 0.05)
elementwise_diff = (res - ref).abs()
elementwise_relative_error = elementwise_diff / ref.abs().clamp(
min=torch.finfo(ref.dtype).eps
)
all_elements_within_threshold = torch.all(elementwise_relative_error < 0.06)
self.assertTrue(
all_elements_within_threshold, "Some elements have error >= 0.06"
)
@unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
@unittest.skipIf(TEST_WITH_ROCM and IS_REMOTE_GPU, "ROCM is unsupported")
@onlyNativeDeviceTypes
@parametrize("m", [1, 32])
@parametrize("k", [64, 128])
@parametrize("n", [4096, 11008])
def test_compile_dyn_quant_matmul_4bit(self, device, m, k, n):
if self.device_type == "cuda":
self.skipTest("CUDA is unsupported")
q_group = 32
torch.manual_seed(1)
a_float32 = torch.rand((m, k), dtype=torch.float32, device=device)
b_float32 = torch.rand((k, n), dtype=torch.float32, device=device)
in_features = b_float32.size(0)
out_features = b_float32.size(1)
b_uint8, b_scales_and_zeros = _group_quantize_tensor_symmetric(
b_float32, n_bit=4, groupsize=q_group
)
if q_group == in_features:
b_scales_and_zeros = b_scales_and_zeros.to(dtype=torch.float)
else:
b_scales_and_zeros = b_scales_and_zeros.to(dtype=torch.bfloat16)
@torch.compile
def dyn_quant_matmul_4bit(
a, b_uint8, b_scales_and_zeros, q_group, in_features, out_features
):
b_int4pack = torch._dyn_quant_pack_4bit_weight(
b_uint8, b_scales_and_zeros, None, q_group, in_features, out_features
)
return torch._dyn_quant_matmul_4bit(
a,
b_int4pack,
q_group,
in_features,
out_features,
)
res = dyn_quant_matmul_4bit(
a_float32,
b_uint8,
b_scales_and_zeros,
q_group,
in_features,
out_features,
)
ref = torch.mm(a_float32, b_float32)
mean_err = ((res - ref).abs() / ref).mean()
self.assertTrue(mean_err < 0.05)
elementwise_diff = (res - ref).abs()
elementwise_relative_error = elementwise_diff / ref.abs().clamp(
min=torch.finfo(ref.dtype).eps
)
all_elements_within_threshold = torch.all(elementwise_relative_error < 0.06)
self.assertTrue(
all_elements_within_threshold, "Some elements have error >= 0.06"
)
@onlyCPU
@parametrize("m", [32, 64])
@parametrize("k", [32, 64])
@ -8664,8 +8814,6 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
self.assertEqual(ck_out, cpu_out)
def test_permute_matmul(self):
a = torch.ones([2, 5, 24, 24])
b = torch.ones([3, 2, 5, 24, 24])
@ -8745,7 +8893,6 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
ref = alpha * A @ B + beta * C
self.assertEqual(rc, ref)
@dtypes(torch.float, torch.double)
@precisionOverride({torch.float32: 1e-4})
def test_1_sized_with_0_strided(self, device, dtype):

third_party/kleidiai vendored Submodule

Submodule third_party/kleidiai added at 202603f38a

View File

@ -1299,6 +1299,7 @@ def _valgrind_toggle_and_dump_stats() -> None: ... # CALLGRIND_TOGGLE_COLLECT a
has_openmp: _bool
has_mkl: _bool
_has_kleidiai: _bool
_has_mps: _bool
has_lapack: _bool
_has_cuda: _bool

View File

@ -1372,6 +1372,8 @@ torch_c_binding_in_graph_functions = dict.fromkeys(
"torch._dim_arange", "torch._dim_arange",
"torch._dirichlet_grad", "torch._dirichlet_grad",
"torch._disable_functionalization", "torch._disable_functionalization",
"torch._dyn_quant_matmul_4bit",
"torch._dyn_quant_pack_4bit_weight",
"torch._efficientzerotensor", "torch._efficientzerotensor",
"torch._embedding_bag_forward_only", "torch._embedding_bag_forward_only",
"torch._embedding_bag", "torch._embedding_bag",

View File

@ -94,3 +94,6 @@ def register_woq_mm_ops() -> None:
return autotune_select_algorithm(
"_weight_int8pack_mm", choices, [mat1, mat2, scale], aten_layout
)
lowering.make_fallback(aten._dyn_quant_matmul_4bit)
lowering.make_fallback(aten._dyn_quant_pack_4bit_weight)
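
Note: the two `make_fallback` registrations above keep Inductor from trying to generate its own kernels for these ops; compiled graphs call straight into the ATen implementations. A minimal sketch of what that enables (hedged: assumes an AArch64 build in which the 4-bit ops are registered, with `x` and `packed` prepared elsewhere):

```python
import torch

@torch.compile
def quantized_linear(x, packed, groupsize, in_features, out_features):
    # Inductor falls back to the eager ATen kernel for this op.
    return torch.ops.aten._dyn_quant_matmul_4bit(
        x, packed, groupsize, in_features, out_features
    )
```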

View File

@ -3344,6 +3344,161 @@ def meta__weight_int4pack_mm_for_cpu(x, w, q_group_size, q_scale_and_zeros):
return x.new_empty(x.size(0), w.size(0), dtype=x.dtype)
def kai_roundup(a: int, b: int) -> int:
return ((a + b - 1) // b) * b
def get_kai_packed_weight_size(n_bits, N, K, groupsize):
if n_bits == 4:
if groupsize == K: # channelwise
# dotprod params only [1x8x32_neon_dotprod]
kai_nr = 8
kai_kr = 16
kai_sr = 2
kai_num_bytes_sum_rhs = 4 # sizeof(int32_t)
kai_num_bytes_multiplier_rhs = 4 # sizeof(float)
kai_num_bytes_bias = 4 # sizeof(float)
def kai_k_roundedup(k, kr, sr):
# Since we pack a float and int32 value at the end of the row,
# we must make sure that k is a multiple of 4 for alignment
kr_sr_roundedup4 = kai_roundup(kr * sr, 4)
return kai_roundup(k, kr_sr_roundedup4)
def kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(
k, nr, kr, sr
):
k_internal = kai_k_roundedup(k, kr, sr)
assert (k_internal % 2) == 0, "k_internal must be even"
return nr * (
(k_internal // 2)
+ kai_num_bytes_multiplier_rhs
+ kai_num_bytes_sum_rhs
+ kai_num_bytes_bias
)
def kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(
n, k, nr, kr, sr
):
num_rows = kai_roundup(n, nr) // nr
return (
num_rows
* kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(
k, nr, kr, sr
)
)
return kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(
N, K, kai_nr, kai_kr, kai_sr
)
elif groupsize % 32 == 0 and K % groupsize == 0: # groupwise
kai_nr = 8
kai_kr = 16
kai_sr = 2
kai_num_bytes_sum_rhs = 4
kai_num_bytes_bias = 4
kai_nr_multiple_of = 4
kai_bl_multiple_of = 32
def kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(
n, k, nr, kr, sr, bl
):
assert (bl % kr) == 0
assert (nr % kai_nr_multiple_of) == 0
assert (bl % kai_bl_multiple_of) == 0
num_rows = kai_roundup(n, nr) // nr
return (
num_rows
* kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(
k, nr, kr, sr, bl
)
)
def kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(
k, nr, kr, sr, bl
):
assert (bl % kr) == 0
assert (nr % kai_nr_multiple_of) == 0
assert (bl % kai_bl_multiple_of) == 0
# kr and sr are unused in the calculation
num_bytes_multiplier_rhs = kai_get_bf16_datatype_size_in_bytes()
num_blocks_per_row = kai_num_blocks_per_row(k, bl)
num_bytes_per_block = kai_num_bytes_per_block(
bl, num_bytes_multiplier_rhs
)
return nr * (
(num_bytes_per_block * num_blocks_per_row)
+ kai_num_bytes_sum_rhs
+ kai_num_bytes_bias
)
# This function returns the size of these datatypes stored as an enum. We modify it to just return the bf16 datatype size
# https://gitlab.arm.com/kleidi/kleidiai/-/blob/main/kai/kai_common.h?ref_type=heads#L55
def kai_get_bf16_datatype_size_in_bytes():
return 2 # 2 bytes
def kai_num_blocks_per_row(k, bl):
assert (bl % kai_bl_multiple_of) == 0
return kai_roundup(k, bl) // bl
def kai_num_bytes_per_block(bl, num_bytes_multiplier_rhs):
assert (bl % kai_bl_multiple_of) == 0
return (bl // 2) + num_bytes_multiplier_rhs
return kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(
N, K, kai_nr, kai_kr, kai_sr, groupsize
)
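
A quick numeric check of the channelwise branch of `get_kai_packed_weight_size` above, using assumed example shapes (K = 128, N = 64) and the dotprod packing parameters nr = 8, kr = 16, sr = 2:

```python
def roundup(a, b):
    return ((a + b - 1) // b) * b

K, N, nr, kr, sr = 128, 64, 8, 16, 2
k_internal = roundup(K, roundup(kr * sr, 4))   # 128
# per packed row: int4 data + fp32 scale + int32 row sum + fp32 bias
stride = nr * (k_internal // 2 + 4 + 4 + 4)    # 608 bytes
packed_size = (roundup(N, nr) // nr) * stride  # 4864 bytes
assert packed_size == 4864
```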
@register_meta([aten._dyn_quant_pack_4bit_weight])
def meta__dyn_quant_pack_4bit_weight(
weights, scales_zeros, bias: Optional[Tensor], block_size, in_features, out_features
):
torch._check(
weights.dtype is torch.uint8,
lambda: f"expected w to be uint8, got {weights.dtype}",
)
if torch.backends.kleidiai.is_available() and (
(block_size == in_features and scales_zeros.dtype == torch.float)
or (
block_size < in_features
and block_size % 32 == 0
and in_features % block_size == 0
and scales_zeros.dtype == torch.bfloat16
)
):
packed_weight_size = get_kai_packed_weight_size(
4, out_features, in_features, block_size
)
return weights.new_empty(int(packed_weight_size), dtype=torch.uint8)
packed_weight_size = weights.numel() + scales_zeros.numel()
return weights.new_empty(packed_weight_size, dtype=torch.float)
@register_meta([aten._dyn_quant_matmul_4bit])
def meta__dyn_quant_matmul_4bit(
inp,
packed_weights,
block_size,
in_features,
out_features,
):
torch._check(inp.dim() == 2, lambda: "input must be a 2D tensor")
torch._check(
inp.dtype in [torch.float32],
lambda: f"expected input to be f32, got {inp.dtype}",
)
M = inp.size(0)
return inp.new_empty(M, out_features, dtype=inp.dtype)
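
These meta registrations give the new ops fake-tensor shape and dtype propagation, which is what allows torch.compile and export to trace them without running real kernels. A small hedged sanity sketch (assumes a build with the registrations above; only shapes are computed on the meta device):

```python
import torch

m, k, n, group = 4, 64, 32, 32
inp = torch.empty(m, k, dtype=torch.float32, device="meta")
# The packed buffer is opaque to the meta matmul; any uint8 meta tensor stands in here.
packed = torch.empty(1, dtype=torch.uint8, device="meta")
out = torch.ops.aten._dyn_quant_matmul_4bit(inp, packed, group, k, n)
assert out.shape == (m, n) and out.dtype == torch.float32
```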
@register_meta([aten._weight_int8pack_mm])
def meta__weight_int8pack_mm(x, w, q_scales):
torch._check(x.dim() == 2, lambda: "x must be a 2D tensor")

View File

@ -62,6 +62,7 @@ from torch.backends import (
cuda as cuda,
cudnn as cudnn,
cusparselt as cusparselt,
kleidiai as kleidiai,
mha as mha,
mkl as mkl,
mkldnn as mkldnn,

View File

@ -0,0 +1,7 @@
# mypy: allow-untyped-defs
import torch
def is_available():
r"""Return whether PyTorch is built with KleidiAI support."""
return torch._C._has_kleidiai
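
A short hedged usage sketch for the new backend flag (how a caller gates on it is an assumption about downstream use, not part of this change):

```python
import torch

if torch.backends.kleidiai.is_available():
    print("KleidiAI kernels compiled in: 4-bit dynamic-quant matmuls can use the Arm path")
else:
    print("No KleidiAI support in this build: the reference 4-bit kernels are used")
```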

View File

@ -1939,6 +1939,8 @@ Call this whenever a new thread is created in order to propagate values from
ASSERT_TRUE(
set_module_attr("has_openmp", at::hasOpenMP() ? Py_True : Py_False));
ASSERT_TRUE(set_module_attr("has_mkl", at::hasMKL() ? Py_True : Py_False));
ASSERT_TRUE(
set_module_attr("_has_kleidiai", at::hasKleidiAI() ? Py_True : Py_False));
ASSERT_TRUE(
set_module_attr("has_lapack", at::hasLAPACK() ? Py_True : Py_False));

View File

@ -18,6 +18,8 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__adaptive_avg_pool3d_backward(At
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__addmm_activation(AtenTensorHandle self, AtenTensorHandle mat1, AtenTensorHandle mat2, double beta, double alpha, int32_t use_gelu, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__cdist_backward(AtenTensorHandle grad, AtenTensorHandle x1, AtenTensorHandle x2, double p, AtenTensorHandle cdist, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__cdist_forward(AtenTensorHandle x1, AtenTensorHandle x2, double p, int64_t* compute_mode, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__dyn_quant_matmul_4bit(AtenTensorHandle inp, AtenTensorHandle packed_weights, int64_t block_size, int64_t in_features, int64_t out_features, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__dyn_quant_pack_4bit_weight(AtenTensorHandle weights, AtenTensorHandle scales_zeros, AtenTensorHandle* bias, int64_t block_size, int64_t in_features, int64_t out_features, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__efficientzerotensor(const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__embedding_bag(AtenTensorHandle weight, AtenTensorHandle indices, AtenTensorHandle offsets, int32_t scale_grad_by_freq, int64_t mode, int32_t sparse, AtenTensorHandle* per_sample_weights, int32_t include_last_offset, int64_t padding_idx, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__embedding_bag_dense_backward(AtenTensorHandle grad, AtenTensorHandle indices, AtenTensorHandle offset2bag, AtenTensorHandle bag_size, AtenTensorHandle maximum_indices, int64_t num_weights, int32_t scale_grad_by_freq, int64_t mode, AtenTensorHandle* per_sample_weights, int64_t padding_idx, AtenTensorHandle* ret0);

View File

@ -498,6 +498,39 @@ def _group_quantize_tensor(w, n_bit=4, q_group_size=16):
return out, scales_and_zeros
def _group_quantize_tensor_symmetric(
w, n_bit=4, groupsize=32
):
# W is of shape [K x N]
# We transpose W since quantization is applied on [N x K]
w = w.transpose(0, 1).contiguous()
assert w.dim() == 2
assert groupsize > 1
assert w.shape[-1] % groupsize == 0
# Calculate scale and zeros
to_quant = w.reshape(-1, groupsize)
max_val = to_quant.abs().amax(dim=1, keepdim=True)
eps = torch.finfo(max_val.dtype).eps
max_int = 2 ** (n_bit - 1) - 1 # For 4-bit, this is 7
scales = max_val.clamp(min=eps) / max_int
zeros = torch.zeros_like(scales)
# Quantize the weight
scales = scales.to(torch.float32).reshape(w.shape[0], -1)
zeros = zeros.to(torch.float32).reshape(w.shape[0], -1)
scales = scales.reshape(-1, 1)
zeros = zeros.reshape(-1, 1)
max_int = 2**n_bit - 1
w_int8 = to_quant.div(scales).add(8.5).to(torch.int8).clamp(max=max_int)
# We pack 2 signed int4 values into one unsigned uint8 container.
# This reduces the weight size by half and improves load performance
out_uint8 = (w_int8[::, 1::2] << 4 | w_int8[::, ::2]).to(torch.uint8)
scales_and_zeros = scales.squeeze().contiguous()
return out_uint8, scales_and_zeros
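
Taken together with the ops above, this helper supports an end-to-end flow like the one exercised by the new tests. A hedged sketch (assumes an AArch64 CPU build where the 4-bit ops are available; `_group_quantize_tensor_symmetric` is a test-only helper, not a public API):

```python
import torch
from torch.testing._internal.common_quantization import _group_quantize_tensor_symmetric

m, k, n, group = 8, 128, 64, 32
a = torch.rand(m, k, dtype=torch.float32)
b = torch.rand(k, n, dtype=torch.float32)

b_uint8, scales = _group_quantize_tensor_symmetric(b, n_bit=4, groupsize=group)
# Groupwise packing expects bf16 scales; channelwise (group == k) expects fp32.
scales = scales.to(torch.float32 if group == k else torch.bfloat16)
packed = torch._dyn_quant_pack_4bit_weight(b_uint8, scales, None, group, k, n)
out = torch._dyn_quant_matmul_4bit(a, packed, group, k, n)
assert out.shape == (m, n)
```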
def _dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype):
# source: https://github.com/pytorch-labs/gpt-fast/blob/main/quantize.py
# default setup for affine quantization of activations
@ -530,7 +563,6 @@ def _dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype):
return quant, scales.to(x_dtype), zero_points
# QuantizationTestCase used as a base class for testing quantization on modules
class QuantizationTestCase(TestCase):
def setUp(self):

View File

@ -45,6 +45,8 @@ inductor_fallback_ops = {
"aten.cummin.default", "aten.cummin.default",
"aten.cumprod.default", "aten.cumprod.default",
"aten.cumsum.default", "aten.cumsum.default",
"aten._dyn_quant_matmul_4bit.default",
"aten._dyn_quant_pack_4bit_weight.default",
"aten._efficient_attention_backward.default", "aten._efficient_attention_backward.default",
"aten._efficient_attention_forward.default", "aten._efficient_attention_forward.default",
"aten._efficientzerotensor.default", "aten._efficientzerotensor.default",