diff --git a/.gitmodules b/.gitmodules index 3813fc2b96b3..36d5becb57c3 100644 --- a/.gitmodules +++ b/.gitmodules @@ -131,6 +131,3 @@ path = third_party/composable_kernel url = https://github.com/ROCm/composable_kernel.git branch = develop -[submodule "third_party/kleidiai"] - path = third_party/kleidiai - url = https://git.gitlab.arm.com/kleidi/kleidiai.git diff --git a/BUILD.bazel b/BUILD.bazel index df46835f363e..0e3bef24ebac 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -257,7 +257,6 @@ filegroup( # target that generates these sources... ) -# TODO: Enable support for KleidiAI bazel build header_template_rule( name = "aten_src_ATen_config", src = "aten/src/ATen/Config.h.in", @@ -277,7 +276,6 @@ header_template_rule( "@AT_PARALLEL_NATIVE@": "1", "@AT_BLAS_F2C@": "0", "@AT_BLAS_USE_CBLAS_DOT@": "1", - "@AT_KLEIDIAI_ENABLED@": "0", }, ) diff --git a/CMakeLists.txt b/CMakeLists.txt index 93793d22790b..8e09a57ffa70 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -377,8 +377,6 @@ cmake_dependent_option( cmake_dependent_option(BUILD_FUNCTORCH "Build Functorch" ON "BUILD_PYTHON" OFF) cmake_dependent_option(BUILD_BUNDLE_PTXAS "Bundle PTX into torch/bin fodler" OFF "USE_CUDA" OFF) -cmake_dependent_option(USE_KLEIDIAI "Use KleidiAI for the ARM CPU & AARCH64 architecture." ON - "CPU_AARCH64" OFF) option(USE_MIMALLOC "Use mimalloc" OFF) # Enable third party mimalloc library to improve memory allocation performance @@ -420,8 +418,6 @@ endif() if(WIN32) set(USE_TENSORPIPE OFF) message(WARNING "TensorPipe cannot be used on Windows. Set it to OFF") - set(USE_KLEIDIAI OFF) - message(WARNING "KleidiAI cannot be used on Windows. Set it to OFF") if(USE_DISTRIBUTED AND NOT DEFINED ENV{libuv_ROOT}) find_library( @@ -671,9 +667,6 @@ if(ANDROID message(WARNING "INTERN_BUILD_MOBILE is on, disabling BUILD_LAZY_TS_BACKEND") set(BUILD_LAZY_TS_BACKEND OFF) - set(USE_KLEIDIAI OFF) - message(WARNING "KleidiAI cannot be used on Mobile builds. Set it to OFF") - # Set -ffunction-sections and -fdata-sections so that each method has its own # text section. This allows the linker to remove unused section when the flag # -Wl,-gc-sections is provided at link time. diff --git a/WORKSPACE b/WORKSPACE index ae7c0644e203..ac06b6bdc5d9 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -309,12 +309,6 @@ local_repository( path = "third_party/gemmlowp/gemmlowp", ) -local_repository( - name = "kleidiai", - path = "third_party/kleidiai", - repo_mapping = {"@com_google_googletest": "@com_google_benchmark"}, -) - ### Unused repos start # `unused` repos are defined to hide bazel files from submodules of submodules. diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index 9d246513f52c..5473468f627f 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -219,10 +219,6 @@ endif() # XNNPACK file(GLOB native_xnnpack "native/xnnpack/*.cpp") -# KLEIDIAI -file(GLOB native_kleidiai "native/kleidiai/*.cpp") -file(GLOB native_kleidiai_h "native/kleidiai/*.h") - # Add files needed from jit folders append_filelist("jit_core_headers" ATen_CORE_HEADERS) append_filelist("jit_core_sources" ATen_CORE_SRCS) @@ -252,10 +248,6 @@ endif() if(AT_MKL_ENABLED) set(all_cpu_cpp ${all_cpu_cpp} ${mkl_cpp}) endif() -if(AT_KLEIDIAI_ENABLED) - set(all_cpu_cpp ${all_cpu_cpp} ${native_kleidiai}) - include_directories(SYSTEM INTERFACE ${KLEIDIAI_INCLUDE_DIRS}) -endif() if(AT_MKLDNN_ENABLED) set(all_cpu_cpp ${all_cpu_cpp} ${mkldnn_cpp}) endif() @@ -645,7 +637,7 @@ install(FILES "${CMAKE_CURRENT_BINARY_DIR}/cmake-exports/ATenConfig.cmake" set(INSTALL_HEADERS ${base_h} ${ATen_CORE_HEADERS} ${native_nested_h} ${ATen_TRANSFORMER_HEADERS}) if(NOT INTERN_BUILD_MOBILE) - list(APPEND INSTALL_HEADERS ${native_h} ${native_cpu_h} ${native_ao_sparse_h} ${native_quantized_h} ${cuda_h} ${native_cuda_h} ${native_hip_h} ${cudnn_h} ${hip_h} ${xpu_h} ${mps_h} ${native_kleidiai_h} ${native_mps_h} ${native_utils_h} ${miopen_h} ${mkldnn_xpu_h}) + list(APPEND INSTALL_HEADERS ${native_h} ${native_cpu_h} ${native_ao_sparse_h} ${native_quantized_h} ${cuda_h} ${native_cuda_h} ${native_hip_h} ${cudnn_h} ${hip_h} ${xpu_h} ${mps_h} ${native_mps_h} ${native_utils_h} ${miopen_h} ${mkldnn_xpu_h}) # Metal if(USE_PYTORCH_METAL_EXPORT) # Add files needed from exporting metal models(optimized_for_mobile) diff --git a/aten/src/ATen/Config.h.in b/aten/src/ATen/Config.h.in index c22e15a52aa2..fdd2ac2bc5f7 100644 --- a/aten/src/ATen/Config.h.in +++ b/aten/src/ATen/Config.h.in @@ -19,4 +19,3 @@ #define AT_PARALLEL_NATIVE @AT_PARALLEL_NATIVE@ #define AT_BLAS_F2C() @AT_BLAS_F2C@ #define AT_BLAS_USE_CBLAS_DOT() @AT_BLAS_USE_CBLAS_DOT@ -#define AT_KLEIDIAI_ENABLED() @AT_KLEIDIAI_ENABLED@ diff --git a/aten/src/ATen/Context.cpp b/aten/src/ATen/Context.cpp index db5380ac961d..ad29671c0287 100644 --- a/aten/src/ATen/Context.cpp +++ b/aten/src/ATen/Context.cpp @@ -428,10 +428,6 @@ bool Context::hasMKLDNN() { #endif } -bool Context::hasKleidiAI() { - return AT_KLEIDIAI_ENABLED(); -} - bool Context::hasOpenMP() { #ifdef _OPENMP return true; diff --git a/aten/src/ATen/Context.h b/aten/src/ATen/Context.h index f986774e3fa9..5f2097206e01 100644 --- a/aten/src/ATen/Context.h +++ b/aten/src/ATen/Context.h @@ -119,7 +119,6 @@ class TORCH_API Context { static bool hasOpenMP(); static bool hasMKL(); - static bool hasKleidiAI(); static bool hasLAPACK(); static bool hasMKLDNN(); static bool hasMAGMA() { @@ -548,10 +547,6 @@ inline bool hasMKL() { return globalContext().hasMKL(); } -inline bool hasKleidiAI() { - return globalContext().hasKleidiAI(); -} - inline bool hasLAPACK() { return globalContext().hasLAPACK(); } diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp index 9cb0c652c573..460585c0590c 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -33,8 +33,6 @@ #include #include #include -#include -#include #include #include #include @@ -3431,8 +3429,6 @@ Tensor kron(const Tensor& self, const Tensor& other) { DEFINE_DISPATCH(weight_to_int4pack_stub); DEFINE_DISPATCH(int4pack_mm_stub); DEFINE_DISPATCH(int8pack_mm_stub); -DEFINE_DISPATCH(dyn_quant_pack_4bit_weight_stub); -DEFINE_DISPATCH(dyn_quant_matmul_4bit_stub); Tensor _convert_weight_to_int4pack_cpu( const Tensor& in, @@ -3496,69 +3492,6 @@ Tensor _weight_int4pack_mm_cpu( return C; } -Tensor _dyn_quant_pack_4bit_weight_cpu( - const Tensor& weights, - const Tensor& scales_zeros, - const std::optional& bias, - const int64_t block_size, - const int64_t in_features, - const int64_t out_features) { - TORCH_CHECK( - weights.dtype() == at::kByte, __func__, " : expect weight to be kByte."); - TORCH_CHECK( - block_size == in_features || - (!(block_size % 32) && !(in_features % block_size)), - __func__, - ": Group size should be multiple of 32, in_features [", - in_features, - "]. Provided ", - block_size); - Tensor packed_weights = - at::empty(weights.sizes(), weights.options().dtype(at::kByte)); - dyn_quant_pack_4bit_weight_stub( - kCPU, - packed_weights, - weights, - scales_zeros, - bias, - out_features, - in_features, - block_size); - return packed_weights; -} - -Tensor _dyn_quant_matmul_4bit_cpu( - const Tensor& inp, - const Tensor& packed_weights, - const int64_t block_size, - const int64_t in_features, - const int64_t out_features) { - auto M = inp.size(0); - TORCH_CHECK( - inp.dtype() == kFloat, - __func__, - " : expect input to be 32-bit float tensor."); - TORCH_CHECK( - block_size == in_features || - (!(block_size % 32) && !(in_features % block_size)), - __func__, - ": Group size should be multiple of 32, in_features [", - in_features, - "]. Provided ", - block_size); - auto output = at::empty({M, out_features}, inp.options()); - dyn_quant_matmul_4bit_stub( - kCPU, - output, - inp, - packed_weights, - M, - out_features, - in_features, - block_size); - return output; -} - Tensor _weight_int8pack_mm_cpu( const Tensor& A, const Tensor& B, diff --git a/aten/src/ATen/native/cpu/int4mm_kernel.cpp b/aten/src/ATen/native/cpu/int4mm_kernel.cpp index c8e0b8e86793..11a34eefb95a 100644 --- a/aten/src/ATen/native/cpu/int4mm_kernel.cpp +++ b/aten/src/ATen/native/cpu/int4mm_kernel.cpp @@ -8,19 +8,8 @@ #include #include #include -#include #include - -#ifndef AT_PER_OPERATOR_HEADERS -#include -#else -#include -#endif - -#if AT_KLEIDIAI_ENABLED() -#include -#include -#endif +#include #if (defined(_WIN32) || defined(_WIN64)) #define RESTRICT __restrict @@ -773,457 +762,10 @@ void int4pack_mm_kernel( } } -#if AT_KLEIDIAI_ENABLED() -bool can_use_kleidiai( - const at::Tensor& scales_zeros, - const int64_t K, - const int64_t block_size) { - bool ret = false; - if (cpuinfo_has_arm_neon_dot()) { - // The Groupwise kernel requires BFloat16 Scales and Channelwise kernel - // requires Float32 Scales. If not provided, we will use fallback - // implementation. - if ((block_size == K && scales_zeros.dtype() == at::kFloat) || - ((block_size < K && !(block_size % 32) && !(K % block_size)) && - scales_zeros.dtype() == at::kBFloat16)) { - ret = true; - } - } - return ret; -} -#endif - -/** - * The Int4 quantized weights must be represented as a uint8 tensor - * For matrix multiplication with a weight shape of (N x K) - * the shape of the 4-bit quantized weights is [N, K/groupsize, groupsize/2]. - * - * For KleidiAI weight packing, the scales, biases, and Int4 quantized - * weights are packed into a single `packed_weights` structure, optimized for - * Arm instructions. - * - * In the fallback reference kernel, no special packing is required for - * Int4 quantized weights. - * - * The Groupwise kernel requires BFloat16 Scales and Channelwise kernel requires - * Float32 Scales. If not provided, we will use fallback implementation. - */ -void dyn_quant_pack_4bit_weight_kernel( - Tensor& packed_weights, - const Tensor& weights, - const Tensor& scales_zeros, - const std::optional& bias, - const int64_t N, - const int64_t K, - const int64_t block_size) { -#if AT_KLEIDIAI_ENABLED() - if (can_use_kleidiai(scales_zeros, K, block_size)) { - const int64_t weight_packed_size = - kleidiai::kai_pack_rhs_int4_size(N, K, block_size); - packed_weights.resize_({weight_packed_size}); - kleidiai::kai_pack_int4_rhs( - packed_weights, weights, scales_zeros, bias, N, K, block_size); - } else -#endif - { - TORCH_CHECK( - bias.has_value() == 0, - __func__, - " : Bias is unsupported in reference implementation"); - packed_weights = packed_weights.to(kFloat); - auto weight_reshaped = weights.view({-1}).to(kFloat); - auto scales_zeros_reshaped = scales_zeros.view({-1}).to(kFloat); - auto res = at::cat({weight_reshaped, scales_zeros_reshaped}, 0); - packed_weights.resize_(res.sizes()).copy_(res); - } -} - -static void ref_dyn_quant_matmul_4bit_channelwise_kernel( - size_t m, - size_t n, - size_t k, - const float* lhs_f32, - const uint8_t* rhs_qs4cx, - const float* rhs_scales_f32, - float* dst_f32, - float scalar_min, - float scalar_max) { - const size_t input_size_8bit = m * (k + sizeof(int32_t) + sizeof(float)); - - auto lhs_qa8dx_buffer = std::make_unique(input_size_8bit); - uint8_t* lhs_qa8dx = lhs_qa8dx_buffer.get(); - - // Lambda for quantizing the fp32 input to 8 bit symmetric and pack it in - // required format for matmul - auto input_quant_pack_8bit_channelwise = - [&](size_t m, size_t k, const float* lhs_f32, int8_t* lhs_qa8dx) { - const size_t dst_stride = - (k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t)); - - const size_t lhs_qa8dx_stride = k; - - for (size_t m_idx = 0; m_idx < m; ++m_idx) { - const float* src_ptr = lhs_f32 + m_idx * lhs_qa8dx_stride; - - float max0 = -FLT_MAX; - float min0 = FLT_MAX; - - // Find min/max for each channel - for (size_t k_idx = 0; k_idx < k; ++k_idx) { - const float src0_0 = src_ptr[k_idx]; - - max0 = (std::max)(src0_0, max0); - min0 = (std::min)(src0_0, min0); - } - - // Maximum/minimum int8 values - const float qmin = (float)INT8_MIN; - const float qmax = (float)INT8_MAX; - - const float rmin0 = (std::min)(0.0f, min0); - const float rmax0 = (std::max)(0.0f, max0); - - const float scale0 = - rmin0 == rmax0 ? 1.f : (qmax - qmin) / (rmax0 - rmin0); - - // Reciprocal to quantize - const float recip_scale0 = scale0 ? 1.0f / scale0 : 0.0f; - - const float descaled_min0 = rmin0 * scale0; - const float descaled_max0 = rmax0 * scale0; - - const float zero_point_from_min_error0 = qmin + descaled_min0; - const float zero_point_from_max_error0 = qmax + descaled_max0; - - float zero_point0 = - zero_point_from_min_error0 + zero_point_from_max_error0 > 0 - ? qmin - descaled_min0 - : qmax - descaled_max0; - - zero_point0 = (std::max)(zero_point0, qmin); - zero_point0 = (std::min)(zero_point0, qmax); - - // Round to nearest integer - const int32_t nudged_zero_point0 = lrintf(zero_point0); - - int8_t* dst_ptr = (int8_t*)lhs_qa8dx + m_idx * dst_stride; - - // LHS offset at the beginning of the row - *((float*)(dst_ptr)) = recip_scale0; - dst_ptr += sizeof(float); - *((int32_t*)(dst_ptr)) = -nudged_zero_point0; - dst_ptr += sizeof(int32_t); - - // Quantize the channels - for (size_t k_idx = 0; k_idx < k; ++k_idx) { - const float src0_0 = src_ptr[k_idx]; - - // Scale the values - int32_t v0_s32 = (int32_t)(std::round(src0_0 * scale0)); - - v0_s32 = v0_s32 + nudged_zero_point0; - v0_s32 = (std::max)(v0_s32, static_cast(INT8_MIN)); - v0_s32 = (std::min)(v0_s32, static_cast(INT8_MAX)); - dst_ptr[0] = (int8_t)v0_s32; - dst_ptr += sizeof(int8_t); - } - } - }; - - // Dynamically Quantize the float32 input to 8 bit assymetric - input_quant_pack_8bit_channelwise(m, k, lhs_f32, (int8_t*)lhs_qa8dx); - - const size_t lhs_stride = - k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t); - - const size_t rhs_qs4cx_stride = ((((k + 2 - 1) / 2) * 2) / 2); - - for (size_t m_idx = 0; m_idx < m; ++m_idx) { - const int8_t* lhs_ptr_start = (int8_t*)lhs_qa8dx + m_idx * lhs_stride; - - for (size_t n_idx = 0; n_idx < n; ++n_idx) { - // Main f32 accumulator - int32_t iacc = 0; - - const int8_t* lhs_ptr = lhs_ptr_start; - const uint8_t* rhs_ptr = rhs_qs4cx + n_idx * rhs_qs4cx_stride; - - // Get the LHS quantization parameters stored at the - // beginning of each row - const float lhs_scale = *(const float*)lhs_ptr; - lhs_ptr += sizeof(float); - - const int32_t lhs_offset = *(const int32_t*)lhs_ptr; - lhs_ptr += sizeof(int32_t); - - for (size_t k_idx = 0; k_idx < k; ++k_idx) { - // Get the LHS values - const int32_t lhs_v0 = (int32_t)lhs_ptr[0]; - - // Get the RHS values - const uint8_t rhs_byte = rhs_ptr[0]; - - // Unpack the RHS values - int32_t rhs_v0 = 0; - if ((k_idx % 2) == 0) { - rhs_v0 = (((int32_t)(rhs_byte & 0x0F)) - 8); - } else { - rhs_v0 = (((int32_t)(rhs_byte >> 4)) - 8); - } - - iacc += lhs_v0 * rhs_v0; - iacc += lhs_offset * rhs_v0; - - lhs_ptr += 1; - - // Increment only when k_idx is not a multiple of 2 - rhs_ptr += k_idx % 2; - } - - // Get the RHS scale - const float rhs_scale = rhs_scales_f32[n_idx]; - - float main_acc = iacc * rhs_scale; - - main_acc = main_acc * lhs_scale; - - // Clamp (min-max) operation - main_acc = (std::max)(main_acc, scalar_min); - main_acc = (std::min)(main_acc, scalar_max); - - dst_f32[0] = main_acc; - dst_f32 += 1; - } - } -} - -static void ref_dyn_quant_matmul_4bit_groupwise_kernel( - size_t m, - size_t n, - size_t k, - size_t bl, - const float* lhs_f32, - const uint8_t* rhs_qs4c32, - const float* rhs_scales_fp32, - float* dst_f32, - float scalar_min, - float scalar_max) { - // Lambda for LHS quantization - auto lhs_quant_pack = [&](size_t m, - size_t k, - const float* lhs_f32, - int8_t* lhs_qa8dx) { - const size_t dst_stride = - (k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t)); - - for (size_t row_idx = 0; row_idx < m; ++row_idx) { - const float* src_ptr = lhs_f32 + row_idx * k; - - float max0 = -FLT_MAX; - float min0 = FLT_MAX; - - for (size_t k_idx = 0; k_idx < k; ++k_idx) { - const float src0_0 = src_ptr[k_idx]; - max0 = (std::max)(src0_0, max0); - min0 = (std::min)(src0_0, min0); - } - - const float qmin = (float)INT8_MIN; - const float qmax = (float)INT8_MAX; - - const float rmin0 = (std::min)(0.0f, min0); - const float rmax0 = (std::max)(0.0f, max0); - const float scale0 = - (rmin0 == rmax0) ? 1.f : (qmax - qmin) / (rmax0 - rmin0); - const float recip_scale0 = scale0 ? 1.0f / scale0 : 0.0f; - - const float descaled_min0 = rmin0 * scale0; - const float descaled_max0 = rmax0 * scale0; - - float zero_point0 = (qmin + descaled_min0 + qmax + descaled_max0 > 0) - ? qmin - descaled_min0 - : qmax - descaled_max0; - - zero_point0 = (std::max)(zero_point0, qmin); - zero_point0 = (std::min)(zero_point0, qmax); - const int32_t nudged_zero_point0 = lrintf(zero_point0); - - int8_t* dst_ptr = (int8_t*)lhs_qa8dx + row_idx * dst_stride; - - *((float*)(dst_ptr)) = recip_scale0; - dst_ptr += sizeof(float); - *((int32_t*)(dst_ptr)) = -nudged_zero_point0; - dst_ptr += sizeof(int32_t); - - for (size_t k_idx = 0; k_idx < k; ++k_idx) { - const float src0_0 = src_ptr[k_idx]; - int32_t v0_s32 = (int32_t)(std::round(src0_0 * scale0)); - v0_s32 = (std::max)( - (std::min)( - v0_s32 + nudged_zero_point0, static_cast(INT8_MAX)), - static_cast(INT8_MIN)); - dst_ptr[0] = (int8_t)v0_s32; - dst_ptr += sizeof(int8_t); - } - } - }; - - auto lhs_qa8dx_buffer = std::make_unique( - m * (k + sizeof(float) + sizeof(int32_t))); // Allocate for LHS - int8_t* lhs_qa8dx = lhs_qa8dx_buffer.get(); - // Quantize and pack LHS - lhs_quant_pack(m, k, lhs_f32, lhs_qa8dx); - - const size_t num_blocks_row = (((k + bl - 1) / bl) * bl) / bl; - const size_t lhs_stride = k + sizeof(float) + sizeof(int32_t); - const size_t rhs_stride = (((k + 2 - 1) / 2) * 2) / 2; - - for (size_t row_idx = 0; row_idx < m; ++row_idx) { - const int8_t* lhs_ptr_start = lhs_qa8dx + row_idx * lhs_stride; - - for (size_t col_idx = 0; col_idx < n; ++col_idx) { - float main_acc = 0.0f; - const int8_t* lhs_ptr = lhs_ptr_start; - const uint8_t* rhs_ptr = rhs_qs4c32 + col_idx * rhs_stride; - - const float lhs_scale = *(const float*)lhs_ptr; - lhs_ptr += sizeof(float); - const int32_t lhs_offset = *(const int32_t*)lhs_ptr; - lhs_ptr += sizeof(int32_t); - - for (size_t block_idx = 0; block_idx < num_blocks_row; ++block_idx) { - const float rhs_scale = - rhs_scales_fp32[block_idx + col_idx * num_blocks_row]; - int32_t iacc = 0; - - for (size_t i = 0; i < bl; ++i) { - const size_t k_idx = block_idx * bl + i; - if (k_idx >= k) { - break; - } - - const int32_t lhs_v0 = (int32_t)lhs_ptr[0]; - const uint8_t rhs_byte = rhs_ptr[0]; - int32_t rhs_v0 = (k_idx % 2 == 0) ? (((int32_t)(rhs_byte & 0x0F)) - 8) - : (((int32_t)(rhs_byte >> 4)) - 8); - - iacc += lhs_v0 * rhs_v0; - iacc += lhs_offset * rhs_v0; - - lhs_ptr += 1; - rhs_ptr += (k_idx % 2); - } - - main_acc += iacc * rhs_scale; - } - - main_acc = main_acc * lhs_scale; - main_acc = (std::max)(main_acc, scalar_min); - main_acc = (std::min)(main_acc, scalar_max); - - dst_f32[0] = main_acc; - dst_f32 += 1; - } - } -} - -/** - * Dynamic Input Quant 4 bit weights matmul execution flow - (INT4 Weights + FP scales + FP32 Bias) - FP32 Input Packed Buffer - | | - Quantize Cast - to INT8 to INT8 - | | - v v - INT8 Input INT8 Weights - \ / - \ / - \ / - INT8 Matrix Multiplication - | - v - FP32 Dequantized and Accumulate in FP32 - | - v - FP32 Final Output - - * The Groupwise kernel requires BFloat16 Scales and Channelwise kernel requires - * Float32 Scales. If not provided, we will use fallback implementation. - */ -void dyn_quant_matmul_4bit_kernel( - const Tensor& output, - const Tensor& inp, - const Tensor& packed_weights, - const int64_t M, - const int64_t N, - const int64_t K, - const int64_t block_size) { -#if AT_KLEIDIAI_ENABLED() - const int64_t weight_packed_size = - kleidiai::kai_pack_rhs_int4_size(N, K, block_size); - if (weight_packed_size == packed_weights.numel()) { - // KleidiAI interface intenally handles the Channelwise and groupwise - // distinction - kleidiai::kai_quant_pack_lhs_int4_mm( - output, inp, packed_weights, M, N, K, block_size); - } else -#endif - { - float* lhs_f32 = reinterpret_cast(inp.data_ptr()); - const auto weights_size = N * K / 2; - // The weights needs to be in uint8_t data type after quantization - auto extracted_weights = - (packed_weights.narrow(0, 0, weights_size)).to(kByte); - auto float32_scales = - (packed_weights.narrow( - 0, weights_size, packed_weights.size(0) - weights_size)) - .to(kFloat); - uint8_t* rhs_4bit = - reinterpret_cast(extracted_weights.data_ptr()); - float* rhs_scales_f32 = reinterpret_cast(float32_scales.data_ptr()); - float* dst_f32 = reinterpret_cast(output.data_ptr()); - if (block_size == K) { - ref_dyn_quant_matmul_4bit_channelwise_kernel( - M, - N, - K, - lhs_f32, - rhs_4bit, - rhs_scales_f32, - dst_f32, - -FLT_MAX, - FLT_MAX); - } else if (!(block_size % 32) && !(K % block_size)) { - ref_dyn_quant_matmul_4bit_groupwise_kernel( - M, - N, - K, - block_size, - lhs_f32, - rhs_4bit, - rhs_scales_f32, - dst_f32, - -FLT_MAX, - FLT_MAX); - } else { - TORCH_CHECK( - block_size == K || (!(block_size % 32) && !(K % block_size)), - __func__, - ": Group size should be multiple 32 or in_features [", - K, - "]. Provided ", - block_size); - } - } -} - } // anonymous namespace ALSO_REGISTER_AVX512_DISPATCH(weight_to_int4pack_stub, &weight_to_int4pack_kernel) ALSO_REGISTER_AVX512_DISPATCH(int4pack_mm_stub, &int4pack_mm_kernel) -REGISTER_DISPATCH(dyn_quant_pack_4bit_weight_stub, &dyn_quant_pack_4bit_weight_kernel) -REGISTER_DISPATCH(dyn_quant_matmul_4bit_stub, &dyn_quant_matmul_4bit_kernel) } // at::native C10_DIAGNOSTIC_POP() diff --git a/aten/src/ATen/native/cpu/int_mm_kernel.h b/aten/src/ATen/native/cpu/int_mm_kernel.h index b5479e45fd5f..1131aa9b53c9 100644 --- a/aten/src/ATen/native/cpu/int_mm_kernel.h +++ b/aten/src/ATen/native/cpu/int_mm_kernel.h @@ -5,34 +5,12 @@ namespace at::native { -using weight_to_int4pack_fn = void (*)(const Tensor&, const Tensor&); -using int4pack_mm_fn = - void (*)(const Tensor&, const Tensor&, const Tensor&, int, const Tensor&); -using int8pack_mm_fn = - void (*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&); -using dyn_quant_pack_4bit_weight_fn = void (*)( - Tensor&, - const Tensor&, - const Tensor&, - const std::optional& bias, - const int64_t, - const int64_t, - const int64_t); -using dyn_quant_matmul_4bit_fn = void (*)( - const Tensor&, - const Tensor&, - const Tensor&, - const int64_t, - const int64_t, - const int64_t, - const int64_t); +using weight_to_int4pack_fn = void(*)(const Tensor&, const Tensor&); +using int4pack_mm_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, int, const Tensor&); +using int8pack_mm_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&); DECLARE_DISPATCH(weight_to_int4pack_fn, weight_to_int4pack_stub) DECLARE_DISPATCH(int4pack_mm_fn, int4pack_mm_stub) DECLARE_DISPATCH(int8pack_mm_fn, int8pack_mm_stub) -DECLARE_DISPATCH( - dyn_quant_pack_4bit_weight_fn, - dyn_quant_pack_4bit_weight_stub) -DECLARE_DISPATCH(dyn_quant_matmul_4bit_fn, dyn_quant_matmul_4bit_stub) } // namespace at::native diff --git a/aten/src/ATen/native/kleidiai/kai_kernels.cpp b/aten/src/ATen/native/kleidiai/kai_kernels.cpp deleted file mode 100644 index 04036a59c477..000000000000 --- a/aten/src/ATen/native/kleidiai/kai_kernels.cpp +++ /dev/null @@ -1,482 +0,0 @@ -#include -#include -#include - -#include - -#include -#include -#include -#include -#if AT_KLEIDIAI_ENABLED() -#include - -namespace at::native::kleidiai { - -void kai_pack_int4_rhs( - const Tensor& weight_packed, - const Tensor& weight, - const Tensor& scales, - const std::optional& bias, - const int64_t n, - const int64_t k, - const int64_t bl) { - // Prefer Channelwise kernel over Groupwise kernel for conflicting cases - if (bl == k) { - // Channelwise - auto kernel_packet = kai_select_channelwise_matmul_ukernel( - kai_kernel_id:: - matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod); - auto& params = kernel_packet.rhs_pack_params; - params.lhs_zero_point = 1; - params.rhs_zero_point = 8; - - kai_pack_rhs_channelwise_int4( - kernel_packet, weight_packed, weight, scales, bias, n, k); - } else if (!(bl % 32) && !(k % bl)) { - // Groupwise - auto kernel_packet = kai_select_groupwise_matmul_ukernel( - kai_kernel_id:: - matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod); - - const int64_t rhs_stride = kai_roundup(k, 2) / 2; - const int64_t scale_stride = (kai_roundup(k, bl) / bl) * sizeof(uint16_t); - auto& params = kernel_packet.rhs_pack_params; - params.lhs_zero_point = 1; - params.rhs_zero_point = 8; - params.scale_dt = kai_datatype::kai_dt_bf16; - - kai_pack_rhs_groupwise_int4( - kernel_packet, - weight_packed, - weight, - scales, - bias, - n, - k, - bl, - rhs_stride, - scale_stride); - } -} - -size_t kai_pack_rhs_int4_size( - const int64_t n, - const int64_t k, - const int64_t bl) { - size_t packed_size = n * k; - // Prefer Channelwise kernel over Groupwise kernel for conflicting cases - if (bl == k) { - // Channelwise - auto kernel_packet = kai_select_channelwise_matmul_ukernel( - kai_kernel_id:: - matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod); - const auto& ukernel = kernel_packet.ukernel; - const size_t nr = ukernel.get_nr(); - const size_t kr = ukernel.get_kr(); - const size_t sr = ukernel.get_sr(); - packed_size = kernel_packet.kai_get_rhs_packed_size(n, k, nr, kr, sr); - } else if (!(bl % 32) && !(k % bl)) { - // Groupwise - auto kernel_packet = kai_select_groupwise_matmul_ukernel( - kai_kernel_id:: - matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod); - const auto& ukernel = kernel_packet.ukernel; - const size_t nr = ukernel.get_nr(); - const size_t kr = ukernel.get_kr(); - const size_t sr = ukernel.get_sr(); - packed_size = kernel_packet.kai_get_rhs_packed_size( - n, k, nr, kr, sr, bl, kai_datatype::kai_dt_bf16); - } - return packed_size; -} - -static void matmul_channelwise( - kai_matmul_ukernel_f32_qa8dxp_qs4cxp& kernel_packet, - size_t m_increment, - size_t m_start, - size_t m_per_thread, - size_t n_start, - size_t n_per_thread, - size_t n, - size_t k, - size_t mr, - size_t nr, - size_t kr, - size_t sr, - size_t dst_stride, - size_t lhs_stride, - uint8_t* lhs_native_mtx_f32, - uint8_t* lhs_packed_mtx_qa8dx, - uint8_t* rhs_packed_mtx_qs4cx, - uint8_t* dst_act_mtx_f32) { - for (size_t m0 = 0; m0 < m_per_thread; m0 += m_increment) { - const float* src_ptr = - (const float*)(lhs_native_mtx_f32 + - kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32( - m_start + m0, lhs_stride)); - void* lhs_packed_ptr = - (void*)(lhs_packed_mtx_qa8dx + - kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32( - 0, k, mr, kr, sr)); - const void* rhs_packed_ptr = - (const void*)((const char*)rhs_packed_mtx_qs4cx + - kernel_packet.ukernel.get_rhs_packed_offset(n_start, k)); - float* dst_ptr = (float*)((uint8_t*)dst_act_mtx_f32 + - kernel_packet.ukernel.get_dst_offset( - m_start + m0, n_start, dst_stride)); - - // Quantize and pack the Input - kernel_packet.kai_run_lhs_quant_pack( - m_increment, k, mr, kr, sr, 0, src_ptr, lhs_stride, lhs_packed_ptr); - - // Run Matmul on Int4 packed weights and Quantized Packed Input - kernel_packet.ukernel.run_matmul( - m_increment, - n_per_thread, - k, - lhs_packed_ptr, - rhs_packed_ptr, - dst_ptr, - dst_stride, - sizeof(float), - -FLT_MAX, - FLT_MAX); - } -} - -static void matmul_groupwise( - kai_matmul_ukernel_f32_qa8dxp_qs4c32p& kernel_packet, - size_t m_increment, - size_t m_start, - size_t m_per_thread, - size_t n_start, - size_t n_per_thread, - size_t n, - size_t k, - size_t bl, - size_t mr, - size_t nr, - size_t kr, - size_t sr, - size_t dst_stride, - size_t lhs_stride, - uint8_t* lhs_native_mtx_f32, - uint8_t* lhs_packed_mtx_qa8dx, - uint8_t* rhs_packed_mtx_qs4cx, - uint8_t* dst_act_mtx_f32) { - for (size_t m0 = 0; m0 < m_per_thread; m0 += m_increment) { - const float* src_ptr = - (const float*)(lhs_native_mtx_f32 + - kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32( - m_start + m0, lhs_stride)); - size_t lhs_packed_offset = - kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(0, k, mr, kr, sr); - void* lhs_packed_ptr = (void*)(lhs_packed_mtx_qa8dx + lhs_packed_offset); - const void* rhs_packed_ptr = - (const void*)((const char*)rhs_packed_mtx_qs4cx + - kernel_packet.ukernel.get_rhs_packed_offset( - n_start, k, bl)); - float* dst_ptr = (float*)((uint8_t*)dst_act_mtx_f32 + - kernel_packet.ukernel.get_dst_offset( - m_start + m0, n_start, dst_stride)); - - // Quantize and pack the Input - kernel_packet.kai_run_lhs_quant_pack( - m_increment, k, mr, kr, sr, 0, src_ptr, lhs_stride, lhs_packed_ptr); - - // Run Matmul on Int4 packed weights and Quantized Packed Input - kernel_packet.ukernel.run_matmul( - m_increment, - n_per_thread, - k, - bl, - lhs_packed_ptr, - rhs_packed_ptr, - dst_ptr, - dst_stride, - sizeof(float), - -FLT_MAX, - FLT_MAX); - } -} - -struct ThreadDivision { - int64_t num_threads_x; - int64_t num_threads_y; - bool can_gemm; // True if GEMM is selected, false if GEMV is used. For Certain - // Configurations, GEMV Kernel might be used even if M>1 -}; - -inline static unsigned int round_down_to_power_of_2(unsigned int n) { - n |= n >> 1; - n |= n >> 2; - n |= n >> 4; - n |= n >> 8; - n |= n >> 16; - return n - (n >> 1); -} - -inline static void adjust_max_threads(int64_t& max_threads) { - // We would not like to round down to nearest power of 2 always - // There can be possible thread split combination between powers of 2 for odd - // shapes - // TODO:: Decide better strategy based on hint of input and weight shapes - max_threads = round_down_to_power_of_2(max_threads); -} - -static std::pair split_2d(const int64_t max_threads) { - int64_t sqrt_threads = std::sqrt(max_threads); - - for (int64_t i = sqrt_threads; i >= 1; --i) { - if (max_threads % i == 0) { - return {i, max_threads / i}; - } - } - - return {1, max_threads}; // Theres still a possibility of 1D blocking when - // calling GEMM kernel -} - -inline static ThreadDivision get_thread_division( - int64_t max_threads, - const int64_t m, - const int64_t n, - const int64_t k, - const int64_t gemm_m_step, - const int64_t gemm_n_step, - const int64_t gemv_m_step, - const int64_t gemv_n_step) { - adjust_max_threads(max_threads); - ThreadDivision division{1, 1, false}; - - // Split threads 2D for GEMM - if (m % gemm_m_step == 0 && n % gemm_n_step == 0) { - while (max_threads > 0) { - auto [num_thread_y, num_thread_x] = split_2d(max_threads); - if (m % num_thread_y == 0 && n % num_thread_x == 0) { - int64_t m_per_thread = m / num_thread_y; - int64_t n_per_thread = n / num_thread_x; - if (m_per_thread % gemm_m_step == 0 && - n_per_thread % gemm_n_step == 0) { - division = {num_thread_x, num_thread_y, true}; - return division; - } - } - max_threads -= 2; - } - } - // Split threads 1D for GEMV - if (n % gemv_n_step == 0) { - for (; max_threads > 0; max_threads -= 2) { - if (n % max_threads == 0 && (n / max_threads) % gemv_n_step == 0) { - division.num_threads_x = max_threads; - return division; - } - } - } - return division; -} - -static void kai_quant_pack_lhs_int4_mm_groupwise( - const Tensor& output, - const Tensor& input, - const Tensor& weight, - const int64_t m, - const int64_t n, - const int64_t k, - const int64_t bl) { - // Kernel IDs for GEMM and GEMV - kai_kernel_id gemm_id = - kai_kernel_id::matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_4x8x32_neon_i8mm; - kai_kernel_id gemv_id = kai_kernel_id:: - matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod; - - // Get the total number of threads available and choose GEMM or GEMV steps - const int64_t total_threads = at::get_num_threads(); - auto gemm_kernel_packet = kai_select_groupwise_matmul_ukernel(gemv_id); - if (cpuinfo_has_arm_i8mm()) { - gemm_kernel_packet = kai_select_groupwise_matmul_ukernel(gemm_id); - } - auto gemv_kernel_packet = kai_select_groupwise_matmul_ukernel(gemv_id); - - // Retrieve m_step and n_step values from GEMM and GEMV kernels - const int64_t gemm_m_step = gemm_kernel_packet.ukernel.get_m_step(); - const int64_t gemm_n_step = gemm_kernel_packet.ukernel.get_n_step(); - const int64_t gemv_m_step = gemv_kernel_packet.ukernel.get_m_step(); - const int64_t gemv_n_step = gemv_kernel_packet.ukernel.get_n_step(); - // Determine threading and kernel type - ThreadDivision division = get_thread_division( - total_threads, - m, - n, - k, - gemm_m_step, - gemm_n_step, - gemv_m_step, - gemv_n_step); - // Select appropriate kernel packet based on the chosen kernel type - auto& kernel_packet = - division.can_gemm ? gemm_kernel_packet : gemv_kernel_packet; - - // Thread blocking parameters - const size_t mr = kernel_packet.ukernel.get_mr(); - const size_t nr = kernel_packet.ukernel.get_nr(); - const size_t kr = kernel_packet.ukernel.get_kr(); - const size_t sr = kernel_packet.ukernel.get_sr(); - const size_t m_increment = kernel_packet.ukernel.get_m_step(); - const size_t n_per_thread = n / division.num_threads_x; - const size_t m_per_thread = m / division.num_threads_y; - const int64_t num_threads = division.num_threads_y * division.num_threads_x; - const size_t dst_stride = n * sizeof(float); - const size_t lhs_stride = k * sizeof(float); - - const size_t lhs_packed_size = - kernel_packet.kai_get_lhs_packed_size(m_increment, k, mr, kr, sr); - - uint8_t* dst_act_mtx_f32 = reinterpret_cast(output.data_ptr()); - uint8_t* lhs_native_mtx_f32 = reinterpret_cast(input.data_ptr()); - uint8_t* rhs_packed_mtx_qs4cx = reinterpret_cast(weight.data_ptr()); - auto lhs_packed = std::make_unique(lhs_packed_size * num_threads); - uint8_t* lhs_packed_base = lhs_packed.get(); - - at::parallel_for( - 0, num_threads, /*grain_size=*/0, [&](int64_t begin, int64_t end) { - for (const auto thread_id : c10::irange(begin, end)) { - size_t y = thread_id / division.num_threads_x; - size_t x = thread_id % division.num_threads_x; - uint8_t* lhs_packed_ptr = lhs_packed_base + - (x + y * division.num_threads_x) * lhs_packed_size; - matmul_groupwise( - std::ref(kernel_packet), - m_increment, - /*m_start=*/y * m_per_thread, - m_per_thread, - x * n_per_thread, - n_per_thread, - n, - k, - bl, - mr, - nr, - kr, - sr, - dst_stride, - lhs_stride, - lhs_native_mtx_f32, - lhs_packed_ptr, - rhs_packed_mtx_qs4cx, - dst_act_mtx_f32); - } - }); -} - -static void kai_quant_pack_lhs_int4_mm_channelwise( - const Tensor& output, - const Tensor& input, - const Tensor& weight, - const int64_t m, - const int64_t n, - const int64_t k) { - // Kernel IDs for GEMM and GEMV - kai_kernel_id gemm_id = - kai_kernel_id::matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm; - kai_kernel_id gemv_id = - kai_kernel_id::matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod; - - // Get the total number of threads available and choose GEMM or GEMV steps - const int64_t total_threads = at::get_num_threads(); - auto gemm_kernel_packet = kai_select_channelwise_matmul_ukernel(gemv_id); - if (cpuinfo_has_arm_i8mm()) { - gemm_kernel_packet = kai_select_channelwise_matmul_ukernel(gemm_id); - } - auto gemv_kernel_packet = kai_select_channelwise_matmul_ukernel(gemv_id); - - // Retrieve m_step and n_step values from GEMM and GEMV kernels - const int64_t gemm_m_step = gemm_kernel_packet.ukernel.get_m_step(); - const int64_t gemm_n_step = gemm_kernel_packet.ukernel.get_n_step(); - const int64_t gemv_m_step = gemv_kernel_packet.ukernel.get_m_step(); - const int64_t gemv_n_step = gemv_kernel_packet.ukernel.get_n_step(); - // Determine threading and kernel type - ThreadDivision division = get_thread_division( - total_threads, - m, - n, - k, - gemm_m_step, - gemm_n_step, - gemv_m_step, - gemv_n_step); - // Select appropriate kernel packet based on the chosen kernel type - auto& kernel_packet = - division.can_gemm ? gemm_kernel_packet : gemv_kernel_packet; - - // Thread blocking parameters - const size_t mr = kernel_packet.ukernel.get_mr(); - const size_t nr = kernel_packet.ukernel.get_nr(); - const size_t kr = kernel_packet.ukernel.get_kr(); - const size_t sr = kernel_packet.ukernel.get_sr(); - const size_t m_increment = kernel_packet.ukernel.get_m_step(); - const size_t n_per_thread = n / division.num_threads_x; - const size_t m_per_thread = m / division.num_threads_y; - const int64_t num_threads = division.num_threads_y * division.num_threads_x; - const size_t dst_stride = n * sizeof(float); - const size_t lhs_stride = k * sizeof(float); - const size_t lhs_packed_size = - kernel_packet.kai_get_lhs_packed_size(m_increment, k, mr, kr, sr); - - uint8_t* dst_act_mtx_f32 = reinterpret_cast(output.data_ptr()); - uint8_t* lhs_native_mtx_f32 = reinterpret_cast(input.data_ptr()); - uint8_t* rhs_packed_mtx_qs4cx = reinterpret_cast(weight.data_ptr()); - auto lhs_packed = std::make_unique(lhs_packed_size * num_threads); - uint8_t* lhs_packed_base = lhs_packed.get(); - - at::parallel_for( - 0, num_threads, /*grain_size=*/0, [&](int64_t begin, int64_t end) { - for (const auto thread_id : c10::irange(begin, end)) { - size_t y = thread_id / division.num_threads_x; - size_t x = thread_id % division.num_threads_x; - uint8_t* lhs_packed_ptr = lhs_packed_base + - (x + y * division.num_threads_x) * lhs_packed_size; - matmul_channelwise( - std::ref(kernel_packet), - m_increment, - /*m_start=*/y * m_per_thread, - m_per_thread, - x * n_per_thread, - n_per_thread, - n, - k, - mr, - nr, - kr, - sr, - dst_stride, - lhs_stride, - lhs_native_mtx_f32, - lhs_packed_ptr, - rhs_packed_mtx_qs4cx, - dst_act_mtx_f32); - } - }); -} - -void kai_quant_pack_lhs_int4_mm( - const Tensor& output, - const Tensor& input, - const Tensor& weight, - const int64_t m, - const int64_t n, - const int64_t k, - const int64_t bl) { - // Prefer Channelwise kernel over Groupwise kernel for conflicting cases - if (bl == k) { - kleidiai::kai_quant_pack_lhs_int4_mm_channelwise( - output, input, weight, m, n, k); - } else if (!(bl % 32) && !(k % bl)) { - kleidiai::kai_quant_pack_lhs_int4_mm_groupwise( - output, input, weight, m, n, k, bl); - } -} -} // namespace at::native::kleidiai -#endif diff --git a/aten/src/ATen/native/kleidiai/kai_kernels.h b/aten/src/ATen/native/kleidiai/kai_kernels.h deleted file mode 100644 index 9b522d7f7705..000000000000 --- a/aten/src/ATen/native/kleidiai/kai_kernels.h +++ /dev/null @@ -1,42 +0,0 @@ -#pragma once -#include -#include -#if AT_KLEIDIAI_ENABLED() - -namespace at::native::kleidiai { - -/** - * @brief Rearranges the quantized weight to support kleidiai inference - * @param bl Groupsize for quantization should be multiple of 32 - */ -void kai_pack_int4_rhs( - const Tensor& weight_packed, - const Tensor& weight, - const Tensor& scales, - const std::optional& bias, - const int64_t n, - const int64_t k, - const int64_t bl); - -/** - * @brief Outputs the buffer size for the packed weights - * @param bl Groupsize for quantization. 32 for groupwise , 0 for channelwise - */ -size_t kai_pack_rhs_int4_size( - const int64_t n, - const int64_t k, - const int64_t bl); - -/** - * @brief Run 2 operations ( Input quantize and pack -> 4 bit Matmul ) - */ -void kai_quant_pack_lhs_int4_mm( - const Tensor& output, - const Tensor& input, - const Tensor& weight, - const int64_t m, - const int64_t n, - const int64_t k, - const int64_t bl); -} // namespace at::native::kleidiai -#endif diff --git a/aten/src/ATen/native/kleidiai/kai_pack.h b/aten/src/ATen/native/kleidiai/kai_pack.h deleted file mode 100644 index 4ff3371ab5e2..000000000000 --- a/aten/src/ATen/native/kleidiai/kai_pack.h +++ /dev/null @@ -1,106 +0,0 @@ -#pragma once -#include -#include -#include -#include -#if AT_KLEIDIAI_ENABLED() - -namespace at::native::kleidiai { - -template -void kai_pack_rhs_groupwise_int4( - T& kernel, - const Tensor& weight_packed, - const Tensor& weight, - const Tensor& scales, - const std::optional& bias, - const int64_t n, - const int64_t k, - const int64_t bl, - const int64_t rhs_stride, - const int64_t scale_stride) { - const auto& ukernel = kernel.ukernel; - const size_t nr = ukernel.get_nr(); - const size_t kr = ukernel.get_kr(); - const size_t sr = ukernel.get_sr(); - auto weight_packed_data = - reinterpret_cast(weight_packed.data_ptr()); - const auto weight_data = weight.data_ptr(); - auto scales_data = scales.const_data_ptr(); - - if (weight_data == nullptr) { - AT_ERROR("kai_pack_rhs_channelwise_int4: Weight data pointer is null"); - } - - if (scales_data == nullptr) { - AT_ERROR("kai_pack_rhs_channelwise_int4: Scales data pointer is null"); - } - - float* bias_ptr = bias.has_value() ? bias.value().data_ptr() : NULL; - auto& params = kernel.rhs_pack_params; - - kernel.kai_run_rhs_pack( - /*num_groups=*/1, - n, - k, - nr, - kr, - sr, - bl, - (const uint8_t*)(weight_data), - rhs_stride, - bias_ptr, - scales_data, - scale_stride, - weight_packed_data, - 0, - ¶ms); -} - -template -void kai_pack_rhs_channelwise_int4( - T& kernel, - const Tensor& weight_packed, - const Tensor& weight, - const Tensor& scales, - const std::optional& bias, - const int64_t n, - const int64_t k) { - const auto& ukernel = kernel.ukernel; - const size_t nr = ukernel.get_nr(); - const size_t kr = ukernel.get_kr(); - const size_t sr = ukernel.get_sr(); - auto weight_packed_data = - reinterpret_cast(weight_packed.data_ptr()); - const auto weight_data = weight.data_ptr(); - const auto scales_data = scales.data_ptr(); - - if (weight_data == nullptr) { - AT_ERROR("kai_pack_rhs_channelwise_int4: Weight data pointer is null"); - } - - if (scales_data == nullptr) { - AT_ERROR("kai_pack_rhs_channelwise_int4: Scales data pointer is null"); - } - - float* bias_ptr = bias.has_value() ? bias.value().data_ptr() : NULL; - auto& params = kernel.rhs_pack_params; - - kernel.kai_run_rhs_pack( - /*num_groups=*/1, - n, - k, - nr, - kr, - sr, - (const uint8_t*)(weight_data), - (const float*)(bias_ptr), - (const float*)(scales_data), - weight_packed_data, - 0, - ¶ms); -} - -} // namespace at::native::kleidiai - -#endif diff --git a/aten/src/ATen/native/kleidiai/kai_ukernel_interface.cpp b/aten/src/ATen/native/kleidiai/kai_ukernel_interface.cpp deleted file mode 100644 index 0de198d7dc01..000000000000 --- a/aten/src/ATen/native/kleidiai/kai_ukernel_interface.cpp +++ /dev/null @@ -1,72 +0,0 @@ -#include - -#if AT_KLEIDIAI_ENABLED() - -namespace at::native::kleidiai { - -// Kernel Mapping - Groupwise -std::unordered_map groupwise_8bit_4bit_kernels = - {{kai_kernel_id:: - matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, - {{kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, - kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, - kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, - kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, - kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, - kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, - kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, - kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, - kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, - kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod, - kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod}}}, - {kai_kernel_id::matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_4x8x32_neon_i8mm, - {{kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, - kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, - kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, - kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, - kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, - kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, - kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, - kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, - kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, - kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm, - kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm}}}}; - -kai_matmul_ukernel_f32_qa8dxp_qs4c32p kai_select_groupwise_matmul_ukernel( - kai_kernel_id id) { - return groupwise_8bit_4bit_kernels.at(id); -} - -// Kernel Mapping - Channelwise -std::unordered_map channelwise_8bit_4bit_kernels = - {{kai_kernel_id::matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, - {{kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, - kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, - kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, - kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, - kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, - kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, - kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, - kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, - kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, - kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod, - kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod}}}, - {kai_kernel_id::matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, - {{kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, - kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, - kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, - kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, - kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, - kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, - kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, - kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, - kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, - kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm, - kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm}}}}; - -kai_matmul_ukernel_f32_qa8dxp_qs4cxp kai_select_channelwise_matmul_ukernel( - const kai_kernel_id id) { - return channelwise_8bit_4bit_kernels.at(id); -} -} // namespace at::native::kleidiai -#endif diff --git a/aten/src/ATen/native/kleidiai/kai_ukernel_interface.h b/aten/src/ATen/native/kleidiai/kai_ukernel_interface.h deleted file mode 100644 index c0835729f88b..000000000000 --- a/aten/src/ATen/native/kleidiai/kai_ukernel_interface.h +++ /dev/null @@ -1,144 +0,0 @@ -#pragma once -#include -#include -#if AT_KLEIDIAI_ENABLED() - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace at::native::kleidiai { - -enum class kai_kernel_id { - matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod = - 0, // Groupwise 4 bit GEMV - matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_4x8x32_neon_i8mm = - 1, // Groupwise 4 bit GEMM - matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod = - 2, // Channelwise 4 bit GEMV - matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm = - 3 // Channelwise 4 bit GEMM -}; - -// Channelwise Kernel mapping -struct kai_matmul_ukernel_f32_qa8dxp_qs4cxp { - struct kai_matmul_clamp_f32_qai8dxp_qsi4cxp_ukernel ukernel; - struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params rhs_pack_params; - size_t (*kai_get_lhs_packed_size)( - size_t m, - size_t k, - size_t mr, - size_t kr, - size_t sr); - size_t (*kai_get_rhs_packed_size)( - size_t n, - size_t k, - size_t nr, - size_t kr, - size_t sr); - void (*kai_run_lhs_quant_pack)( - size_t m, - size_t k, - size_t mr, - size_t kr, - size_t sr, - size_t m_idx_start, - const float* lhs, - size_t lhs_stride, - void* lhs_packed); - void (*kai_run_rhs_pack)( - size_t num_groups, - size_t n, - size_t k, - size_t nr, - size_t kr, - size_t sr, - const uint8_t* rhs, - const float* bias, - const float* scale, - void* rhs_packed, - size_t extra_bytes, - const struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params* params); - - kai_matmul_ukernel_f32_qa8dxp_qs4cxp( - const kai_matmul_clamp_f32_qai8dxp_qsi4cxp_ukernel& kernel) - : ukernel(kernel), - kai_get_lhs_packed_size( - &kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32), - kai_get_rhs_packed_size( - &kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0), - kai_run_lhs_quant_pack(&kai_run_lhs_quant_pack_qai8dxp_f32), - kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0) {} -}; - -struct kai_matmul_ukernel_f32_qa8dxp_qs4cxp -kai_select_channelwise_matmul_ukernel(const kai_kernel_id id); - -// Groupwise Kernel mapping -struct kai_matmul_ukernel_f32_qa8dxp_qs4c32p { - struct kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel ukernel; - struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params rhs_pack_params; - size_t (*kai_get_lhs_packed_size)( - size_t m, - size_t k, - size_t mr, - size_t kr, - size_t sr); - size_t (*kai_get_rhs_packed_size)( - size_t n, - size_t k, - size_t nr, - size_t kr, - size_t sr, - size_t bl, - enum kai_datatype scale_dt); - void (*kai_run_lhs_quant_pack)( - size_t m, - size_t k, - size_t mr, - size_t kr, - size_t sr, - size_t m_idx_start, - const float* lhs, - size_t lhs_stride, - void* lhs_packed); - void (*kai_run_rhs_pack)( - size_t num_groups, - size_t n, - size_t k, - size_t nr, - size_t kr, - size_t sr, - size_t bl, - const uint8_t* rhs, - size_t rhs_stride, - const float* bias, - const void* scale, - size_t scale_stride, - void* rhs_packed, - size_t extra_bytes, - const struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params* params); - - kai_matmul_ukernel_f32_qa8dxp_qs4c32p( - const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& kernel) - : ukernel(kernel), - kai_get_lhs_packed_size( - &kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32), - kai_get_rhs_packed_size( - &kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0), - kai_run_lhs_quant_pack(&kai_run_lhs_quant_pack_qai8dxp_f32), - kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0) {} -}; - -struct kai_matmul_ukernel_f32_qa8dxp_qs4c32p kai_select_groupwise_matmul_ukernel( - const kai_kernel_id id); - -} // namespace at::native::kleidiai -#endif diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 2a3f152f5653..2607edb9bcda 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -4180,14 +4180,6 @@ dispatch: CPU: _weight_int4pack_mm_cpu -- func: _dyn_quant_pack_4bit_weight(Tensor weights, Tensor scales_zeros, Tensor? bias, int block_size, int in_features, int out_features) -> Tensor - dispatch: - CPU: _dyn_quant_pack_4bit_weight_cpu - -- func: _dyn_quant_matmul_4bit(Tensor inp, Tensor packed_weights, int block_size, int in_features, int out_features) -> Tensor - dispatch: - CPU: _dyn_quant_matmul_4bit_cpu - - func: _weight_int8pack_mm(Tensor self, Tensor mat2, Tensor scales) -> Tensor dispatch: CPU: _weight_int8pack_mm_cpu diff --git a/buckbuild.bzl b/buckbuild.bzl index 17153b5df77b..0b1fe3de05ed 100644 --- a/buckbuild.bzl +++ b/buckbuild.bzl @@ -1074,7 +1074,6 @@ def define_buck_targets( ], ) - # TODO: Enable support for KleidiAI bazel build # @lint-ignore BUCKLINT fb_native.genrule( name = "generate_aten_config", @@ -1127,9 +1126,6 @@ def define_buck_targets( "--replace", "@AT_BLAS_USE_CBLAS_DOT@", "AT_BLAS_USE_CBLAS_DOT_FBXPLAT", - "--replace", - "@AT_KLEIDIAI_ENABLED@", - "0", ]), outs = { "Config.h": ["Config.h"], diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index 9342555d9bc7..a6c45c51f10a 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -152,7 +152,6 @@ endif() set(AT_MKLDNN_ACL_ENABLED 0) set(AT_MKLDNN_ENABLED 0) set(AT_MKL_ENABLED 0) -set(AT_KLEIDIAI_ENABLED 0) # setting default preferred BLAS options if not already present. if(NOT INTERN_BUILD_MOBILE) set(BLAS "MKL" CACHE STRING "Selected BLAS library") @@ -1469,35 +1468,6 @@ if(NOT INTERN_BUILD_MOBILE) message("disabling MKLDNN because USE_MKLDNN is not set") endif() - if(USE_KLEIDIAI) - if(CMAKE_C_COMPILER_ID STREQUAL "Clang" AND CMAKE_C_COMPILER_VERSION VERSION_LESS "11" ) - message(WARNING "KleidiAI: Using non-supported Clang version. Expected 11 or newer, received ${CMAKE_C_COMPILER_VERSION}.") - endif() - if(CMAKE_C_COMPILER_ID STREQUAL "GNU" AND CMAKE_C_COMPILER_VERSION VERSION_LESS "11" ) - message(WARNING "KleidiAI: Using non-supported GCC version. Expected 11 or newer, received ${CMAKE_C_COMPILER_VERSION}.") - endif() - set(TEMP_BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS}) - set(BUILD_SHARED_LIBS OFF CACHE BOOL "Build shared libs" FORCE) - set(AT_KLEIDIAI_ENABLED 1) - set(KLEIDIAI_BUILD_TESTS OFF) # Disable building KLEIDIAI tests - set(KLEIDIAI_SRC "${PROJECT_SOURCE_DIR}/third_party/kleidiai") - add_subdirectory(${KLEIDIAI_SRC}) - set(KLEIDIAI_INCLUDE_DIRS - ${KLEIDIAI_SRC}/ - ${KLEIDIAI_SRC}/kai/ - ${KLEIDIAI_SRC}/kai/ukernels/ - ${KLEIDIAI_SRC}/kai/ukernels/matmul/ - ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/ - ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/ - ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/ - ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/ - ) - include_directories(SYSTEM INTERFACE ${KLEIDIAI_INCLUDE_DIRS}) - list(APPEND Caffe2_DEPENDENCY_LIBS kleidiai) - # Recover build options. - set(BUILD_SHARED_LIBS ${TEMP_BUILD_SHARED_LIBS} CACHE BOOL "Build shared libs" FORCE) - endif() - if(UNIX AND NOT APPLE) include(CheckLibraryExists) # https://github.com/libgit2/libgit2/issues/2128#issuecomment-35649830 diff --git a/cmake/Summary.cmake b/cmake/Summary.cmake index b46560e123ba..7c3f02b491af 100644 --- a/cmake/Summary.cmake +++ b/cmake/Summary.cmake @@ -154,9 +154,6 @@ function(caffe2_print_configuration_summary) message(STATUS " USE_MKLDNN_ACL : ${USE_MKLDNN_ACL}") message(STATUS " USE_MKLDNN_CBLAS : ${USE_MKLDNN_CBLAS}") endif() - if(${USE_KLEIDIAI}) - message(STATUS " USE_KLEIDIAI : ${USE_KLEIDIAI}") - endif() message(STATUS " USE_UCC : ${USE_UCC}") if(${USE_UCC}) message(STATUS " USE_SYSTEM_UCC : ${USE_SYSTEM_UCC}") diff --git a/cmake/TorchConfig.cmake.in b/cmake/TorchConfig.cmake.in index 855edd350818..8f2b2c30aee6 100644 --- a/cmake/TorchConfig.cmake.in +++ b/cmake/TorchConfig.cmake.in @@ -97,10 +97,6 @@ else() append_torchlib_if_found(microkernels-prod) endif() - if(@USE_KLEIDIAI@) - append_torchlib_if_found(kleidiai) - endif() - append_torchlib_if_found(caffe2_protos protobuf-lite protobuf protoc) append_torchlib_if_found(onnx onnx_proto) diff --git a/docs/source/backends.rst b/docs/source/backends.rst index de11a3c95748..75ab9e2672f2 100644 --- a/docs/source/backends.rst +++ b/docs/source/backends.rst @@ -205,7 +205,6 @@ torch.backends.openmp .. add anything to the rendered page for now. .. py:module:: torch.backends.quantized .. py:module:: torch.backends.xnnpack -.. py:module:: torch.backends.kleidiai torch.backends.opt_einsum diff --git a/setup.py b/setup.py index bc97c0df79aa..f22b3ffec30b 100644 --- a/setup.py +++ b/setup.py @@ -1221,7 +1221,6 @@ def main(): "include/ATen/native/cuda/*.cuh", "include/ATen/native/hip/*.h", "include/ATen/native/hip/*.cuh", - "include/ATen/native/kleidiai/*.h", "include/ATen/native/mps/*.h", "include/ATen/native/mkldnn/xpu/*.h", "include/ATen/native/mkldnn/xpu/detail/*.h", diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect index c75b2b05087e..0bcb572aad24 100644 --- a/test/expect/HasDecompTest.test_has_decomposition.expect +++ b/test/expect/HasDecompTest.test_has_decomposition.expect @@ -92,8 +92,6 @@ aten::_dimI aten::_dimV aten::_dirichlet_grad aten::_dirichlet_grad.out -aten::_dyn_quant_matmul_4bit -aten::_dyn_quant_pack_4bit_weight aten::_efficient_attention_backward aten::_efficient_attention_forward aten::_efficientzerotensor diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 189a3d86ce0b..7e86f5aaa416 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -79,7 +79,6 @@ from torch.testing._internal.common_device_type import ( from torch.testing._internal.common_dtype import all_types, get_all_dtypes from torch.testing._internal.common_quantization import ( _dynamically_quantize_per_channel, - _group_quantize_tensor_symmetric, ) from torch.testing._internal.common_utils import ( DeterministicGuard, @@ -2269,85 +2268,6 @@ class CommonTemplate: b_int8pack, b_scales = convert_weight_to_int8pack(b) self.common(fn, (a, b_int8pack, b_scales, c)) - @xfail_if_triton_cpu - @skipCUDAIf(True, "No _dyn_quant_pack_4bit_weight implementation on CUDA") - @skipIfRocm - @skipIfXpu(msg="No _dyn_quant_pack_4bit_weight implementation on XPU") - def test__dyn_quant_pack_4bit_weight(self): - q_group = 32 - k = 128 - n = 128 - - torch.manual_seed(1) - b = torch.rand((k, n), dtype=torch.float32) - in_features = b.size(0) - out_features = b.size(1) - - def dyn_quant_pack_4bit_weight(b, in_features, out_features): - b_uint8, b_scales_and_zeros = _group_quantize_tensor_symmetric( - b, n_bit=4, groupsize=q_group - ) - - if q_group == in_features: - b_scales_and_zeros = b_scales_and_zeros.to(torch.float) - else: - b_scales_and_zeros = b_scales_and_zeros.to(torch.bfloat16) - b_int4pack = torch._dyn_quant_pack_4bit_weight( - b_uint8, b_scales_and_zeros, None, q_group, in_features, out_features - ) - - return b_int4pack, b_scales_and_zeros - - def fn(b, in_features, out_features): - b_int4pack, _ = dyn_quant_pack_4bit_weight(b, in_features, out_features) - return b_int4pack - - self.common(fn, (b, in_features, out_features)) - - @xfail_if_triton_cpu - @skipCUDAIf(True, "No _dyn_quant_matmul_4bit implementation on CUDA") - @skipIfRocm - @skipIfXpu(msg="No _dyn_quant_matmul_4bit implementation on XPU") - def test__dyn_quant_matmul_4bit(self): - q_group = 32 - m = 32 - k = 128 - n = 128 - - torch.manual_seed(1) - a = torch.rand((m, k), dtype=torch.float32) - b = torch.rand((k, n), dtype=torch.float32) - in_features = b.size(0) - out_features = b.size(1) - - def dyn_quant_pack_4bit_weight(b, in_features, out_features): - b_uint8, b_scales_and_zeros = _group_quantize_tensor_symmetric( - b, n_bit=4, groupsize=q_group - ) - - if q_group == in_features: - b_scales_and_zeros = b_scales_and_zeros.to(torch.float) - else: - b_scales_and_zeros = b_scales_and_zeros.to(torch.bfloat16) - b_int4pack = torch._dyn_quant_pack_4bit_weight( - b_uint8, b_scales_and_zeros, None, q_group, in_features, out_features - ) - - return b_int4pack, b_scales_and_zeros - - def fn(a, q_group, in_features, out_features): - b_int4pack, _ = dyn_quant_pack_4bit_weight(b, in_features, out_features) - res = torch._dyn_quant_matmul_4bit( - a, - b_int4pack, - q_group, - in_features, - out_features, - ) - return res - - self.common(fn, (a, q_group, in_features, out_features)) - def test_expanded_reduction(self): def fn(x, y): z = x * y diff --git a/test/test_linalg.py b/test/test_linalg.py index c372ea3772be..a7daaf666284 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -34,8 +34,7 @@ from torch.testing._internal.common_dtype import ( ) from torch.testing._internal.common_cuda import SM53OrLater, SM80OrLater, SM90OrLater, tf32_on_and_off, _get_magma_version, \ _get_torch_cuda_version, CDNA2OrLater, TEST_MULTIGPU -from torch.testing._internal.common_quantization import _group_quantize_tensor, _dynamically_quantize_per_channel, \ - _group_quantize_tensor_symmetric +from torch.testing._internal.common_quantization import _group_quantize_tensor, _dynamically_quantize_per_channel from torch.testing._internal.common_mkldnn import bf32_on_and_off from torch.distributions.binomial import Binomial import torch.backends.opt_einsum as opt_einsum @@ -910,6 +909,7 @@ class TestLinalg(TestCase): torch.randn((3, 52, 52), device=device, dtype=dtype), torch.randn((4, 2, 26, 26), device=device, dtype=dtype)) + ops = (torch.det, torch.Tensor.det, torch.linalg.det) for t in tensors: @@ -1438,6 +1438,7 @@ class TestLinalg(TestCase): continue run_test_case(make_arg(shape), ord, dim, keepdim) + @onlyCUDA @dtypes(torch.bfloat16, torch.float16) def test_norm_fused_type_promotion(self, device, dtype): @@ -4343,6 +4344,7 @@ class TestLinalg(TestCase): triangular_solve_zero_batch_helper((batchsize, 5, 5), (batchsize, 5, 10), upper, unitriangular, transpose) + @slowTest @skipCUDAIfNoMagma @skipCPUIfNoLapack @@ -4463,6 +4465,7 @@ class TestLinalg(TestCase): self.assertTrue("An output with one or more elements was resized" in str(w[0].message)) self.assertTrue("An output with one or more elements was resized" in str(w[1].message)) + def check_single_matmul(self, x, y): def assertEqual(answer, expected): @@ -5686,6 +5689,7 @@ class TestLinalg(TestCase): else: self.assertEqual(B_, X_ @ A) + sizes = ((3, 3), (5, 5), (4, 2), (3, 4), (0, 0), (0, 1), (1, 0)) batches = ((0,), (), (1,), (2,), (3,), (1, 0), (3, 5)) # Non pivoting just implemented for CUDA @@ -5718,6 +5722,7 @@ class TestLinalg(TestCase): with self.assertRaisesRegex(RuntimeError, 'LU without pivoting is not implemented on the CPU'): f(torch.empty(1, 2, 2), pivot=False) + @precisionOverride({torch.float32: 1e-2, torch.complex64: 1e-2}) @skipCUDAIfNoMagmaAndNoCusolver @skipCPUIfNoLapack @@ -5745,6 +5750,7 @@ class TestLinalg(TestCase): for b, n in shapes: yield make_arg((b, n, n)), make_arg((b, n, rhs)) + for A, B in gen_matrices(): LU, pivots = torch.linalg.lu_factor(A) for backend in backends: @@ -5759,6 +5765,7 @@ class TestLinalg(TestCase): else: self.assertEqual(B_left, X @ A_adj) + @onlyCPU @dtypes(*floating_and_complex_types()) def test_linalg_lu_cpu_errors(self, device, dtype): @@ -5799,6 +5806,7 @@ class TestLinalg(TestCase): with self.assertRaisesRegex(RuntimeError, r"between 1 and LU.size\(-2\)."): torch.lu_unpack(LU, pivots) + # Rectangular tests sample = torch.randn(2, 3, 5, device=device, dtype=dtype) B = torch.randn(2, 3, 5, device=device, dtype=dtype) @@ -5815,6 +5823,7 @@ class TestLinalg(TestCase): with self.assertRaisesRegex(RuntimeError, r"between 1 and LU.size\(-2\)."): torch.lu_unpack(LU, pivots) + @skipCPUIfNoLapack @skipCUDAIfNoMagma @dtypes(torch.double) @@ -6423,6 +6432,7 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2: y = torch.baddbmm(input_tensor, batch1, batch2, beta=0.0, out=out) self.assertEqual(out, y_ref) + @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error") @onlyCUDA def test_matmul_45724(self, device): @@ -6595,6 +6605,7 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2: torch._int_mm(a_int8, b_int8, out=c_int32_result) self.assertEqual(c_int32_result.float(), torch.mm(a_float, b_float)) + @unittest.skipIf(IS_WINDOWS, "Skipped on Windows!") @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error") @onlyNativeDeviceTypes @@ -6657,6 +6668,7 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2: mean_err = ((res - ref).abs() / ref).mean() self.assertTrue(mean_err < 0.05) + @unittest.skipIf(IS_WINDOWS, "Skipped on Windows!") @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error") @onlyNativeDeviceTypes @@ -6709,168 +6721,6 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2: mean_err = ((res - ref).abs() / ref).mean() self.assertTrue(mean_err < 0.05) - @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error") - @unittest.skipIf(TEST_WITH_ROCM and IS_REMOTE_GPU, "ROCM is unsupported") - @onlyNativeDeviceTypes - @parametrize("k", [64, 256]) - @parametrize("n", [32, 48, 64, 128]) - def test__dyn_quant_pack_4bit_weight(self, device, k, n): - # TODO: Fix https://github.com/pytorch/pytorch/issues/131425 and use OpInfo instead - # Weight shape is [K x N] - if self.device_type == "cuda": - self.skipTest("CUDA Backend is unsupported") - - torch.manual_seed(1) - block_size = 32 - b = torch.rand((k, n), dtype=torch.bfloat16, device=device) - in_features = b.size(0) - out_features = b.size(1) - b_uint8, b_scales_and_zeros = _group_quantize_tensor_symmetric( - b, n_bit=4, groupsize=block_size - ) - b_int4pack = torch._dyn_quant_pack_4bit_weight( - b_uint8, b_scales_and_zeros, None, block_size, in_features, out_features - ) - b_int4pack_meta = torch._dyn_quant_pack_4bit_weight( - b_uint8, b_scales_and_zeros, None, block_size, in_features, out_features - ) - self.assertEqual(b_int4pack.shape, b_int4pack_meta.shape) - - @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error") - @unittest.skipIf(TEST_WITH_ROCM and IS_REMOTE_GPU, "ROCM is unsupported") - @onlyNativeDeviceTypes - @parametrize("m", [1, 32]) - @parametrize("k", [64, 128]) - @parametrize("n", [4096, 11008]) - def test__dyn_quant_matmul_4bit(self, device, m, k, n): - if self.device_type == "cuda": - self.skipTest("CUDA is unsupported") - - q_group = 32 - - torch.manual_seed(1) - a_float32 = torch.rand((m, k), dtype=torch.float32, device=device) - b_float32 = torch.rand((k, n), dtype=torch.float32, device=device) - in_features = b_float32.size(0) - out_features = b_float32.size(1) - - def dyn_quant_pack_4bit_weight(b, in_features, out_features): - b_uint8, b_scales_and_zeros = _group_quantize_tensor_symmetric( - b, n_bit=4, groupsize=q_group - ) - - if q_group == in_features: - b_scales_and_zeros = b_scales_and_zeros.to(torch.float) - else: - b_scales_and_zeros = b_scales_and_zeros.to(torch.bfloat16) - b_int4pack = torch._dyn_quant_pack_4bit_weight( - b_uint8, b_scales_and_zeros, None, q_group, in_features, out_features - ) - - return b_int4pack, b_scales_and_zeros - - def dyn_quant_matmul_4bit( - a, b_int4pack, q_group, in_features, out_features - ): - return torch._dyn_quant_matmul_4bit( - a, - b_int4pack, - q_group, - in_features, - out_features, - ) - - b_int4pack, b_scales_and_zeros = dyn_quant_pack_4bit_weight( - b_float32, in_features, out_features - ) - - dtypes = [torch.float32] - - for dtype in dtypes: - a = a_float32.to(dtype=dtype) - b = b_float32.to(dtype=dtype) - ref = torch.mm(a, b) - res = dyn_quant_matmul_4bit( - a, - b_int4pack, - q_group, - in_features, - out_features, - ) - mean_err = ((res - ref).abs() / ref).mean() - self.assertTrue(mean_err < 0.05) - elementwise_diff = (res - ref).abs() - elementwise_relative_error = elementwise_diff / ref.abs().clamp( - min=torch.finfo(ref.dtype).eps - ) - all_elements_within_threshold = torch.all(elementwise_relative_error < 0.06) - self.assertTrue( - all_elements_within_threshold, "Some elements have error >= 0.06" - ) - - @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error") - @unittest.skipIf(TEST_WITH_ROCM and IS_REMOTE_GPU, "ROCM is unsupported") - @onlyNativeDeviceTypes - @parametrize("m", [1, 32]) - @parametrize("k", [64, 128]) - @parametrize("n", [4096, 11008]) - def test_compile_dyn_quant_matmul_4bit(self, device, m, k, n): - if self.device_type == "cuda": - self.skipTest("CUDA is unsupported") - - q_group = 32 - - torch.manual_seed(1) - a_float32 = torch.rand((m, k), dtype=torch.float32, device=device) - b_float32 = torch.rand((k, n), dtype=torch.float32, device=device) - in_features = b_float32.size(0) - out_features = b_float32.size(1) - - b_uint8, b_scales_and_zeros = _group_quantize_tensor_symmetric( - b_float32, n_bit=4, groupsize=q_group - ) - - if q_group == in_features: - b_scales_and_zeros = b_scales_and_zeros.to(dtype=torch.float) - else: - b_scales_and_zeros = b_scales_and_zeros.to(dtype=torch.bfloat16) - - @torch.compile - def dyn_quant_matmul_4bit( - a, b_uint8, b_scales_and_zeros, q_group, in_features, out_features - ): - b_int4pack = torch._dyn_quant_pack_4bit_weight( - b_uint8, b_scales_and_zeros, None, q_group, in_features, out_features - ) - return torch._dyn_quant_matmul_4bit( - a, - b_int4pack, - q_group, - in_features, - out_features, - ) - - res = dyn_quant_matmul_4bit( - a_float32, - b_uint8, - b_scales_and_zeros, - q_group, - in_features, - out_features, - ) - ref = torch.mm(a_float32, b_float32) - - mean_err = ((res - ref).abs() / ref).mean() - self.assertTrue(mean_err < 0.05) - elementwise_diff = (res - ref).abs() - elementwise_relative_error = elementwise_diff / ref.abs().clamp( - min=torch.finfo(ref.dtype).eps - ) - all_elements_within_threshold = torch.all(elementwise_relative_error < 0.06) - self.assertTrue( - all_elements_within_threshold, "Some elements have error >= 0.06" - ) - @onlyCPU @parametrize("m", [32, 64]) @parametrize("k", [32, 64]) @@ -8802,6 +8652,8 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2: self.assertEqual(ck_out, cpu_out) + + def test_permute_matmul(self): a = torch.ones([2, 5, 24, 24]) b = torch.ones([3, 2, 5, 24, 24]) @@ -8881,6 +8733,7 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2: ref = alpha * A @ B + beta * C self.assertEqual(rc, ref) + @dtypes(torch.float, torch.double) @precisionOverride({torch.float32: 1e-4}) def test_1_sized_with_0_strided(self, device, dtype): diff --git a/third_party/kleidiai b/third_party/kleidiai deleted file mode 160000 index 202603f38a9d..000000000000 --- a/third_party/kleidiai +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 202603f38a9df9d2ded89f12b41ded621c71d4ea diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index f4be60bb6ffb..9aef8f129773 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -1315,7 +1315,6 @@ def _valgrind_toggle_and_dump_stats() -> None: ... # CALLGRIND_TOGGLE_COLLECT a has_openmp: _bool has_mkl: _bool -_has_kleidiai: _bool _has_mps: _bool has_lapack: _bool _has_cuda: _bool diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index d4ee91cbd29f..e8551e5cca61 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -1382,8 +1382,6 @@ torch_c_binding_in_graph_functions = dict.fromkeys( "torch._dim_arange", "torch._dirichlet_grad", "torch._disable_functionalization", - "torch._dyn_quant_matmul_4bit", - "torch._dyn_quant_pack_4bit_weight", "torch._efficientzerotensor", "torch._embedding_bag_forward_only", "torch._embedding_bag", diff --git a/torch/_inductor/quantized_lowerings.py b/torch/_inductor/quantized_lowerings.py index 1862140b1265..07778d6346ec 100644 --- a/torch/_inductor/quantized_lowerings.py +++ b/torch/_inductor/quantized_lowerings.py @@ -94,6 +94,3 @@ def register_woq_mm_ops() -> None: return autotune_select_algorithm( "_weight_int8pack_mm", choices, [mat1, mat2, scale], aten_layout ) - - lowering.make_fallback(aten._dyn_quant_matmul_4bit) - lowering.make_fallback(aten._dyn_quant_pack_4bit_weight) diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index 173c6c95b2b6..64b724c5fb7b 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -3510,161 +3510,6 @@ def meta__weight_int4pack_mm_for_cpu(x, w, q_group_size, q_scale_and_zeros): return x.new_empty(x.size(0), w.size(0), dtype=x.dtype) -def kai_roundup(a: int, b: int) -> int: - return ((a + b - 1) // b) * b - - -def get_kai_packed_weight_size(n_bits, N, K, groupsize): - if n_bits == 4: - if groupsize == K: # channelwise - # dotprod params only [1x8x32_neon_dotprod] - kai_nr = 8 - kai_kr = 16 - kai_sr = 2 - kai_num_bytes_sum_rhs = 4 # sizeof(int32_t) - kai_num_bytes_multiplier_rhs = 4 # sizeof(float) - kai_num_bytes_bias = 4 # sizeof(float) - - def kai_k_roundedup(k, kr, sr): - # Since we pack a float and int32 value at the end of the row, - # we must make sure that k is a multiple of 4 for alignment - kr_sr_roundedup4 = kai_roundup(kr * sr, 4) - return kai_roundup(k, kr_sr_roundedup4) - - def kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0( - k, nr, kr, sr - ): - k_internal = kai_k_roundedup(k, kr, sr) - - assert (k_internal % 2) == 0, "k_internal must be even" - - return nr * ( - (k_internal // 2) - + kai_num_bytes_multiplier_rhs - + kai_num_bytes_sum_rhs - + kai_num_bytes_bias - ) - - def kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0( - n, k, nr, kr, sr - ): - num_rows = kai_roundup(n, nr) // nr - - return ( - num_rows - * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0( - k, nr, kr, sr - ) - ) - - return kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0( - N, K, kai_nr, kai_kr, kai_sr - ) - elif groupsize % 32 == 0 and K % groupsize == 0: # groupwise - kai_nr = 8 - kai_kr = 16 - kai_sr = 2 - kai_num_bytes_sum_rhs = 4 - kai_num_bytes_bias = 4 - kai_nr_multiple_of = 4 - kai_bl_multiple_of = 32 - - def kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( - n, k, nr, kr, sr, bl - ): - assert (bl % kr) == 0 - assert (nr % kai_nr_multiple_of) == 0 - assert (bl % kai_bl_multiple_of) == 0 - - num_rows = kai_roundup(n, nr) // nr - - return ( - num_rows - * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( - k, nr, kr, sr, bl - ) - ) - - def kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( - k, nr, kr, sr, bl - ): - assert (bl % kr) == 0 - assert (nr % kai_nr_multiple_of) == 0 - assert (bl % kai_bl_multiple_of) == 0 - - # kr and sr are unused in the calculation - num_bytes_multiplier_rhs = kai_get_bf16_datatype_size_in_bytes() - num_blocks_per_row = kai_num_blocks_per_row(k, bl) - num_bytes_per_block = kai_num_bytes_per_block( - bl, num_bytes_multiplier_rhs - ) - - return nr * ( - (num_bytes_per_block * num_blocks_per_row) - + kai_num_bytes_sum_rhs - + kai_num_bytes_bias - ) - - # This funtion retuns size of these datatypes stored as enum. We modify it to just return bf16 datatype - # https://gitlab.arm.com/kleidi/kleidiai/-/blob/main/kai/kai_common.h?ref_type=heads#L55 - def kai_get_bf16_datatype_size_in_bytes(): - return 2 # 2 bytes - - def kai_num_blocks_per_row(k, bl): - assert (bl % kai_bl_multiple_of) == 0 - return kai_roundup(k, bl) // bl - - def kai_num_bytes_per_block(bl, num_bytes_multiplier_rhs): - assert (bl % kai_bl_multiple_of) == 0 - return (bl // 2) + num_bytes_multiplier_rhs - - return kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0( - N, K, kai_nr, kai_kr, kai_sr, groupsize - ) - - -@register_meta([aten._dyn_quant_pack_4bit_weight]) -def meta__dyn_quant_pack_4bit_weight( - weights, scales_zeros, bias: Optional[Tensor], block_size, in_features, out_features -): - torch._check( - weights.dtype is torch.uint8, - lambda: f"expected w to be uint8, got {weights.dtype}", - ) - if torch.backends.kleidiai.is_available() and ( - (block_size == in_features and scales_zeros.dtype == torch.float) - or ( - block_size < in_features - and block_size % 32 == 0 - and in_features % block_size == 0 - and scales_zeros.dtype == torch.bfloat16 - ) - ): - packed_weight_size = get_kai_packed_weight_size( - 4, out_features, in_features, block_size - ) - return weights.new_empty(int(packed_weight_size), dtype=torch.uint8) - packed_weight_size = weights.numel() + scales_zeros.numel() - return weights.new_empty(packed_weight_size, dtype=torch.float) - - -@register_meta([aten._dyn_quant_matmul_4bit]) -def meta__dyn_quant_matmul_4bit( - inp, - packed_weights, - block_size, - in_features, - out_features, -): - torch._check(inp.dim() == 2, lambda: "input must be a 2D tensor") - torch._check( - inp.dtype in [torch.float32], - lambda: f"expected input to be f32, got {inp.dtype}", - ) - M = inp.size(0) - return inp.new_empty(M, out_features, dtype=inp.dtype) - - @register_meta([aten._weight_int8pack_mm]) def meta__weight_int8pack_mm(x, w, q_scales): torch._check(x.dim() == 2, lambda: "x must be a 2D tensor") diff --git a/torch/backends/__init__.py b/torch/backends/__init__.py index 90166913e324..301735188869 100644 --- a/torch/backends/__init__.py +++ b/torch/backends/__init__.py @@ -62,7 +62,6 @@ from torch.backends import ( cuda as cuda, cudnn as cudnn, cusparselt as cusparselt, - kleidiai as kleidiai, mha as mha, mkl as mkl, mkldnn as mkldnn, diff --git a/torch/backends/kleidiai/__init__.py b/torch/backends/kleidiai/__init__.py deleted file mode 100644 index 1a681b77ef58..000000000000 --- a/torch/backends/kleidiai/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# mypy: allow-untyped-defs -import torch - - -def is_available(): - r"""Return whether PyTorch is built with KleidiAI support.""" - return torch._C._has_kleidiai diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index 2230b15aeb3a..957461c6d161 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -1953,8 +1953,6 @@ Call this whenever a new thread is created in order to propagate values from ASSERT_TRUE( set_module_attr("has_openmp", at::hasOpenMP() ? Py_True : Py_False)); ASSERT_TRUE(set_module_attr("has_mkl", at::hasMKL() ? Py_True : Py_False)); - ASSERT_TRUE( - set_module_attr("_has_kleidiai", at::hasKleidiAI() ? Py_True : Py_False)); ASSERT_TRUE( set_module_attr("has_lapack", at::hasLAPACK() ? Py_True : Py_False)); diff --git a/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h b/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h index 2a5eb60e9c89..4f6e930eaa61 100644 --- a/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h +++ b/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h @@ -18,8 +18,6 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__adaptive_avg_pool3d_backward(At AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__addmm_activation(AtenTensorHandle self, AtenTensorHandle mat1, AtenTensorHandle mat2, double beta, double alpha, int32_t use_gelu, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__cdist_backward(AtenTensorHandle grad, AtenTensorHandle x1, AtenTensorHandle x2, double p, AtenTensorHandle cdist, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__cdist_forward(AtenTensorHandle x1, AtenTensorHandle x2, double p, int64_t* compute_mode, AtenTensorHandle* ret0); -AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__dyn_quant_matmul_4bit(AtenTensorHandle inp, AtenTensorHandle packed_weights, int64_t block_size, int64_t in_features, int64_t out_features, AtenTensorHandle* ret0); -AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__dyn_quant_pack_4bit_weight(AtenTensorHandle weights, AtenTensorHandle scales_zeros, AtenTensorHandle* bias, int64_t block_size, int64_t in_features, int64_t out_features, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__efficientzerotensor(const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__embedding_bag(AtenTensorHandle weight, AtenTensorHandle indices, AtenTensorHandle offsets, int32_t scale_grad_by_freq, int64_t mode, int32_t sparse, AtenTensorHandle* per_sample_weights, int32_t include_last_offset, int64_t padding_idx, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3); AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__embedding_bag_dense_backward(AtenTensorHandle grad, AtenTensorHandle indices, AtenTensorHandle offset2bag, AtenTensorHandle bag_size, AtenTensorHandle maximum_indices, int64_t num_weights, int32_t scale_grad_by_freq, int64_t mode, AtenTensorHandle* per_sample_weights, int64_t padding_idx, AtenTensorHandle* ret0); diff --git a/torch/testing/_internal/common_quantization.py b/torch/testing/_internal/common_quantization.py index 6b52bd280f42..431b442e1f38 100644 --- a/torch/testing/_internal/common_quantization.py +++ b/torch/testing/_internal/common_quantization.py @@ -498,39 +498,6 @@ def _group_quantize_tensor(w, n_bit=4, q_group_size=16): return out, scales_and_zeros -def _group_quantize_tensor_symmetric( - w, n_bit=4, groupsize=32 -): - # W is of shape [K x N] - # We transpose W as Quantization is applied on [N x K] - w = w.transpose(0, 1).contiguous() - assert w.dim() == 2 - assert groupsize > 1 - assert w.shape[-1] % groupsize == 0 - # Calculate scale and zeros - to_quant = w.reshape(-1, groupsize) - max_val = to_quant.abs().amax(dim=1, keepdim=True) - eps = torch.finfo(max_val.dtype).eps - max_int = 2 ** (n_bit - 1) - 1 # For 4-bit, this is 7 - scales = max_val.clamp(min=eps) / max_int - zeros = torch.zeros_like(scales) - - # Quantize the weight - scales = scales.to(torch.float32).reshape(w.shape[0], -1) - zeros = zeros.to(torch.float32).reshape(w.shape[0], -1) - scales = scales.reshape(-1, 1) - zeros = zeros.reshape(-1, 1) - max_int = 2**n_bit - 1 - w_int8 = to_quant.div(scales).add(8.5).to(torch.int8).clamp(max=max_int) - # We pack 2 signed int4 values in unsigned uint8 container. - # This reduces the weight size by half and improves load perf - out_uint8 = (w_int8[::, 1::2] << 4 | w_int8[::, ::2]).to(torch.uint8) - - scales_and_zeros = scales.squeeze().contiguous() - - return out_uint8, scales_and_zeros - - def _dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype): # source: https://github.com/pytorch-labs/gpt-fast/blob/main/quantize.py # default setup for affine quantization of activations @@ -563,6 +530,7 @@ def _dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype): return quant, scales.to(x_dtype), zero_points + # QuantizationTestCase used as a base class for testing quantization on modules class QuantizationTestCase(TestCase): def setUp(self): diff --git a/torchgen/aoti/fallback_ops.py b/torchgen/aoti/fallback_ops.py index dead690831f2..106414f88419 100644 --- a/torchgen/aoti/fallback_ops.py +++ b/torchgen/aoti/fallback_ops.py @@ -44,8 +44,6 @@ inductor_fallback_ops = { "aten.cummin.default", "aten.cumprod.default", "aten.cumsum.default", - "aten._dyn_quant_matmul_4bit.default", - "aten._dyn_quant_pack_4bit_weight.default", "aten._efficient_attention_backward.default", "aten._efficient_attention_forward.default", "aten._efficientzerotensor.default",