From d3ff2d42c28a2c187cbedfd8f60b84a4dfa2d6bf Mon Sep 17 00:00:00 2001
From: Nikhil Gupta
Date: Wed, 18 Dec 2024 22:30:05 +0000
Subject: [PATCH] [ARM][feat]: Add 4 bit dynamic quantization matmuls & KleidiAI Backend (#134124)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Description:

1. Quantize Linear Layer Weights to 4 Bits:
   Quantize the weights of the Linear layer to 4 bits using symmetric quantization, packing two 4-bit weights into one uint8 container. Choose a quantization scheme (channel-wise or group-wise), with the group size being a multiple of 32.

2. Prepare Quantized Weights, Scales, and Optional Bias:
   After quantizing, obtain the quantized_weights, scales, and groupsize. If the original Linear layer has a bias, prepare it as well.

3. Pack the Weights Efficiently:
   Use torch.ops.aten._dyn_quant_pack_4bit_weight to optimally pack the weights, scales, and optional bias.

```python
packed_weights = torch.ops.aten._dyn_quant_pack_4bit_weight(weight, scales_and_zeros, bias, groupsize, in_features, out_features)
```

   in_features and out_features must match the corresponding parameters of the Linear layer.

4. Perform Dynamic Quantized Matrix Multiplication:
   Use torch.ops.aten._dyn_quant_matmul_4bit to perform matrix multiplication with the packed quantized weights.

```python
output = torch.ops.aten._dyn_quant_matmul_4bit(input, packed_weights, groupsize, in_features, out_features)
```

   Required inputs: the input tensor, packed_weights, groupsize, in_features, and out_features.

API Usage: https://github.com/pytorch/pytorch/issues/143289

Model Perf:
7B Transformer model:
  Prefill: 340 t/s
  Decode: 40 t/s
2B Transformer model:
  Prefill: 747 t/s
  Decode: 80 t/s

Tests:
python test/test_linalg.py -k test__dyn_quant_pack_4bit_weight
Ran 1 test in 0.016s
OK

python test/test_linalg.py -k test__dyn_quant_matmul_4bit
Ran 8 tests in 0.077s
OK

python test/test_linalg.py -k test_compile_dyn_quant_matmul_4bit
Ran 8 tests in 11.454s

Change-Id: Ia1672bad5e6ec94e64d8bb1971395d60f4b3a452
Fixes #ISSUE_NUMBER
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134124
Approved by: https://github.com/digantdesai, https://github.com/malfet
---
 .gitmodules                                   |   3 +
 BUILD.bazel                                   |   2 +
 CMakeLists.txt                                |   4 +
 WORKSPACE                                     |   6 +
 aten/src/ATen/CMakeLists.txt                  |  10 +-
 aten/src/ATen/Config.h.in                     |   1 +
 aten/src/ATen/Context.cpp                     |   4 +
 aten/src/ATen/Context.h                       |   5 +
 aten/src/ATen/native/LinearAlgebra.cpp        |  67 +++
 aten/src/ATen/native/cpu/int4mm_kernel.cpp    | 459 +++++++++++++++++-
 aten/src/ATen/native/cpu/int_mm_kernel.h      |  28 +-
 aten/src/ATen/native/kleidiai/kai_kernels.cpp | 440 +++++++++++++++++
 aten/src/ATen/native/kleidiai/kai_kernels.h   |  42 ++
 aten/src/ATen/native/kleidiai/kai_pack.h      | 106 ++++
 .../native/kleidiai/kai_ukernel_interface.cpp |  72 +++
 .../native/kleidiai/kai_ukernel_interface.h   | 144 ++++++
 aten/src/ATen/native/native_functions.yaml    |   8 +
 buckbuild.bzl                                 |   4 +
 cmake/Dependencies.cmake                      |  30 ++
 cmake/Summary.cmake                           |   3 +
 cmake/TorchConfig.cmake.in                    |   4 +
 docs/source/backends.rst                      |   1 +
 setup.py                                      |   1 +
 ...asDecompTest.test_has_decomposition.expect |   2 +
 test/inductor/test_torchinductor.py           |  80 +++
 test/test_linalg.py                           | 181 ++++++-
 third_party/kleidiai                          |   1 +
 torch/_C/__init__.pyi.in                      |   1 +
 torch/_dynamo/trace_rules.py                  |   2 +
 torch/_inductor/quantized_lowerings.py        |   3 +
 torch/_meta_registrations.py                  | 155 ++++++
 torch/backends/__init__.py                    |   1 +
 torch/backends/kleidiai/__init__.py           |   7 +
 torch/csrc/Module.cpp                         |   2 +
.../aoti_torch/generated/c_shim_cpu.h | 2 + .../testing/_internal/common_quantization.py | 34 +- torchgen/aoti/fallback_ops.py | 2 + 37 files changed, 1894 insertions(+), 23 deletions(-) create mode 100644 aten/src/ATen/native/kleidiai/kai_kernels.cpp create mode 100644 aten/src/ATen/native/kleidiai/kai_kernels.h create mode 100644 aten/src/ATen/native/kleidiai/kai_pack.h create mode 100644 aten/src/ATen/native/kleidiai/kai_ukernel_interface.cpp create mode 100644 aten/src/ATen/native/kleidiai/kai_ukernel_interface.h create mode 160000 third_party/kleidiai create mode 100644 torch/backends/kleidiai/__init__.py diff --git a/.gitmodules b/.gitmodules index 36d5becb57c3..3813fc2b96b3 100644 --- a/.gitmodules +++ b/.gitmodules @@ -131,3 +131,6 @@ path = third_party/composable_kernel url = https://github.com/ROCm/composable_kernel.git branch = develop +[submodule "third_party/kleidiai"] + path = third_party/kleidiai + url = https://git.gitlab.arm.com/kleidi/kleidiai.git diff --git a/BUILD.bazel b/BUILD.bazel index 65e7b391528f..4307e7b09626 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -254,6 +254,7 @@ filegroup( # target that generates these sources... ) +# TODO: Enable support for KleidiAI bazel build header_template_rule( name = "aten_src_ATen_config", src = "aten/src/ATen/Config.h.in", @@ -273,6 +274,7 @@ header_template_rule( "@AT_PARALLEL_NATIVE@": "1", "@AT_BLAS_F2C@": "0", "@AT_BLAS_USE_CBLAS_DOT@": "1", + "@AT_KLEIDIAI_ENABLED@": "0", }, ) diff --git a/CMakeLists.txt b/CMakeLists.txt index 1971b2fd9c75..b294cc7b5a4f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -377,6 +377,8 @@ cmake_dependent_option( cmake_dependent_option(BUILD_FUNCTORCH "Build Functorch" ON "BUILD_PYTHON" OFF) cmake_dependent_option(BUILD_BUNDLE_PTXAS "Bundle PTX into torch/bin fodler" OFF "USE_CUDA" OFF) +cmake_dependent_option(USE_KLEIDIAI "Use KleidiAI for the ARM CPU & AARCH64 architecture." ON + "CPU_AARCH64" OFF) option(USE_MIMALLOC "Use mimalloc" OFF) # Enable third party mimalloc library to improve memory allocation performance @@ -418,6 +420,8 @@ endif() if(WIN32) set(USE_TENSORPIPE OFF) message(WARNING "TensorPipe cannot be used on Windows. Set it to OFF") + set(USE_KLEIDIAI OFF) + message(WARNING "KleidiAI cannot be used on Windows. Set it to OFF") if(USE_DISTRIBUTED AND NOT DEFINED ENV{libuv_ROOT}) find_library( diff --git a/WORKSPACE b/WORKSPACE index ac06b6bdc5d9..ae7c0644e203 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -309,6 +309,12 @@ local_repository( path = "third_party/gemmlowp/gemmlowp", ) +local_repository( + name = "kleidiai", + path = "third_party/kleidiai", + repo_mapping = {"@com_google_googletest": "@com_google_benchmark"}, +) + ### Unused repos start # `unused` repos are defined to hide bazel files from submodules of submodules. 
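For reference, a minimal end-to-end sketch of the two new ops, modeled on the usage in the tests this patch adds (test/test_linalg.py and test/inductor/test_torchinductor.py). The shapes and group size are illustrative; `_group_quantize_tensor_symmetric` is the internal test helper in torch.testing._internal.common_quantization that this patch extends, and the float32 / bfloat16 scale dtypes follow the channel-wise / group-wise convention described in the kernel comments further down. On builds without KleidiAI the same ops should fall through to the reference CPU kernels.

```python
# Illustrative sketch mirroring this patch's tests; shapes are arbitrary examples.
import torch
from torch.testing._internal.common_quantization import _group_quantize_tensor_symmetric

m, k, n = 32, 128, 128      # activation rows, in_features, out_features
block_size = 32             # group size: a multiple of 32, or equal to k for channel-wise

a = torch.rand((m, k), dtype=torch.float32)   # fp32 activations
b = torch.rand((k, n), dtype=torch.float32)   # fp32 weights, shape [K, N]

# Symmetric 4-bit group quantization: packed uint8 weights plus per-group scales.
b_uint8, b_scales_and_zeros = _group_quantize_tensor_symmetric(
    b, n_bit=4, groupsize=block_size
)

# Channel-wise (block_size == in_features) uses float32 scales,
# group-wise uses bfloat16 scales, as in the tests below.
scales = (
    b_scales_and_zeros.to(torch.float)
    if block_size == k
    else b_scales_and_zeros.to(torch.bfloat16)
)

# Pack weights, scales, and optional bias (None here) into the matmul's layout.
packed = torch._dyn_quant_pack_4bit_weight(b_uint8, scales, None, block_size, k, n)

# Dynamically quantize the input and run the 4-bit matmul.
out = torch._dyn_quant_matmul_4bit(a, packed, block_size, k, n)
print(out.shape)  # torch.Size([32, 128])
```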
diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index f0868ea04898..442ce7bbe890 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -199,6 +199,10 @@ endif() # XNNPACK file(GLOB native_xnnpack "native/xnnpack/*.cpp") +# KLEIDIAI +file(GLOB native_kleidiai "native/kleidiai/*.cpp") +file(GLOB native_kleidiai_h "native/kleidiai/*.h") + # Add files needed from jit folders append_filelist("jit_core_headers" ATen_CORE_HEADERS) append_filelist("jit_core_sources" ATen_CORE_SRCS) @@ -228,6 +232,10 @@ endif() if(AT_MKL_ENABLED) set(all_cpu_cpp ${all_cpu_cpp} ${mkl_cpp}) endif() +if(AT_KLEIDIAI_ENABLED) + set(all_cpu_cpp ${all_cpu_cpp} ${native_kleidiai}) + include_directories(SYSTEM INTERFACE ${KLEIDIAI_INCLUDE_DIRS}) +endif() if(AT_MKLDNN_ENABLED) set(all_cpu_cpp ${all_cpu_cpp} ${mkldnn_cpp}) endif() @@ -611,7 +619,7 @@ install(FILES "${CMAKE_CURRENT_BINARY_DIR}/cmake-exports/ATenConfig.cmake" set(INSTALL_HEADERS ${base_h} ${ATen_CORE_HEADERS} ${native_nested_h} ${ATen_TRANSFORMER_HEADERS}) if(NOT INTERN_BUILD_MOBILE) - list(APPEND INSTALL_HEADERS ${native_h} ${native_cpu_h} ${native_ao_sparse_h} ${native_quantized_h} ${cuda_h} ${native_cuda_h} ${native_hip_h} ${cudnn_h} ${hip_h} ${xpu_h} ${mps_h} ${native_mps_h} ${native_utils_h} ${miopen_h} ${mkldnn_xpu_h}) + list(APPEND INSTALL_HEADERS ${native_h} ${native_cpu_h} ${native_ao_sparse_h} ${native_quantized_h} ${cuda_h} ${native_cuda_h} ${native_hip_h} ${cudnn_h} ${hip_h} ${xpu_h} ${mps_h} ${native_kleidiai_h} ${native_mps_h} ${native_utils_h} ${miopen_h} ${mkldnn_xpu_h}) # Metal if(USE_PYTORCH_METAL_EXPORT) # Add files needed from exporting metal models(optimized_for_mobile) diff --git a/aten/src/ATen/Config.h.in b/aten/src/ATen/Config.h.in index fdd2ac2bc5f7..c22e15a52aa2 100644 --- a/aten/src/ATen/Config.h.in +++ b/aten/src/ATen/Config.h.in @@ -19,3 +19,4 @@ #define AT_PARALLEL_NATIVE @AT_PARALLEL_NATIVE@ #define AT_BLAS_F2C() @AT_BLAS_F2C@ #define AT_BLAS_USE_CBLAS_DOT() @AT_BLAS_USE_CBLAS_DOT@ +#define AT_KLEIDIAI_ENABLED() @AT_KLEIDIAI_ENABLED@ diff --git a/aten/src/ATen/Context.cpp b/aten/src/ATen/Context.cpp index a222c9ce74c8..3fbf3fbff65d 100644 --- a/aten/src/ATen/Context.cpp +++ b/aten/src/ATen/Context.cpp @@ -376,6 +376,10 @@ bool Context::hasMKLDNN() { #endif } +bool Context::hasKleidiAI() { + return AT_KLEIDIAI_ENABLED(); +} + bool Context::hasOpenMP() { #ifdef _OPENMP return true; diff --git a/aten/src/ATen/Context.h b/aten/src/ATen/Context.h index ccbefc9105a8..41bfd34583f4 100644 --- a/aten/src/ATen/Context.h +++ b/aten/src/ATen/Context.h @@ -118,6 +118,7 @@ class TORCH_API Context { static bool hasOpenMP(); static bool hasMKL(); + static bool hasKleidiAI(); static bool hasLAPACK(); static bool hasMKLDNN(); static bool hasMAGMA() { @@ -538,6 +539,10 @@ inline bool hasMKL() { return globalContext().hasMKL(); } +inline bool hasKleidiAI() { + return globalContext().hasKleidiAI(); +} + inline bool hasLAPACK() { return globalContext().hasLAPACK(); } diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp index f98f55b1f9f4..69c90bc63f4f 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -33,6 +33,8 @@ #include #include #include +#include +#include #include #include #include @@ -3429,6 +3431,8 @@ Tensor kron(const Tensor& self, const Tensor& other) { DEFINE_DISPATCH(weight_to_int4pack_stub); DEFINE_DISPATCH(int4pack_mm_stub); DEFINE_DISPATCH(int8pack_mm_stub); 
+DEFINE_DISPATCH(dyn_quant_pack_4bit_weight_stub); +DEFINE_DISPATCH(dyn_quant_matmul_4bit_stub); Tensor _convert_weight_to_int4pack_cpu( const Tensor& in, @@ -3492,6 +3496,69 @@ Tensor _weight_int4pack_mm_cpu( return C; } +Tensor _dyn_quant_pack_4bit_weight_cpu( + const Tensor& weights, + const Tensor& scales_zeros, + const std::optional& bias, + const int64_t block_size, + const int64_t in_features, + const int64_t out_features) { + TORCH_CHECK( + weights.dtype() == at::kByte, __func__, " : expect weight to be kByte."); + TORCH_CHECK( + block_size == in_features || + (!(block_size % 32) && !(in_features % block_size)), + __func__, + ": Group size should be multiple of 32, in_features [", + in_features, + "]. Provided ", + block_size); + Tensor packed_weights = + at::empty(weights.sizes(), weights.options().dtype(at::kByte)); + dyn_quant_pack_4bit_weight_stub( + kCPU, + packed_weights, + weights, + scales_zeros, + bias, + out_features, + in_features, + block_size); + return packed_weights; +} + +Tensor _dyn_quant_matmul_4bit_cpu( + const Tensor& inp, + const Tensor& packed_weights, + const int64_t block_size, + const int64_t in_features, + const int64_t out_features) { + auto M = inp.size(0); + TORCH_CHECK( + inp.dtype() == kFloat, + __func__, + " : expect input to be 32-bit float tensor."); + TORCH_CHECK( + block_size == in_features || + (!(block_size % 32) && !(in_features % block_size)), + __func__, + ": Group size should be multiple of 32, in_features [", + in_features, + "]. Provided ", + block_size); + auto output = at::empty({M, out_features}, inp.options()); + dyn_quant_matmul_4bit_stub( + kCPU, + output, + inp, + packed_weights, + M, + out_features, + in_features, + block_size); + return output; +} + Tensor _weight_int8pack_mm_cpu( const Tensor& A, const Tensor& B, diff --git a/aten/src/ATen/native/cpu/int4mm_kernel.cpp b/aten/src/ATen/native/cpu/int4mm_kernel.cpp index 11a34eefb95a..13f4ae78e7ec 100644 --- a/aten/src/ATen/native/cpu/int4mm_kernel.cpp +++ b/aten/src/ATen/native/cpu/int4mm_kernel.cpp @@ -8,8 +8,18 @@ #include #include #include -#include #include +#include +#include +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#endif + +#if AT_KLEIDIAI_ENABLED() +#include +#endif #if (defined(_WIN32) || defined(_WIN64)) #define RESTRICT __restrict @@ -762,10 +772,457 @@ void int4pack_mm_kernel( } } +#if AT_KLEIDIAI_ENABLED() +bool can_use_kleidiai( + const at::Tensor& scales_zeros, + const int64_t K, + const int64_t block_size) { + bool ret = false; + if (cpuinfo_has_arm_neon_dot()) { + // The Groupwise kernel requires BFloat16 Scales and Channelwise kernel + // requires Float32 Scales. If not provided, we will use fallback + // implementation. + if ((block_size == K && scales_zeros.dtype() == at::kFloat) || + ((block_size < K && !(block_size % 32) && !(K % block_size)) && + scales_zeros.dtype() == at::kBFloat16)) { + ret = true; + } + } + return ret; +} +#endif + +/** + * The Int4 quantized weights must be represented as a uint8 tensor + * For matrix multiplication with a weight shape of (N x K) + * the shape of the 4-bit quantized weights is [N, K/groupsize, groupsize/2]. + * + * For KleidiAI weight packing, the scales, biases, and Int4 quantized + * weights are packed into a single `packed_weights` structure, optimized for + * Arm instructions. + * + * In the fallback reference kernel, no special packing is required for + * Int4 quantized weights. + * + * The Groupwise kernel requires BFloat16 Scales and Channelwise kernel requires + * Float32 Scales. 
If not provided, we will use fallback implementation. + */ +void dyn_quant_pack_4bit_weight_kernel( + Tensor& packed_weights, + const Tensor& weights, + const Tensor& scales_zeros, + const std::optional& bias, + const int64_t N, + const int64_t K, + const int64_t block_size) { +#if AT_KLEIDIAI_ENABLED() + if (can_use_kleidiai(scales_zeros, K, block_size)) { + const int64_t weight_packed_size = + kleidiai::kai_pack_rhs_int4_size(N, K, block_size); + packed_weights.resize_({weight_packed_size}); + kleidiai::kai_pack_int4_rhs( + packed_weights, weights, scales_zeros, bias, N, K, block_size); + } else +#endif + { + TORCH_CHECK( + bias.has_value() == 0, + __func__, + " : Bias is unsupported in reference implementation"); + packed_weights = packed_weights.to(kFloat); + auto weight_reshaped = weights.view({-1}).to(kFloat); + auto scales_zeros_reshaped = scales_zeros.view({-1}).to(kFloat); + auto res = at::cat({weight_reshaped, scales_zeros_reshaped}, 0); + packed_weights.resize_(res.sizes()).copy_(res); + } +} + +static void ref_dyn_quant_matmul_4bit_channelwise_kernel( + size_t m, + size_t n, + size_t k, + const float* lhs_f32, + const uint8_t* rhs_qs4cx, + const float* rhs_scales_f32, + float* dst_f32, + float scalar_min, + float scalar_max) { + const size_t input_size_8bit = m * (k + sizeof(int32_t) + sizeof(float)); + + auto lhs_qa8dx_buffer = std::make_unique(input_size_8bit); + uint8_t* lhs_qa8dx = lhs_qa8dx_buffer.get(); + + // Lambda for quantizing the fp32 input to 8 bit symmetric and pack it in + // required format for matmul + auto input_quant_pack_8bit_channelwise = + [&](size_t m, size_t k, const float* lhs_f32, int8_t* lhs_qa8dx) { + const size_t dst_stride = + (k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t)); + + const size_t lhs_qa8dx_stride = k; + + for (size_t m_idx = 0; m_idx < m; ++m_idx) { + const float* src_ptr = lhs_f32 + m_idx * lhs_qa8dx_stride; + + float max0 = -FLT_MAX; + float min0 = FLT_MAX; + + // Find min/max for each channel + for (size_t k_idx = 0; k_idx < k; ++k_idx) { + const float src0_0 = src_ptr[k_idx]; + + max0 = (std::max)(src0_0, max0); + min0 = (std::min)(src0_0, min0); + } + + // Maximum/minimum int8 values + const float qmin = (float)INT8_MIN; + const float qmax = (float)INT8_MAX; + + const float rmin0 = (std::min)(0.0f, min0); + const float rmax0 = (std::max)(0.0f, max0); + + const float scale0 = + rmin0 == rmax0 ? 1.f : (qmax - qmin) / (rmax0 - rmin0); + + // Reciprocal to quantize + const float recip_scale0 = scale0 ? 1.0f / scale0 : 0.0f; + + const float descaled_min0 = rmin0 * scale0; + const float descaled_max0 = rmax0 * scale0; + + const float zero_point_from_min_error0 = qmin + descaled_min0; + const float zero_point_from_max_error0 = qmax + descaled_max0; + + float zero_point0 = + zero_point_from_min_error0 + zero_point_from_max_error0 > 0 + ? 
qmin - descaled_min0 + : qmax - descaled_max0; + + zero_point0 = (std::max)(zero_point0, qmin); + zero_point0 = (std::min)(zero_point0, qmax); + + // Round to nearest integer + const int32_t nudged_zero_point0 = lrintf(zero_point0); + + int8_t* dst_ptr = (int8_t*)lhs_qa8dx + m_idx * dst_stride; + + // LHS offset at the beginning of the row + *((float*)(dst_ptr)) = recip_scale0; + dst_ptr += sizeof(float); + *((int32_t*)(dst_ptr)) = -nudged_zero_point0; + dst_ptr += sizeof(int32_t); + + // Quantize the channels + for (size_t k_idx = 0; k_idx < k; ++k_idx) { + const float src0_0 = src_ptr[k_idx]; + + // Scale the values + int32_t v0_s32 = (int32_t)(std::round(src0_0 * scale0)); + + v0_s32 = v0_s32 + nudged_zero_point0; + v0_s32 = (std::max)(v0_s32, static_cast(INT8_MIN)); + v0_s32 = (std::min)(v0_s32, static_cast(INT8_MAX)); + dst_ptr[0] = (int8_t)v0_s32; + dst_ptr += sizeof(int8_t); + } + } + }; + + // Dynamically Quantize the float32 input to 8 bit assymetric + input_quant_pack_8bit_channelwise(m, k, lhs_f32, (int8_t*)lhs_qa8dx); + + const size_t lhs_stride = + k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t); + + const size_t rhs_qs4cx_stride = ((((k + 2 - 1) / 2) * 2) / 2); + + for (size_t m_idx = 0; m_idx < m; ++m_idx) { + const int8_t* lhs_ptr_start = (int8_t*)lhs_qa8dx + m_idx * lhs_stride; + + for (size_t n_idx = 0; n_idx < n; ++n_idx) { + // Main f32 accumulator + int32_t iacc = 0; + + const int8_t* lhs_ptr = lhs_ptr_start; + const uint8_t* rhs_ptr = rhs_qs4cx + n_idx * rhs_qs4cx_stride; + + // Get the LHS quantization parameters stored at the + // beginning of each row + const float lhs_scale = *(const float*)lhs_ptr; + lhs_ptr += sizeof(float); + + const int32_t lhs_offset = *(const int32_t*)lhs_ptr; + lhs_ptr += sizeof(int32_t); + + for (size_t k_idx = 0; k_idx < k; ++k_idx) { + // Get the LHS values + const int32_t lhs_v0 = (int32_t)lhs_ptr[0]; + + // Get the RHS values + const uint8_t rhs_byte = rhs_ptr[0]; + + // Unpack the RHS values + int32_t rhs_v0 = 0; + if ((k_idx % 2) == 0) { + rhs_v0 = (((int32_t)(rhs_byte & 0x0F)) - 8); + } else { + rhs_v0 = (((int32_t)(rhs_byte >> 4)) - 8); + } + + iacc += lhs_v0 * rhs_v0; + iacc += lhs_offset * rhs_v0; + + lhs_ptr += 1; + + // Increment only when k_idx is not a multiple of 2 + rhs_ptr += k_idx % 2; + } + + // Get the RHS scale + const float rhs_scale = rhs_scales_f32[n_idx]; + + float main_acc = iacc * rhs_scale; + + main_acc = main_acc * lhs_scale; + + // Clamp (min-max) operation + main_acc = (std::max)(main_acc, scalar_min); + main_acc = (std::min)(main_acc, scalar_max); + + dst_f32[0] = main_acc; + dst_f32 += 1; + } + } +}; + +static void ref_dyn_quant_matmul_4bit_groupwise_kernel( + size_t m, + size_t n, + size_t k, + size_t bl, + const float* lhs_f32, + const uint8_t* rhs_qs4c32, + const float* rhs_scales_fp32, + float* dst_f32, + float scalar_min, + float scalar_max) { + // Lambda for LHS quantization + auto lhs_quant_pack = [&](size_t m, + size_t k, + const float* lhs_f32, + int8_t* lhs_qa8dx) { + const size_t dst_stride = + (k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t)); + + for (size_t row_idx = 0; row_idx < m; ++row_idx) { + const float* src_ptr = lhs_f32 + row_idx * k; + + float max0 = -FLT_MAX; + float min0 = FLT_MAX; + + for (size_t k_idx = 0; k_idx < k; ++k_idx) { + const float src0_0 = src_ptr[k_idx]; + max0 = (std::max)(src0_0, max0); + min0 = (std::min)(src0_0, min0); + } + + const float qmin = (float)INT8_MIN; + const float qmax = (float)INT8_MAX; + + const float rmin0 = (std::min)(0.0f, min0); + const 
float rmax0 = (std::max)(0.0f, max0); + const float scale0 = + (rmin0 == rmax0) ? 1.f : (qmax - qmin) / (rmax0 - rmin0); + const float recip_scale0 = scale0 ? 1.0f / scale0 : 0.0f; + + const float descaled_min0 = rmin0 * scale0; + const float descaled_max0 = rmax0 * scale0; + + float zero_point0 = (qmin + descaled_min0 + qmax + descaled_max0 > 0) + ? qmin - descaled_min0 + : qmax - descaled_max0; + + zero_point0 = (std::max)(zero_point0, qmin); + zero_point0 = (std::min)(zero_point0, qmax); + const int32_t nudged_zero_point0 = lrintf(zero_point0); + + int8_t* dst_ptr = (int8_t*)lhs_qa8dx + row_idx * dst_stride; + + *((float*)(dst_ptr)) = recip_scale0; + dst_ptr += sizeof(float); + *((int32_t*)(dst_ptr)) = -nudged_zero_point0; + dst_ptr += sizeof(int32_t); + + for (size_t k_idx = 0; k_idx < k; ++k_idx) { + const float src0_0 = src_ptr[k_idx]; + int32_t v0_s32 = (int32_t)(std::round(src0_0 * scale0)); + v0_s32 = (std::max)( + (std::min)( + v0_s32 + nudged_zero_point0, static_cast(INT8_MAX)), + static_cast(INT8_MIN)); + dst_ptr[0] = (int8_t)v0_s32; + dst_ptr += sizeof(int8_t); + } + } + }; + + auto lhs_qa8dx_buffer = std::make_unique( + m * (k + sizeof(float) + sizeof(int32_t))); // Allocate for LHS + int8_t* lhs_qa8dx = lhs_qa8dx_buffer.get(); + // Quantize and pack LHS + lhs_quant_pack(m, k, lhs_f32, lhs_qa8dx); + + const size_t num_blocks_row = (((k + bl - 1) / bl) * bl) / bl; + const size_t lhs_stride = k + sizeof(float) + sizeof(int32_t); + const size_t rhs_stride = (((k + 2 - 1) / 2) * 2) / 2; + + for (size_t row_idx = 0; row_idx < m; ++row_idx) { + const int8_t* lhs_ptr_start = lhs_qa8dx + row_idx * lhs_stride; + + for (size_t col_idx = 0; col_idx < n; ++col_idx) { + float main_acc = 0.0f; + const int8_t* lhs_ptr = lhs_ptr_start; + const uint8_t* rhs_ptr = rhs_qs4c32 + col_idx * rhs_stride; + + const float lhs_scale = *(const float*)lhs_ptr; + lhs_ptr += sizeof(float); + const int32_t lhs_offset = *(const int32_t*)lhs_ptr; + lhs_ptr += sizeof(int32_t); + + for (size_t block_idx = 0; block_idx < num_blocks_row; ++block_idx) { + const float rhs_scale = + rhs_scales_fp32[block_idx + col_idx * num_blocks_row]; + int32_t iacc = 0; + + for (size_t i = 0; i < bl; ++i) { + const size_t k_idx = block_idx * bl + i; + if (k_idx >= k) { + break; + } + + const int32_t lhs_v0 = (int32_t)lhs_ptr[0]; + const uint8_t rhs_byte = rhs_ptr[0]; + int32_t rhs_v0 = (k_idx % 2 == 0) ? (((int32_t)(rhs_byte & 0x0F)) - 8) + : (((int32_t)(rhs_byte >> 4)) - 8); + + iacc += lhs_v0 * rhs_v0; + iacc += lhs_offset * rhs_v0; + + lhs_ptr += 1; + rhs_ptr += (k_idx % 2); + } + + main_acc += iacc * rhs_scale; + } + + main_acc = main_acc * lhs_scale; + main_acc = (std::max)(main_acc, scalar_min); + main_acc = (std::min)(main_acc, scalar_max); + + dst_f32[0] = main_acc; + dst_f32 += 1; + } + } +} + +/** + * Dynamic Input Quant 4 bit weights matmul execution flow + (INT4 Weights + FP scales + FP32 Bias) + FP32 Input Packed Buffer + | | + Quantize Cast + to INT8 to INT8 + | | + v v + INT8 Input INT8 Weights + \ / + \ / + \ / + INT8 Matrix Multiplication + | + v + FP32 Dequantized and Accumulate in FP32 + | + v + FP32 Final Output + + * The Groupwise kernel requires BFloat16 Scales and Channelwise kernel requires + * Float32 Scales. If not provided, we will use fallback implementation. 
+ */ +void dyn_quant_matmul_4bit_kernel( + const Tensor& output, + const Tensor& inp, + const Tensor& packed_weights, + const int64_t M, + const int64_t N, + const int64_t K, + const int64_t block_size) { +#if AT_KLEIDIAI_ENABLED() + const int64_t weight_packed_size = + kleidiai::kai_pack_rhs_int4_size(N, K, block_size); + if (weight_packed_size == packed_weights.numel()) { + // KleidiAI interface intenally handles the Channelwise and groupwise + // distinction + kleidiai::kai_quant_pack_lhs_int4_mm( + output, inp, packed_weights, M, N, K, block_size); + } else +#endif + { + float* lhs_f32 = reinterpret_cast(inp.data_ptr()); + const auto weights_size = N * K / 2; + // The weights needs to be in uint8_t data type after quantization + auto extracted_weights = + (packed_weights.narrow(0, 0, weights_size)).to(kByte); + auto float32_scales = + (packed_weights.narrow( + 0, weights_size, packed_weights.size(0) - weights_size)) + .to(kFloat); + uint8_t* rhs_4bit = + reinterpret_cast(extracted_weights.data_ptr()); + float* rhs_scales_f32 = reinterpret_cast(float32_scales.data_ptr()); + float* dst_f32 = reinterpret_cast(output.data_ptr()); + if (block_size == K) { + ref_dyn_quant_matmul_4bit_channelwise_kernel( + M, + N, + K, + lhs_f32, + rhs_4bit, + rhs_scales_f32, + dst_f32, + -FLT_MAX, + FLT_MAX); + } else if (!(block_size % 32) && !(K % block_size)) { + ref_dyn_quant_matmul_4bit_groupwise_kernel( + M, + N, + K, + block_size, + lhs_f32, + rhs_4bit, + rhs_scales_f32, + dst_f32, + -FLT_MAX, + FLT_MAX); + } else { + TORCH_CHECK( + block_size == K || (!(block_size % 32) && !(K % block_size)), + __func__, + ": Group size should be multiple 32 or in_features [", + K, + "]. Provided ", + block_size); + } + } +} + } // anonymous namespace ALSO_REGISTER_AVX512_DISPATCH(weight_to_int4pack_stub, &weight_to_int4pack_kernel) ALSO_REGISTER_AVX512_DISPATCH(int4pack_mm_stub, &int4pack_mm_kernel) +REGISTER_DISPATCH(dyn_quant_pack_4bit_weight_stub, &dyn_quant_pack_4bit_weight_kernel) +REGISTER_DISPATCH(dyn_quant_matmul_4bit_stub, &dyn_quant_matmul_4bit_kernel) } // at::native C10_DIAGNOSTIC_POP() diff --git a/aten/src/ATen/native/cpu/int_mm_kernel.h b/aten/src/ATen/native/cpu/int_mm_kernel.h index 1131aa9b53c9..ee04a18c1df4 100644 --- a/aten/src/ATen/native/cpu/int_mm_kernel.h +++ b/aten/src/ATen/native/cpu/int_mm_kernel.h @@ -5,12 +5,34 @@ namespace at::native { -using weight_to_int4pack_fn = void(*)(const Tensor&, const Tensor&); -using int4pack_mm_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, int, const Tensor&); -using int8pack_mm_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&); +using weight_to_int4pack_fn = void (*)(const Tensor&, const Tensor&); +using int4pack_mm_fn = + void (*)(const Tensor&, const Tensor&, const Tensor&, int, const Tensor&); +using int8pack_mm_fn = + void (*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&); +using dyn_quant_pack_4bit_weight_fn = void (*)( + Tensor&, + const Tensor&, + const Tensor&, + const std::optional& bias, + const int64_t, + const int64_t, + const int64_t); +using dyn_quant_matmul_4bit_fn = void (*)( + const Tensor&, + const Tensor&, + const Tensor&, + const int64_t, + const int64_t, + const int64_t, + const int64_t); DECLARE_DISPATCH(weight_to_int4pack_fn, weight_to_int4pack_stub) DECLARE_DISPATCH(int4pack_mm_fn, int4pack_mm_stub) DECLARE_DISPATCH(int8pack_mm_fn, int8pack_mm_stub) +DECLARE_DISPATCH( + dyn_quant_pack_4bit_weight_fn, + dyn_quant_pack_4bit_weight_stub); +DECLARE_DISPATCH(dyn_quant_matmul_4bit_fn, 
dyn_quant_matmul_4bit_stub); } // namespace at::native diff --git a/aten/src/ATen/native/kleidiai/kai_kernels.cpp b/aten/src/ATen/native/kleidiai/kai_kernels.cpp new file mode 100644 index 000000000000..7872bffca4c9 --- /dev/null +++ b/aten/src/ATen/native/kleidiai/kai_kernels.cpp @@ -0,0 +1,440 @@ +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#if AT_KLEIDIAI_ENABLED() + +namespace at::native::kleidiai { + +void kai_pack_int4_rhs( + const Tensor& weight_packed, + const Tensor& weight, + const Tensor& scales, + const std::optional& bias, + const int64_t n, + const int64_t k, + const int64_t bl) { + // Prefer Channelwise kernel over Groupwise kernel for conflicting cases + if (bl == k) { + // Channelwise + auto kernel_packet = kai_select_channelwise_matmul_ukernel( + kai_kernel_id:: + matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod); + auto& params = kernel_packet.rhs_pack_params; + params.lhs_zero_point = 1; + params.rhs_zero_point = 8; + + kai_pack_rhs_channelwise_int4( + kernel_packet, weight_packed, weight, scales, bias, n, k); + } else if (!(bl % 32) && !(k % bl)) { + // Groupwise + auto kernel_packet = kai_select_groupwise_matmul_ukernel( + kai_kernel_id:: + matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod); + + const int64_t rhs_stride = kai_roundup(k, 2) / 2; + const int64_t scale_stride = (kai_roundup(k, bl) / bl) * sizeof(uint16_t); + auto& params = kernel_packet.rhs_pack_params; + params.lhs_zero_point = 1; + params.rhs_zero_point = 8; + params.scale_dt = kai_datatype::kai_dt_bf16; + + kai_pack_rhs_groupwise_int4( + kernel_packet, + weight_packed, + weight, + scales, + bias, + n, + k, + bl, + rhs_stride, + scale_stride); + } +} + +size_t kai_pack_rhs_int4_size( + const int64_t n, + const int64_t k, + const int64_t bl) { + size_t packed_size = n * k; + // Prefer Channelwise kernel over Groupwise kernel for conflicting cases + if (bl == k) { + // Channelwise + auto kernel_packet = kai_select_channelwise_matmul_ukernel( + kai_kernel_id:: + matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod); + const auto& ukernel = kernel_packet.ukernel; + const size_t nr = ukernel.get_nr(); + const size_t kr = ukernel.get_kr(); + const size_t sr = ukernel.get_sr(); + packed_size = kernel_packet.kai_get_rhs_packed_size(n, k, nr, kr, sr); + } else if (!(bl % 32) && !(k % bl)) { + // Groupwise + auto kernel_packet = kai_select_groupwise_matmul_ukernel( + kai_kernel_id:: + matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod); + const auto& ukernel = kernel_packet.ukernel; + const size_t nr = ukernel.get_nr(); + const size_t kr = ukernel.get_kr(); + const size_t sr = ukernel.get_sr(); + packed_size = kernel_packet.kai_get_rhs_packed_size( + n, k, nr, kr, sr, bl, kai_datatype::kai_dt_bf16); + } + return packed_size; +} + +static void matmul_channelwise( + kai_matmul_ukernel_f32_qa8dxp_qs4cxp& kernel_packet, + size_t m_increment, + size_t m_start, + size_t m_per_thread, + size_t n_start, + size_t n_per_thread, + size_t n, + size_t k, + size_t mr, + size_t nr, + size_t kr, + size_t sr, + size_t dst_stride, + size_t lhs_stride, + uint8_t* lhs_native_mtx_f32, + uint8_t* lhs_packed_mtx_qa8dx, + uint8_t* rhs_packed_mtx_qs4cx, + uint8_t* dst_act_mtx_f32) { + for (size_t m0 = 0; m0 < m_per_thread; m0 += m_increment) { + const float* src_ptr = + (const float*)(lhs_native_mtx_f32 + + kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32( + m_start + m0, lhs_stride)); + void* lhs_packed_ptr = + (void*)(lhs_packed_mtx_qa8dx + + 
kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32( + 0, k, mr, kr, sr)); + const void* rhs_packed_ptr = + (const void*)((const char*)rhs_packed_mtx_qs4cx + + kernel_packet.ukernel.get_rhs_packed_offset(n_start, k)); + float* dst_ptr = (float*)((uint8_t*)dst_act_mtx_f32 + + kernel_packet.ukernel.get_dst_offset( + m_start + m0, n_start, dst_stride)); + + // Quantize and pack the Input + kernel_packet.kai_run_lhs_quant_pack( + m_increment, k, mr, kr, sr, 0, src_ptr, lhs_stride, lhs_packed_ptr); + + // Run Matmul on Int4 packed weights and Quantized Packed Input + kernel_packet.ukernel.run_matmul( + m_increment, + n_per_thread, + k, + lhs_packed_ptr, + rhs_packed_ptr, + dst_ptr, + dst_stride, + sizeof(float), + -FLT_MAX, + FLT_MAX); + } +} + +static void matmul_groupwise( + const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& ukernel, + const size_t m, + const size_t num_n_per_thread, + const size_t n_start, + const size_t k, + const size_t bl, + const size_t dst_stride, + const void* lhs_ptr, + uint8_t* rhs_packed, + uint8_t* dst_data) { + const size_t rhs_packed_offset = + ukernel.get_rhs_packed_offset(n_start, k, bl); + const size_t dst_offset = ukernel.get_dst_offset(0, n_start, dst_stride); + + const void* rhs_ptr = (const void*)(rhs_packed + rhs_packed_offset); + float* dst_ptr = (float*)((uint8_t*)dst_data + dst_offset); + + // Run Matmul on Int4 packed weights and Quantized Packed Input + ukernel.run_matmul( + m, + num_n_per_thread, + k, + bl, + lhs_ptr, + rhs_ptr, + dst_ptr, + dst_stride, + sizeof(float), + -FLT_MAX, + FLT_MAX); +} + +struct ThreadDivision { + int64_t num_threads_x; + int64_t num_threads_y; + bool use_gemm; // True if GEMM is selected, false if GEMV is used +}; + +inline static unsigned int round_down_to_power_of_2(unsigned int n) { + n |= n >> 1; + n |= n >> 2; + n |= n >> 4; + n |= n >> 8; + n |= n >> 16; + return n - (n >> 1); +} + +inline static void adjust_max_threads(int64_t& max_threads) { + // We would not like to round down to nearest power of 2 always + // There can be possible thread split combination between powers of 2 for odd + // shapes + // TODO:: Decide better strategy based on hint of input and weight shapes + max_threads = round_down_to_power_of_2(max_threads); +} + +static std::pair split_2d(const int64_t max_threads) { + int64_t sqrt_threads = std::sqrt(max_threads); + + for (int64_t i = sqrt_threads; i >= 1; --i) { + if (max_threads % i == 0) { + return {i, max_threads / i}; + } + } + + return {1, max_threads}; // Theres still a possibility of 1D blocking when + // calling GEMM kernel +} + +inline static ThreadDivision get_thread_division( + int64_t max_threads, + const int64_t m, + const int64_t n, + const int64_t k, + const int64_t gemm_m_step, + const int64_t gemm_n_step, + const int64_t gemv_m_step, + const int64_t gemv_n_step) { + adjust_max_threads(max_threads); + ThreadDivision division{1, 1, false}; + + // Split threads 2D for GEMM + if (m % gemm_m_step == 0 && n % gemm_n_step == 0) { + while (max_threads > 0) { + auto [num_thread_y, num_thread_x] = split_2d(max_threads); + if (m % num_thread_y == 0 && n % num_thread_x == 0) { + int64_t m_per_thread = m / num_thread_y; + int64_t n_per_thread = n / num_thread_x; + if (m_per_thread % gemm_m_step == 0 && + n_per_thread % gemm_n_step == 0) { + division = {num_thread_x, num_thread_y, true}; + return division; + } + } + max_threads -= 2; + } + } + // Split threads 1D for GEMV + if (n % gemv_n_step == 0) { + for (; max_threads > 0; max_threads -= 2) { + if (n % max_threads == 0 && (n / 
max_threads) % gemv_n_step == 0) { + division.num_threads_x = max_threads; + return division; + } + } + } + return division; +} + +static void kai_quant_pack_lhs_int4_mm_groupwise( + const Tensor& output, + const Tensor& input, + const Tensor& weight, + const int64_t m, + const int64_t n, + const int64_t k, + const int64_t bl) { + kai_kernel_id id = kai_kernel_id:: + matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod; + if (cpuinfo_has_arm_i8mm() && m > 1) { + id = + kai_kernel_id::matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_4x8x32_neon_i8mm; + } + auto kernel_packet = kai_select_groupwise_matmul_ukernel(id); + + const auto& ukernel = kernel_packet.ukernel; + + const size_t mr = ukernel.get_mr(); + const size_t kr = ukernel.get_kr(); + const size_t sr = ukernel.get_sr(); + const size_t n_step = ukernel.get_n_step(); + int64_t total_threads = at::get_num_threads(); + int64_t num_threads_x = 1; + adjust_max_threads(total_threads); + // Split threads 1D only for now + if (n % n_step == 0) { + for (; total_threads > 0; total_threads -= 2) { + if (n % total_threads == 0 && (n / total_threads) % n_step == 0) { + num_threads_x = total_threads; + break; + } + } + } + + const size_t num_n_per_thread = n / num_threads_x; + + const size_t dst_stride = n * sizeof(float); + float* lhs = reinterpret_cast(input.data_ptr()); + uint8_t* rhs_packed_mtx_qs4cx = reinterpret_cast(weight.data_ptr()); + + uint8_t* dst_act_mtx_f32 = reinterpret_cast(output.data_ptr()); + const size_t lhs_packed_size = + kernel_packet.kai_get_lhs_packed_size(m, k, mr, kr, sr); + auto lhs_packed = std::make_unique(lhs_packed_size); + + // Quantize and pack the Input + kernel_packet.kai_run_lhs_quant_pack( + m, + k, + mr, + kr, + sr, + 0, + (const float*)lhs, + k * sizeof(float), + (void*)lhs_packed.get()); + + at::parallel_for(0, num_threads_x, 0, [&](int begin, int end) { + for (const auto x : c10::irange(begin, end)) { + matmul_groupwise( + std::ref(ukernel), + m, + num_n_per_thread, + x * num_n_per_thread, + k, + bl, + dst_stride, + lhs_packed.get(), + rhs_packed_mtx_qs4cx, + dst_act_mtx_f32); + } + }); +} + +static void kai_quant_pack_lhs_int4_mm_channelwise( + const Tensor& output, + const Tensor& input, + const Tensor& weight, + const int64_t m, + const int64_t n, + const int64_t k) { + // Kernel IDs for GEMM and GEMV + kai_kernel_id gemm_id = + kai_kernel_id::matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm; + kai_kernel_id gemv_id = + kai_kernel_id::matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod; + + // Get the total number of threads available and choose GEMM or GEMV steps + const int64_t total_threads = at::get_num_threads(); + auto gemm_kernel_packet = kai_select_channelwise_matmul_ukernel(gemv_id); + if (cpuinfo_has_arm_i8mm()) { + gemm_kernel_packet = kai_select_channelwise_matmul_ukernel(gemm_id); + } + auto gemv_kernel_packet = kai_select_channelwise_matmul_ukernel(gemv_id); + + // Retrieve m_step and n_step values from GEMM and GEMV kernels + const int64_t gemm_m_step = gemm_kernel_packet.ukernel.get_m_step(); + const int64_t gemm_n_step = gemm_kernel_packet.ukernel.get_n_step(); + const int64_t gemv_m_step = gemv_kernel_packet.ukernel.get_m_step(); + const int64_t gemv_n_step = gemv_kernel_packet.ukernel.get_n_step(); + // Determine threading and kernel type + ThreadDivision division = get_thread_division( + total_threads, + m, + n, + k, + gemm_m_step, + gemm_n_step, + gemv_m_step, + gemv_n_step); + // Select appropriate kernel packet based on the chosen kernel type + auto& kernel_packet = + 
division.use_gemm ? gemm_kernel_packet : gemv_kernel_packet; + + // Thread blocking parameters + const size_t mr = kernel_packet.ukernel.get_mr(); + const size_t nr = kernel_packet.ukernel.get_nr(); + const size_t kr = kernel_packet.ukernel.get_kr(); + const size_t sr = kernel_packet.ukernel.get_sr(); + const size_t m_increment = kernel_packet.ukernel.get_m_step(); + const size_t n_per_thread = n / division.num_threads_x; + const size_t m_per_thread = m / division.num_threads_y; + const int64_t num_threads = division.num_threads_y * division.num_threads_x; + const size_t dst_stride = n * sizeof(float); + const size_t lhs_stride = k * sizeof(float); + + const size_t lhs_packed_size = + kernel_packet.kai_get_lhs_packed_size(m_increment, k, mr, kr, sr); + + uint8_t* dst_act_mtx_f32 = reinterpret_cast(output.data_ptr()); + uint8_t* lhs_native_mtx_f32 = reinterpret_cast(input.data_ptr()); + uint8_t* rhs_packed_mtx_qs4cx = reinterpret_cast(weight.data_ptr()); + auto lhs_packed = std::make_unique(lhs_packed_size * num_threads); + uint8_t* lhs_packed_base = lhs_packed.get(); + + at::parallel_for(0, num_threads, 0, [&](int64_t begin, int64_t end) { + for (const auto i : c10::irange(begin, end)) { + size_t y = i / division.num_threads_x; + size_t x = i % division.num_threads_x; + uint8_t* lhs_packed_ptr = + lhs_packed_base + (x + y * division.num_threads_x) * lhs_packed_size; + matmul_channelwise( + std::ref(kernel_packet), + m_increment, + y * m_per_thread, + m_per_thread, + x * n_per_thread, + n_per_thread, + n, + k, + mr, + nr, + kr, + sr, + dst_stride, + lhs_stride, + lhs_native_mtx_f32, + lhs_packed_ptr, + rhs_packed_mtx_qs4cx, + dst_act_mtx_f32); + } + }); +} + +void kai_quant_pack_lhs_int4_mm( + const Tensor& output, + const Tensor& input, + const Tensor& weight, + const int64_t m, + const int64_t n, + const int64_t k, + const int64_t bl) { + // Prefer Channelwise kernel over Groupwise kernel for conflicting cases + if (bl == k) { + kleidiai::kai_quant_pack_lhs_int4_mm_channelwise( + output, input, weight, m, n, k); + } else if (!(bl % 32) && !(k % bl)) { + kleidiai::kai_quant_pack_lhs_int4_mm_groupwise( + output, input, weight, m, n, k, bl); + } +} +} // namespace at::native::kleidiai +#endif diff --git a/aten/src/ATen/native/kleidiai/kai_kernels.h b/aten/src/ATen/native/kleidiai/kai_kernels.h new file mode 100644 index 000000000000..9b522d7f7705 --- /dev/null +++ b/aten/src/ATen/native/kleidiai/kai_kernels.h @@ -0,0 +1,42 @@ +#pragma once +#include +#include +#if AT_KLEIDIAI_ENABLED() + +namespace at::native::kleidiai { + +/** + * @brief Rearranges the quantized weight to support kleidiai inference + * @param bl Groupsize for quantization should be multiple of 32 + */ +void kai_pack_int4_rhs( + const Tensor& weight_packed, + const Tensor& weight, + const Tensor& scales, + const std::optional& bias, + const int64_t n, + const int64_t k, + const int64_t bl); + +/** + * @brief Outputs the buffer size for the packed weights + * @param bl Groupsize for quantization. 
32 for groupwise , 0 for channelwise + */ +size_t kai_pack_rhs_int4_size( + const int64_t n, + const int64_t k, + const int64_t bl); + +/** + * @brief Run 2 operations ( Input quantize and pack -> 4 bit Matmul ) + */ +void kai_quant_pack_lhs_int4_mm( + const Tensor& output, + const Tensor& input, + const Tensor& weight, + const int64_t m, + const int64_t n, + const int64_t k, + const int64_t bl); +} // namespace at::native::kleidiai +#endif diff --git a/aten/src/ATen/native/kleidiai/kai_pack.h b/aten/src/ATen/native/kleidiai/kai_pack.h new file mode 100644 index 000000000000..4ff3371ab5e2 --- /dev/null +++ b/aten/src/ATen/native/kleidiai/kai_pack.h @@ -0,0 +1,106 @@ +#pragma once +#include +#include +#include +#include +#if AT_KLEIDIAI_ENABLED() + +namespace at::native::kleidiai { + +template +void kai_pack_rhs_groupwise_int4( + T& kernel, + const Tensor& weight_packed, + const Tensor& weight, + const Tensor& scales, + const std::optional& bias, + const int64_t n, + const int64_t k, + const int64_t bl, + const int64_t rhs_stride, + const int64_t scale_stride) { + const auto& ukernel = kernel.ukernel; + const size_t nr = ukernel.get_nr(); + const size_t kr = ukernel.get_kr(); + const size_t sr = ukernel.get_sr(); + auto weight_packed_data = + reinterpret_cast(weight_packed.data_ptr()); + const auto weight_data = weight.data_ptr(); + auto scales_data = scales.const_data_ptr(); + + if (weight_data == nullptr) { + AT_ERROR("kai_pack_rhs_channelwise_int4: Weight data pointer is null"); + } + + if (scales_data == nullptr) { + AT_ERROR("kai_pack_rhs_channelwise_int4: Scales data pointer is null"); + } + + float* bias_ptr = bias.has_value() ? bias.value().data_ptr() : NULL; + auto& params = kernel.rhs_pack_params; + + kernel.kai_run_rhs_pack( + /*num_groups=*/1, + n, + k, + nr, + kr, + sr, + bl, + (const uint8_t*)(weight_data), + rhs_stride, + bias_ptr, + scales_data, + scale_stride, + weight_packed_data, + 0, + ¶ms); +} + +template +void kai_pack_rhs_channelwise_int4( + T& kernel, + const Tensor& weight_packed, + const Tensor& weight, + const Tensor& scales, + const std::optional& bias, + const int64_t n, + const int64_t k) { + const auto& ukernel = kernel.ukernel; + const size_t nr = ukernel.get_nr(); + const size_t kr = ukernel.get_kr(); + const size_t sr = ukernel.get_sr(); + auto weight_packed_data = + reinterpret_cast(weight_packed.data_ptr()); + const auto weight_data = weight.data_ptr(); + const auto scales_data = scales.data_ptr(); + + if (weight_data == nullptr) { + AT_ERROR("kai_pack_rhs_channelwise_int4: Weight data pointer is null"); + } + + if (scales_data == nullptr) { + AT_ERROR("kai_pack_rhs_channelwise_int4: Scales data pointer is null"); + } + + float* bias_ptr = bias.has_value() ? 
bias.value().data_ptr() : NULL; + auto& params = kernel.rhs_pack_params; + + kernel.kai_run_rhs_pack( + /*num_groups=*/1, + n, + k, + nr, + kr, + sr, + (const uint8_t*)(weight_data), + (const float*)(bias_ptr), + (const float*)(scales_data), + weight_packed_data, + 0, + ¶ms); +} + +} // namespace at::native::kleidiai + +#endif diff --git a/aten/src/ATen/native/kleidiai/kai_ukernel_interface.cpp b/aten/src/ATen/native/kleidiai/kai_ukernel_interface.cpp new file mode 100644 index 000000000000..0de198d7dc01 --- /dev/null +++ b/aten/src/ATen/native/kleidiai/kai_ukernel_interface.cpp @@ -0,0 +1,72 @@ +#include + +#if AT_KLEIDIAI_ENABLED() + +namespace at::native::kleidiai { + +// Kernel Mapping - Groupwise +std::unordered_map groupwise_8bit_4bit_kernels = + {{kai_kernel_id:: + matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + {{kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, + kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod}}}, + {kai_kernel_id::matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_4x8x32_neon_i8mm, + {{kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, + kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, + kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, + kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, + kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, + kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, + kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, + kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, + kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, + kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm}}}}; + +kai_matmul_ukernel_f32_qa8dxp_qs4c32p kai_select_groupwise_matmul_ukernel( + kai_kernel_id id) { + return groupwise_8bit_4bit_kernels.at(id); +} + +// Kernel Mapping - Channelwise +std::unordered_map channelwise_8bit_4bit_kernels = + {{kai_kernel_id::matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, + {{kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, + kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, + kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, + kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, + kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, + kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, + 
kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, + kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, + kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, + kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod}}}, + {kai_kernel_id::matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, + {{kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, + kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, + kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, + kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, + kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, + kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, + kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, + kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, + kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, + kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm}}}}; + +kai_matmul_ukernel_f32_qa8dxp_qs4cxp kai_select_channelwise_matmul_ukernel( + const kai_kernel_id id) { + return channelwise_8bit_4bit_kernels.at(id); +} +} // namespace at::native::kleidiai +#endif diff --git a/aten/src/ATen/native/kleidiai/kai_ukernel_interface.h b/aten/src/ATen/native/kleidiai/kai_ukernel_interface.h new file mode 100644 index 000000000000..c0835729f88b --- /dev/null +++ b/aten/src/ATen/native/kleidiai/kai_ukernel_interface.h @@ -0,0 +1,144 @@ +#pragma once +#include +#include +#if AT_KLEIDIAI_ENABLED() + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at::native::kleidiai { + +enum class kai_kernel_id { + matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod = + 0, // Groupwise 4 bit GEMV + matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_4x8x32_neon_i8mm = + 1, // Groupwise 4 bit GEMM + matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod = + 2, // Channelwise 4 bit GEMV + matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm = + 3 // Channelwise 4 bit GEMM +}; + +// Channelwise Kernel mapping +struct kai_matmul_ukernel_f32_qa8dxp_qs4cxp { + struct kai_matmul_clamp_f32_qai8dxp_qsi4cxp_ukernel ukernel; + struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params rhs_pack_params; + size_t (*kai_get_lhs_packed_size)( + size_t m, + size_t k, + size_t mr, + size_t kr, + size_t sr); + size_t (*kai_get_rhs_packed_size)( + size_t n, + size_t k, + size_t nr, + size_t kr, + size_t sr); + void (*kai_run_lhs_quant_pack)( + size_t m, + size_t k, + size_t mr, + size_t kr, + size_t sr, + size_t m_idx_start, + const float* lhs, + size_t lhs_stride, + void* lhs_packed); + void (*kai_run_rhs_pack)( + size_t num_groups, + size_t n, + size_t k, + size_t nr, + size_t kr, + size_t sr, + const uint8_t* rhs, + const float* bias, + const float* scale, + void* rhs_packed, + size_t extra_bytes, + const struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params* params); + + kai_matmul_ukernel_f32_qa8dxp_qs4cxp( + const kai_matmul_clamp_f32_qai8dxp_qsi4cxp_ukernel& kernel) + : ukernel(kernel), + kai_get_lhs_packed_size( + &kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32), + kai_get_rhs_packed_size( + &kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0), + kai_run_lhs_quant_pack(&kai_run_lhs_quant_pack_qai8dxp_f32), + 
kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0) {} +}; + +struct kai_matmul_ukernel_f32_qa8dxp_qs4cxp +kai_select_channelwise_matmul_ukernel(const kai_kernel_id id); + +// Groupwise Kernel mapping +struct kai_matmul_ukernel_f32_qa8dxp_qs4c32p { + struct kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel ukernel; + struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params rhs_pack_params; + size_t (*kai_get_lhs_packed_size)( + size_t m, + size_t k, + size_t mr, + size_t kr, + size_t sr); + size_t (*kai_get_rhs_packed_size)( + size_t n, + size_t k, + size_t nr, + size_t kr, + size_t sr, + size_t bl, + enum kai_datatype scale_dt); + void (*kai_run_lhs_quant_pack)( + size_t m, + size_t k, + size_t mr, + size_t kr, + size_t sr, + size_t m_idx_start, + const float* lhs, + size_t lhs_stride, + void* lhs_packed); + void (*kai_run_rhs_pack)( + size_t num_groups, + size_t n, + size_t k, + size_t nr, + size_t kr, + size_t sr, + size_t bl, + const uint8_t* rhs, + size_t rhs_stride, + const float* bias, + const void* scale, + size_t scale_stride, + void* rhs_packed, + size_t extra_bytes, + const struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params* params); + + kai_matmul_ukernel_f32_qa8dxp_qs4c32p( + const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& kernel) + : ukernel(kernel), + kai_get_lhs_packed_size( + &kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32), + kai_get_rhs_packed_size( + &kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0), + kai_run_lhs_quant_pack(&kai_run_lhs_quant_pack_qai8dxp_f32), + kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0) {} +}; + +struct kai_matmul_ukernel_f32_qa8dxp_qs4c32p kai_select_groupwise_matmul_ukernel( + const kai_kernel_id id); + +} // namespace at::native::kleidiai +#endif diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index ede2e97aead2..d68787e17fed 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -4177,6 +4177,14 @@ dispatch: CPU: _weight_int4pack_mm_cpu +- func: _dyn_quant_pack_4bit_weight(Tensor weights, Tensor scales_zeros, Tensor? bias, int block_size, int in_features, int out_features) -> Tensor + dispatch: + CPU: _dyn_quant_pack_4bit_weight_cpu + +- func: _dyn_quant_matmul_4bit(Tensor inp, Tensor packed_weights, int block_size, int in_features, int out_features) -> Tensor + dispatch: + CPU: _dyn_quant_matmul_4bit_cpu + - func: _weight_int8pack_mm(Tensor self, Tensor mat2, Tensor scales) -> Tensor dispatch: CPU: _weight_int8pack_mm_cpu diff --git a/buckbuild.bzl b/buckbuild.bzl index c6d65dc521a9..3632ed4111a3 100644 --- a/buckbuild.bzl +++ b/buckbuild.bzl @@ -1070,6 +1070,7 @@ def define_buck_targets( ], ) + # TODO: Enable support for KleidiAI bazel build # @lint-ignore BUCKLINT fb_native.genrule( name = "generate_aten_config", @@ -1122,6 +1123,9 @@ def define_buck_targets( "--replace", "@AT_BLAS_USE_CBLAS_DOT@", "AT_BLAS_USE_CBLAS_DOT_FBXPLAT", + "--replace", + "@AT_KLEIDIAI_ENABLED@", + "0", ]), outs = { "Config.h": ["Config.h"], diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index 1813f4418a28..279810e10b00 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -152,6 +152,7 @@ endif() set(AT_MKLDNN_ACL_ENABLED 0) set(AT_MKLDNN_ENABLED 0) set(AT_MKL_ENABLED 0) +set(AT_KLEIDIAI_ENABLED 0) # setting default preferred BLAS options if not already present. 
if(NOT INTERN_BUILD_MOBILE) set(BLAS "MKL" CACHE STRING "Selected BLAS library") @@ -1480,6 +1481,35 @@ if(NOT INTERN_BUILD_MOBILE) message("disabling MKLDNN because USE_MKLDNN is not set") endif() + if(USE_KLEIDIAI) + if(CMAKE_C_COMPILER_ID STREQUAL "Clang" AND CMAKE_C_COMPILER_VERSION VERSION_LESS "11" ) + message(WARNING "KleidiAI: Using non-supported Clang version. Expected 11 or newer, received ${CMAKE_C_COMPILER_VERSION}.") + endif() + if(CMAKE_C_COMPILER_ID STREQUAL "GNU" AND CMAKE_C_COMPILER_VERSION VERSION_LESS "11" ) + message(WARNING "KleidiAI: Using non-supported GCC version. Expected 11 or newer, received ${CMAKE_C_COMPILER_VERSION}.") + endif() + set(TEMP_BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS}) + set(BUILD_SHARED_LIBS OFF CACHE BOOL "Build shared libs" FORCE) + set(AT_KLEIDIAI_ENABLED 1) + set(KLEIDIAI_BUILD_TESTS OFF) # Disable building KLEIDIAI tests + set(KLEIDIAI_SRC "${PROJECT_SOURCE_DIR}/third_party/kleidiai") + add_subdirectory(${KLEIDIAI_SRC}) + set(KLEIDIAI_INCLUDE_DIRS + ${KLEIDIAI_SRC}/ + ${KLEIDIAI_SRC}/kai/ + ${KLEIDIAI_SRC}/kai/ukernels/ + ${KLEIDIAI_SRC}/kai/ukernels/matmul/ + ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/ + ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/ + ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/ + ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/ + ) + include_directories(SYSTEM INTERFACE ${KLEIDIAI_INCLUDE_DIRS}) + list(APPEND Caffe2_DEPENDENCY_LIBS kleidiai) + # Recover build options. + set(BUILD_SHARED_LIBS ${TEMP_BUILD_SHARED_LIBS} CACHE BOOL "Build shared libs" FORCE) + endif() + if(UNIX AND NOT APPLE) include(CheckLibraryExists) # https://github.com/libgit2/libgit2/issues/2128#issuecomment-35649830 diff --git a/cmake/Summary.cmake b/cmake/Summary.cmake index d8a4bcf21916..9355e01aad98 100644 --- a/cmake/Summary.cmake +++ b/cmake/Summary.cmake @@ -153,6 +153,9 @@ function(caffe2_print_configuration_summary) message(STATUS " USE_MKLDNN_ACL : ${USE_MKLDNN_ACL}") message(STATUS " USE_MKLDNN_CBLAS : ${USE_MKLDNN_CBLAS}") endif() + if(${USE_KLEIDIAI}) + message(STATUS " USE_KLEIDIAI : ${USE_KLEIDIAI}") + endif() message(STATUS " USE_UCC : ${USE_UCC}") if(${USE_UCC}) message(STATUS " USE_SYSTEM_UCC : ${USE_SYSTEM_UCC}") diff --git a/cmake/TorchConfig.cmake.in b/cmake/TorchConfig.cmake.in index 8f2b2c30aee6..855edd350818 100644 --- a/cmake/TorchConfig.cmake.in +++ b/cmake/TorchConfig.cmake.in @@ -97,6 +97,10 @@ else() append_torchlib_if_found(microkernels-prod) endif() + if(@USE_KLEIDIAI@) + append_torchlib_if_found(kleidiai) + endif() + append_torchlib_if_found(caffe2_protos protobuf-lite protobuf protoc) append_torchlib_if_found(onnx onnx_proto) diff --git a/docs/source/backends.rst b/docs/source/backends.rst index 2fd9277fa814..6d3500c85421 100644 --- a/docs/source/backends.rst +++ b/docs/source/backends.rst @@ -203,6 +203,7 @@ torch.backends.openmp .. add anything to the rendered page for now. .. py:module:: torch.backends.quantized .. py:module:: torch.backends.xnnpack +.. 
py:module:: torch.backends.kleidiai torch.backends.opt_einsum diff --git a/setup.py b/setup.py index 5ed97c6df7b6..4a4d5e51d212 100644 --- a/setup.py +++ b/setup.py @@ -1221,6 +1221,7 @@ def main(): "include/ATen/native/cuda/*.cuh", "include/ATen/native/hip/*.h", "include/ATen/native/hip/*.cuh", + "include/ATen/native/kleidiai/*.h", "include/ATen/native/mps/*.h", "include/ATen/native/mkldnn/xpu/*.h", "include/ATen/native/mkldnn/xpu/detail/*.h", diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect index 883399f855cc..46ab3a57d90d 100644 --- a/test/expect/HasDecompTest.test_has_decomposition.expect +++ b/test/expect/HasDecompTest.test_has_decomposition.expect @@ -92,6 +92,8 @@ aten::_dimI aten::_dimV aten::_dirichlet_grad aten::_dirichlet_grad.out +aten::_dyn_quant_matmul_4bit +aten::_dyn_quant_pack_4bit_weight aten::_efficient_attention_backward aten::_efficient_attention_forward aten::_efficientzerotensor diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index e26936cf17e2..de66b58caba6 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -75,6 +75,7 @@ from torch.testing._internal.common_device_type import ( from torch.testing._internal.common_dtype import all_types, get_all_dtypes from torch.testing._internal.common_quantization import ( _dynamically_quantize_per_channel, + _group_quantize_tensor_symmetric, ) from torch.testing._internal.common_utils import ( DeterministicGuard, @@ -2223,6 +2224,85 @@ class CommonTemplate: b_int8pack, b_scales = convert_weight_to_int8pack(b) self.common(fn, (a, b_int8pack, b_scales, c)) + @xfail_if_triton_cpu + @skipCUDAIf(True, "No _dyn_quant_pack_4bit_weight implementation on CUDA") + @skipIfRocm + @skipIfXpu(msg="No _dyn_quant_pack_4bit_weight implementation on XPU") + def test__dyn_quant_pack_4bit_weight(self): + q_group = 32 + k = 128 + n = 128 + + torch.manual_seed(1) + b = torch.rand((k, n), dtype=torch.float32) + in_features = b.size(0) + out_features = b.size(1) + + def dyn_quant_pack_4bit_weight(b, in_features, out_features): + b_uint8, b_scales_and_zeros = _group_quantize_tensor_symmetric( + b, n_bit=4, groupsize=q_group + ) + + if q_group == in_features: + b_scales_and_zeros = b_scales_and_zeros.to(torch.float) + else: + b_scales_and_zeros = b_scales_and_zeros.to(torch.bfloat16) + b_int4pack = torch._dyn_quant_pack_4bit_weight( + b_uint8, b_scales_and_zeros, None, q_group, in_features, out_features + ) + + return b_int4pack, b_scales_and_zeros + + def fn(b, in_features, out_features): + b_int4pack, _ = dyn_quant_pack_4bit_weight(b, in_features, out_features) + return b_int4pack + + self.common(fn, (b, in_features, out_features)) + + @xfail_if_triton_cpu + @skipCUDAIf(True, "No _dyn_quant_matmul_4bit implementation on CUDA") + @skipIfRocm + @skipIfXpu(msg="No _dyn_quant_matmul_4bit implementation on XPU") + def test__dyn_quant_matmul_4bit(self): + q_group = 32 + m = 32 + k = 128 + n = 128 + + torch.manual_seed(1) + a = torch.rand((m, k), dtype=torch.float32) + b = torch.rand((k, n), dtype=torch.float32) + in_features = b.size(0) + out_features = b.size(1) + + def dyn_quant_pack_4bit_weight(b, in_features, out_features): + b_uint8, b_scales_and_zeros = _group_quantize_tensor_symmetric( + b, n_bit=4, groupsize=q_group + ) + + if q_group == in_features: + b_scales_and_zeros = b_scales_and_zeros.to(torch.float) + else: + b_scales_and_zeros = b_scales_and_zeros.to(torch.bfloat16) + b_int4pack = 
torch._dyn_quant_pack_4bit_weight( + b_uint8, b_scales_and_zeros, None, q_group, in_features, out_features + ) + + return b_int4pack, b_scales_and_zeros + + def fn(a, q_group, in_features, out_features): + b_int4pack, _ = dyn_quant_pack_4bit_weight(b, in_features, out_features) + res = torch._dyn_quant_matmul_4bit( + a, + b_int4pack, + q_group, + in_features, + out_features, + ) + return res + + self.common(fn, (a, q_group, in_features, out_features)) + def test_expanded_reduction(self): def fn(x, y): z = x * y diff --git a/test/test_linalg.py b/test/test_linalg.py index c6fb81c86a37..c40ea267d2b8 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -33,7 +33,8 @@ from torch.testing._internal.common_dtype import ( ) from torch.testing._internal.common_cuda import SM53OrLater, SM80OrLater, SM90OrLater, tf32_on_and_off, _get_magma_version, \ _get_torch_cuda_version, CDNA2OrLater, TEST_MULTIGPU -from torch.testing._internal.common_quantization import _group_quantize_tensor, _dynamically_quantize_per_channel +from torch.testing._internal.common_quantization import _group_quantize_tensor, _dynamically_quantize_per_channel, \ + _group_quantize_tensor_symmetric from torch.testing._internal.common_mkldnn import bf32_on_and_off from torch.distributions.binomial import Binomial import torch.backends.opt_einsum as opt_einsum @@ -908,7 +909,6 @@ class TestLinalg(TestCase): torch.randn((3, 52, 52), device=device, dtype=dtype), torch.randn((4, 2, 26, 26), device=device, dtype=dtype)) - ops = (torch.det, torch.Tensor.det, torch.linalg.det) for t in tensors: @@ -1437,7 +1437,6 @@ class TestLinalg(TestCase): continue run_test_case(make_arg(shape), ord, dim, keepdim) - @onlyCUDA @dtypes(torch.bfloat16, torch.float16) def test_norm_fused_type_promotion(self, device, dtype): @@ -4343,7 +4342,6 @@ class TestLinalg(TestCase): triangular_solve_zero_batch_helper((batchsize, 5, 5), (batchsize, 5, 10), upper, unitriangular, transpose) - @slowTest @skipCUDAIfNoMagma @skipCPUIfNoLapack @@ -4464,7 +4462,6 @@ class TestLinalg(TestCase): self.assertTrue("An output with one or more elements was resized" in str(w[0].message)) self.assertTrue("An output with one or more elements was resized" in str(w[1].message)) - def check_single_matmul(self, x, y): def assertEqual(answer, expected): @@ -5700,7 +5697,6 @@ class TestLinalg(TestCase): else: self.assertEqual(B_, X_ @ A) - sizes = ((3, 3), (5, 5), (4, 2), (3, 4), (0, 0), (0, 1), (1, 0)) batches = ((0,), (), (1,), (2,), (3,), (1, 0), (3, 5)) # Non pivoting just implemented for CUDA @@ -5733,7 +5729,6 @@ class TestLinalg(TestCase): with self.assertRaisesRegex(RuntimeError, 'LU without pivoting is not implemented on the CPU'): f(torch.empty(1, 2, 2), pivot=False) - @precisionOverride({torch.float32: 1e-2, torch.complex64: 1e-2}) @skipCUDAIfNoMagmaAndNoCusolver @skipCPUIfNoLapack @@ -5761,7 +5756,6 @@ class TestLinalg(TestCase): for b, n in shapes: yield make_arg((b, n, n)), make_arg((b, n, rhs)) - for A, B in gen_matrices(): LU, pivots = torch.linalg.lu_factor(A) for backend in backends: @@ -5776,7 +5770,6 @@ class TestLinalg(TestCase): else: self.assertEqual(B_left, X @ A_adj) - @onlyCPU @dtypes(*floating_and_complex_types()) def test_linalg_lu_cpu_errors(self, device, dtype): @@ -5817,7 +5810,6 @@ class TestLinalg(TestCase): with self.assertRaisesRegex(RuntimeError, r"between 1 and LU.size\(-2\)."): torch.lu_unpack(LU, pivots) - # Rectangular tests sample = torch.randn(2, 3, 5, device=device, dtype=dtype) B = torch.randn(2, 3, 5, device=device, dtype=dtype) @@ 
-5834,7 +5826,6 @@ class TestLinalg(TestCase): with self.assertRaisesRegex(RuntimeError, r"between 1 and LU.size\(-2\)."): torch.lu_unpack(LU, pivots) - @skipCPUIfNoLapack @skipCUDAIfNoMagma @dtypes(torch.double) @@ -6443,7 +6434,6 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2: y = torch.baddbmm(input_tensor, batch1, batch2, beta=0.0, out=out) self.assertEqual(out, y_ref) - @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error") @onlyCUDA def test_matmul_45724(self, device): @@ -6616,7 +6606,6 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2: torch._int_mm(a_int8, b_int8, out=c_int32_result) self.assertEqual(c_int32_result.float(), torch.mm(a_float, b_float)) - @unittest.skipIf(IS_WINDOWS, "Skipped on Windows!") @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error") @onlyNativeDeviceTypes @@ -6679,7 +6668,6 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2: mean_err = ((res - ref).abs() / ref).mean() self.assertTrue(mean_err < 0.05) - @unittest.skipIf(IS_WINDOWS, "Skipped on Windows!") @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error") @onlyNativeDeviceTypes @@ -6732,6 +6720,168 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2: mean_err = ((res - ref).abs() / ref).mean() self.assertTrue(mean_err < 0.05) + @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error") + @unittest.skipIf(TEST_WITH_ROCM and IS_REMOTE_GPU, "ROCM is unsupported") + @onlyNativeDeviceTypes + @parametrize("k", [64, 256]) + @parametrize("n", [32, 48, 64, 128]) + def test__dyn_quant_pack_4bit_weight(self, device, k, n): + # TODO: Fix https://github.com/pytorch/pytorch/issues/131425 and use OpInfo instead + # Weight shape is [K x N] + if self.device_type == "cuda": + self.skipTest("CUDA Backend is unsupported") + + torch.manual_seed(1) + block_size = 32 + b = torch.rand((k, n), dtype=torch.bfloat16, device=device) + in_features = b.size(0) + out_features = b.size(1) + b_uint8, b_scales_and_zeros = _group_quantize_tensor_symmetric( + b, n_bit=4, groupsize=block_size + ) + b_int4pack = torch._dyn_quant_pack_4bit_weight( + b_uint8, b_scales_and_zeros, None, block_size, in_features, out_features + ) + b_int4pack_meta = torch._dyn_quant_pack_4bit_weight( + b_uint8, b_scales_and_zeros, None, block_size, in_features, out_features + ) + self.assertEqual(b_int4pack.shape, b_int4pack_meta.shape) + + @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error") + @unittest.skipIf(TEST_WITH_ROCM and IS_REMOTE_GPU, "ROCM is unsupported") + @onlyNativeDeviceTypes + @parametrize("m", [1, 32]) + @parametrize("k", [64, 128]) + @parametrize("n", [4096, 11008]) + def test__dyn_quant_matmul_4bit(self, device, m, k, n): + if self.device_type == "cuda": + self.skipTest("CUDA is unsupported") + + q_group = 32 + + torch.manual_seed(1) + a_float32 = torch.rand((m, k), dtype=torch.float32, device=device) + b_float32 = torch.rand((k, n), dtype=torch.float32, device=device) + in_features = b_float32.size(0) + out_features = b_float32.size(1) + + def dyn_quant_pack_4bit_weight(b, in_features, out_features): + b_uint8, b_scales_and_zeros = _group_quantize_tensor_symmetric( + b, n_bit=4, groupsize=q_group + ) + + if q_group == in_features: + b_scales_and_zeros = b_scales_and_zeros.to(torch.float) + else: + b_scales_and_zeros = b_scales_and_zeros.to(torch.bfloat16) + b_int4pack = torch._dyn_quant_pack_4bit_weight( + b_uint8, b_scales_and_zeros, None, q_group, 
in_features, out_features + ) + + return b_int4pack, b_scales_and_zeros + + def dyn_quant_matmul_4bit( + a, b_int4pack, q_group, in_features, out_features + ): + return torch._dyn_quant_matmul_4bit( + a, + b_int4pack, + q_group, + in_features, + out_features, + ) + + b_int4pack, b_scales_and_zeros = dyn_quant_pack_4bit_weight( + b_float32, in_features, out_features + ) + + dtypes = [torch.float32] + + for dtype in dtypes: + a = a_float32.to(dtype=dtype) + b = b_float32.to(dtype=dtype) + ref = torch.mm(a, b) + res = dyn_quant_matmul_4bit( + a, + b_int4pack, + q_group, + in_features, + out_features, + ) + mean_err = ((res - ref).abs() / ref).mean() + self.assertTrue(mean_err < 0.05) + elementwise_diff = (res - ref).abs() + elementwise_relative_error = elementwise_diff / ref.abs().clamp( + min=torch.finfo(ref.dtype).eps + ) + all_elements_within_threshold = torch.all(elementwise_relative_error < 0.06) + self.assertTrue( + all_elements_within_threshold, "Some elements have error >= 0.06" + ) + + @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error") + @unittest.skipIf(TEST_WITH_ROCM and IS_REMOTE_GPU, "ROCM is unsupported") + @onlyNativeDeviceTypes + @parametrize("m", [1, 32]) + @parametrize("k", [64, 128]) + @parametrize("n", [4096, 11008]) + def test_compile_dyn_quant_matmul_4bit(self, device, m, k, n): + if self.device_type == "cuda": + self.skipTest("CUDA is unsupported") + + q_group = 32 + + torch.manual_seed(1) + a_float32 = torch.rand((m, k), dtype=torch.float32, device=device) + b_float32 = torch.rand((k, n), dtype=torch.float32, device=device) + in_features = b_float32.size(0) + out_features = b_float32.size(1) + + b_uint8, b_scales_and_zeros = _group_quantize_tensor_symmetric( + b_float32, n_bit=4, groupsize=q_group + ) + + if q_group == in_features: + b_scales_and_zeros = b_scales_and_zeros.to(dtype=torch.float) + else: + b_scales_and_zeros = b_scales_and_zeros.to(dtype=torch.bfloat16) + + @torch.compile + def dyn_quant_matmul_4bit( + a, b_uint8, b_scales_and_zeros, q_group, in_features, out_features + ): + b_int4pack = torch._dyn_quant_pack_4bit_weight( + b_uint8, b_scales_and_zeros, None, q_group, in_features, out_features + ) + return torch._dyn_quant_matmul_4bit( + a, + b_int4pack, + q_group, + in_features, + out_features, + ) + + res = dyn_quant_matmul_4bit( + a_float32, + b_uint8, + b_scales_and_zeros, + q_group, + in_features, + out_features, + ) + ref = torch.mm(a_float32, b_float32) + + mean_err = ((res - ref).abs() / ref).mean() + self.assertTrue(mean_err < 0.05) + elementwise_diff = (res - ref).abs() + elementwise_relative_error = elementwise_diff / ref.abs().clamp( + min=torch.finfo(ref.dtype).eps + ) + all_elements_within_threshold = torch.all(elementwise_relative_error < 0.06) + self.assertTrue( + all_elements_within_threshold, "Some elements have error >= 0.06" + ) + @onlyCPU @parametrize("m", [32, 64]) @parametrize("k", [32, 64]) @@ -8663,8 +8813,6 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2: self.assertEqual(ck_out, cpu_out) - - def test_permute_matmul(self): a = torch.ones([2, 5, 24, 24]) b = torch.ones([3, 2, 5, 24, 24]) @@ -8744,7 +8892,6 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2: ref = alpha * A @ B + beta * C self.assertEqual(rc, ref) - @dtypes(torch.float, torch.double) @precisionOverride({torch.float32: 1e-4}) def test_1_sized_with_0_strided(self, device, dtype): diff --git a/third_party/kleidiai b/third_party/kleidiai new file mode 160000 index 000000000000..202603f38a9d 
--- /dev/null +++ b/third_party/kleidiai @@ -0,0 +1 @@ +Subproject commit 202603f38a9df9d2ded89f12b41ded621c71d4ea diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 208f1d97a978..f7ac713849e7 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -1299,6 +1299,7 @@ def _valgrind_toggle_and_dump_stats() -> None: ... # CALLGRIND_TOGGLE_COLLECT a has_openmp: _bool has_mkl: _bool +_has_kleidiai: _bool _has_mps: _bool has_lapack: _bool _has_cuda: _bool diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 7fa91fb308bc..ea46c0c86d50 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -1372,6 +1372,8 @@ torch_c_binding_in_graph_functions = dict.fromkeys( "torch._dim_arange", "torch._dirichlet_grad", "torch._disable_functionalization", + "torch._dyn_quant_matmul_4bit", + "torch._dyn_quant_pack_4bit_weight", "torch._efficientzerotensor", "torch._embedding_bag_forward_only", "torch._embedding_bag", diff --git a/torch/_inductor/quantized_lowerings.py b/torch/_inductor/quantized_lowerings.py index 07778d6346ec..1862140b1265 100644 --- a/torch/_inductor/quantized_lowerings.py +++ b/torch/_inductor/quantized_lowerings.py @@ -94,3 +94,6 @@ def register_woq_mm_ops() -> None: return autotune_select_algorithm( "_weight_int8pack_mm", choices, [mat1, mat2, scale], aten_layout ) + + lowering.make_fallback(aten._dyn_quant_matmul_4bit) + lowering.make_fallback(aten._dyn_quant_pack_4bit_weight) diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index b201ab1a60d5..c69017747997 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -3344,6 +3344,161 @@ def meta__weight_int4pack_mm_for_cpu(x, w, q_group_size, q_scale_and_zeros): return x.new_empty(x.size(0), w.size(0), dtype=x.dtype) +def kai_roundup(a: int, b: int) -> int: + return ((a + b - 1) // b) * b + + +def get_kai_packed_weight_size(n_bits, N, K, groupsize): + if n_bits == 4: + if groupsize == K: # channelwise + # dotprod params only [1x8x32_neon_dotprod] + kai_nr = 8 + kai_kr = 16 + kai_sr = 2 + kai_num_bytes_sum_rhs = 4 # sizeof(int32_t) + kai_num_bytes_multiplier_rhs = 4 # sizeof(float) + kai_num_bytes_bias = 4 # sizeof(float) + + def kai_k_roundedup(k, kr, sr): + # Since we pack a float and int32 value at the end of the row, + # we must make sure that k is a multiple of 4 for alignment + kr_sr_roundedup4 = kai_roundup(kr * sr, 4) + return kai_roundup(k, kr_sr_roundedup4) + + def kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0( + k, nr, kr, sr + ): + k_internal = kai_k_roundedup(k, kr, sr) + + assert (k_internal % 2) == 0, "k_internal must be even" + + return nr * ( + (k_internal // 2) + + kai_num_bytes_multiplier_rhs + + kai_num_bytes_sum_rhs + + kai_num_bytes_bias + ) + + def kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0( + n, k, nr, kr, sr + ): + num_rows = kai_roundup(n, nr) // nr + + return ( + num_rows + * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0( + k, nr, kr, sr + ) + ) + + return kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0( + N, K, kai_nr, kai_kr, kai_sr + ) + elif groupsize % 32 == 0 and K % groupsize == 0: # groupwise + kai_nr = 8 + kai_kr = 16 + kai_sr = 2 + kai_num_bytes_sum_rhs = 4 + kai_num_bytes_bias = 4 + kai_nr_multiple_of = 4 + kai_bl_multiple_of = 32 + + def kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( + n, k, nr, kr, sr, bl + ): + assert (bl % kr) == 0 + assert (nr % kai_nr_multiple_of) == 0 + assert (bl % kai_bl_multiple_of) == 0 + + 
num_rows = kai_roundup(n, nr) // nr + + return ( + num_rows + * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( + k, nr, kr, sr, bl + ) + ) + + def kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( + k, nr, kr, sr, bl + ): + assert (bl % kr) == 0 + assert (nr % kai_nr_multiple_of) == 0 + assert (bl % kai_bl_multiple_of) == 0 + + # kr and sr are unused in the calculation + num_bytes_multiplier_rhs = kai_get_bf16_datatype_size_in_bytes() + num_blocks_per_row = kai_num_blocks_per_row(k, bl) + num_bytes_per_block = kai_num_bytes_per_block( + bl, num_bytes_multiplier_rhs + ) + + return nr * ( + (num_bytes_per_block * num_blocks_per_row) + + kai_num_bytes_sum_rhs + + kai_num_bytes_bias + ) + + # This function returns the size of these datatypes stored as an enum. We modify it to just return the bf16 datatype size. + # https://gitlab.arm.com/kleidi/kleidiai/-/blob/main/kai/kai_common.h?ref_type=heads#L55 + def kai_get_bf16_datatype_size_in_bytes(): + return 2 # 2 bytes + + def kai_num_blocks_per_row(k, bl): + assert (bl % kai_bl_multiple_of) == 0 + return kai_roundup(k, bl) // bl + + def kai_num_bytes_per_block(bl, num_bytes_multiplier_rhs): + assert (bl % kai_bl_multiple_of) == 0 + return (bl // 2) + num_bytes_multiplier_rhs + + return kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( + N, K, kai_nr, kai_kr, kai_sr, groupsize + ) + + +@register_meta([aten._dyn_quant_pack_4bit_weight]) +def meta__dyn_quant_pack_4bit_weight( + weights, scales_zeros, bias: Optional[Tensor], block_size, in_features, out_features +): + torch._check( + weights.dtype is torch.uint8, + lambda: f"expected w to be uint8, got {weights.dtype}", + ) + if torch.backends.kleidiai.is_available() and ( + (block_size == in_features and scales_zeros.dtype == torch.float) + or ( + block_size < in_features + and block_size % 32 == 0 + and in_features % block_size == 0 + and scales_zeros.dtype == torch.bfloat16 + ) + ): + packed_weight_size = get_kai_packed_weight_size( + 4, out_features, in_features, block_size + ) + return weights.new_empty(int(packed_weight_size), dtype=torch.uint8) + packed_weight_size = weights.numel() + scales_zeros.numel() + return weights.new_empty(packed_weight_size, dtype=torch.float) + + +@register_meta([aten._dyn_quant_matmul_4bit]) +def meta__dyn_quant_matmul_4bit( + inp, + packed_weights, + block_size, + in_features, + out_features, +): + torch._check(inp.dim() == 2, lambda: "input must be a 2D tensor") + torch._check( + inp.dtype in [torch.float32], + lambda: f"expected input to be f32, got {inp.dtype}", + ) + M = inp.size(0) + return inp.new_empty(M, out_features, dtype=inp.dtype) + + @register_meta([aten._weight_int8pack_mm]) def meta__weight_int8pack_mm(x, w, q_scales): torch._check(x.dim() == 2, lambda: "x must be a 2D tensor") diff --git a/torch/backends/__init__.py b/torch/backends/__init__.py index 301735188869..90166913e324 100644 --- a/torch/backends/__init__.py +++ b/torch/backends/__init__.py @@ -62,6 +62,7 @@ from torch.backends import ( cuda as cuda, cudnn as cudnn, cusparselt as cusparselt, + kleidiai as kleidiai, mha as mha, mkl as mkl, mkldnn as mkldnn, diff --git a/torch/backends/kleidiai/__init__.py b/torch/backends/kleidiai/__init__.py new file mode 100644 index 000000000000..1a681b77ef58 --- /dev/null +++ b/torch/backends/kleidiai/__init__.py @@ -0,0 +1,7 @@ +# mypy: allow-untyped-defs +import torch + + +def is_available(): + r"""Return whether PyTorch is built with KleidiAI support.""" + return torch._C._has_kleidiai diff --git a/torch/csrc/Module.cpp
b/torch/csrc/Module.cpp index 4b083fb7bde4..93dd040c24a9 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -1939,6 +1939,8 @@ Call this whenever a new thread is created in order to propagate values from ASSERT_TRUE( set_module_attr("has_openmp", at::hasOpenMP() ? Py_True : Py_False)); ASSERT_TRUE(set_module_attr("has_mkl", at::hasMKL() ? Py_True : Py_False)); + ASSERT_TRUE( + set_module_attr("_has_kleidiai", at::hasKleidiAI() ? Py_True : Py_False)); ASSERT_TRUE( set_module_attr("has_lapack", at::hasLAPACK() ? Py_True : Py_False)); diff --git a/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h b/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h index 5d8c34483270..b6d39364748c 100644 --- a/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h +++ b/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h @@ -18,6 +18,8 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__adaptive_avg_pool3d_backward(At AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__addmm_activation(AtenTensorHandle self, AtenTensorHandle mat1, AtenTensorHandle mat2, double beta, double alpha, int32_t use_gelu, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__cdist_backward(AtenTensorHandle grad, AtenTensorHandle x1, AtenTensorHandle x2, double p, AtenTensorHandle cdist, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__cdist_forward(AtenTensorHandle x1, AtenTensorHandle x2, double p, int64_t* compute_mode, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__dyn_quant_matmul_4bit(AtenTensorHandle inp, AtenTensorHandle packed_weights, int64_t block_size, int64_t in_features, int64_t out_features, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__dyn_quant_pack_4bit_weight(AtenTensorHandle weights, AtenTensorHandle scales_zeros, AtenTensorHandle* bias, int64_t block_size, int64_t in_features, int64_t out_features, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__efficientzerotensor(const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__embedding_bag(AtenTensorHandle weight, AtenTensorHandle indices, AtenTensorHandle offsets, int32_t scale_grad_by_freq, int64_t mode, int32_t sparse, AtenTensorHandle* per_sample_weights, int32_t include_last_offset, int64_t padding_idx, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__embedding_bag_dense_backward(AtenTensorHandle grad, AtenTensorHandle indices, AtenTensorHandle offset2bag, AtenTensorHandle bag_size, AtenTensorHandle maximum_indices, int64_t num_weights, int32_t scale_grad_by_freq, int64_t mode, AtenTensorHandle* per_sample_weights, int64_t padding_idx, AtenTensorHandle* ret0); diff --git a/torch/testing/_internal/common_quantization.py b/torch/testing/_internal/common_quantization.py index 4e6723e3f6c7..7403224b15c2 100644 --- a/torch/testing/_internal/common_quantization.py +++ b/torch/testing/_internal/common_quantization.py @@ -498,6 +498,39 @@ def _group_quantize_tensor(w, n_bit=4, q_group_size=16): return out, scales_and_zeros +def _group_quantize_tensor_symmetric( + w, n_bit=4, groupsize=32 +): + # W is of shape [K x N] + # We transpose W as Quantization is applied on [N x K] + w = w.transpose(0, 1).contiguous() + assert w.dim() == 2 + assert groupsize > 1 + assert w.shape[-1] % 
groupsize == 0 + # Calculate scale and zeros + to_quant = w.reshape(-1, groupsize) + max_val = to_quant.abs().amax(dim=1, keepdim=True) + eps = torch.finfo(max_val.dtype).eps + max_int = 2 ** (n_bit - 1) - 1 # For 4-bit, this is 7 + scales = max_val.clamp(min=eps) / max_int + zeros = torch.zeros_like(scales) + + # Quantize the weight + scales = scales.to(torch.float32).reshape(w.shape[0], -1) + zeros = zeros.to(torch.float32).reshape(w.shape[0], -1) + scales = scales.reshape(-1, 1) + zeros = zeros.reshape(-1, 1) + max_int = 2**n_bit - 1 + w_int8 = to_quant.div(scales).add(8.5).to(torch.int8).clamp(max=max_int) + # We pack 2 signed int4 values in unsigned uint8 container. + # This reduces the weight size by half and improves load perf + out_uint8 = (w_int8[::, 1::2] << 4 | w_int8[::, ::2]).to(torch.uint8) + + scales_and_zeros = scales.squeeze().contiguous() + + return out_uint8, scales_and_zeros + + def _dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype): # source: https://github.com/pytorch-labs/gpt-fast/blob/main/quantize.py # default setup for affine quantization of activations @@ -530,7 +563,6 @@ def _dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype): return quant, scales.to(x_dtype), zero_points - # QuantizationTestCase used as a base class for testing quantization on modules class QuantizationTestCase(TestCase): def setUp(self): diff --git a/torchgen/aoti/fallback_ops.py b/torchgen/aoti/fallback_ops.py index 7e40b020ad8f..b9cf12025556 100644 --- a/torchgen/aoti/fallback_ops.py +++ b/torchgen/aoti/fallback_ops.py @@ -45,6 +45,8 @@ inductor_fallback_ops = { "aten.cummin.default", "aten.cumprod.default", "aten.cumsum.default", + "aten._dyn_quant_matmul_4bit.default", + "aten._dyn_quant_pack_4bit_weight.default", "aten._efficient_attention_backward.default", "aten._efficient_attention_forward.default", "aten._efficientzerotensor.default",
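The new tests in test/test_linalg.py and test/inductor/test_torchinductor.py all exercise the same three-step flow: symmetrically group-quantize an fp32 [K x N] weight into packed uint8 nibbles, pack it with torch._dyn_quant_pack_4bit_weight, then run torch._dyn_quant_matmul_4bit on fp32 activations. The snippet below is a minimal sketch of that flow, not a supported public API: it assumes a CPU build of PyTorch that contains this patch, reuses the private test helper _group_quantize_tensor_symmetric added above, and the m/k/n/groupsize values are just one point from the test parametrization.

```python
# Minimal sketch of the dynamic 4-bit weight-only matmul flow exercised by the tests.
# Assumes a CPU build of PyTorch that includes aten._dyn_quant_pack_4bit_weight and
# aten._dyn_quant_matmul_4bit; _group_quantize_tensor_symmetric is a test-only helper.
import torch
from torch.testing._internal.common_quantization import _group_quantize_tensor_symmetric

torch.manual_seed(1)
m, k, n = 32, 128, 4096  # activation rows, in_features, out_features (one test configuration)
q_group = 32             # group size: a multiple of 32, or equal to k for channel-wise

a = torch.rand(m, k, dtype=torch.float32)  # fp32 activations, quantized dynamically inside the op
b = torch.rand(k, n, dtype=torch.float32)  # fp32 weight, shape [K x N]

# 1. Symmetric 4-bit group quantization; two 4-bit values are packed per uint8.
b_uint8, b_scales = _group_quantize_tensor_symmetric(b, n_bit=4, groupsize=q_group)

# 2. Scale dtype follows the tests: fp32 for channel-wise, bf16 for group-wise.
b_scales = b_scales.to(torch.float32 if q_group == k else torch.bfloat16)

# 3. Pack quantized weights, scales and optional bias (None here) into the backend layout.
packed = torch._dyn_quant_pack_4bit_weight(b_uint8, b_scales, None, q_group, k, n)

# 4. Run the dynamically quantized matmul and compare against the fp32 reference,
#    using the same 5% mean relative error bound as the tests.
out = torch._dyn_quant_matmul_4bit(a, packed, q_group, k, n)
ref = torch.mm(a, b)
rel_err = (out - ref).abs() / ref.abs().clamp(min=torch.finfo(ref.dtype).eps)
assert rel_err.mean() < 0.05
```

The scale-dtype switch mirrors the meta registration above: the channel-wise case (block_size == in_features) expects fp32 scales, while the group-wise case (block_size a multiple of 32 that divides in_features) expects bf16 scales.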
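For the KleidiAI path, meta__dyn_quant_pack_4bit_weight reserves a buffer whose size comes from get_kai_packed_weight_size. As a worked example of the channel-wise branch (the nr=8, kr=16, sr=2 constants of the 1x8x32_neon_dotprod packing, plus 4 bytes each for the per-channel scale, row sum and bias), here is a hedged, self-contained re-derivation; the 9728-byte figure is plain arithmetic on those constants, not a value asserted by any test.

```python
# Hedged sketch: packed-buffer size for the channel-wise KleidiAI layout
# (groupsize == in_features), mirroring get_kai_packed_weight_size above.
def kai_roundup(a: int, b: int) -> int:
    return ((a + b - 1) // b) * b

def packed_size_channelwise(n: int, k: int, nr: int = 8, kr: int = 16, sr: int = 2) -> int:
    # Pad K so each packed row ends on a 4-byte-aligned boundary.
    k_internal = kai_roundup(k, kai_roundup(kr * sr, 4))
    # Per block of nr output channels: int4 data + fp32 scale + int32 row sum + fp32 bias.
    bytes_per_row_block = nr * (k_internal // 2 + 4 + 4 + 4)
    return (kai_roundup(n, nr) // nr) * bytes_per_row_block

# For a 128x128 weight quantized channel-wise: 16 row blocks * 608 bytes = 9728 bytes,
# i.e. 8192 bytes of packed int4 data plus 12 bytes of metadata per output channel.
assert packed_size_channelwise(n=128, k=128) == 9728
```

The group-wise branch follows the same pattern but stores a 2-byte bf16 scale per 32-wide block (bl // 2 data bytes plus the multiplier per block) instead of a single fp32 scale per row.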