Revert "[ARM][feat]: Add 4 bit dynamic quantization matmuls & KleidiAI Backend (#134124)"

This reverts commit 4b82251011f85f9d1395b451d61e976af844d9b1.

Reverted https://github.com/pytorch/pytorch/pull/134124 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it breaks lots of internal build ([comment](https://github.com/pytorch/pytorch/pull/134124#issuecomment-2555953189))
PyTorch MergeBot
2024-12-19 23:33:17 +00:00
parent 145fd5bad0
commit 8136daff5a
37 changed files with 23 additions and 1898 deletions

.gitmodules (vendored, 3 lines changed)

@@ -131,6 +131,3 @@
path = third_party/composable_kernel
url = https://github.com/ROCm/composable_kernel.git
branch = develop
[submodule "third_party/kleidiai"]
path = third_party/kleidiai
url = https://git.gitlab.arm.com/kleidi/kleidiai.git


@@ -254,7 +254,6 @@ filegroup(
# target that generates these sources...
)
# TODO: Enable support for KleidiAI bazel build
header_template_rule(
name = "aten_src_ATen_config",
src = "aten/src/ATen/Config.h.in",
@@ -274,7 +273,6 @@ header_template_rule(
"@AT_PARALLEL_NATIVE@": "1",
"@AT_BLAS_F2C@": "0",
"@AT_BLAS_USE_CBLAS_DOT@": "1",
"@AT_KLEIDIAI_ENABLED@": "0",
},
)


@@ -377,8 +377,6 @@ cmake_dependent_option(
cmake_dependent_option(BUILD_FUNCTORCH "Build Functorch" ON "BUILD_PYTHON" OFF)
cmake_dependent_option(BUILD_BUNDLE_PTXAS "Bundle PTX into torch/bin folder"
OFF "USE_CUDA" OFF)
cmake_dependent_option(USE_KLEIDIAI "Use KleidiAI for the ARM CPU & AARCH64 architecture." ON
"CPU_AARCH64" OFF)
option(USE_MIMALLOC "Use mimalloc" OFF)
# Enable third party mimalloc library to improve memory allocation performance
@@ -420,8 +418,6 @@ endif()
if(WIN32)
set(USE_TENSORPIPE OFF)
message(WARNING "TensorPipe cannot be used on Windows. Set it to OFF")
set(USE_KLEIDIAI OFF)
message(WARNING "KleidiAI cannot be used on Windows. Set it to OFF")
if(USE_DISTRIBUTED AND NOT DEFINED ENV{libuv_ROOT})
find_library(
@@ -671,9 +667,6 @@ if(ANDROID
message(WARNING "INTERN_BUILD_MOBILE is on, disabling BUILD_LAZY_TS_BACKEND")
set(BUILD_LAZY_TS_BACKEND OFF)
set(USE_KLEIDIAI OFF)
message(WARNING "KleidiAI cannot be used on Mobile builds. Set it to OFF")
# Set -ffunction-sections and -fdata-sections so that each method has its own
# text section. This allows the linker to remove unused section when the flag
# -Wl,-gc-sections is provided at link time.


@@ -309,12 +309,6 @@ local_repository(
path = "third_party/gemmlowp/gemmlowp",
)
local_repository(
name = "kleidiai",
path = "third_party/kleidiai",
repo_mapping = {"@com_google_googletest": "@com_google_benchmark"},
)
### Unused repos start
# `unused` repos are defined to hide bazel files from submodules of submodules.


@@ -199,10 +199,6 @@ endif()
# XNNPACK
file(GLOB native_xnnpack "native/xnnpack/*.cpp")
# KLEIDIAI
file(GLOB native_kleidiai "native/kleidiai/*.cpp")
file(GLOB native_kleidiai_h "native/kleidiai/*.h")
# Add files needed from jit folders
append_filelist("jit_core_headers" ATen_CORE_HEADERS)
append_filelist("jit_core_sources" ATen_CORE_SRCS)
@@ -232,10 +228,6 @@ endif()
if(AT_MKL_ENABLED)
set(all_cpu_cpp ${all_cpu_cpp} ${mkl_cpp})
endif()
if(AT_KLEIDIAI_ENABLED)
set(all_cpu_cpp ${all_cpu_cpp} ${native_kleidiai})
include_directories(SYSTEM INTERFACE ${KLEIDIAI_INCLUDE_DIRS})
endif()
if(AT_MKLDNN_ENABLED)
set(all_cpu_cpp ${all_cpu_cpp} ${mkldnn_cpp})
endif()
@@ -619,7 +611,7 @@ install(FILES "${CMAKE_CURRENT_BINARY_DIR}/cmake-exports/ATenConfig.cmake"
set(INSTALL_HEADERS ${base_h} ${ATen_CORE_HEADERS} ${native_nested_h} ${ATen_TRANSFORMER_HEADERS})
if(NOT INTERN_BUILD_MOBILE)
list(APPEND INSTALL_HEADERS ${native_h} ${native_cpu_h} ${native_ao_sparse_h} ${native_quantized_h} ${cuda_h} ${native_cuda_h} ${native_hip_h} ${cudnn_h} ${hip_h} ${xpu_h} ${mps_h} ${native_kleidiai_h} ${native_mps_h} ${native_utils_h} ${miopen_h} ${mkldnn_xpu_h})
list(APPEND INSTALL_HEADERS ${native_h} ${native_cpu_h} ${native_ao_sparse_h} ${native_quantized_h} ${cuda_h} ${native_cuda_h} ${native_hip_h} ${cudnn_h} ${hip_h} ${xpu_h} ${mps_h} ${native_mps_h} ${native_utils_h} ${miopen_h} ${mkldnn_xpu_h})
# Metal
if(USE_PYTORCH_METAL_EXPORT)
# Add files needed from exporting metal models(optimized_for_mobile)


@@ -19,4 +19,3 @@
#define AT_PARALLEL_NATIVE @AT_PARALLEL_NATIVE@
#define AT_BLAS_F2C() @AT_BLAS_F2C@
#define AT_BLAS_USE_CBLAS_DOT() @AT_BLAS_USE_CBLAS_DOT@
#define AT_KLEIDIAI_ENABLED() @AT_KLEIDIAI_ENABLED@


@@ -376,10 +376,6 @@ bool Context::hasMKLDNN() {
#endif
}
bool Context::hasKleidiAI() {
return AT_KLEIDIAI_ENABLED();
}
bool Context::hasOpenMP() {
#ifdef _OPENMP
return true;


@@ -118,7 +118,6 @@ class TORCH_API Context {
static bool hasOpenMP();
static bool hasMKL();
static bool hasKleidiAI();
static bool hasLAPACK();
static bool hasMKLDNN();
static bool hasMAGMA() {
@@ -539,10 +538,6 @@ inline bool hasMKL() {
return globalContext().hasMKL();
}
inline bool hasKleidiAI() {
return globalContext().hasKleidiAI();
}
inline bool hasLAPACK() {
return globalContext().hasLAPACK();
}


@@ -33,8 +33,6 @@
#include <ATen/ops/_addmm_activation_native.h>
#include <ATen/ops/_compute_linear_combination_native.h>
#include <ATen/ops/_convert_weight_to_int4pack_for_cpu_native.h>
#include <ATen/ops/_dyn_quant_matmul_4bit_native.h>
#include <ATen/ops/_dyn_quant_pack_4bit_weight_native.h>
#include <ATen/ops/_int_mm_native.h>
#include <ATen/ops/_linalg_check_errors.h>
#include <ATen/ops/_linalg_det.h>
@@ -3431,8 +3429,6 @@ Tensor kron(const Tensor& self, const Tensor& other) {
DEFINE_DISPATCH(weight_to_int4pack_stub);
DEFINE_DISPATCH(int4pack_mm_stub);
DEFINE_DISPATCH(int8pack_mm_stub);
DEFINE_DISPATCH(dyn_quant_pack_4bit_weight_stub);
DEFINE_DISPATCH(dyn_quant_matmul_4bit_stub);
Tensor _convert_weight_to_int4pack_cpu(
const Tensor& in,
@@ -3496,69 +3492,6 @@ Tensor _weight_int4pack_mm_cpu(
return C;
}
Tensor _dyn_quant_pack_4bit_weight_cpu(
const Tensor& weights,
const Tensor& scales_zeros,
const std::optional<Tensor>& bias,
const int64_t block_size,
const int64_t in_features,
const int64_t out_features) {
TORCH_CHECK(
weights.dtype() == at::kByte, __func__, " : expect weight to be kByte.");
TORCH_CHECK(
block_size == in_features ||
(!(block_size % 32) && !(in_features % block_size)),
__func__,
": Group size should be multiple of 32, in_features [",
in_features,
"]. Provided ",
block_size);
Tensor packed_weights =
at::empty(weights.sizes(), weights.options().dtype(at::kByte));
dyn_quant_pack_4bit_weight_stub(
kCPU,
packed_weights,
weights,
scales_zeros,
bias,
out_features,
in_features,
block_size);
return packed_weights;
}
Tensor _dyn_quant_matmul_4bit_cpu(
const Tensor& inp,
const Tensor& packed_weights,
const int64_t block_size,
const int64_t in_features,
const int64_t out_features) {
auto M = inp.size(0);
TORCH_CHECK(
inp.dtype() == kFloat,
__func__,
" : expect input to be 32-bit float tensor.");
TORCH_CHECK(
block_size == in_features ||
(!(block_size % 32) && !(in_features % block_size)),
__func__,
": Group size should be multiple of 32, in_features [",
in_features,
"]. Provided ",
block_size);
auto output = at::empty({M, out_features}, inp.options());
dyn_quant_matmul_4bit_stub(
kCPU,
output,
inp,
packed_weights,
M,
out_features,
in_features,
block_size);
return output;
}
Tensor _weight_int8pack_mm_cpu(
const Tensor& A,
const Tensor& B,


@@ -8,19 +8,8 @@
#include <ATen/cpu/vec/vec.h>
#include <ATen/native/cpu/int_mm_kernel.h>
#include <ATen/native/cpu/utils.h>
#include <c10/util/Unroll.h>
#include <c10/util/irange.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/cat.h>
#endif
#if AT_KLEIDIAI_ENABLED()
#include <ATen/native/kleidiai/kai_kernels.h>
#include <cpuinfo.h>
#endif
#include <c10/util/Unroll.h>
#if (defined(_WIN32) || defined(_WIN64))
#define RESTRICT __restrict
@@ -773,457 +762,10 @@ void int4pack_mm_kernel(
}
}
#if AT_KLEIDIAI_ENABLED()
bool can_use_kleidiai(
const at::Tensor& scales_zeros,
const int64_t K,
const int64_t block_size) {
bool ret = false;
if (cpuinfo_has_arm_neon_dot()) {
// The Groupwise kernel requires BFloat16 Scales and Channelwise kernel
// requires Float32 Scales. If not provided, we will use fallback
// implementation.
if ((block_size == K && scales_zeros.dtype() == at::kFloat) ||
((block_size < K && !(block_size % 32) && !(K % block_size)) &&
scales_zeros.dtype() == at::kBFloat16)) {
ret = true;
}
}
return ret;
}
#endif
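
As a quick reference, the eligibility check above amounts to a small predicate. The sketch below restates it in Python; the has_neon_dot flag stands in for the cpuinfo_has_arm_neon_dot() query, which only exists on the C++ side, so treat this as an illustration rather than a usable API:

import torch

def can_use_kleidiai(scales_zeros: torch.Tensor, k: int, block_size: int,
                     has_neon_dot: bool) -> bool:
    # Channelwise path: block_size == K and float32 scales.
    channelwise = block_size == k and scales_zeros.dtype == torch.float32
    # Groupwise path: block_size a multiple of 32 that divides K, bfloat16 scales.
    groupwise = (block_size < k and block_size % 32 == 0
                 and k % block_size == 0
                 and scales_zeros.dtype == torch.bfloat16)
    return has_neon_dot and (channelwise or groupwise)
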
/**
* The Int4 quantized weights must be represented as a uint8 tensor
* For matrix multiplication with a weight shape of (N x K)
* the shape of the 4-bit quantized weights is [N, K/groupsize, groupsize/2].
*
* For KleidiAI weight packing, the scales, biases, and Int4 quantized
* weights are packed into a single `packed_weights` structure, optimized for
* Arm instructions.
*
* In the fallback reference kernel, no special packing is required for
* Int4 quantized weights.
*
* The Groupwise kernel requires BFloat16 Scales and Channelwise kernel requires
* Float32 Scales. If not provided, we will use fallback implementation.
*/
void dyn_quant_pack_4bit_weight_kernel(
Tensor& packed_weights,
const Tensor& weights,
const Tensor& scales_zeros,
const std::optional<Tensor>& bias,
const int64_t N,
const int64_t K,
const int64_t block_size) {
#if AT_KLEIDIAI_ENABLED()
if (can_use_kleidiai(scales_zeros, K, block_size)) {
const int64_t weight_packed_size =
kleidiai::kai_pack_rhs_int4_size(N, K, block_size);
packed_weights.resize_({weight_packed_size});
kleidiai::kai_pack_int4_rhs(
packed_weights, weights, scales_zeros, bias, N, K, block_size);
} else
#endif
{
TORCH_CHECK(
bias.has_value() == 0,
__func__,
" : Bias is unsupported in reference implementation");
packed_weights = packed_weights.to(kFloat);
auto weight_reshaped = weights.view({-1}).to(kFloat);
auto scales_zeros_reshaped = scales_zeros.view({-1}).to(kFloat);
auto res = at::cat({weight_reshaped, scales_zeros_reshaped}, 0);
packed_weights.resize_(res.sizes()).copy_(res);
}
}
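
To make the layout described in the comment above concrete, here is a small Python sketch of the shape bookkeeping and of the reference (non-KleidiAI) fallback path of the packing kernel just shown. The sizes are made up for illustration; the Arm-optimised packed buffer produced by KleidiAI is opaque and not reproduced here.

import torch

N, K, groupsize = 64, 128, 32                       # example sizes (assumed)
weights = torch.randint(0, 256, (N, K // groupsize, groupsize // 2),
                        dtype=torch.uint8)           # two 4-bit values per byte
scales = torch.rand(N, K // groupsize, dtype=torch.bfloat16)

# Reference fallback: flatten weights and scales, cast both to float32,
# and concatenate them into a single packed buffer.
packed_fallback = torch.cat([weights.reshape(-1).to(torch.float32),
                             scales.reshape(-1).to(torch.float32)])
assert packed_fallback.numel() == N * K // 2 + N * (K // groupsize)
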
static void ref_dyn_quant_matmul_4bit_channelwise_kernel(
size_t m,
size_t n,
size_t k,
const float* lhs_f32,
const uint8_t* rhs_qs4cx,
const float* rhs_scales_f32,
float* dst_f32,
float scalar_min,
float scalar_max) {
const size_t input_size_8bit = m * (k + sizeof(int32_t) + sizeof(float));
auto lhs_qa8dx_buffer = std::make_unique<uint8_t[]>(input_size_8bit);
uint8_t* lhs_qa8dx = lhs_qa8dx_buffer.get();
// Lambda for quantizing the fp32 input to 8 bit symmetric and pack it in
// required format for matmul
auto input_quant_pack_8bit_channelwise =
[&](size_t m, size_t k, const float* lhs_f32, int8_t* lhs_qa8dx) {
const size_t dst_stride =
(k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t));
const size_t lhs_qa8dx_stride = k;
for (size_t m_idx = 0; m_idx < m; ++m_idx) {
const float* src_ptr = lhs_f32 + m_idx * lhs_qa8dx_stride;
float max0 = -FLT_MAX;
float min0 = FLT_MAX;
// Find min/max for each channel
for (size_t k_idx = 0; k_idx < k; ++k_idx) {
const float src0_0 = src_ptr[k_idx];
max0 = (std::max)(src0_0, max0);
min0 = (std::min)(src0_0, min0);
}
// Maximum/minimum int8 values
const float qmin = (float)INT8_MIN;
const float qmax = (float)INT8_MAX;
const float rmin0 = (std::min)(0.0f, min0);
const float rmax0 = (std::max)(0.0f, max0);
const float scale0 =
rmin0 == rmax0 ? 1.f : (qmax - qmin) / (rmax0 - rmin0);
// Reciprocal to quantize
const float recip_scale0 = scale0 ? 1.0f / scale0 : 0.0f;
const float descaled_min0 = rmin0 * scale0;
const float descaled_max0 = rmax0 * scale0;
const float zero_point_from_min_error0 = qmin + descaled_min0;
const float zero_point_from_max_error0 = qmax + descaled_max0;
float zero_point0 =
zero_point_from_min_error0 + zero_point_from_max_error0 > 0
? qmin - descaled_min0
: qmax - descaled_max0;
zero_point0 = (std::max)(zero_point0, qmin);
zero_point0 = (std::min)(zero_point0, qmax);
// Round to nearest integer
const int32_t nudged_zero_point0 = lrintf(zero_point0);
int8_t* dst_ptr = (int8_t*)lhs_qa8dx + m_idx * dst_stride;
// LHS offset at the beginning of the row
*((float*)(dst_ptr)) = recip_scale0;
dst_ptr += sizeof(float);
*((int32_t*)(dst_ptr)) = -nudged_zero_point0;
dst_ptr += sizeof(int32_t);
// Quantize the channels
for (size_t k_idx = 0; k_idx < k; ++k_idx) {
const float src0_0 = src_ptr[k_idx];
// Scale the values
int32_t v0_s32 = (int32_t)(std::round(src0_0 * scale0));
v0_s32 = v0_s32 + nudged_zero_point0;
v0_s32 = (std::max)(v0_s32, static_cast<int32_t>(INT8_MIN));
v0_s32 = (std::min)(v0_s32, static_cast<int32_t>(INT8_MAX));
dst_ptr[0] = (int8_t)v0_s32;
dst_ptr += sizeof(int8_t);
}
}
};
// Dynamically quantize the float32 input to 8-bit asymmetric
input_quant_pack_8bit_channelwise(m, k, lhs_f32, (int8_t*)lhs_qa8dx);
const size_t lhs_stride =
k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t);
const size_t rhs_qs4cx_stride = ((((k + 2 - 1) / 2) * 2) / 2);
for (size_t m_idx = 0; m_idx < m; ++m_idx) {
const int8_t* lhs_ptr_start = (int8_t*)lhs_qa8dx + m_idx * lhs_stride;
for (size_t n_idx = 0; n_idx < n; ++n_idx) {
// Main f32 accumulator
int32_t iacc = 0;
const int8_t* lhs_ptr = lhs_ptr_start;
const uint8_t* rhs_ptr = rhs_qs4cx + n_idx * rhs_qs4cx_stride;
// Get the LHS quantization parameters stored at the
// beginning of each row
const float lhs_scale = *(const float*)lhs_ptr;
lhs_ptr += sizeof(float);
const int32_t lhs_offset = *(const int32_t*)lhs_ptr;
lhs_ptr += sizeof(int32_t);
for (size_t k_idx = 0; k_idx < k; ++k_idx) {
// Get the LHS values
const int32_t lhs_v0 = (int32_t)lhs_ptr[0];
// Get the RHS values
const uint8_t rhs_byte = rhs_ptr[0];
// Unpack the RHS values
int32_t rhs_v0 = 0;
if ((k_idx % 2) == 0) {
rhs_v0 = (((int32_t)(rhs_byte & 0x0F)) - 8);
} else {
rhs_v0 = (((int32_t)(rhs_byte >> 4)) - 8);
}
iacc += lhs_v0 * rhs_v0;
iacc += lhs_offset * rhs_v0;
lhs_ptr += 1;
// Increment only when k_idx is not a multiple of 2
rhs_ptr += k_idx % 2;
}
// Get the RHS scale
const float rhs_scale = rhs_scales_f32[n_idx];
float main_acc = iacc * rhs_scale;
main_acc = main_acc * lhs_scale;
// Clamp (min-max) operation
main_acc = (std::max)(main_acc, scalar_min);
main_acc = (std::min)(main_acc, scalar_max);
dst_f32[0] = main_acc;
dst_f32 += 1;
}
}
};
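
The scale and zero-point arithmetic inside the quantization lambda above is easier to check in isolation. The following is a plain-Python restatement of the per-row parameter computation, offered as a sketch for verification only; all names are local to this example.

INT8_MIN, INT8_MAX = -128, 127

def row_quant_params(row):
    # Asymmetric int8 parameters for one LHS row, transcribed from the lambda above.
    rmin = min(0.0, min(row))
    rmax = max(0.0, max(row))
    scale = 1.0 if rmin == rmax else (INT8_MAX - INT8_MIN) / (rmax - rmin)
    recip_scale = 1.0 / scale if scale else 0.0
    # Zero-point selection, transcribed from the lambda above.
    err_min = INT8_MIN + rmin * scale
    err_max = INT8_MAX + rmax * scale
    zp = INT8_MIN - rmin * scale if err_min + err_max > 0 else INT8_MAX - rmax * scale
    zp = max(float(INT8_MIN), min(float(INT8_MAX), zp))
    return recip_scale, round(zp)

# A row spanning [-1, 3] gives scale = 255 / 4 = 63.75 (stored as its
# reciprocal, about 0.0157) and a nudged zero point of -64.
print(row_quant_params([-1.0, 0.5, 3.0]))
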
static void ref_dyn_quant_matmul_4bit_groupwise_kernel(
size_t m,
size_t n,
size_t k,
size_t bl,
const float* lhs_f32,
const uint8_t* rhs_qs4c32,
const float* rhs_scales_fp32,
float* dst_f32,
float scalar_min,
float scalar_max) {
// Lambda for LHS quantization
auto lhs_quant_pack = [&](size_t m,
size_t k,
const float* lhs_f32,
int8_t* lhs_qa8dx) {
const size_t dst_stride =
(k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t));
for (size_t row_idx = 0; row_idx < m; ++row_idx) {
const float* src_ptr = lhs_f32 + row_idx * k;
float max0 = -FLT_MAX;
float min0 = FLT_MAX;
for (size_t k_idx = 0; k_idx < k; ++k_idx) {
const float src0_0 = src_ptr[k_idx];
max0 = (std::max)(src0_0, max0);
min0 = (std::min)(src0_0, min0);
}
const float qmin = (float)INT8_MIN;
const float qmax = (float)INT8_MAX;
const float rmin0 = (std::min)(0.0f, min0);
const float rmax0 = (std::max)(0.0f, max0);
const float scale0 =
(rmin0 == rmax0) ? 1.f : (qmax - qmin) / (rmax0 - rmin0);
const float recip_scale0 = scale0 ? 1.0f / scale0 : 0.0f;
const float descaled_min0 = rmin0 * scale0;
const float descaled_max0 = rmax0 * scale0;
float zero_point0 = (qmin + descaled_min0 + qmax + descaled_max0 > 0)
? qmin - descaled_min0
: qmax - descaled_max0;
zero_point0 = (std::max)(zero_point0, qmin);
zero_point0 = (std::min)(zero_point0, qmax);
const int32_t nudged_zero_point0 = lrintf(zero_point0);
int8_t* dst_ptr = (int8_t*)lhs_qa8dx + row_idx * dst_stride;
*((float*)(dst_ptr)) = recip_scale0;
dst_ptr += sizeof(float);
*((int32_t*)(dst_ptr)) = -nudged_zero_point0;
dst_ptr += sizeof(int32_t);
for (size_t k_idx = 0; k_idx < k; ++k_idx) {
const float src0_0 = src_ptr[k_idx];
int32_t v0_s32 = (int32_t)(std::round(src0_0 * scale0));
v0_s32 = (std::max)(
(std::min)(
v0_s32 + nudged_zero_point0, static_cast<int32_t>(INT8_MAX)),
static_cast<int32_t>(INT8_MIN));
dst_ptr[0] = (int8_t)v0_s32;
dst_ptr += sizeof(int8_t);
}
}
};
auto lhs_qa8dx_buffer = std::make_unique<int8_t[]>(
m * (k + sizeof(float) + sizeof(int32_t))); // Allocate for LHS
int8_t* lhs_qa8dx = lhs_qa8dx_buffer.get();
// Quantize and pack LHS
lhs_quant_pack(m, k, lhs_f32, lhs_qa8dx);
const size_t num_blocks_row = (((k + bl - 1) / bl) * bl) / bl;
const size_t lhs_stride = k + sizeof(float) + sizeof(int32_t);
const size_t rhs_stride = (((k + 2 - 1) / 2) * 2) / 2;
for (size_t row_idx = 0; row_idx < m; ++row_idx) {
const int8_t* lhs_ptr_start = lhs_qa8dx + row_idx * lhs_stride;
for (size_t col_idx = 0; col_idx < n; ++col_idx) {
float main_acc = 0.0f;
const int8_t* lhs_ptr = lhs_ptr_start;
const uint8_t* rhs_ptr = rhs_qs4c32 + col_idx * rhs_stride;
const float lhs_scale = *(const float*)lhs_ptr;
lhs_ptr += sizeof(float);
const int32_t lhs_offset = *(const int32_t*)lhs_ptr;
lhs_ptr += sizeof(int32_t);
for (size_t block_idx = 0; block_idx < num_blocks_row; ++block_idx) {
const float rhs_scale =
rhs_scales_fp32[block_idx + col_idx * num_blocks_row];
int32_t iacc = 0;
for (size_t i = 0; i < bl; ++i) {
const size_t k_idx = block_idx * bl + i;
if (k_idx >= k) {
break;
}
const int32_t lhs_v0 = (int32_t)lhs_ptr[0];
const uint8_t rhs_byte = rhs_ptr[0];
int32_t rhs_v0 = (k_idx % 2 == 0) ? (((int32_t)(rhs_byte & 0x0F)) - 8)
: (((int32_t)(rhs_byte >> 4)) - 8);
iacc += lhs_v0 * rhs_v0;
iacc += lhs_offset * rhs_v0;
lhs_ptr += 1;
rhs_ptr += (k_idx % 2);
}
main_acc += iacc * rhs_scale;
}
main_acc = main_acc * lhs_scale;
main_acc = (std::max)(main_acc, scalar_min);
main_acc = (std::min)(main_acc, scalar_max);
dst_f32[0] = main_acc;
dst_f32 += 1;
}
}
}
/**
* Dynamic Input Quant 4 bit weights matmul execution flow
(INT4 Weights + FP scales + FP32 Bias)
FP32 Input Packed Buffer
| |
Quantize Cast
to INT8 to INT8
| |
v v
INT8 Input INT8 Weights
\ /
\ /
\ /
INT8 Matrix Multiplication
|
v
FP32 Dequantized and Accumulate in FP32
|
v
FP32 Final Output
* The Groupwise kernel requires BFloat16 Scales and Channelwise kernel requires
* Float32 Scales. If not provided, we will use fallback implementation.
*/
void dyn_quant_matmul_4bit_kernel(
const Tensor& output,
const Tensor& inp,
const Tensor& packed_weights,
const int64_t M,
const int64_t N,
const int64_t K,
const int64_t block_size) {
#if AT_KLEIDIAI_ENABLED()
const int64_t weight_packed_size =
kleidiai::kai_pack_rhs_int4_size(N, K, block_size);
if (weight_packed_size == packed_weights.numel()) {
// KleidiAI interface internally handles the Channelwise and groupwise
// distinction
kleidiai::kai_quant_pack_lhs_int4_mm(
output, inp, packed_weights, M, N, K, block_size);
} else
#endif
{
float* lhs_f32 = reinterpret_cast<float*>(inp.data_ptr());
const auto weights_size = N * K / 2;
// The weights needs to be in uint8_t data type after quantization
auto extracted_weights =
(packed_weights.narrow(0, 0, weights_size)).to(kByte);
auto float32_scales =
(packed_weights.narrow(
0, weights_size, packed_weights.size(0) - weights_size))
.to(kFloat);
uint8_t* rhs_4bit =
reinterpret_cast<uint8_t*>(extracted_weights.data_ptr());
float* rhs_scales_f32 = reinterpret_cast<float*>(float32_scales.data_ptr());
float* dst_f32 = reinterpret_cast<float*>(output.data_ptr());
if (block_size == K) {
ref_dyn_quant_matmul_4bit_channelwise_kernel(
M,
N,
K,
lhs_f32,
rhs_4bit,
rhs_scales_f32,
dst_f32,
-FLT_MAX,
FLT_MAX);
} else if (!(block_size % 32) && !(K % block_size)) {
ref_dyn_quant_matmul_4bit_groupwise_kernel(
M,
N,
K,
block_size,
lhs_f32,
rhs_4bit,
rhs_scales_f32,
dst_f32,
-FLT_MAX,
FLT_MAX);
} else {
TORCH_CHECK(
block_size == K || (!(block_size % 32) && !(K % block_size)),
__func__,
": Group size should be multiple 32 or in_features [",
K,
"]. Provided ",
block_size);
}
}
}
} // anonymous namespace
ALSO_REGISTER_AVX512_DISPATCH(weight_to_int4pack_stub, &weight_to_int4pack_kernel)
ALSO_REGISTER_AVX512_DISPATCH(int4pack_mm_stub, &int4pack_mm_kernel)
REGISTER_DISPATCH(dyn_quant_pack_4bit_weight_stub, &dyn_quant_pack_4bit_weight_kernel)
REGISTER_DISPATCH(dyn_quant_matmul_4bit_stub, &dyn_quant_matmul_4bit_kernel)
} // at::native
C10_DIAGNOSTIC_POP()
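
At the Python level, the flow documented in the kernel above corresponds to the two private ops that this revert removes. The sketch below is adapted from the tests deleted further down in this diff; it only runs on a build that still contains #134124, and the shapes are illustrative.

import torch
from torch.testing._internal.common_quantization import (
    _group_quantize_tensor_symmetric,  # helper used by the deleted tests
)

m, k, n, q_group = 32, 128, 4096, 32
a = torch.rand(m, k, dtype=torch.float32)
b = torch.rand(k, n, dtype=torch.float32)

# 4-bit symmetric groupwise quantization of the weights.
b_uint8, b_scales_and_zeros = _group_quantize_tensor_symmetric(
    b, n_bit=4, groupsize=q_group
)
# Channelwise (q_group == k) expects float32 scales, groupwise expects bfloat16.
scales = b_scales_and_zeros.float() if q_group == k else b_scales_and_zeros.bfloat16()

packed = torch._dyn_quant_pack_4bit_weight(b_uint8, scales, None, q_group, k, n)
out = torch._dyn_quant_matmul_4bit(a, packed, q_group, k, n)  # approximates a @ b
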


@@ -5,34 +5,12 @@
namespace at::native {
using weight_to_int4pack_fn = void (*)(const Tensor&, const Tensor&);
using int4pack_mm_fn =
void (*)(const Tensor&, const Tensor&, const Tensor&, int, const Tensor&);
using int8pack_mm_fn =
void (*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&);
using dyn_quant_pack_4bit_weight_fn = void (*)(
Tensor&,
const Tensor&,
const Tensor&,
const std::optional<Tensor>& bias,
const int64_t,
const int64_t,
const int64_t);
using dyn_quant_matmul_4bit_fn = void (*)(
const Tensor&,
const Tensor&,
const Tensor&,
const int64_t,
const int64_t,
const int64_t,
const int64_t);
using weight_to_int4pack_fn = void(*)(const Tensor&, const Tensor&);
using int4pack_mm_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, int, const Tensor&);
using int8pack_mm_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&);
DECLARE_DISPATCH(weight_to_int4pack_fn, weight_to_int4pack_stub)
DECLARE_DISPATCH(int4pack_mm_fn, int4pack_mm_stub)
DECLARE_DISPATCH(int8pack_mm_fn, int8pack_mm_stub)
DECLARE_DISPATCH(
dyn_quant_pack_4bit_weight_fn,
dyn_quant_pack_4bit_weight_stub);
DECLARE_DISPATCH(dyn_quant_matmul_4bit_fn, dyn_quant_matmul_4bit_stub);
} // namespace at::native


@@ -1,440 +0,0 @@
#include <ATen/native/kleidiai/kai_kernels.h>
#include <ATen/native/kleidiai/kai_pack.h>
#include <ATen/native/kleidiai/kai_ukernel_interface.h>
#include <ATen/Parallel.h>
#include <algorithm>
#include <cfloat>
#include <cmath>
#include <unordered_map>
#if AT_KLEIDIAI_ENABLED()
#include <cpuinfo.h>
namespace at::native::kleidiai {
void kai_pack_int4_rhs(
const Tensor& weight_packed,
const Tensor& weight,
const Tensor& scales,
const std::optional<Tensor>& bias,
const int64_t n,
const int64_t k,
const int64_t bl) {
// Prefer Channelwise kernel over Groupwise kernel for conflicting cases
if (bl == k) {
// Channelwise
auto kernel_packet = kai_select_channelwise_matmul_ukernel(
kai_kernel_id::
matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod);
auto& params = kernel_packet.rhs_pack_params;
params.lhs_zero_point = 1;
params.rhs_zero_point = 8;
kai_pack_rhs_channelwise_int4<kai_matmul_ukernel_f32_qa8dxp_qs4cxp>(
kernel_packet, weight_packed, weight, scales, bias, n, k);
} else if (!(bl % 32) && !(k % bl)) {
// Groupwise
auto kernel_packet = kai_select_groupwise_matmul_ukernel(
kai_kernel_id::
matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod);
const int64_t rhs_stride = kai_roundup(k, 2) / 2;
const int64_t scale_stride = (kai_roundup(k, bl) / bl) * sizeof(uint16_t);
auto& params = kernel_packet.rhs_pack_params;
params.lhs_zero_point = 1;
params.rhs_zero_point = 8;
params.scale_dt = kai_datatype::kai_dt_bf16;
kai_pack_rhs_groupwise_int4<kai_matmul_ukernel_f32_qa8dxp_qs4c32p>(
kernel_packet,
weight_packed,
weight,
scales,
bias,
n,
k,
bl,
rhs_stride,
scale_stride);
}
}
size_t kai_pack_rhs_int4_size(
const int64_t n,
const int64_t k,
const int64_t bl) {
size_t packed_size = n * k;
// Prefer Channelwise kernel over Groupwise kernel for conflicting cases
if (bl == k) {
// Channelwise
auto kernel_packet = kai_select_channelwise_matmul_ukernel(
kai_kernel_id::
matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod);
const auto& ukernel = kernel_packet.ukernel;
const size_t nr = ukernel.get_nr();
const size_t kr = ukernel.get_kr();
const size_t sr = ukernel.get_sr();
packed_size = kernel_packet.kai_get_rhs_packed_size(n, k, nr, kr, sr);
} else if (!(bl % 32) && !(k % bl)) {
// Groupwise
auto kernel_packet = kai_select_groupwise_matmul_ukernel(
kai_kernel_id::
matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod);
const auto& ukernel = kernel_packet.ukernel;
const size_t nr = ukernel.get_nr();
const size_t kr = ukernel.get_kr();
const size_t sr = ukernel.get_sr();
packed_size = kernel_packet.kai_get_rhs_packed_size(
n, k, nr, kr, sr, bl, kai_datatype::kai_dt_bf16);
}
return packed_size;
}
static void matmul_channelwise(
kai_matmul_ukernel_f32_qa8dxp_qs4cxp& kernel_packet,
size_t m_increment,
size_t m_start,
size_t m_per_thread,
size_t n_start,
size_t n_per_thread,
size_t n,
size_t k,
size_t mr,
size_t nr,
size_t kr,
size_t sr,
size_t dst_stride,
size_t lhs_stride,
uint8_t* lhs_native_mtx_f32,
uint8_t* lhs_packed_mtx_qa8dx,
uint8_t* rhs_packed_mtx_qs4cx,
uint8_t* dst_act_mtx_f32) {
for (size_t m0 = 0; m0 < m_per_thread; m0 += m_increment) {
const float* src_ptr =
(const float*)(lhs_native_mtx_f32 +
kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32(
m_start + m0, lhs_stride));
void* lhs_packed_ptr =
(void*)(lhs_packed_mtx_qa8dx +
kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(
0, k, mr, kr, sr));
const void* rhs_packed_ptr =
(const void*)((const char*)rhs_packed_mtx_qs4cx +
kernel_packet.ukernel.get_rhs_packed_offset(n_start, k));
float* dst_ptr = (float*)((uint8_t*)dst_act_mtx_f32 +
kernel_packet.ukernel.get_dst_offset(
m_start + m0, n_start, dst_stride));
// Quantize and pack the Input
kernel_packet.kai_run_lhs_quant_pack(
m_increment, k, mr, kr, sr, 0, src_ptr, lhs_stride, lhs_packed_ptr);
// Run Matmul on Int4 packed weights and Quantized Packed Input
kernel_packet.ukernel.run_matmul(
m_increment,
n_per_thread,
k,
lhs_packed_ptr,
rhs_packed_ptr,
dst_ptr,
dst_stride,
sizeof(float),
-FLT_MAX,
FLT_MAX);
}
}
static void matmul_groupwise(
const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& ukernel,
const size_t m,
const size_t num_n_per_thread,
const size_t n_start,
const size_t k,
const size_t bl,
const size_t dst_stride,
const void* lhs_ptr,
uint8_t* rhs_packed,
uint8_t* dst_data) {
const size_t rhs_packed_offset =
ukernel.get_rhs_packed_offset(n_start, k, bl);
const size_t dst_offset = ukernel.get_dst_offset(0, n_start, dst_stride);
const void* rhs_ptr = (const void*)(rhs_packed + rhs_packed_offset);
float* dst_ptr = (float*)((uint8_t*)dst_data + dst_offset);
// Run Matmul on Int4 packed weights and Quantized Packed Input
ukernel.run_matmul(
m,
num_n_per_thread,
k,
bl,
lhs_ptr,
rhs_ptr,
dst_ptr,
dst_stride,
sizeof(float),
-FLT_MAX,
FLT_MAX);
}
struct ThreadDivision {
int64_t num_threads_x;
int64_t num_threads_y;
bool use_gemm; // True if GEMM is selected, false if GEMV is used
};
inline static unsigned int round_down_to_power_of_2(unsigned int n) {
n |= n >> 1;
n |= n >> 2;
n |= n >> 4;
n |= n >> 8;
n |= n >> 16;
return n - (n >> 1);
}
inline static void adjust_max_threads(int64_t& max_threads) {
// We would not like to round down to nearest power of 2 always
// There can be possible thread split combination between powers of 2 for odd
// shapes
// TODO:: Decide better strategy based on hint of input and weight shapes
max_threads = round_down_to_power_of_2(max_threads);
}
static std::pair<int64_t, int64_t> split_2d(const int64_t max_threads) {
int64_t sqrt_threads = std::sqrt(max_threads);
for (int64_t i = sqrt_threads; i >= 1; --i) {
if (max_threads % i == 0) {
return {i, max_threads / i};
}
}
return {1, max_threads}; // There's still a possibility of 1D blocking when
// calling GEMM kernel
}
inline static ThreadDivision get_thread_division(
int64_t max_threads,
const int64_t m,
const int64_t n,
const int64_t k,
const int64_t gemm_m_step,
const int64_t gemm_n_step,
const int64_t gemv_m_step,
const int64_t gemv_n_step) {
adjust_max_threads(max_threads);
ThreadDivision division{1, 1, false};
// Split threads 2D for GEMM
if (m % gemm_m_step == 0 && n % gemm_n_step == 0) {
while (max_threads > 0) {
auto [num_thread_y, num_thread_x] = split_2d(max_threads);
if (m % num_thread_y == 0 && n % num_thread_x == 0) {
int64_t m_per_thread = m / num_thread_y;
int64_t n_per_thread = n / num_thread_x;
if (m_per_thread % gemm_m_step == 0 &&
n_per_thread % gemm_n_step == 0) {
division = {num_thread_x, num_thread_y, true};
return division;
}
}
max_threads -= 2;
}
}
// Split threads 1D for GEMV
if (n % gemv_n_step == 0) {
for (; max_threads > 0; max_threads -= 2) {
if (n % max_threads == 0 && (n / max_threads) % gemv_n_step == 0) {
division.num_threads_x = max_threads;
return division;
}
}
}
return division;
}
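
The thread-splitting helpers above are self-contained integer arithmetic, so they are easy to sanity-check outside the build. A small Python sketch of the same scheme (round the thread count down to a power of two, then search for a factor pair close to a square):

def round_down_to_power_of_2(n: int) -> int:
    # Bit-smearing trick, as in the C++ helper above.
    n |= n >> 1
    n |= n >> 2
    n |= n >> 4
    n |= n >> 8
    n |= n >> 16
    return n - (n >> 1)

def split_2d(max_threads: int):
    # Largest factor pair (y, x) with y <= sqrt(max_threads).
    i = int(max_threads ** 0.5)
    while i > 1:
        if max_threads % i == 0:
            return i, max_threads // i
        i -= 1
    return 1, max_threads

print(round_down_to_power_of_2(12))   # -> 8
print(split_2d(8))                    # -> (2, 4)
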
static void kai_quant_pack_lhs_int4_mm_groupwise(
const Tensor& output,
const Tensor& input,
const Tensor& weight,
const int64_t m,
const int64_t n,
const int64_t k,
const int64_t bl) {
kai_kernel_id id = kai_kernel_id::
matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod;
if (cpuinfo_has_arm_i8mm() && m > 1) {
id =
kai_kernel_id::matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_4x8x32_neon_i8mm;
}
auto kernel_packet = kai_select_groupwise_matmul_ukernel(id);
const auto& ukernel = kernel_packet.ukernel;
const size_t mr = ukernel.get_mr();
const size_t kr = ukernel.get_kr();
const size_t sr = ukernel.get_sr();
const size_t n_step = ukernel.get_n_step();
int64_t total_threads = at::get_num_threads();
int64_t num_threads_x = 1;
adjust_max_threads(total_threads);
// Split threads 1D only for now
if (n % n_step == 0) {
for (; total_threads > 0; total_threads -= 2) {
if (n % total_threads == 0 && (n / total_threads) % n_step == 0) {
num_threads_x = total_threads;
break;
}
}
}
const size_t num_n_per_thread = n / num_threads_x;
const size_t dst_stride = n * sizeof(float);
float* lhs = reinterpret_cast<float*>(input.data_ptr());
uint8_t* rhs_packed_mtx_qs4cx = reinterpret_cast<uint8_t*>(weight.data_ptr());
uint8_t* dst_act_mtx_f32 = reinterpret_cast<uint8_t*>(output.data_ptr());
const size_t lhs_packed_size =
kernel_packet.kai_get_lhs_packed_size(m, k, mr, kr, sr);
auto lhs_packed = std::make_unique<uint8_t[]>(lhs_packed_size);
// Quantize and pack the Input
kernel_packet.kai_run_lhs_quant_pack(
m,
k,
mr,
kr,
sr,
0,
(const float*)lhs,
k * sizeof(float),
(void*)lhs_packed.get());
at::parallel_for(0, num_threads_x, 0, [&](int begin, int end) {
for (const auto x : c10::irange(begin, end)) {
matmul_groupwise(
std::ref(ukernel),
m,
num_n_per_thread,
x * num_n_per_thread,
k,
bl,
dst_stride,
lhs_packed.get(),
rhs_packed_mtx_qs4cx,
dst_act_mtx_f32);
}
});
}
static void kai_quant_pack_lhs_int4_mm_channelwise(
const Tensor& output,
const Tensor& input,
const Tensor& weight,
const int64_t m,
const int64_t n,
const int64_t k) {
// Kernel IDs for GEMM and GEMV
kai_kernel_id gemm_id =
kai_kernel_id::matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm;
kai_kernel_id gemv_id =
kai_kernel_id::matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod;
// Get the total number of threads available and choose GEMM or GEMV steps
const int64_t total_threads = at::get_num_threads();
auto gemm_kernel_packet = kai_select_channelwise_matmul_ukernel(gemv_id);
if (cpuinfo_has_arm_i8mm()) {
gemm_kernel_packet = kai_select_channelwise_matmul_ukernel(gemm_id);
}
auto gemv_kernel_packet = kai_select_channelwise_matmul_ukernel(gemv_id);
// Retrieve m_step and n_step values from GEMM and GEMV kernels
const int64_t gemm_m_step = gemm_kernel_packet.ukernel.get_m_step();
const int64_t gemm_n_step = gemm_kernel_packet.ukernel.get_n_step();
const int64_t gemv_m_step = gemv_kernel_packet.ukernel.get_m_step();
const int64_t gemv_n_step = gemv_kernel_packet.ukernel.get_n_step();
// Determine threading and kernel type
ThreadDivision division = get_thread_division(
total_threads,
m,
n,
k,
gemm_m_step,
gemm_n_step,
gemv_m_step,
gemv_n_step);
// Select appropriate kernel packet based on the chosen kernel type
auto& kernel_packet =
division.use_gemm ? gemm_kernel_packet : gemv_kernel_packet;
// Thread blocking parameters
const size_t mr = kernel_packet.ukernel.get_mr();
const size_t nr = kernel_packet.ukernel.get_nr();
const size_t kr = kernel_packet.ukernel.get_kr();
const size_t sr = kernel_packet.ukernel.get_sr();
const size_t m_increment = kernel_packet.ukernel.get_m_step();
const size_t n_per_thread = n / division.num_threads_x;
const size_t m_per_thread = m / division.num_threads_y;
const int64_t num_threads = division.num_threads_y * division.num_threads_x;
const size_t dst_stride = n * sizeof(float);
const size_t lhs_stride = k * sizeof(float);
const size_t lhs_packed_size =
kernel_packet.kai_get_lhs_packed_size(m_increment, k, mr, kr, sr);
uint8_t* dst_act_mtx_f32 = reinterpret_cast<uint8_t*>(output.data_ptr());
uint8_t* lhs_native_mtx_f32 = reinterpret_cast<uint8_t*>(input.data_ptr());
uint8_t* rhs_packed_mtx_qs4cx = reinterpret_cast<uint8_t*>(weight.data_ptr());
auto lhs_packed = std::make_unique<uint8_t[]>(lhs_packed_size * num_threads);
uint8_t* lhs_packed_base = lhs_packed.get();
at::parallel_for(0, num_threads, 0, [&](int64_t begin, int64_t end) {
for (const auto i : c10::irange(begin, end)) {
size_t y = i / division.num_threads_x;
size_t x = i % division.num_threads_x;
uint8_t* lhs_packed_ptr =
lhs_packed_base + (x + y * division.num_threads_x) * lhs_packed_size;
matmul_channelwise(
std::ref(kernel_packet),
m_increment,
y * m_per_thread,
m_per_thread,
x * n_per_thread,
n_per_thread,
n,
k,
mr,
nr,
kr,
sr,
dst_stride,
lhs_stride,
lhs_native_mtx_f32,
lhs_packed_ptr,
rhs_packed_mtx_qs4cx,
dst_act_mtx_f32);
}
});
}
void kai_quant_pack_lhs_int4_mm(
const Tensor& output,
const Tensor& input,
const Tensor& weight,
const int64_t m,
const int64_t n,
const int64_t k,
const int64_t bl) {
// Prefer Channelwise kernel over Groupwise kernel for conflicting cases
if (bl == k) {
kleidiai::kai_quant_pack_lhs_int4_mm_channelwise(
output, input, weight, m, n, k);
} else if (!(bl % 32) && !(k % bl)) {
kleidiai::kai_quant_pack_lhs_int4_mm_groupwise(
output, input, weight, m, n, k, bl);
}
}
} // namespace at::native::kleidiai
#endif


@@ -1,42 +0,0 @@
#pragma once
#include <ATen/Config.h>
#include <ATen/core/Tensor.h>
#if AT_KLEIDIAI_ENABLED()
namespace at::native::kleidiai {
/**
* @brief Rearranges the quantized weight to support kleidiai inference
* @param bl Groupsize for quantization should be multiple of 32
*/
void kai_pack_int4_rhs(
const Tensor& weight_packed,
const Tensor& weight,
const Tensor& scales,
const std::optional<Tensor>& bias,
const int64_t n,
const int64_t k,
const int64_t bl);
/**
* @brief Outputs the buffer size for the packed weights
* @param bl Groupsize for quantization. 32 for groupwise , 0 for channelwise
*/
size_t kai_pack_rhs_int4_size(
const int64_t n,
const int64_t k,
const int64_t bl);
/**
* @brief Run 2 operations ( Input quantize and pack -> 4 bit Matmul )
*/
void kai_quant_pack_lhs_int4_mm(
const Tensor& output,
const Tensor& input,
const Tensor& weight,
const int64_t m,
const int64_t n,
const int64_t k,
const int64_t bl);
} // namespace at::native::kleidiai
#endif


@@ -1,106 +0,0 @@
#pragma once
#include <ATen/Config.h>
#include <ATen/core/Tensor.h>
#include <ATen/ops/empty.h>
#include <torch/library.h>
#if AT_KLEIDIAI_ENABLED()
namespace at::native::kleidiai {
template <typename T>
void kai_pack_rhs_groupwise_int4(
T& kernel,
const Tensor& weight_packed,
const Tensor& weight,
const Tensor& scales,
const std::optional<Tensor>& bias,
const int64_t n,
const int64_t k,
const int64_t bl,
const int64_t rhs_stride,
const int64_t scale_stride) {
const auto& ukernel = kernel.ukernel;
const size_t nr = ukernel.get_nr();
const size_t kr = ukernel.get_kr();
const size_t sr = ukernel.get_sr();
auto weight_packed_data =
reinterpret_cast<uint8_t*>(weight_packed.data_ptr());
const auto weight_data = weight.data_ptr<uint8_t>();
auto scales_data = scales.const_data_ptr();
if (weight_data == nullptr) {
AT_ERROR("kai_pack_rhs_channelwise_int4: Weight data pointer is null");
}
if (scales_data == nullptr) {
AT_ERROR("kai_pack_rhs_channelwise_int4: Scales data pointer is null");
}
float* bias_ptr = bias.has_value() ? bias.value().data_ptr<float>() : NULL;
auto& params = kernel.rhs_pack_params;
kernel.kai_run_rhs_pack(
/*num_groups=*/1,
n,
k,
nr,
kr,
sr,
bl,
(const uint8_t*)(weight_data),
rhs_stride,
bias_ptr,
scales_data,
scale_stride,
weight_packed_data,
0,
&params);
}
template <typename T>
void kai_pack_rhs_channelwise_int4(
T& kernel,
const Tensor& weight_packed,
const Tensor& weight,
const Tensor& scales,
const std::optional<Tensor>& bias,
const int64_t n,
const int64_t k) {
const auto& ukernel = kernel.ukernel;
const size_t nr = ukernel.get_nr();
const size_t kr = ukernel.get_kr();
const size_t sr = ukernel.get_sr();
auto weight_packed_data =
reinterpret_cast<uint8_t*>(weight_packed.data_ptr());
const auto weight_data = weight.data_ptr<uint8_t>();
const auto scales_data = scales.data_ptr<float>();
if (weight_data == nullptr) {
AT_ERROR("kai_pack_rhs_channelwise_int4: Weight data pointer is null");
}
if (scales_data == nullptr) {
AT_ERROR("kai_pack_rhs_channelwise_int4: Scales data pointer is null");
}
float* bias_ptr = bias.has_value() ? bias.value().data_ptr<float>() : NULL;
auto& params = kernel.rhs_pack_params;
kernel.kai_run_rhs_pack(
/*num_groups=*/1,
n,
k,
nr,
kr,
sr,
(const uint8_t*)(weight_data),
(const float*)(bias_ptr),
(const float*)(scales_data),
weight_packed_data,
0,
&params);
}
} // namespace at::native::kleidiai
#endif


@@ -1,72 +0,0 @@
#include <ATen/native/kleidiai/kai_ukernel_interface.h>
#if AT_KLEIDIAI_ENABLED()
namespace at::native::kleidiai {
// Kernel Mapping - Groupwise
std::unordered_map<kai_kernel_id, kai_matmul_ukernel_f32_qa8dxp_qs4c32p> groupwise_8bit_4bit_kernels =
{{kai_kernel_id::
matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
{{kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod}}},
{kai_kernel_id::matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_4x8x32_neon_i8mm,
{{kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm,
kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm,
kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm,
kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm,
kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm,
kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm,
kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm,
kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm,
kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm,
kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm,
kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm}}}};
kai_matmul_ukernel_f32_qa8dxp_qs4c32p kai_select_groupwise_matmul_ukernel(
kai_kernel_id id) {
return groupwise_8bit_4bit_kernels.at(id);
}
// Kernel Mapping - Channelwise
std::unordered_map<kai_kernel_id, kai_matmul_ukernel_f32_qa8dxp_qs4cxp> channelwise_8bit_4bit_kernels =
{{kai_kernel_id::matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod,
{{kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod,
kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod,
kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod,
kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod,
kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod,
kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod,
kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod,
kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod,
kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod,
kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod,
kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod}}},
{kai_kernel_id::matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm,
{{kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm,
kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm,
kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm,
kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm,
kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm,
kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm,
kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm,
kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm,
kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm,
kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm,
kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm}}}};
kai_matmul_ukernel_f32_qa8dxp_qs4cxp kai_select_channelwise_matmul_ukernel(
const kai_kernel_id id) {
return channelwise_8bit_4bit_kernels.at(id);
}
} // namespace at::native::kleidiai
#endif


@@ -1,144 +0,0 @@
#pragma once
#include <ATen/Config.h>
#include <unordered_map>
#if AT_KLEIDIAI_ENABLED()
#include <kai_common.h>
#include <kai_lhs_quant_pack_qai8dxp_f32.h>
#include <kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h>
#include <kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.h>
#include <kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h>
#include <kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.h>
#include <kai_matmul_clamp_f32_qai8dxp_qsi4c32p_interface.h>
#include <kai_matmul_clamp_f32_qai8dxp_qsi4cxp_interface.h>
#include <kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h>
#include <kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.h>
namespace at::native::kleidiai {
enum class kai_kernel_id {
matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod =
0, // Groupwise 4 bit GEMV
matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_4x8x32_neon_i8mm =
1, // Groupwise 4 bit GEMM
matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod =
2, // Channelwise 4 bit GEMV
matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm =
3 // Channelwise 4 bit GEMM
};
// Channelwise Kernel mapping
struct kai_matmul_ukernel_f32_qa8dxp_qs4cxp {
struct kai_matmul_clamp_f32_qai8dxp_qsi4cxp_ukernel ukernel;
struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params rhs_pack_params;
size_t (*kai_get_lhs_packed_size)(
size_t m,
size_t k,
size_t mr,
size_t kr,
size_t sr);
size_t (*kai_get_rhs_packed_size)(
size_t n,
size_t k,
size_t nr,
size_t kr,
size_t sr);
void (*kai_run_lhs_quant_pack)(
size_t m,
size_t k,
size_t mr,
size_t kr,
size_t sr,
size_t m_idx_start,
const float* lhs,
size_t lhs_stride,
void* lhs_packed);
void (*kai_run_rhs_pack)(
size_t num_groups,
size_t n,
size_t k,
size_t nr,
size_t kr,
size_t sr,
const uint8_t* rhs,
const float* bias,
const float* scale,
void* rhs_packed,
size_t extra_bytes,
const struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params* params);
kai_matmul_ukernel_f32_qa8dxp_qs4cxp(
const kai_matmul_clamp_f32_qai8dxp_qsi4cxp_ukernel& kernel)
: ukernel(kernel),
kai_get_lhs_packed_size(
&kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32),
kai_get_rhs_packed_size(
&kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0),
kai_run_lhs_quant_pack(&kai_run_lhs_quant_pack_qai8dxp_f32),
kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0) {}
};
struct kai_matmul_ukernel_f32_qa8dxp_qs4cxp
kai_select_channelwise_matmul_ukernel(const kai_kernel_id id);
// Groupwise Kernel mapping
struct kai_matmul_ukernel_f32_qa8dxp_qs4c32p {
struct kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel ukernel;
struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params rhs_pack_params;
size_t (*kai_get_lhs_packed_size)(
size_t m,
size_t k,
size_t mr,
size_t kr,
size_t sr);
size_t (*kai_get_rhs_packed_size)(
size_t n,
size_t k,
size_t nr,
size_t kr,
size_t sr,
size_t bl,
enum kai_datatype scale_dt);
void (*kai_run_lhs_quant_pack)(
size_t m,
size_t k,
size_t mr,
size_t kr,
size_t sr,
size_t m_idx_start,
const float* lhs,
size_t lhs_stride,
void* lhs_packed);
void (*kai_run_rhs_pack)(
size_t num_groups,
size_t n,
size_t k,
size_t nr,
size_t kr,
size_t sr,
size_t bl,
const uint8_t* rhs,
size_t rhs_stride,
const float* bias,
const void* scale,
size_t scale_stride,
void* rhs_packed,
size_t extra_bytes,
const struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params* params);
kai_matmul_ukernel_f32_qa8dxp_qs4c32p(
const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& kernel)
: ukernel(kernel),
kai_get_lhs_packed_size(
&kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32),
kai_get_rhs_packed_size(
&kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0),
kai_run_lhs_quant_pack(&kai_run_lhs_quant_pack_qai8dxp_f32),
kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0) {}
};
struct kai_matmul_ukernel_f32_qa8dxp_qs4c32p kai_select_groupwise_matmul_ukernel(
const kai_kernel_id id);
} // namespace at::native::kleidiai
#endif


@@ -4177,14 +4177,6 @@
dispatch:
CPU: _weight_int4pack_mm_cpu
- func: _dyn_quant_pack_4bit_weight(Tensor weights, Tensor scales_zeros, Tensor? bias, int block_size, int in_features, int out_features) -> Tensor
dispatch:
CPU: _dyn_quant_pack_4bit_weight_cpu
- func: _dyn_quant_matmul_4bit(Tensor inp, Tensor packed_weights, int block_size, int in_features, int out_features) -> Tensor
dispatch:
CPU: _dyn_quant_matmul_4bit_cpu
- func: _weight_int8pack_mm(Tensor self, Tensor mat2, Tensor scales) -> Tensor
dispatch:
CPU: _weight_int8pack_mm_cpu


@@ -1070,7 +1070,6 @@ def define_buck_targets(
],
)
# TODO: Enable support for KleidiAI bazel build
# @lint-ignore BUCKLINT
fb_native.genrule(
name = "generate_aten_config",
@@ -1123,9 +1122,6 @@ def define_buck_targets(
"--replace",
"@AT_BLAS_USE_CBLAS_DOT@",
"AT_BLAS_USE_CBLAS_DOT_FBXPLAT",
"--replace",
"@AT_KLEIDIAI_ENABLED@",
"0",
]),
outs = {
"Config.h": ["Config.h"],


@@ -152,7 +152,6 @@ endif()
set(AT_MKLDNN_ACL_ENABLED 0)
set(AT_MKLDNN_ENABLED 0)
set(AT_MKL_ENABLED 0)
set(AT_KLEIDIAI_ENABLED 0)
# setting default preferred BLAS options if not already present.
if(NOT INTERN_BUILD_MOBILE)
set(BLAS "MKL" CACHE STRING "Selected BLAS library")
@@ -1481,35 +1480,6 @@ if(NOT INTERN_BUILD_MOBILE)
message("disabling MKLDNN because USE_MKLDNN is not set")
endif()
if(USE_KLEIDIAI)
if(CMAKE_C_COMPILER_ID STREQUAL "Clang" AND CMAKE_C_COMPILER_VERSION VERSION_LESS "11" )
message(WARNING "KleidiAI: Using non-supported Clang version. Expected 11 or newer, received ${CMAKE_C_COMPILER_VERSION}.")
endif()
if(CMAKE_C_COMPILER_ID STREQUAL "GNU" AND CMAKE_C_COMPILER_VERSION VERSION_LESS "11" )
message(WARNING "KleidiAI: Using non-supported GCC version. Expected 11 or newer, received ${CMAKE_C_COMPILER_VERSION}.")
endif()
set(TEMP_BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS})
set(BUILD_SHARED_LIBS OFF CACHE BOOL "Build shared libs" FORCE)
set(AT_KLEIDIAI_ENABLED 1)
set(KLEIDIAI_BUILD_TESTS OFF) # Disable building KLEIDIAI tests
set(KLEIDIAI_SRC "${PROJECT_SOURCE_DIR}/third_party/kleidiai")
add_subdirectory(${KLEIDIAI_SRC})
set(KLEIDIAI_INCLUDE_DIRS
${KLEIDIAI_SRC}/
${KLEIDIAI_SRC}/kai/
${KLEIDIAI_SRC}/kai/ukernels/
${KLEIDIAI_SRC}/kai/ukernels/matmul/
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/
)
include_directories(SYSTEM INTERFACE ${KLEIDIAI_INCLUDE_DIRS})
list(APPEND Caffe2_DEPENDENCY_LIBS kleidiai)
# Recover build options.
set(BUILD_SHARED_LIBS ${TEMP_BUILD_SHARED_LIBS} CACHE BOOL "Build shared libs" FORCE)
endif()
if(UNIX AND NOT APPLE)
include(CheckLibraryExists)
# https://github.com/libgit2/libgit2/issues/2128#issuecomment-35649830


@@ -153,9 +153,6 @@ function(caffe2_print_configuration_summary)
message(STATUS " USE_MKLDNN_ACL : ${USE_MKLDNN_ACL}")
message(STATUS " USE_MKLDNN_CBLAS : ${USE_MKLDNN_CBLAS}")
endif()
if(${USE_KLEIDIAI})
message(STATUS " USE_KLEIDIAI : ${USE_KLEIDIAI}")
endif()
message(STATUS " USE_UCC : ${USE_UCC}")
if(${USE_UCC})
message(STATUS " USE_SYSTEM_UCC : ${USE_SYSTEM_UCC}")


@@ -97,10 +97,6 @@ else()
append_torchlib_if_found(microkernels-prod)
endif()
if(@USE_KLEIDIAI@)
append_torchlib_if_found(kleidiai)
endif()
append_torchlib_if_found(caffe2_protos protobuf-lite protobuf protoc)
append_torchlib_if_found(onnx onnx_proto)


@@ -203,7 +203,6 @@ torch.backends.openmp
.. add anything to the rendered page for now.
.. py:module:: torch.backends.quantized
.. py:module:: torch.backends.xnnpack
.. py:module:: torch.backends.kleidiai
torch.backends.opt_einsum


@@ -1221,7 +1221,6 @@ def main():
"include/ATen/native/cuda/*.cuh",
"include/ATen/native/hip/*.h",
"include/ATen/native/hip/*.cuh",
"include/ATen/native/kleidiai/*.h",
"include/ATen/native/mps/*.h",
"include/ATen/native/mkldnn/xpu/*.h",
"include/ATen/native/mkldnn/xpu/detail/*.h",


@@ -92,8 +92,6 @@ aten::_dimI
aten::_dimV
aten::_dirichlet_grad
aten::_dirichlet_grad.out
aten::_dyn_quant_matmul_4bit
aten::_dyn_quant_pack_4bit_weight
aten::_efficient_attention_backward
aten::_efficient_attention_forward
aten::_efficientzerotensor


@@ -76,7 +76,6 @@ from torch.testing._internal.common_device_type import (
from torch.testing._internal.common_dtype import all_types, get_all_dtypes
from torch.testing._internal.common_quantization import (
_dynamically_quantize_per_channel,
_group_quantize_tensor_symmetric,
)
from torch.testing._internal.common_utils import (
DeterministicGuard,
@@ -2225,85 +2224,6 @@ class CommonTemplate:
b_int8pack, b_scales = convert_weight_to_int8pack(b)
self.common(fn, (a, b_int8pack, b_scales, c))
@xfail_if_triton_cpu
@skipCUDAIf(True, "No _dyn_quant_pack_4bit_weight implementation on CUDA")
@skipIfRocm
@skipIfXpu(msg="No _dyn_quant_pack_4bit_weight implementation on XPU")
def test__dyn_quant_pack_4bit_weight(self):
q_group = 32
k = 128
n = 128
torch.manual_seed(1)
b = torch.rand((k, n), dtype=torch.float32)
in_features = b.size(0)
out_features = b.size(1)
def dyn_quant_pack_4bit_weight(b, in_features, out_features):
b_uint8, b_scales_and_zeros = _group_quantize_tensor_symmetric(
b, n_bit=4, groupsize=q_group
)
if q_group == in_features:
b_scales_and_zeros = b_scales_and_zeros.to(torch.float)
else:
b_scales_and_zeros = b_scales_and_zeros.to(torch.bfloat16)
b_int4pack = torch._dyn_quant_pack_4bit_weight(
b_uint8, b_scales_and_zeros, None, q_group, in_features, out_features
)
return b_int4pack, b_scales_and_zeros
def fn(b, in_features, out_features):
b_int4pack, _ = dyn_quant_pack_4bit_weight(b, in_features, out_features)
return b_int4pack
self.common(fn, (b, in_features, out_features))
@xfail_if_triton_cpu
@skipCUDAIf(True, "No _dyn_quant_matmul_4bit implementation on CUDA")
@skipIfRocm
@skipIfXpu(msg="No _dyn_quant_matmul_4bit implementation on XPU")
def test__dyn_quant_matmul_4bit(self):
q_group = 32
m = 32
k = 128
n = 128
torch.manual_seed(1)
a = torch.rand((m, k), dtype=torch.float32)
b = torch.rand((k, n), dtype=torch.float32)
in_features = b.size(0)
out_features = b.size(1)
def dyn_quant_pack_4bit_weight(b, in_features, out_features):
b_uint8, b_scales_and_zeros = _group_quantize_tensor_symmetric(
b, n_bit=4, groupsize=q_group
)
if q_group == in_features:
b_scales_and_zeros = b_scales_and_zeros.to(torch.float)
else:
b_scales_and_zeros = b_scales_and_zeros.to(torch.bfloat16)
b_int4pack = torch._dyn_quant_pack_4bit_weight(
b_uint8, b_scales_and_zeros, None, q_group, in_features, out_features
)
return b_int4pack, b_scales_and_zeros
def fn(a, q_group, in_features, out_features):
b_int4pack, _ = dyn_quant_pack_4bit_weight(b, in_features, out_features)
res = torch._dyn_quant_matmul_4bit(
a,
b_int4pack,
q_group,
in_features,
out_features,
)
return res
self.common(fn, (a, q_group, in_features, out_features))
def test_expanded_reduction(self):
def fn(x, y):
z = x * y


@@ -34,8 +34,7 @@ from torch.testing._internal.common_dtype import (
)
from torch.testing._internal.common_cuda import SM53OrLater, SM80OrLater, SM90OrLater, tf32_on_and_off, _get_magma_version, \
_get_torch_cuda_version, CDNA2OrLater, TEST_MULTIGPU
from torch.testing._internal.common_quantization import _group_quantize_tensor, _dynamically_quantize_per_channel, \
_group_quantize_tensor_symmetric
from torch.testing._internal.common_quantization import _group_quantize_tensor, _dynamically_quantize_per_channel
from torch.testing._internal.common_mkldnn import bf32_on_and_off
from torch.distributions.binomial import Binomial
import torch.backends.opt_einsum as opt_einsum
@@ -910,6 +909,7 @@ class TestLinalg(TestCase):
torch.randn((3, 52, 52), device=device, dtype=dtype),
torch.randn((4, 2, 26, 26), device=device, dtype=dtype))
ops = (torch.det, torch.Tensor.det,
torch.linalg.det)
for t in tensors:
@@ -1438,6 +1438,7 @@ class TestLinalg(TestCase):
continue
run_test_case(make_arg(shape), ord, dim, keepdim)
@onlyCUDA
@dtypes(torch.bfloat16, torch.float16)
def test_norm_fused_type_promotion(self, device, dtype):
@@ -4343,6 +4344,7 @@ class TestLinalg(TestCase):
triangular_solve_zero_batch_helper((batchsize, 5, 5), (batchsize, 5, 10),
upper, unitriangular, transpose)
@slowTest
@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@@ -4463,6 +4465,7 @@ class TestLinalg(TestCase):
self.assertTrue("An output with one or more elements was resized" in str(w[0].message))
self.assertTrue("An output with one or more elements was resized" in str(w[1].message))
def check_single_matmul(self, x, y):
def assertEqual(answer, expected):
@@ -5686,6 +5689,7 @@ class TestLinalg(TestCase):
else:
self.assertEqual(B_, X_ @ A)
sizes = ((3, 3), (5, 5), (4, 2), (3, 4), (0, 0), (0, 1), (1, 0))
batches = ((0,), (), (1,), (2,), (3,), (1, 0), (3, 5))
# Non pivoting just implemented for CUDA
@@ -5718,6 +5722,7 @@ class TestLinalg(TestCase):
with self.assertRaisesRegex(RuntimeError, 'LU without pivoting is not implemented on the CPU'):
f(torch.empty(1, 2, 2), pivot=False)
@precisionOverride({torch.float32: 1e-2, torch.complex64: 1e-2})
@skipCUDAIfNoMagmaAndNoCusolver
@skipCPUIfNoLapack
@@ -5745,6 +5750,7 @@ class TestLinalg(TestCase):
for b, n in shapes:
yield make_arg((b, n, n)), make_arg((b, n, rhs))
for A, B in gen_matrices():
LU, pivots = torch.linalg.lu_factor(A)
for backend in backends:
@@ -5759,6 +5765,7 @@ class TestLinalg(TestCase):
else:
self.assertEqual(B_left, X @ A_adj)
@onlyCPU
@dtypes(*floating_and_complex_types())
def test_linalg_lu_cpu_errors(self, device, dtype):
@@ -5799,6 +5806,7 @@ class TestLinalg(TestCase):
with self.assertRaisesRegex(RuntimeError, r"between 1 and LU.size\(-2\)."):
torch.lu_unpack(LU, pivots)
# Rectangular tests
sample = torch.randn(2, 3, 5, device=device, dtype=dtype)
B = torch.randn(2, 3, 5, device=device, dtype=dtype)
@@ -5815,6 +5823,7 @@ class TestLinalg(TestCase):
with self.assertRaisesRegex(RuntimeError, r"between 1 and LU.size\(-2\)."):
torch.lu_unpack(LU, pivots)
@skipCPUIfNoLapack
@skipCUDAIfNoMagma
@dtypes(torch.double)
@@ -6423,6 +6432,7 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
y = torch.baddbmm(input_tensor, batch1, batch2, beta=0.0, out=out)
self.assertEqual(out, y_ref)
@unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
@onlyCUDA
def test_matmul_45724(self, device):
@@ -6595,6 +6605,7 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
torch._int_mm(a_int8, b_int8, out=c_int32_result)
self.assertEqual(c_int32_result.float(), torch.mm(a_float, b_float))
@unittest.skipIf(IS_WINDOWS, "Skipped on Windows!")
@unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
@onlyNativeDeviceTypes
@@ -6657,6 +6668,7 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
mean_err = ((res - ref).abs() / ref).mean()
self.assertTrue(mean_err < 0.05)
@unittest.skipIf(IS_WINDOWS, "Skipped on Windows!")
@unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
@onlyNativeDeviceTypes
@@ -6709,168 +6721,6 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
mean_err = ((res - ref).abs() / ref).mean()
self.assertTrue(mean_err < 0.05)
@unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
@unittest.skipIf(TEST_WITH_ROCM and IS_REMOTE_GPU, "ROCM is unsupported")
@onlyNativeDeviceTypes
@parametrize("k", [64, 256])
@parametrize("n", [32, 48, 64, 128])
def test__dyn_quant_pack_4bit_weight(self, device, k, n):
# TODO: Fix https://github.com/pytorch/pytorch/issues/131425 and use OpInfo instead
# Weight shape is [K x N]
if self.device_type == "cuda":
self.skipTest("CUDA Backend is unsupported")
torch.manual_seed(1)
block_size = 32
b = torch.rand((k, n), dtype=torch.bfloat16, device=device)
in_features = b.size(0)
out_features = b.size(1)
b_uint8, b_scales_and_zeros = _group_quantize_tensor_symmetric(
b, n_bit=4, groupsize=block_size
)
b_int4pack = torch._dyn_quant_pack_4bit_weight(
b_uint8, b_scales_and_zeros, None, block_size, in_features, out_features
)
b_int4pack_meta = torch._dyn_quant_pack_4bit_weight(
b_uint8, b_scales_and_zeros, None, block_size, in_features, out_features
)
self.assertEqual(b_int4pack.shape, b_int4pack_meta.shape)
@unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
@unittest.skipIf(TEST_WITH_ROCM and IS_REMOTE_GPU, "ROCM is unsupported")
@onlyNativeDeviceTypes
@parametrize("m", [1, 32])
@parametrize("k", [64, 128])
@parametrize("n", [4096, 11008])
def test__dyn_quant_matmul_4bit(self, device, m, k, n):
if self.device_type == "cuda":
self.skipTest("CUDA is unsupported")
q_group = 32
torch.manual_seed(1)
a_float32 = torch.rand((m, k), dtype=torch.float32, device=device)
b_float32 = torch.rand((k, n), dtype=torch.float32, device=device)
in_features = b_float32.size(0)
out_features = b_float32.size(1)
def dyn_quant_pack_4bit_weight(b, in_features, out_features):
b_uint8, b_scales_and_zeros = _group_quantize_tensor_symmetric(
b, n_bit=4, groupsize=q_group
)
if q_group == in_features:
b_scales_and_zeros = b_scales_and_zeros.to(torch.float)
else:
b_scales_and_zeros = b_scales_and_zeros.to(torch.bfloat16)
b_int4pack = torch._dyn_quant_pack_4bit_weight(
b_uint8, b_scales_and_zeros, None, q_group, in_features, out_features
)
return b_int4pack, b_scales_and_zeros
def dyn_quant_matmul_4bit(
a, b_int4pack, q_group, in_features, out_features
):
return torch._dyn_quant_matmul_4bit(
a,
b_int4pack,
q_group,
in_features,
out_features,
)
b_int4pack, b_scales_and_zeros = dyn_quant_pack_4bit_weight(
b_float32, in_features, out_features
)
dtypes = [torch.float32]
for dtype in dtypes:
a = a_float32.to(dtype=dtype)
b = b_float32.to(dtype=dtype)
ref = torch.mm(a, b)
res = dyn_quant_matmul_4bit(
a,
b_int4pack,
q_group,
in_features,
out_features,
)
mean_err = ((res - ref).abs() / ref).mean()
self.assertTrue(mean_err < 0.05)
elementwise_diff = (res - ref).abs()
elementwise_relative_error = elementwise_diff / ref.abs().clamp(
min=torch.finfo(ref.dtype).eps
)
all_elements_within_threshold = torch.all(elementwise_relative_error < 0.06)
self.assertTrue(
all_elements_within_threshold, "Some elements have error >= 0.06"
)
@unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
@unittest.skipIf(TEST_WITH_ROCM and IS_REMOTE_GPU, "ROCM is unsupported")
@onlyNativeDeviceTypes
@parametrize("m", [1, 32])
@parametrize("k", [64, 128])
@parametrize("n", [4096, 11008])
def test_compile_dyn_quant_matmul_4bit(self, device, m, k, n):
if self.device_type == "cuda":
self.skipTest("CUDA is unsupported")
q_group = 32
torch.manual_seed(1)
a_float32 = torch.rand((m, k), dtype=torch.float32, device=device)
b_float32 = torch.rand((k, n), dtype=torch.float32, device=device)
in_features = b_float32.size(0)
out_features = b_float32.size(1)
b_uint8, b_scales_and_zeros = _group_quantize_tensor_symmetric(
b_float32, n_bit=4, groupsize=q_group
)
if q_group == in_features:
b_scales_and_zeros = b_scales_and_zeros.to(dtype=torch.float)
else:
b_scales_and_zeros = b_scales_and_zeros.to(dtype=torch.bfloat16)
@torch.compile
def dyn_quant_matmul_4bit(
a, b_uint8, b_scales_and_zeros, q_group, in_features, out_features
):
b_int4pack = torch._dyn_quant_pack_4bit_weight(
b_uint8, b_scales_and_zeros, None, q_group, in_features, out_features
)
return torch._dyn_quant_matmul_4bit(
a,
b_int4pack,
q_group,
in_features,
out_features,
)
res = dyn_quant_matmul_4bit(
a_float32,
b_uint8,
b_scales_and_zeros,
q_group,
in_features,
out_features,
)
ref = torch.mm(a_float32, b_float32)
mean_err = ((res - ref).abs() / ref).mean()
self.assertTrue(mean_err < 0.05)
elementwise_diff = (res - ref).abs()
elementwise_relative_error = elementwise_diff / ref.abs().clamp(
min=torch.finfo(ref.dtype).eps
)
all_elements_within_threshold = torch.all(elementwise_relative_error < 0.06)
self.assertTrue(
all_elements_within_threshold, "Some elements have error >= 0.06"
)
@onlyCPU
@parametrize("m", [32, 64])
@parametrize("k", [32, 64])
@ -8802,6 +8652,8 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
self.assertEqual(ck_out, cpu_out)
def test_permute_matmul(self):
a = torch.ones([2, 5, 24, 24])
b = torch.ones([3, 2, 5, 24, 24])
@ -8881,6 +8733,7 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
ref = alpha * A @ B + beta * C
self.assertEqual(rc, ref)
@dtypes(torch.float, torch.double)
@precisionOverride({torch.float32: 1e-4})
def test_1_sized_with_0_strided(self, device, dtype):

Submodule third_party/kleidiai deleted from 202603f38a

View File

@ -1299,7 +1299,6 @@ def _valgrind_toggle_and_dump_stats() -> None: ... # CALLGRIND_TOGGLE_COLLECT a
has_openmp: _bool
has_mkl: _bool
_has_kleidiai: _bool
_has_mps: _bool
has_lapack: _bool
_has_cuda: _bool

View File

@ -1372,8 +1372,6 @@ torch_c_binding_in_graph_functions = dict.fromkeys(
"torch._dim_arange",
"torch._dirichlet_grad",
"torch._disable_functionalization",
"torch._dyn_quant_matmul_4bit",
"torch._dyn_quant_pack_4bit_weight",
"torch._efficientzerotensor",
"torch._embedding_bag_forward_only",
"torch._embedding_bag",

View File

@ -94,6 +94,3 @@ def register_woq_mm_ops() -> None:
return autotune_select_algorithm(
"_weight_int8pack_mm", choices, [mat1, mat2, scale], aten_layout
)
lowering.make_fallback(aten._dyn_quant_matmul_4bit)
lowering.make_fallback(aten._dyn_quant_pack_4bit_weight)

View File

@ -3344,161 +3344,6 @@ def meta__weight_int4pack_mm_for_cpu(x, w, q_group_size, q_scale_and_zeros):
return x.new_empty(x.size(0), w.size(0), dtype=x.dtype)
def kai_roundup(a: int, b: int) -> int:
return ((a + b - 1) // b) * b
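# Worked example of the round-up helper above (illustrative arithmetic only):
# rounding 60 up to the nearest multiple of 32 gives ((60 + 31) // 32) * 32 = 64,
# while a value that is already a multiple, e.g. 64, maps to itself.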
def get_kai_packed_weight_size(n_bits, N, K, groupsize):
if n_bits == 4:
if groupsize == K: # channelwise
# dotprod params only [1x8x32_neon_dotprod]
kai_nr = 8
kai_kr = 16
kai_sr = 2
kai_num_bytes_sum_rhs = 4 # sizeof(int32_t)
kai_num_bytes_multiplier_rhs = 4 # sizeof(float)
kai_num_bytes_bias = 4 # sizeof(float)
def kai_k_roundedup(k, kr, sr):
# Since we pack a float and int32 value at the end of the row,
# we must make sure that k is a multiple of 4 for alignment
kr_sr_roundedup4 = kai_roundup(kr * sr, 4)
return kai_roundup(k, kr_sr_roundedup4)
def kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(
k, nr, kr, sr
):
k_internal = kai_k_roundedup(k, kr, sr)
assert (k_internal % 2) == 0, "k_internal must be even"
return nr * (
(k_internal // 2)
+ kai_num_bytes_multiplier_rhs
+ kai_num_bytes_sum_rhs
+ kai_num_bytes_bias
)
def kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(
n, k, nr, kr, sr
):
num_rows = kai_roundup(n, nr) // nr
return (
num_rows
* kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(
k, nr, kr, sr
)
)
return kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(
N, K, kai_nr, kai_kr, kai_sr
)
elif groupsize % 32 == 0 and K % groupsize == 0: # groupwise
kai_nr = 8
kai_kr = 16
kai_sr = 2
kai_num_bytes_sum_rhs = 4
kai_num_bytes_bias = 4
kai_nr_multiple_of = 4
kai_bl_multiple_of = 32
def kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(
n, k, nr, kr, sr, bl
):
assert (bl % kr) == 0
assert (nr % kai_nr_multiple_of) == 0
assert (bl % kai_bl_multiple_of) == 0
num_rows = kai_roundup(n, nr) // nr
return (
num_rows
* kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(
k, nr, kr, sr, bl
)
)
def kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(
k, nr, kr, sr, bl
):
assert (bl % kr) == 0
assert (nr % kai_nr_multiple_of) == 0
assert (bl % kai_bl_multiple_of) == 0
# kr and sr are unused in the calculation
num_bytes_multiplier_rhs = kai_get_bf16_datatype_size_in_bytes()
num_blocks_per_row = kai_num_blocks_per_row(k, bl)
num_bytes_per_block = kai_num_bytes_per_block(
bl, num_bytes_multiplier_rhs
)
return nr * (
(num_bytes_per_block * num_blocks_per_row)
+ kai_num_bytes_sum_rhs
+ kai_num_bytes_bias
)
# This function returns the size of these datatypes, stored as an enum in KleidiAI. We modify it to just return the bf16 datatype size
# https://gitlab.arm.com/kleidi/kleidiai/-/blob/main/kai/kai_common.h?ref_type=heads#L55
def kai_get_bf16_datatype_size_in_bytes():
return 2 # 2 bytes
def kai_num_blocks_per_row(k, bl):
assert (bl % kai_bl_multiple_of) == 0
return kai_roundup(k, bl) // bl
def kai_num_bytes_per_block(bl, num_bytes_multiplier_rhs):
assert (bl % kai_bl_multiple_of) == 0
return (bl // 2) + num_bytes_multiplier_rhs
return kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(
N, K, kai_nr, kai_kr, kai_sr, groupsize
)
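# Hand-computed sketch of the helper above (illustrative numbers, not a
# reference value from the KleidiAI sources). For the channelwise path with
# N=32, K=64, groupsize=K: kr * sr = 32 rounds up to 32, so k_internal = 64;
# the per-row stride is 8 * (64 // 2 + 4 + 4 + 4) = 352 bytes and
# num_rows = kai_roundup(32, 8) // 8 = 4, giving a packed buffer of
# 4 * 352 = 1408 bytes, i.e. get_kai_packed_weight_size(4, 32, 64, 64) == 1408.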
@register_meta([aten._dyn_quant_pack_4bit_weight])
def meta__dyn_quant_pack_4bit_weight(
weights, scales_zeros, bias: Optional[Tensor], block_size, in_features, out_features
):
torch._check(
weights.dtype is torch.uint8,
lambda: f"expected w to be uint8, got {weights.dtype}",
)
if torch.backends.kleidiai.is_available() and (
(block_size == in_features and scales_zeros.dtype == torch.float)
or (
block_size < in_features
and block_size % 32 == 0
and in_features % block_size == 0
and scales_zeros.dtype == torch.bfloat16
)
):
packed_weight_size = get_kai_packed_weight_size(
4, out_features, in_features, block_size
)
return weights.new_empty(int(packed_weight_size), dtype=torch.uint8)
packed_weight_size = weights.numel() + scales_zeros.numel()
return weights.new_empty(packed_weight_size, dtype=torch.float)
@register_meta([aten._dyn_quant_matmul_4bit])
def meta__dyn_quant_matmul_4bit(
inp,
packed_weights,
block_size,
in_features,
out_features,
):
torch._check(inp.dim() == 2, lambda: "input must be a 2D tensor")
torch._check(
inp.dtype in [torch.float32],
lambda: f"expected input to be f32, got {inp.dtype}",
)
M = inp.size(0)
return inp.new_empty(M, out_features, dtype=inp.dtype)
@register_meta([aten._weight_int8pack_mm])
def meta__weight_int8pack_mm(x, w, q_scales):
torch._check(x.dim() == 2, lambda: "x must be a 2D tensor")

View File

@ -62,7 +62,6 @@ from torch.backends import (
cuda as cuda,
cudnn as cudnn,
cusparselt as cusparselt,
kleidiai as kleidiai,
mha as mha,
mkl as mkl,
mkldnn as mkldnn,

View File

@ -1,7 +0,0 @@
# mypy: allow-untyped-defs
import torch
def is_available():
r"""Return whether PyTorch is built with KleidiAI support."""
return torch._C._has_kleidiai
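# Minimal usage sketch (hypothetical call site, not code from this module):
# gate an AArch64-only fast path on the flag exposed above, e.g.
#
#   if torch.backends.kleidiai.is_available():
#       ...  # take the KleidiAI-accelerated 4-bit path
#   else:
#       ...  # fall back to the reference kernels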

View File

@ -1939,8 +1939,6 @@ Call this whenever a new thread is created in order to propagate values from
ASSERT_TRUE(
set_module_attr("has_openmp", at::hasOpenMP() ? Py_True : Py_False));
ASSERT_TRUE(set_module_attr("has_mkl", at::hasMKL() ? Py_True : Py_False));
ASSERT_TRUE(
set_module_attr("_has_kleidiai", at::hasKleidiAI() ? Py_True : Py_False));
ASSERT_TRUE(
set_module_attr("has_lapack", at::hasLAPACK() ? Py_True : Py_False));

View File

@ -18,8 +18,6 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__adaptive_avg_pool3d_backward(At
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__addmm_activation(AtenTensorHandle self, AtenTensorHandle mat1, AtenTensorHandle mat2, double beta, double alpha, int32_t use_gelu, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__cdist_backward(AtenTensorHandle grad, AtenTensorHandle x1, AtenTensorHandle x2, double p, AtenTensorHandle cdist, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__cdist_forward(AtenTensorHandle x1, AtenTensorHandle x2, double p, int64_t* compute_mode, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__dyn_quant_matmul_4bit(AtenTensorHandle inp, AtenTensorHandle packed_weights, int64_t block_size, int64_t in_features, int64_t out_features, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__dyn_quant_pack_4bit_weight(AtenTensorHandle weights, AtenTensorHandle scales_zeros, AtenTensorHandle* bias, int64_t block_size, int64_t in_features, int64_t out_features, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__efficientzerotensor(const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__embedding_bag(AtenTensorHandle weight, AtenTensorHandle indices, AtenTensorHandle offsets, int32_t scale_grad_by_freq, int64_t mode, int32_t sparse, AtenTensorHandle* per_sample_weights, int32_t include_last_offset, int64_t padding_idx, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3);
AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__embedding_bag_dense_backward(AtenTensorHandle grad, AtenTensorHandle indices, AtenTensorHandle offset2bag, AtenTensorHandle bag_size, AtenTensorHandle maximum_indices, int64_t num_weights, int32_t scale_grad_by_freq, int64_t mode, AtenTensorHandle* per_sample_weights, int64_t padding_idx, AtenTensorHandle* ret0);

View File

@ -498,39 +498,6 @@ def _group_quantize_tensor(w, n_bit=4, q_group_size=16):
return out, scales_and_zeros
def _group_quantize_tensor_symmetric(
w, n_bit=4, groupsize=32
):
# W is of shape [K x N]
# We transpose W because quantization is applied on [N x K]
w = w.transpose(0, 1).contiguous()
assert w.dim() == 2
assert groupsize > 1
assert w.shape[-1] % groupsize == 0
# Calculate scale and zeros
to_quant = w.reshape(-1, groupsize)
max_val = to_quant.abs().amax(dim=1, keepdim=True)
eps = torch.finfo(max_val.dtype).eps
max_int = 2 ** (n_bit - 1) - 1 # For 4-bit, this is 7
scales = max_val.clamp(min=eps) / max_int
zeros = torch.zeros_like(scales)
# Quantize the weight
scales = scales.to(torch.float32).reshape(w.shape[0], -1)
zeros = zeros.to(torch.float32).reshape(w.shape[0], -1)
scales = scales.reshape(-1, 1)
zeros = zeros.reshape(-1, 1)
max_int = 2**n_bit - 1
w_int8 = to_quant.div(scales).add(8.5).to(torch.int8).clamp(max=max_int)
# We pack 2 signed int4 values into one unsigned uint8 container.
# This halves the weight size and improves load performance
out_uint8 = (w_int8[::, 1::2] << 4 | w_int8[::, ::2]).to(torch.uint8)
scales_and_zeros = scales.squeeze().contiguous()
return out_uint8, scales_and_zeros
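# A minimal sketch (not part of the original helpers; the name
# `_unpack_int4_weight` is an illustrative assumption) showing how the nibble
# packing above can be reversed when inspecting packed weights in tests:
def _unpack_int4_weight(out_uint8):
    # The low nibble holds the even columns and the high nibble the odd
    # columns, mirroring `w_int8[::, 1::2] << 4 | w_int8[::, ::2]` above.
    low = torch.bitwise_and(out_uint8, 0xF).to(torch.int8)
    high = torch.bitwise_right_shift(out_uint8, 4).to(torch.int8)
    w_int8 = torch.stack((low, high), dim=-1).reshape(out_uint8.shape[0], -1)
    # Undo the +8 offset applied during quantization; multiplying by the
    # per-group scales (not shown here) approximately recovers w.
    return w_int8 - 8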
def _dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype):
# source: https://github.com/pytorch-labs/gpt-fast/blob/main/quantize.py
# default setup for affine quantization of activations
@ -563,6 +530,7 @@ def _dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype):
return quant, scales.to(x_dtype), zero_points
# QuantizationTestCase used as a base class for testing quantization on modules
class QuantizationTestCase(TestCase):
def setUp(self):

View File

@ -45,8 +45,6 @@ inductor_fallback_ops = {
"aten.cummin.default",
"aten.cumprod.default",
"aten.cumsum.default",
"aten._dyn_quant_matmul_4bit.default",
"aten._dyn_quant_pack_4bit_weight.default",
"aten._efficient_attention_backward.default",
"aten._efficient_attention_forward.default",
"aten._efficientzerotensor.default",