Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[ARM][feat]: Add 4 bit dynamic quantization matmuls & KleidiAI Backend (#134124)
Description:

1. Quantize Linear layer weights to 4 bits: quantize the weights of the Linear layer to 4 bits using symmetric quantization, packing two 4-bit weights into one uint8 container. Choose a quantization scheme (channel-wise or group-wise), with the group size being a multiple of 32.

2. Prepare the quantized weights, scales, and optional bias: after quantizing, obtain the quantized_weights, scales, and groupsize. If the original Linear layer has a bias, prepare it as well.

3. Pack the weights efficiently: use torch.ops.aten._dyn_quant_pack_4bit_weight to optimally pack the weights, scales, and optional bias.

```python
packed_weights = torch.ops.aten._dyn_quant_pack_4bit_weight(
    weight, scales_and_zeros, bias, groupsize, in_features, out_features)
```

   Input parameters include in_features and out_features (the same as the Linear layer's corresponding parameters).

4. Perform the dynamic quantized matrix multiplication: use torch.ops.aten._dyn_quant_matmul_4bit to perform matrix multiplication with the quantized weights.

```python
output = torch.ops.aten._dyn_quant_matmul_4bit(
    input, packed_weights, groupsize, in_features, out_features)
```

   Required inputs: the input tensor, packed_weights, groupsize, and the in_features and out_features.

API usage: https://github.com/pytorch/pytorch/issues/143289

Model perf:
- 7B Transformer model: Prefill 340 t/s, Decode 40 t/s
- 2B Transformer model: Prefill 747 t/s, Decode 80 t/s

Tests:
- python test/test_linalg.py -k test__dyn_quant_pack_4bit_weight (Ran 1 test in 0.016s, OK)
- python test/test_linalg.py -k test__dyn_quant_matmul_4bit (Ran 8 tests in 0.077s, OK)
- python test/test_linalg.py -k test_compile_dyn_quant_matmul_4bit (Ran 8 tests in 11.454s)

Change-Id: Ia1672bad5e6ec94e64d8bb1971395d60f4b3a452
Fixes #ISSUE_NUMBER
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134124
Approved by: https://github.com/digantdesai, https://github.com/malfet
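The steps above can be strung together end to end. The sketch below is illustrative only: the two `torch.ops.aten` calls are the ops added by this PR, while the symmetric group-wise 4-bit quantization and nibble packing are hand-rolled assumptions for demonstration, not a quantizer shipped by PyTorch.

```python
import torch

in_features, out_features, groupsize = 64, 32, 32
weight_fp32 = torch.randn(out_features, in_features)

# 1. Symmetric 4-bit quantization per group of `groupsize` input channels.
w = weight_fp32.reshape(out_features, in_features // groupsize, groupsize)
scales = w.abs().amax(dim=-1, keepdim=True) / 7.0                     # int4 range [-8, 7]
q = (torch.clamp(torch.round(w / scales), -8, 7) + 8).to(torch.uint8)  # shift into [0, 15]

# 2. Pack two 4-bit values into one uint8 (even index in the low nibble).
q = q.reshape(out_features, -1)
quantized_weights = q[:, ::2] | (q[:, 1::2] << 4)

# Group-wise kernels expect BFloat16 scales; channel-wise expects Float32.
scales_and_zeros = scales.reshape(out_features, -1).to(torch.bfloat16)

# 3. Pack weights, scales, and (optional) bias for the dynamic-quant matmul.
packed_weights = torch.ops.aten._dyn_quant_pack_4bit_weight(
    quantized_weights, scales_and_zeros, None, groupsize, in_features, out_features)

# 4. Dynamically quantize the FP32 input and run the 4-bit matmul.
x = torch.randn(4, in_features)
out = torch.ops.aten._dyn_quant_matmul_4bit(
    x, packed_weights, groupsize, in_features, out_features)
print(out.shape)  # torch.Size([4, 32])
```

On AArch64 builds with USE_KLEIDIAI this should take the KleidiAI path; otherwise the reference kernels added in this PR handle it.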
Committed by: PyTorch MergeBot
Parent: c5ddf5dd90
Commit: 4b82251011
3  .gitmodules  (vendored)
@@ -131,3 +131,6 @@
 	path = third_party/composable_kernel
 	url = https://github.com/ROCm/composable_kernel.git
 	branch = develop
+[submodule "third_party/kleidiai"]
+	path = third_party/kleidiai
+	url = https://git.gitlab.arm.com/kleidi/kleidiai.git
@@ -254,6 +254,7 @@ filegroup(
     # target that generates these sources...
 )
 
+# TODO: Enable support for KleidiAI bazel build
 header_template_rule(
     name = "aten_src_ATen_config",
     src = "aten/src/ATen/Config.h.in",
@@ -273,6 +274,7 @@ header_template_rule(
         "@AT_PARALLEL_NATIVE@": "1",
        "@AT_BLAS_F2C@": "0",
        "@AT_BLAS_USE_CBLAS_DOT@": "1",
+        "@AT_KLEIDIAI_ENABLED@": "0",
    },
)
 
@@ -377,6 +377,8 @@ cmake_dependent_option(
 cmake_dependent_option(BUILD_FUNCTORCH "Build Functorch" ON "BUILD_PYTHON" OFF)
 cmake_dependent_option(BUILD_BUNDLE_PTXAS "Bundle PTX into torch/bin fodler"
                        OFF "USE_CUDA" OFF)
+cmake_dependent_option(USE_KLEIDIAI "Use KleidiAI for the ARM CPU & AARCH64 architecture." ON
+                       "CPU_AARCH64" OFF)
 
 option(USE_MIMALLOC "Use mimalloc" OFF)
 # Enable third party mimalloc library to improve memory allocation performance
@@ -418,6 +420,8 @@ endif()
 if(WIN32)
   set(USE_TENSORPIPE OFF)
   message(WARNING "TensorPipe cannot be used on Windows. Set it to OFF")
+  set(USE_KLEIDIAI OFF)
+  message(WARNING "KleidiAI cannot be used on Windows. Set it to OFF")
 
   if(USE_DISTRIBUTED AND NOT DEFINED ENV{libuv_ROOT})
     find_library(
@@ -667,6 +671,9 @@ if(ANDROID
   message(WARNING "INTERN_BUILD_MOBILE is on, disabling BUILD_LAZY_TS_BACKEND")
   set(BUILD_LAZY_TS_BACKEND OFF)
 
+  set(USE_KLEIDIAI OFF)
+  message(WARNING "KleidiAI cannot be used on Mobile builds. Set it to OFF")
+
   # Set -ffunction-sections and -fdata-sections so that each method has its own
   # text section. This allows the linker to remove unused section when the flag
   # -Wl,-gc-sections is provided at link time.
@@ -309,6 +309,12 @@ local_repository(
     path = "third_party/gemmlowp/gemmlowp",
 )
 
+local_repository(
+    name = "kleidiai",
+    path = "third_party/kleidiai",
+    repo_mapping = {"@com_google_googletest": "@com_google_benchmark"},
+)
+
 ### Unused repos start
 
 # `unused` repos are defined to hide bazel files from submodules of submodules.
@@ -199,6 +199,10 @@ endif()
 # XNNPACK
 file(GLOB native_xnnpack "native/xnnpack/*.cpp")
 
+# KLEIDIAI
+file(GLOB native_kleidiai "native/kleidiai/*.cpp")
+file(GLOB native_kleidiai_h "native/kleidiai/*.h")
+
 # Add files needed from jit folders
 append_filelist("jit_core_headers" ATen_CORE_HEADERS)
 append_filelist("jit_core_sources" ATen_CORE_SRCS)
@@ -228,6 +232,10 @@ endif()
 if(AT_MKL_ENABLED)
   set(all_cpu_cpp ${all_cpu_cpp} ${mkl_cpp})
 endif()
+if(AT_KLEIDIAI_ENABLED)
+  set(all_cpu_cpp ${all_cpu_cpp} ${native_kleidiai})
+  include_directories(SYSTEM INTERFACE ${KLEIDIAI_INCLUDE_DIRS})
+endif()
 if(AT_MKLDNN_ENABLED)
   set(all_cpu_cpp ${all_cpu_cpp} ${mkldnn_cpp})
 endif()
@@ -611,7 +619,7 @@ install(FILES "${CMAKE_CURRENT_BINARY_DIR}/cmake-exports/ATenConfig.cmake"
 
 set(INSTALL_HEADERS ${base_h} ${ATen_CORE_HEADERS} ${native_nested_h} ${ATen_TRANSFORMER_HEADERS})
 if(NOT INTERN_BUILD_MOBILE)
-  list(APPEND INSTALL_HEADERS ${native_h} ${native_cpu_h} ${native_ao_sparse_h} ${native_quantized_h} ${cuda_h} ${native_cuda_h} ${native_hip_h} ${cudnn_h} ${hip_h} ${xpu_h} ${mps_h} ${native_mps_h} ${native_utils_h} ${miopen_h} ${mkldnn_xpu_h})
+  list(APPEND INSTALL_HEADERS ${native_h} ${native_cpu_h} ${native_ao_sparse_h} ${native_quantized_h} ${cuda_h} ${native_cuda_h} ${native_hip_h} ${cudnn_h} ${hip_h} ${xpu_h} ${mps_h} ${native_kleidiai_h} ${native_mps_h} ${native_utils_h} ${miopen_h} ${mkldnn_xpu_h})
 # Metal
 if(USE_PYTORCH_METAL_EXPORT)
   # Add files needed from exporting metal models(optimized_for_mobile)
@@ -19,3 +19,4 @@
 #define AT_PARALLEL_NATIVE @AT_PARALLEL_NATIVE@
 #define AT_BLAS_F2C() @AT_BLAS_F2C@
 #define AT_BLAS_USE_CBLAS_DOT() @AT_BLAS_USE_CBLAS_DOT@
+#define AT_KLEIDIAI_ENABLED() @AT_KLEIDIAI_ENABLED@
@@ -376,6 +376,10 @@ bool Context::hasMKLDNN() {
 #endif
 }
 
+bool Context::hasKleidiAI() {
+  return AT_KLEIDIAI_ENABLED();
+}
+
 bool Context::hasOpenMP() {
 #ifdef _OPENMP
   return true;
@@ -118,6 +118,7 @@ class TORCH_API Context {
 
   static bool hasOpenMP();
   static bool hasMKL();
+  static bool hasKleidiAI();
   static bool hasLAPACK();
   static bool hasMKLDNN();
   static bool hasMAGMA() {
@@ -538,6 +539,10 @@ inline bool hasMKL() {
   return globalContext().hasMKL();
 }
 
+inline bool hasKleidiAI() {
+  return globalContext().hasKleidiAI();
+}
+
 inline bool hasLAPACK() {
   return globalContext().hasLAPACK();
 }
@@ -33,6 +33,8 @@
 #include <ATen/ops/_addmm_activation_native.h>
 #include <ATen/ops/_compute_linear_combination_native.h>
 #include <ATen/ops/_convert_weight_to_int4pack_for_cpu_native.h>
+#include <ATen/ops/_dyn_quant_matmul_4bit_native.h>
+#include <ATen/ops/_dyn_quant_pack_4bit_weight_native.h>
 #include <ATen/ops/_int_mm_native.h>
 #include <ATen/ops/_linalg_check_errors.h>
 #include <ATen/ops/_linalg_det.h>
@@ -3429,6 +3431,8 @@ Tensor kron(const Tensor& self, const Tensor& other) {
 DEFINE_DISPATCH(weight_to_int4pack_stub);
 DEFINE_DISPATCH(int4pack_mm_stub);
 DEFINE_DISPATCH(int8pack_mm_stub);
+DEFINE_DISPATCH(dyn_quant_pack_4bit_weight_stub);
+DEFINE_DISPATCH(dyn_quant_matmul_4bit_stub);
 
 Tensor _convert_weight_to_int4pack_cpu(
     const Tensor& in,
@@ -3492,6 +3496,69 @@ Tensor _weight_int4pack_mm_cpu(
   return C;
 }
 
+Tensor _dyn_quant_pack_4bit_weight_cpu(
+    const Tensor& weights,
+    const Tensor& scales_zeros,
+    const std::optional<Tensor>& bias,
+    const int64_t block_size,
+    const int64_t in_features,
+    const int64_t out_features) {
+  TORCH_CHECK(
+      weights.dtype() == at::kByte, __func__, " : expect weight to be kByte.");
+  TORCH_CHECK(
+      block_size == in_features ||
+          (!(block_size % 32) && !(in_features % block_size)),
+      __func__,
+      ": Group size should be multiple of 32, in_features [",
+      in_features,
+      "]. Provided ",
+      block_size);
+  Tensor packed_weights =
+      at::empty(weights.sizes(), weights.options().dtype(at::kByte));
+  dyn_quant_pack_4bit_weight_stub(
+      kCPU,
+      packed_weights,
+      weights,
+      scales_zeros,
+      bias,
+      out_features,
+      in_features,
+      block_size);
+  return packed_weights;
+}
+
+Tensor _dyn_quant_matmul_4bit_cpu(
+    const Tensor& inp,
+    const Tensor& packed_weights,
+    const int64_t block_size,
+    const int64_t in_features,
+    const int64_t out_features) {
+  auto M = inp.size(0);
+  TORCH_CHECK(
+      inp.dtype() == kFloat,
+      __func__,
+      " : expect input to be 32-bit float tensor.");
+  TORCH_CHECK(
+      block_size == in_features ||
+          (!(block_size % 32) && !(in_features % block_size)),
+      __func__,
+      ": Group size should be multiple of 32, in_features [",
+      in_features,
+      "]. Provided ",
+      block_size);
+  auto output = at::empty({M, out_features}, inp.options());
+  dyn_quant_matmul_4bit_stub(
+      kCPU,
+      output,
+      inp,
+      packed_weights,
+      M,
+      out_features,
+      in_features,
+      block_size);
+  return output;
+}
+
 Tensor _weight_int8pack_mm_cpu(
     const Tensor& A,
     const Tensor& B,
@@ -8,8 +8,19 @@
 #include <ATen/cpu/vec/vec.h>
 #include <ATen/native/cpu/int_mm_kernel.h>
 #include <ATen/native/cpu/utils.h>
-#include <c10/util/irange.h>
 #include <c10/util/Unroll.h>
+#include <c10/util/irange.h>
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include <ATen/Functions.h>
+#else
+#include <ATen/ops/cat.h>
+#endif
+
+#if AT_KLEIDIAI_ENABLED()
+#include <ATen/native/kleidiai/kai_kernels.h>
+#include <cpuinfo.h>
+#endif
 
 #if (defined(_WIN32) || defined(_WIN64))
 #define RESTRICT __restrict
@ -762,10 +773,457 @@ void int4pack_mm_kernel(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#if AT_KLEIDIAI_ENABLED()
|
||||||
|
bool can_use_kleidiai(
|
||||||
|
const at::Tensor& scales_zeros,
|
||||||
|
const int64_t K,
|
||||||
|
const int64_t block_size) {
|
||||||
|
bool ret = false;
|
||||||
|
if (cpuinfo_has_arm_neon_dot()) {
|
||||||
|
// The Groupwise kernel requires BFloat16 Scales and Channelwise kernel
|
||||||
|
// requires Float32 Scales. If not provided, we will use fallback
|
||||||
|
// implementation.
|
||||||
|
if ((block_size == K && scales_zeros.dtype() == at::kFloat) ||
|
||||||
|
((block_size < K && !(block_size % 32) && !(K % block_size)) &&
|
||||||
|
scales_zeros.dtype() == at::kBFloat16)) {
|
||||||
|
ret = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
/**
 * The Int4 quantized weights must be represented as a uint8 tensor.
 * For matrix multiplication with a weight shape of (N x K)
 * the shape of the 4-bit quantized weights is [N, K/groupsize, groupsize/2].
 *
 * For KleidiAI weight packing, the scales, biases, and Int4 quantized
 * weights are packed into a single `packed_weights` structure, optimized for
 * Arm instructions.
 *
 * In the fallback reference kernel, no special packing is required for
 * Int4 quantized weights.
 *
 * The Groupwise kernel requires BFloat16 Scales and Channelwise kernel requires
 * Float32 Scales. If not provided, we will use fallback implementation.
 */
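To make the layout described in the comment above concrete, here is a small illustrative Python sketch (not part of the patch) of how two 4-bit values share one uint8 byte, matching the way the reference kernel below unpacks them: the even k index lives in the low nibble, the odd k index in the high nibble.

```python
import torch

# Tiny hypothetical example: N=1 output channel, K=4 inputs, groupsize=4.
q = torch.tensor([3, -2, 7, -8], dtype=torch.int8)        # signed int4 values
u = (q + 8).to(torch.uint8)                                # shift into [0, 15]
packed = u[::2] | (u[1::2] << 4)                           # two values per byte
packed = packed.reshape(1, 1, 2)                           # [N, K/groupsize, groupsize/2]

# Unpack the same way the reference kernel reads it back.
low = (packed & 0x0F).to(torch.int16) - 8                  # even k indices
high = (packed >> 4).to(torch.int16) - 8                   # odd k indices
print(low.flatten().tolist(), high.flatten().tolist())     # [3, 7] [-2, -8]
```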
|
void dyn_quant_pack_4bit_weight_kernel(
|
||||||
|
Tensor& packed_weights,
|
||||||
|
const Tensor& weights,
|
||||||
|
const Tensor& scales_zeros,
|
||||||
|
const std::optional<Tensor>& bias,
|
||||||
|
const int64_t N,
|
||||||
|
const int64_t K,
|
||||||
|
const int64_t block_size) {
|
||||||
|
#if AT_KLEIDIAI_ENABLED()
|
||||||
|
if (can_use_kleidiai(scales_zeros, K, block_size)) {
|
||||||
|
const int64_t weight_packed_size =
|
||||||
|
kleidiai::kai_pack_rhs_int4_size(N, K, block_size);
|
||||||
|
packed_weights.resize_({weight_packed_size});
|
||||||
|
kleidiai::kai_pack_int4_rhs(
|
||||||
|
packed_weights, weights, scales_zeros, bias, N, K, block_size);
|
||||||
|
} else
|
||||||
|
#endif
|
||||||
|
{
|
||||||
|
TORCH_CHECK(
|
||||||
|
bias.has_value() == 0,
|
||||||
|
__func__,
|
||||||
|
" : Bias is unsupported in reference implementation");
|
||||||
|
packed_weights = packed_weights.to(kFloat);
|
||||||
|
auto weight_reshaped = weights.view({-1}).to(kFloat);
|
||||||
|
auto scales_zeros_reshaped = scales_zeros.view({-1}).to(kFloat);
|
||||||
|
auto res = at::cat({weight_reshaped, scales_zeros_reshaped}, 0);
|
||||||
|
packed_weights.resize_(res.sizes()).copy_(res);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void ref_dyn_quant_matmul_4bit_channelwise_kernel(
|
||||||
|
size_t m,
|
||||||
|
size_t n,
|
||||||
|
size_t k,
|
||||||
|
const float* lhs_f32,
|
||||||
|
const uint8_t* rhs_qs4cx,
|
||||||
|
const float* rhs_scales_f32,
|
||||||
|
float* dst_f32,
|
||||||
|
float scalar_min,
|
||||||
|
float scalar_max) {
|
||||||
|
const size_t input_size_8bit = m * (k + sizeof(int32_t) + sizeof(float));
|
||||||
|
|
||||||
|
auto lhs_qa8dx_buffer = std::make_unique<uint8_t[]>(input_size_8bit);
|
||||||
|
uint8_t* lhs_qa8dx = lhs_qa8dx_buffer.get();
|
||||||
|
|
||||||
|
// Lambda for quantizing the fp32 input to 8 bit symmetric and pack it in
|
||||||
|
// required format for matmul
|
||||||
|
auto input_quant_pack_8bit_channelwise =
|
||||||
|
[&](size_t m, size_t k, const float* lhs_f32, int8_t* lhs_qa8dx) {
|
||||||
|
const size_t dst_stride =
|
||||||
|
(k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t));
|
||||||
|
|
||||||
|
const size_t lhs_qa8dx_stride = k;
|
||||||
|
|
||||||
|
for (size_t m_idx = 0; m_idx < m; ++m_idx) {
|
||||||
|
const float* src_ptr = lhs_f32 + m_idx * lhs_qa8dx_stride;
|
||||||
|
|
||||||
|
float max0 = -FLT_MAX;
|
||||||
|
float min0 = FLT_MAX;
|
||||||
|
|
||||||
|
// Find min/max for each channel
|
||||||
|
for (size_t k_idx = 0; k_idx < k; ++k_idx) {
|
||||||
|
const float src0_0 = src_ptr[k_idx];
|
||||||
|
|
||||||
|
max0 = (std::max)(src0_0, max0);
|
||||||
|
min0 = (std::min)(src0_0, min0);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Maximum/minimum int8 values
|
||||||
|
const float qmin = (float)INT8_MIN;
|
||||||
|
const float qmax = (float)INT8_MAX;
|
||||||
|
|
||||||
|
const float rmin0 = (std::min)(0.0f, min0);
|
||||||
|
const float rmax0 = (std::max)(0.0f, max0);
|
||||||
|
|
||||||
|
const float scale0 =
|
||||||
|
rmin0 == rmax0 ? 1.f : (qmax - qmin) / (rmax0 - rmin0);
|
||||||
|
|
||||||
|
// Reciprocal to quantize
|
||||||
|
const float recip_scale0 = scale0 ? 1.0f / scale0 : 0.0f;
|
||||||
|
|
||||||
|
const float descaled_min0 = rmin0 * scale0;
|
||||||
|
const float descaled_max0 = rmax0 * scale0;
|
||||||
|
|
||||||
|
const float zero_point_from_min_error0 = qmin + descaled_min0;
|
||||||
|
const float zero_point_from_max_error0 = qmax + descaled_max0;
|
||||||
|
|
||||||
|
float zero_point0 =
|
||||||
|
zero_point_from_min_error0 + zero_point_from_max_error0 > 0
|
||||||
|
? qmin - descaled_min0
|
||||||
|
: qmax - descaled_max0;
|
||||||
|
|
||||||
|
zero_point0 = (std::max)(zero_point0, qmin);
|
||||||
|
zero_point0 = (std::min)(zero_point0, qmax);
|
||||||
|
|
||||||
|
// Round to nearest integer
|
||||||
|
const int32_t nudged_zero_point0 = lrintf(zero_point0);
|
||||||
|
|
||||||
|
int8_t* dst_ptr = (int8_t*)lhs_qa8dx + m_idx * dst_stride;
|
||||||
|
|
||||||
|
// LHS offset at the beginning of the row
|
||||||
|
*((float*)(dst_ptr)) = recip_scale0;
|
||||||
|
dst_ptr += sizeof(float);
|
||||||
|
*((int32_t*)(dst_ptr)) = -nudged_zero_point0;
|
||||||
|
dst_ptr += sizeof(int32_t);
|
||||||
|
|
||||||
|
// Quantize the channels
|
||||||
|
for (size_t k_idx = 0; k_idx < k; ++k_idx) {
|
||||||
|
const float src0_0 = src_ptr[k_idx];
|
||||||
|
|
||||||
|
// Scale the values
|
||||||
|
int32_t v0_s32 = (int32_t)(std::round(src0_0 * scale0));
|
||||||
|
|
||||||
|
v0_s32 = v0_s32 + nudged_zero_point0;
|
||||||
|
v0_s32 = (std::max)(v0_s32, static_cast<int32_t>(INT8_MIN));
|
||||||
|
v0_s32 = (std::min)(v0_s32, static_cast<int32_t>(INT8_MAX));
|
||||||
|
dst_ptr[0] = (int8_t)v0_s32;
|
||||||
|
dst_ptr += sizeof(int8_t);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Dynamically quantize the float32 input to 8-bit asymmetric
|
||||||
|
input_quant_pack_8bit_channelwise(m, k, lhs_f32, (int8_t*)lhs_qa8dx);
|
||||||
|
|
||||||
|
const size_t lhs_stride =
|
||||||
|
k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t);
|
||||||
|
|
||||||
|
const size_t rhs_qs4cx_stride = ((((k + 2 - 1) / 2) * 2) / 2);
|
||||||
|
|
||||||
|
for (size_t m_idx = 0; m_idx < m; ++m_idx) {
|
||||||
|
const int8_t* lhs_ptr_start = (int8_t*)lhs_qa8dx + m_idx * lhs_stride;
|
||||||
|
|
||||||
|
for (size_t n_idx = 0; n_idx < n; ++n_idx) {
|
||||||
|
// Main f32 accumulator
|
||||||
|
int32_t iacc = 0;
|
||||||
|
|
||||||
|
const int8_t* lhs_ptr = lhs_ptr_start;
|
||||||
|
const uint8_t* rhs_ptr = rhs_qs4cx + n_idx * rhs_qs4cx_stride;
|
||||||
|
|
||||||
|
// Get the LHS quantization parameters stored at the
|
||||||
|
// beginning of each row
|
||||||
|
const float lhs_scale = *(const float*)lhs_ptr;
|
||||||
|
lhs_ptr += sizeof(float);
|
||||||
|
|
||||||
|
const int32_t lhs_offset = *(const int32_t*)lhs_ptr;
|
||||||
|
lhs_ptr += sizeof(int32_t);
|
||||||
|
|
||||||
|
for (size_t k_idx = 0; k_idx < k; ++k_idx) {
|
||||||
|
// Get the LHS values
|
||||||
|
const int32_t lhs_v0 = (int32_t)lhs_ptr[0];
|
||||||
|
|
||||||
|
// Get the RHS values
|
||||||
|
const uint8_t rhs_byte = rhs_ptr[0];
|
||||||
|
|
||||||
|
// Unpack the RHS values
|
||||||
|
int32_t rhs_v0 = 0;
|
||||||
|
if ((k_idx % 2) == 0) {
|
||||||
|
rhs_v0 = (((int32_t)(rhs_byte & 0x0F)) - 8);
|
||||||
|
} else {
|
||||||
|
rhs_v0 = (((int32_t)(rhs_byte >> 4)) - 8);
|
||||||
|
}
|
||||||
|
|
||||||
|
iacc += lhs_v0 * rhs_v0;
|
||||||
|
iacc += lhs_offset * rhs_v0;
|
||||||
|
|
||||||
|
lhs_ptr += 1;
|
||||||
|
|
||||||
|
// Increment only when k_idx is not a multiple of 2
|
||||||
|
rhs_ptr += k_idx % 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the RHS scale
|
||||||
|
const float rhs_scale = rhs_scales_f32[n_idx];
|
||||||
|
|
||||||
|
float main_acc = iacc * rhs_scale;
|
||||||
|
|
||||||
|
main_acc = main_acc * lhs_scale;
|
||||||
|
|
||||||
|
// Clamp (min-max) operation
|
||||||
|
main_acc = (std::max)(main_acc, scalar_min);
|
||||||
|
main_acc = (std::min)(main_acc, scalar_max);
|
||||||
|
|
||||||
|
dst_f32[0] = main_acc;
|
||||||
|
dst_f32 += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
static void ref_dyn_quant_matmul_4bit_groupwise_kernel(
|
||||||
|
size_t m,
|
||||||
|
size_t n,
|
||||||
|
size_t k,
|
||||||
|
size_t bl,
|
||||||
|
const float* lhs_f32,
|
||||||
|
const uint8_t* rhs_qs4c32,
|
||||||
|
const float* rhs_scales_fp32,
|
||||||
|
float* dst_f32,
|
||||||
|
float scalar_min,
|
||||||
|
float scalar_max) {
|
||||||
|
// Lambda for LHS quantization
|
||||||
|
auto lhs_quant_pack = [&](size_t m,
|
||||||
|
size_t k,
|
||||||
|
const float* lhs_f32,
|
||||||
|
int8_t* lhs_qa8dx) {
|
||||||
|
const size_t dst_stride =
|
||||||
|
(k * sizeof(int8_t) + sizeof(float) + sizeof(int32_t));
|
||||||
|
|
||||||
|
for (size_t row_idx = 0; row_idx < m; ++row_idx) {
|
||||||
|
const float* src_ptr = lhs_f32 + row_idx * k;
|
||||||
|
|
||||||
|
float max0 = -FLT_MAX;
|
||||||
|
float min0 = FLT_MAX;
|
||||||
|
|
||||||
|
for (size_t k_idx = 0; k_idx < k; ++k_idx) {
|
||||||
|
const float src0_0 = src_ptr[k_idx];
|
||||||
|
max0 = (std::max)(src0_0, max0);
|
||||||
|
min0 = (std::min)(src0_0, min0);
|
||||||
|
}
|
||||||
|
|
||||||
|
const float qmin = (float)INT8_MIN;
|
||||||
|
const float qmax = (float)INT8_MAX;
|
||||||
|
|
||||||
|
const float rmin0 = (std::min)(0.0f, min0);
|
||||||
|
const float rmax0 = (std::max)(0.0f, max0);
|
||||||
|
const float scale0 =
|
||||||
|
(rmin0 == rmax0) ? 1.f : (qmax - qmin) / (rmax0 - rmin0);
|
||||||
|
const float recip_scale0 = scale0 ? 1.0f / scale0 : 0.0f;
|
||||||
|
|
||||||
|
const float descaled_min0 = rmin0 * scale0;
|
||||||
|
const float descaled_max0 = rmax0 * scale0;
|
||||||
|
|
||||||
|
float zero_point0 = (qmin + descaled_min0 + qmax + descaled_max0 > 0)
|
||||||
|
? qmin - descaled_min0
|
||||||
|
: qmax - descaled_max0;
|
||||||
|
|
||||||
|
zero_point0 = (std::max)(zero_point0, qmin);
|
||||||
|
zero_point0 = (std::min)(zero_point0, qmax);
|
||||||
|
const int32_t nudged_zero_point0 = lrintf(zero_point0);
|
||||||
|
|
||||||
|
int8_t* dst_ptr = (int8_t*)lhs_qa8dx + row_idx * dst_stride;
|
||||||
|
|
||||||
|
*((float*)(dst_ptr)) = recip_scale0;
|
||||||
|
dst_ptr += sizeof(float);
|
||||||
|
*((int32_t*)(dst_ptr)) = -nudged_zero_point0;
|
||||||
|
dst_ptr += sizeof(int32_t);
|
||||||
|
|
||||||
|
for (size_t k_idx = 0; k_idx < k; ++k_idx) {
|
||||||
|
const float src0_0 = src_ptr[k_idx];
|
||||||
|
int32_t v0_s32 = (int32_t)(std::round(src0_0 * scale0));
|
||||||
|
v0_s32 = (std::max)(
|
||||||
|
(std::min)(
|
||||||
|
v0_s32 + nudged_zero_point0, static_cast<int32_t>(INT8_MAX)),
|
||||||
|
static_cast<int32_t>(INT8_MIN));
|
||||||
|
dst_ptr[0] = (int8_t)v0_s32;
|
||||||
|
dst_ptr += sizeof(int8_t);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
auto lhs_qa8dx_buffer = std::make_unique<int8_t[]>(
|
||||||
|
m * (k + sizeof(float) + sizeof(int32_t))); // Allocate for LHS
|
||||||
|
int8_t* lhs_qa8dx = lhs_qa8dx_buffer.get();
|
||||||
|
// Quantize and pack LHS
|
||||||
|
lhs_quant_pack(m, k, lhs_f32, lhs_qa8dx);
|
||||||
|
|
||||||
|
const size_t num_blocks_row = (((k + bl - 1) / bl) * bl) / bl;
|
||||||
|
const size_t lhs_stride = k + sizeof(float) + sizeof(int32_t);
|
||||||
|
const size_t rhs_stride = (((k + 2 - 1) / 2) * 2) / 2;
|
||||||
|
|
||||||
|
for (size_t row_idx = 0; row_idx < m; ++row_idx) {
|
||||||
|
const int8_t* lhs_ptr_start = lhs_qa8dx + row_idx * lhs_stride;
|
||||||
|
|
||||||
|
for (size_t col_idx = 0; col_idx < n; ++col_idx) {
|
||||||
|
float main_acc = 0.0f;
|
||||||
|
const int8_t* lhs_ptr = lhs_ptr_start;
|
||||||
|
const uint8_t* rhs_ptr = rhs_qs4c32 + col_idx * rhs_stride;
|
||||||
|
|
||||||
|
const float lhs_scale = *(const float*)lhs_ptr;
|
||||||
|
lhs_ptr += sizeof(float);
|
||||||
|
const int32_t lhs_offset = *(const int32_t*)lhs_ptr;
|
||||||
|
lhs_ptr += sizeof(int32_t);
|
||||||
|
|
||||||
|
for (size_t block_idx = 0; block_idx < num_blocks_row; ++block_idx) {
|
||||||
|
const float rhs_scale =
|
||||||
|
rhs_scales_fp32[block_idx + col_idx * num_blocks_row];
|
||||||
|
int32_t iacc = 0;
|
||||||
|
|
||||||
|
for (size_t i = 0; i < bl; ++i) {
|
||||||
|
const size_t k_idx = block_idx * bl + i;
|
||||||
|
if (k_idx >= k) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
const int32_t lhs_v0 = (int32_t)lhs_ptr[0];
|
||||||
|
const uint8_t rhs_byte = rhs_ptr[0];
|
||||||
|
int32_t rhs_v0 = (k_idx % 2 == 0) ? (((int32_t)(rhs_byte & 0x0F)) - 8)
|
||||||
|
: (((int32_t)(rhs_byte >> 4)) - 8);
|
||||||
|
|
||||||
|
iacc += lhs_v0 * rhs_v0;
|
||||||
|
iacc += lhs_offset * rhs_v0;
|
||||||
|
|
||||||
|
lhs_ptr += 1;
|
||||||
|
rhs_ptr += (k_idx % 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
main_acc += iacc * rhs_scale;
|
||||||
|
}
|
||||||
|
|
||||||
|
main_acc = main_acc * lhs_scale;
|
||||||
|
main_acc = (std::max)(main_acc, scalar_min);
|
||||||
|
main_acc = (std::min)(main_acc, scalar_max);
|
||||||
|
|
||||||
|
dst_f32[0] = main_acc;
|
||||||
|
dst_f32 += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
 * Dynamic Input Quant 4 bit weights matmul execution flow
 *  (INT4 Weights + FP scales + FP32 Bias)
 *
 *     FP32 Input          Packed Buffer
 *         |                     |
 *     Quantize                Cast
 *     to INT8                to INT8
 *         |                     |
 *         v                     v
 *     INT8 Input          INT8 Weights
 *          \                   /
 *           \                 /
 *            \               /
 *        INT8 Matrix Multiplication
 *                   |
 *                   v
 *   FP32 Dequantized and Accumulate in FP32
 *                   |
 *                   v
 *           FP32 Final Output
 *
 * The Groupwise kernel requires BFloat16 Scales and Channelwise kernel requires
 * Float32 Scales. If not provided, we will use fallback implementation.
 */
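As a rough numerical illustration of that flow, the simplified Python sketch below mirrors the reference per-row dynamic 8-bit quantization: derive an asymmetric scale and zero point from the row's min/max, do the integer dot product against already-unpacked int4 weights, then rescale to FP32. It is an assumption-laden sketch (single-branch zero-point selection, a made-up `w_scale`), not the PR's kernel.

```python
import torch

def dyn_quant_row_int8(row: torch.Tensor):
    # Asymmetric per-row quantization into int8, simplified from the lambda above.
    rmin, rmax = min(row.min().item(), 0.0), max(row.max().item(), 0.0)
    scale = 1.0 if rmin == rmax else (127.0 - (-128.0)) / (rmax - rmin)
    zp = round(max(min(-128.0 - rmin * scale, 127.0), -128.0))
    q = torch.clamp(torch.round(row * scale) + zp, -128, 127).to(torch.int32)
    return q, 1.0 / scale, zp

x = torch.randn(8)                       # one FP32 input row
w = torch.randint(-8, 8, (8,))           # one already-unpacked int4 weight row
q, recip_scale, zp = dyn_quant_row_int8(x)
acc = int(((q - zp) * w).sum())          # integer matmul accumulated in int32
w_scale = 0.05                           # hypothetical per-channel weight scale
print(acc * w_scale * recip_scale)       # dequantized; close to x @ (w * w_scale)
```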
|
||||||
|
void dyn_quant_matmul_4bit_kernel(
|
||||||
|
const Tensor& output,
|
||||||
|
const Tensor& inp,
|
||||||
|
const Tensor& packed_weights,
|
||||||
|
const int64_t M,
|
||||||
|
const int64_t N,
|
||||||
|
const int64_t K,
|
||||||
|
const int64_t block_size) {
|
||||||
|
#if AT_KLEIDIAI_ENABLED()
|
||||||
|
const int64_t weight_packed_size =
|
||||||
|
kleidiai::kai_pack_rhs_int4_size(N, K, block_size);
|
||||||
|
if (weight_packed_size == packed_weights.numel()) {
|
||||||
|
// KleidiAI interface internally handles the Channelwise and groupwise
|
||||||
|
// distinction
|
||||||
|
kleidiai::kai_quant_pack_lhs_int4_mm(
|
||||||
|
output, inp, packed_weights, M, N, K, block_size);
|
||||||
|
} else
|
||||||
|
#endif
|
||||||
|
{
|
||||||
|
float* lhs_f32 = reinterpret_cast<float*>(inp.data_ptr());
|
||||||
|
const auto weights_size = N * K / 2;
|
||||||
|
// The weights needs to be in uint8_t data type after quantization
|
||||||
|
auto extracted_weights =
|
||||||
|
(packed_weights.narrow(0, 0, weights_size)).to(kByte);
|
||||||
|
auto float32_scales =
|
||||||
|
(packed_weights.narrow(
|
||||||
|
0, weights_size, packed_weights.size(0) - weights_size))
|
||||||
|
.to(kFloat);
|
||||||
|
uint8_t* rhs_4bit =
|
||||||
|
reinterpret_cast<uint8_t*>(extracted_weights.data_ptr());
|
||||||
|
float* rhs_scales_f32 = reinterpret_cast<float*>(float32_scales.data_ptr());
|
||||||
|
float* dst_f32 = reinterpret_cast<float*>(output.data_ptr());
|
||||||
|
if (block_size == K) {
|
||||||
|
ref_dyn_quant_matmul_4bit_channelwise_kernel(
|
||||||
|
M,
|
||||||
|
N,
|
||||||
|
K,
|
||||||
|
lhs_f32,
|
||||||
|
rhs_4bit,
|
||||||
|
rhs_scales_f32,
|
||||||
|
dst_f32,
|
||||||
|
-FLT_MAX,
|
||||||
|
FLT_MAX);
|
||||||
|
} else if (!(block_size % 32) && !(K % block_size)) {
|
||||||
|
ref_dyn_quant_matmul_4bit_groupwise_kernel(
|
||||||
|
M,
|
||||||
|
N,
|
||||||
|
K,
|
||||||
|
block_size,
|
||||||
|
lhs_f32,
|
||||||
|
rhs_4bit,
|
||||||
|
rhs_scales_f32,
|
||||||
|
dst_f32,
|
||||||
|
-FLT_MAX,
|
||||||
|
FLT_MAX);
|
||||||
|
} else {
|
||||||
|
TORCH_CHECK(
|
||||||
|
block_size == K || (!(block_size % 32) && !(K % block_size)),
|
||||||
|
__func__,
|
||||||
|
": Group size should be multiple 32 or in_features [",
|
||||||
|
K,
|
||||||
|
"]. Provided ",
|
||||||
|
block_size);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
} // anonymous namespace
|
} // anonymous namespace
|
||||||
|
|
||||||
ALSO_REGISTER_AVX512_DISPATCH(weight_to_int4pack_stub, &weight_to_int4pack_kernel)
|
ALSO_REGISTER_AVX512_DISPATCH(weight_to_int4pack_stub, &weight_to_int4pack_kernel)
|
||||||
ALSO_REGISTER_AVX512_DISPATCH(int4pack_mm_stub, &int4pack_mm_kernel)
|
ALSO_REGISTER_AVX512_DISPATCH(int4pack_mm_stub, &int4pack_mm_kernel)
|
||||||
|
REGISTER_DISPATCH(dyn_quant_pack_4bit_weight_stub, &dyn_quant_pack_4bit_weight_kernel)
|
||||||
|
REGISTER_DISPATCH(dyn_quant_matmul_4bit_stub, &dyn_quant_matmul_4bit_kernel)
|
||||||
|
|
||||||
} // at::native
|
} // at::native
|
||||||
C10_DIAGNOSTIC_POP()
|
C10_DIAGNOSTIC_POP()
|
||||||
|
@@ -5,12 +5,34 @@
 namespace at::native {
 
-using weight_to_int4pack_fn = void(*)(const Tensor&, const Tensor&);
-using int4pack_mm_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, int, const Tensor&);
-using int8pack_mm_fn = void(*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&);
+using weight_to_int4pack_fn = void (*)(const Tensor&, const Tensor&);
+using int4pack_mm_fn =
+    void (*)(const Tensor&, const Tensor&, const Tensor&, int, const Tensor&);
+using int8pack_mm_fn =
+    void (*)(const Tensor&, const Tensor&, const Tensor&, const Tensor&);
+using dyn_quant_pack_4bit_weight_fn = void (*)(
+    Tensor&,
+    const Tensor&,
+    const Tensor&,
+    const std::optional<Tensor>& bias,
+    const int64_t,
+    const int64_t,
+    const int64_t);
+using dyn_quant_matmul_4bit_fn = void (*)(
+    const Tensor&,
+    const Tensor&,
+    const Tensor&,
+    const int64_t,
+    const int64_t,
+    const int64_t,
+    const int64_t);
 
 DECLARE_DISPATCH(weight_to_int4pack_fn, weight_to_int4pack_stub)
 DECLARE_DISPATCH(int4pack_mm_fn, int4pack_mm_stub)
 DECLARE_DISPATCH(int8pack_mm_fn, int8pack_mm_stub)
+DECLARE_DISPATCH(
+    dyn_quant_pack_4bit_weight_fn,
+    dyn_quant_pack_4bit_weight_stub);
+DECLARE_DISPATCH(dyn_quant_matmul_4bit_fn, dyn_quant_matmul_4bit_stub);
 
 } // namespace at::native
440  aten/src/ATen/native/kleidiai/kai_kernels.cpp  (new file)
@@ -0,0 +1,440 @@
|
|||||||
|
#include <ATen/native/kleidiai/kai_kernels.h>
|
||||||
|
#include <ATen/native/kleidiai/kai_pack.h>
|
||||||
|
#include <ATen/native/kleidiai/kai_ukernel_interface.h>
|
||||||
|
|
||||||
|
#include <ATen/Parallel.h>
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <cfloat>
|
||||||
|
#include <cmath>
|
||||||
|
#include <unordered_map>
|
||||||
|
#if AT_KLEIDIAI_ENABLED()
|
||||||
|
#include <cpuinfo.h>
|
||||||
|
|
||||||
|
namespace at::native::kleidiai {
|
||||||
|
|
||||||
|
void kai_pack_int4_rhs(
|
||||||
|
const Tensor& weight_packed,
|
||||||
|
const Tensor& weight,
|
||||||
|
const Tensor& scales,
|
||||||
|
const std::optional<Tensor>& bias,
|
||||||
|
const int64_t n,
|
||||||
|
const int64_t k,
|
||||||
|
const int64_t bl) {
|
||||||
|
// Prefer Channelwise kernel over Groupwise kernel for conflicting cases
|
||||||
|
if (bl == k) {
|
||||||
|
// Channelwise
|
||||||
|
auto kernel_packet = kai_select_channelwise_matmul_ukernel(
|
||||||
|
kai_kernel_id::
|
||||||
|
matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod);
|
||||||
|
auto& params = kernel_packet.rhs_pack_params;
|
||||||
|
params.lhs_zero_point = 1;
|
||||||
|
params.rhs_zero_point = 8;
|
||||||
|
|
||||||
|
kai_pack_rhs_channelwise_int4<kai_matmul_ukernel_f32_qa8dxp_qs4cxp>(
|
||||||
|
kernel_packet, weight_packed, weight, scales, bias, n, k);
|
||||||
|
} else if (!(bl % 32) && !(k % bl)) {
|
||||||
|
// Groupwise
|
||||||
|
auto kernel_packet = kai_select_groupwise_matmul_ukernel(
|
||||||
|
kai_kernel_id::
|
||||||
|
matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod);
|
||||||
|
|
||||||
|
const int64_t rhs_stride = kai_roundup(k, 2) / 2;
|
||||||
|
const int64_t scale_stride = (kai_roundup(k, bl) / bl) * sizeof(uint16_t);
|
||||||
|
auto& params = kernel_packet.rhs_pack_params;
|
||||||
|
params.lhs_zero_point = 1;
|
||||||
|
params.rhs_zero_point = 8;
|
||||||
|
params.scale_dt = kai_datatype::kai_dt_bf16;
|
||||||
|
|
||||||
|
kai_pack_rhs_groupwise_int4<kai_matmul_ukernel_f32_qa8dxp_qs4c32p>(
|
||||||
|
kernel_packet,
|
||||||
|
weight_packed,
|
||||||
|
weight,
|
||||||
|
scales,
|
||||||
|
bias,
|
||||||
|
n,
|
||||||
|
k,
|
||||||
|
bl,
|
||||||
|
rhs_stride,
|
||||||
|
scale_stride);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t kai_pack_rhs_int4_size(
|
||||||
|
const int64_t n,
|
||||||
|
const int64_t k,
|
||||||
|
const int64_t bl) {
|
||||||
|
size_t packed_size = n * k;
|
||||||
|
// Prefer Channelwise kernel over Groupwise kernel for conflicting cases
|
||||||
|
if (bl == k) {
|
||||||
|
// Channelwise
|
||||||
|
auto kernel_packet = kai_select_channelwise_matmul_ukernel(
|
||||||
|
kai_kernel_id::
|
||||||
|
matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod);
|
||||||
|
const auto& ukernel = kernel_packet.ukernel;
|
||||||
|
const size_t nr = ukernel.get_nr();
|
||||||
|
const size_t kr = ukernel.get_kr();
|
||||||
|
const size_t sr = ukernel.get_sr();
|
||||||
|
packed_size = kernel_packet.kai_get_rhs_packed_size(n, k, nr, kr, sr);
|
||||||
|
} else if (!(bl % 32) && !(k % bl)) {
|
||||||
|
// Groupwise
|
||||||
|
auto kernel_packet = kai_select_groupwise_matmul_ukernel(
|
||||||
|
kai_kernel_id::
|
||||||
|
matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod);
|
||||||
|
const auto& ukernel = kernel_packet.ukernel;
|
||||||
|
const size_t nr = ukernel.get_nr();
|
||||||
|
const size_t kr = ukernel.get_kr();
|
||||||
|
const size_t sr = ukernel.get_sr();
|
||||||
|
packed_size = kernel_packet.kai_get_rhs_packed_size(
|
||||||
|
n, k, nr, kr, sr, bl, kai_datatype::kai_dt_bf16);
|
||||||
|
}
|
||||||
|
return packed_size;
|
||||||
|
}
|
||||||
|
|
||||||
|
static void matmul_channelwise(
|
||||||
|
kai_matmul_ukernel_f32_qa8dxp_qs4cxp& kernel_packet,
|
||||||
|
size_t m_increment,
|
||||||
|
size_t m_start,
|
||||||
|
size_t m_per_thread,
|
||||||
|
size_t n_start,
|
||||||
|
size_t n_per_thread,
|
||||||
|
size_t n,
|
||||||
|
size_t k,
|
||||||
|
size_t mr,
|
||||||
|
size_t nr,
|
||||||
|
size_t kr,
|
||||||
|
size_t sr,
|
||||||
|
size_t dst_stride,
|
||||||
|
size_t lhs_stride,
|
||||||
|
uint8_t* lhs_native_mtx_f32,
|
||||||
|
uint8_t* lhs_packed_mtx_qa8dx,
|
||||||
|
uint8_t* rhs_packed_mtx_qs4cx,
|
||||||
|
uint8_t* dst_act_mtx_f32) {
|
||||||
|
for (size_t m0 = 0; m0 < m_per_thread; m0 += m_increment) {
|
||||||
|
const float* src_ptr =
|
||||||
|
(const float*)(lhs_native_mtx_f32 +
|
||||||
|
kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32(
|
||||||
|
m_start + m0, lhs_stride));
|
||||||
|
void* lhs_packed_ptr =
|
||||||
|
(void*)(lhs_packed_mtx_qa8dx +
|
||||||
|
kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(
|
||||||
|
0, k, mr, kr, sr));
|
||||||
|
const void* rhs_packed_ptr =
|
||||||
|
(const void*)((const char*)rhs_packed_mtx_qs4cx +
|
||||||
|
kernel_packet.ukernel.get_rhs_packed_offset(n_start, k));
|
||||||
|
float* dst_ptr = (float*)((uint8_t*)dst_act_mtx_f32 +
|
||||||
|
kernel_packet.ukernel.get_dst_offset(
|
||||||
|
m_start + m0, n_start, dst_stride));
|
||||||
|
|
||||||
|
// Quantize and pack the Input
|
||||||
|
kernel_packet.kai_run_lhs_quant_pack(
|
||||||
|
m_increment, k, mr, kr, sr, 0, src_ptr, lhs_stride, lhs_packed_ptr);
|
||||||
|
|
||||||
|
// Run Matmul on Int4 packed weights and Quantized Packed Input
|
||||||
|
kernel_packet.ukernel.run_matmul(
|
||||||
|
m_increment,
|
||||||
|
n_per_thread,
|
||||||
|
k,
|
||||||
|
lhs_packed_ptr,
|
||||||
|
rhs_packed_ptr,
|
||||||
|
dst_ptr,
|
||||||
|
dst_stride,
|
||||||
|
sizeof(float),
|
||||||
|
-FLT_MAX,
|
||||||
|
FLT_MAX);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void matmul_groupwise(
|
||||||
|
const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& ukernel,
|
||||||
|
const size_t m,
|
||||||
|
const size_t num_n_per_thread,
|
||||||
|
const size_t n_start,
|
||||||
|
const size_t k,
|
||||||
|
const size_t bl,
|
||||||
|
const size_t dst_stride,
|
||||||
|
const void* lhs_ptr,
|
||||||
|
uint8_t* rhs_packed,
|
||||||
|
uint8_t* dst_data) {
|
||||||
|
const size_t rhs_packed_offset =
|
||||||
|
ukernel.get_rhs_packed_offset(n_start, k, bl);
|
||||||
|
const size_t dst_offset = ukernel.get_dst_offset(0, n_start, dst_stride);
|
||||||
|
|
||||||
|
const void* rhs_ptr = (const void*)(rhs_packed + rhs_packed_offset);
|
||||||
|
float* dst_ptr = (float*)((uint8_t*)dst_data + dst_offset);
|
||||||
|
|
||||||
|
// Run Matmul on Int4 packed weights and Quantized Packed Input
|
||||||
|
ukernel.run_matmul(
|
||||||
|
m,
|
||||||
|
num_n_per_thread,
|
||||||
|
k,
|
||||||
|
bl,
|
||||||
|
lhs_ptr,
|
||||||
|
rhs_ptr,
|
||||||
|
dst_ptr,
|
||||||
|
dst_stride,
|
||||||
|
sizeof(float),
|
||||||
|
-FLT_MAX,
|
||||||
|
FLT_MAX);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ThreadDivision {
|
||||||
|
int64_t num_threads_x;
|
||||||
|
int64_t num_threads_y;
|
||||||
|
bool use_gemm; // True if GEMM is selected, false if GEMV is used
|
||||||
|
};
|
||||||
|
|
||||||
|
inline static unsigned int round_down_to_power_of_2(unsigned int n) {
|
||||||
|
n |= n >> 1;
|
||||||
|
n |= n >> 2;
|
||||||
|
n |= n >> 4;
|
||||||
|
n |= n >> 8;
|
||||||
|
n |= n >> 16;
|
||||||
|
return n - (n >> 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline static void adjust_max_threads(int64_t& max_threads) {
|
||||||
|
// We would not like to round down to nearest power of 2 always
|
||||||
|
// There can be possible thread split combination between powers of 2 for odd
|
||||||
|
// shapes
|
||||||
|
// TODO:: Decide better strategy based on hint of input and weight shapes
|
||||||
|
max_threads = round_down_to_power_of_2(max_threads);
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::pair<int64_t, int64_t> split_2d(const int64_t max_threads) {
|
||||||
|
int64_t sqrt_threads = std::sqrt(max_threads);
|
||||||
|
|
||||||
|
for (int64_t i = sqrt_threads; i >= 1; --i) {
|
||||||
|
if (max_threads % i == 0) {
|
||||||
|
return {i, max_threads / i};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return {1, max_threads}; // There's still a possibility of 1D blocking when
|
||||||
|
// calling GEMM kernel
|
||||||
|
}
|
||||||
|
|
||||||
|
inline static ThreadDivision get_thread_division(
|
||||||
|
int64_t max_threads,
|
||||||
|
const int64_t m,
|
||||||
|
const int64_t n,
|
||||||
|
const int64_t k,
|
||||||
|
const int64_t gemm_m_step,
|
||||||
|
const int64_t gemm_n_step,
|
||||||
|
const int64_t gemv_m_step,
|
||||||
|
const int64_t gemv_n_step) {
|
||||||
|
adjust_max_threads(max_threads);
|
||||||
|
ThreadDivision division{1, 1, false};
|
||||||
|
|
||||||
|
// Split threads 2D for GEMM
|
||||||
|
if (m % gemm_m_step == 0 && n % gemm_n_step == 0) {
|
||||||
|
while (max_threads > 0) {
|
||||||
|
auto [num_thread_y, num_thread_x] = split_2d(max_threads);
|
||||||
|
if (m % num_thread_y == 0 && n % num_thread_x == 0) {
|
||||||
|
int64_t m_per_thread = m / num_thread_y;
|
||||||
|
int64_t n_per_thread = n / num_thread_x;
|
||||||
|
if (m_per_thread % gemm_m_step == 0 &&
|
||||||
|
n_per_thread % gemm_n_step == 0) {
|
||||||
|
division = {num_thread_x, num_thread_y, true};
|
||||||
|
return division;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
max_threads -= 2;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Split threads 1D for GEMV
|
||||||
|
if (n % gemv_n_step == 0) {
|
||||||
|
for (; max_threads > 0; max_threads -= 2) {
|
||||||
|
if (n % max_threads == 0 && (n / max_threads) % gemv_n_step == 0) {
|
||||||
|
division.num_threads_x = max_threads;
|
||||||
|
return division;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return division;
|
||||||
|
}
|
||||||
|
|
||||||
|
static void kai_quant_pack_lhs_int4_mm_groupwise(
|
||||||
|
const Tensor& output,
|
||||||
|
const Tensor& input,
|
||||||
|
const Tensor& weight,
|
||||||
|
const int64_t m,
|
||||||
|
const int64_t n,
|
||||||
|
const int64_t k,
|
||||||
|
const int64_t bl) {
|
||||||
|
kai_kernel_id id = kai_kernel_id::
|
||||||
|
matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod;
|
||||||
|
if (cpuinfo_has_arm_i8mm() && m > 1) {
|
||||||
|
id =
|
||||||
|
kai_kernel_id::matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_4x8x32_neon_i8mm;
|
||||||
|
}
|
||||||
|
auto kernel_packet = kai_select_groupwise_matmul_ukernel(id);
|
||||||
|
|
||||||
|
const auto& ukernel = kernel_packet.ukernel;
|
||||||
|
|
||||||
|
const size_t mr = ukernel.get_mr();
|
||||||
|
const size_t kr = ukernel.get_kr();
|
||||||
|
const size_t sr = ukernel.get_sr();
|
||||||
|
const size_t n_step = ukernel.get_n_step();
|
||||||
|
int64_t total_threads = at::get_num_threads();
|
||||||
|
int64_t num_threads_x = 1;
|
||||||
|
adjust_max_threads(total_threads);
|
||||||
|
// Split threads 1D only for now
|
||||||
|
if (n % n_step == 0) {
|
||||||
|
for (; total_threads > 0; total_threads -= 2) {
|
||||||
|
if (n % total_threads == 0 && (n / total_threads) % n_step == 0) {
|
||||||
|
num_threads_x = total_threads;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const size_t num_n_per_thread = n / num_threads_x;
|
||||||
|
|
||||||
|
const size_t dst_stride = n * sizeof(float);
|
||||||
|
float* lhs = reinterpret_cast<float*>(input.data_ptr());
|
||||||
|
uint8_t* rhs_packed_mtx_qs4cx = reinterpret_cast<uint8_t*>(weight.data_ptr());
|
||||||
|
|
||||||
|
uint8_t* dst_act_mtx_f32 = reinterpret_cast<uint8_t*>(output.data_ptr());
|
||||||
|
const size_t lhs_packed_size =
|
||||||
|
kernel_packet.kai_get_lhs_packed_size(m, k, mr, kr, sr);
|
||||||
|
auto lhs_packed = std::make_unique<uint8_t[]>(lhs_packed_size);
|
||||||
|
|
||||||
|
// Quantize and pack the Input
|
||||||
|
kernel_packet.kai_run_lhs_quant_pack(
|
||||||
|
m,
|
||||||
|
k,
|
||||||
|
mr,
|
||||||
|
kr,
|
||||||
|
sr,
|
||||||
|
0,
|
||||||
|
(const float*)lhs,
|
||||||
|
k * sizeof(float),
|
||||||
|
(void*)lhs_packed.get());
|
||||||
|
|
||||||
|
at::parallel_for(0, num_threads_x, 0, [&](int begin, int end) {
|
||||||
|
for (const auto x : c10::irange(begin, end)) {
|
||||||
|
matmul_groupwise(
|
||||||
|
std::ref(ukernel),
|
||||||
|
m,
|
||||||
|
num_n_per_thread,
|
||||||
|
x * num_n_per_thread,
|
||||||
|
k,
|
||||||
|
bl,
|
||||||
|
dst_stride,
|
||||||
|
lhs_packed.get(),
|
||||||
|
rhs_packed_mtx_qs4cx,
|
||||||
|
dst_act_mtx_f32);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
static void kai_quant_pack_lhs_int4_mm_channelwise(
|
||||||
|
const Tensor& output,
|
||||||
|
const Tensor& input,
|
||||||
|
const Tensor& weight,
|
||||||
|
const int64_t m,
|
||||||
|
const int64_t n,
|
||||||
|
const int64_t k) {
|
||||||
|
// Kernel IDs for GEMM and GEMV
|
||||||
|
kai_kernel_id gemm_id =
|
||||||
|
kai_kernel_id::matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm;
|
||||||
|
kai_kernel_id gemv_id =
|
||||||
|
kai_kernel_id::matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod;
|
||||||
|
|
||||||
|
// Get the total number of threads available and choose GEMM or GEMV steps
|
||||||
|
const int64_t total_threads = at::get_num_threads();
|
||||||
|
auto gemm_kernel_packet = kai_select_channelwise_matmul_ukernel(gemv_id);
|
||||||
|
if (cpuinfo_has_arm_i8mm()) {
|
||||||
|
gemm_kernel_packet = kai_select_channelwise_matmul_ukernel(gemm_id);
|
||||||
|
}
|
||||||
|
auto gemv_kernel_packet = kai_select_channelwise_matmul_ukernel(gemv_id);
|
||||||
|
|
||||||
|
// Retrieve m_step and n_step values from GEMM and GEMV kernels
|
||||||
|
const int64_t gemm_m_step = gemm_kernel_packet.ukernel.get_m_step();
|
||||||
|
const int64_t gemm_n_step = gemm_kernel_packet.ukernel.get_n_step();
|
||||||
|
const int64_t gemv_m_step = gemv_kernel_packet.ukernel.get_m_step();
|
||||||
|
const int64_t gemv_n_step = gemv_kernel_packet.ukernel.get_n_step();
|
||||||
|
// Determine threading and kernel type
|
||||||
|
ThreadDivision division = get_thread_division(
|
||||||
|
total_threads,
|
||||||
|
m,
|
||||||
|
n,
|
||||||
|
k,
|
||||||
|
gemm_m_step,
|
||||||
|
gemm_n_step,
|
||||||
|
gemv_m_step,
|
||||||
|
gemv_n_step);
|
||||||
|
// Select appropriate kernel packet based on the chosen kernel type
|
||||||
|
auto& kernel_packet =
|
||||||
|
division.use_gemm ? gemm_kernel_packet : gemv_kernel_packet;
|
||||||
|
|
||||||
|
// Thread blocking parameters
|
||||||
|
const size_t mr = kernel_packet.ukernel.get_mr();
|
||||||
|
const size_t nr = kernel_packet.ukernel.get_nr();
|
||||||
|
const size_t kr = kernel_packet.ukernel.get_kr();
|
||||||
|
const size_t sr = kernel_packet.ukernel.get_sr();
|
||||||
|
const size_t m_increment = kernel_packet.ukernel.get_m_step();
|
||||||
|
const size_t n_per_thread = n / division.num_threads_x;
|
||||||
|
const size_t m_per_thread = m / division.num_threads_y;
|
||||||
|
const int64_t num_threads = division.num_threads_y * division.num_threads_x;
|
||||||
|
const size_t dst_stride = n * sizeof(float);
|
||||||
|
const size_t lhs_stride = k * sizeof(float);
|
||||||
|
|
||||||
|
const size_t lhs_packed_size =
|
||||||
|
kernel_packet.kai_get_lhs_packed_size(m_increment, k, mr, kr, sr);
|
||||||
|
|
||||||
|
uint8_t* dst_act_mtx_f32 = reinterpret_cast<uint8_t*>(output.data_ptr());
|
||||||
|
uint8_t* lhs_native_mtx_f32 = reinterpret_cast<uint8_t*>(input.data_ptr());
|
||||||
|
uint8_t* rhs_packed_mtx_qs4cx = reinterpret_cast<uint8_t*>(weight.data_ptr());
|
||||||
|
auto lhs_packed = std::make_unique<uint8_t[]>(lhs_packed_size * num_threads);
|
||||||
|
uint8_t* lhs_packed_base = lhs_packed.get();
|
||||||
|
|
||||||
|
at::parallel_for(0, num_threads, 0, [&](int64_t begin, int64_t end) {
|
||||||
|
for (const auto i : c10::irange(begin, end)) {
|
||||||
|
size_t y = i / division.num_threads_x;
|
||||||
|
size_t x = i % division.num_threads_x;
|
||||||
|
uint8_t* lhs_packed_ptr =
|
||||||
|
lhs_packed_base + (x + y * division.num_threads_x) * lhs_packed_size;
|
||||||
|
matmul_channelwise(
|
||||||
|
std::ref(kernel_packet),
|
||||||
|
m_increment,
|
||||||
|
y * m_per_thread,
|
||||||
|
m_per_thread,
|
||||||
|
x * n_per_thread,
|
||||||
|
n_per_thread,
|
||||||
|
n,
|
||||||
|
k,
|
||||||
|
mr,
|
||||||
|
nr,
|
||||||
|
kr,
|
||||||
|
sr,
|
||||||
|
dst_stride,
|
||||||
|
lhs_stride,
|
||||||
|
lhs_native_mtx_f32,
|
||||||
|
lhs_packed_ptr,
|
||||||
|
rhs_packed_mtx_qs4cx,
|
||||||
|
dst_act_mtx_f32);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
void kai_quant_pack_lhs_int4_mm(
|
||||||
|
const Tensor& output,
|
||||||
|
const Tensor& input,
|
||||||
|
const Tensor& weight,
|
||||||
|
const int64_t m,
|
||||||
|
const int64_t n,
|
||||||
|
const int64_t k,
|
||||||
|
const int64_t bl) {
|
||||||
|
// Prefer Channelwise kernel over Groupwise kernel for conflicting cases
|
||||||
|
if (bl == k) {
|
||||||
|
kleidiai::kai_quant_pack_lhs_int4_mm_channelwise(
|
||||||
|
output, input, weight, m, n, k);
|
||||||
|
} else if (!(bl % 32) && !(k % bl)) {
|
||||||
|
kleidiai::kai_quant_pack_lhs_int4_mm_groupwise(
|
||||||
|
output, input, weight, m, n, k, bl);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace at::native::kleidiai
|
||||||
|
#endif
|
42  aten/src/ATen/native/kleidiai/kai_kernels.h  (new file)
@@ -0,0 +1,42 @@
+#pragma once
+#include <ATen/Config.h>
+#include <ATen/core/Tensor.h>
+#if AT_KLEIDIAI_ENABLED()
+
+namespace at::native::kleidiai {
+
+/**
+ * @brief Rearranges the quantized weight to support kleidiai inference
+ * @param bl Groupsize for quantization should be multiple of 32
+ */
+void kai_pack_int4_rhs(
+    const Tensor& weight_packed,
+    const Tensor& weight,
+    const Tensor& scales,
+    const std::optional<Tensor>& bias,
+    const int64_t n,
+    const int64_t k,
+    const int64_t bl);
+
+/**
+ * @brief Outputs the buffer size for the packed weights
+ * @param bl Groupsize for quantization. 32 for groupwise, 0 for channelwise
+ */
+size_t kai_pack_rhs_int4_size(
+    const int64_t n,
+    const int64_t k,
+    const int64_t bl);
+
+/**
+ * @brief Run 2 operations ( Input quantize and pack -> 4 bit Matmul )
+ */
+void kai_quant_pack_lhs_int4_mm(
+    const Tensor& output,
+    const Tensor& input,
+    const Tensor& weight,
+    const int64_t m,
+    const int64_t n,
+    const int64_t k,
+    const int64_t bl);
+} // namespace at::native::kleidiai
+#endif
106  aten/src/ATen/native/kleidiai/kai_pack.h  (new file)
@@ -0,0 +1,106 @@
|
|||||||
|
#pragma once
|
||||||
|
#include <ATen/Config.h>
|
||||||
|
#include <ATen/core/Tensor.h>
|
||||||
|
#include <ATen/ops/empty.h>
|
||||||
|
#include <torch/library.h>
|
||||||
|
#if AT_KLEIDIAI_ENABLED()
|
||||||
|
|
||||||
|
namespace at::native::kleidiai {
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void kai_pack_rhs_groupwise_int4(
|
||||||
|
T& kernel,
|
||||||
|
const Tensor& weight_packed,
|
||||||
|
const Tensor& weight,
|
||||||
|
const Tensor& scales,
|
||||||
|
const std::optional<Tensor>& bias,
|
||||||
|
const int64_t n,
|
||||||
|
const int64_t k,
|
||||||
|
const int64_t bl,
|
||||||
|
const int64_t rhs_stride,
|
||||||
|
const int64_t scale_stride) {
|
||||||
|
const auto& ukernel = kernel.ukernel;
|
||||||
|
const size_t nr = ukernel.get_nr();
|
||||||
|
const size_t kr = ukernel.get_kr();
|
||||||
|
const size_t sr = ukernel.get_sr();
|
||||||
|
auto weight_packed_data =
|
||||||
|
reinterpret_cast<uint8_t*>(weight_packed.data_ptr());
|
||||||
|
const auto weight_data = weight.data_ptr<uint8_t>();
|
||||||
|
auto scales_data = scales.const_data_ptr();
|
||||||
|
|
||||||
|
if (weight_data == nullptr) {
|
||||||
|
AT_ERROR("kai_pack_rhs_channelwise_int4: Weight data pointer is null");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (scales_data == nullptr) {
|
||||||
|
AT_ERROR("kai_pack_rhs_channelwise_int4: Scales data pointer is null");
|
||||||
|
}
|
||||||
|
|
||||||
|
float* bias_ptr = bias.has_value() ? bias.value().data_ptr<float>() : NULL;
|
||||||
|
auto& params = kernel.rhs_pack_params;
|
||||||
|
|
||||||
|
kernel.kai_run_rhs_pack(
|
||||||
|
/*num_groups=*/1,
|
||||||
|
n,
|
||||||
|
k,
|
||||||
|
nr,
|
||||||
|
kr,
|
||||||
|
sr,
|
||||||
|
bl,
|
||||||
|
(const uint8_t*)(weight_data),
|
||||||
|
rhs_stride,
|
||||||
|
bias_ptr,
|
||||||
|
scales_data,
|
||||||
|
scale_stride,
|
||||||
|
weight_packed_data,
|
||||||
|
0,
|
||||||
|
¶ms);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void kai_pack_rhs_channelwise_int4(
|
||||||
|
T& kernel,
|
||||||
|
const Tensor& weight_packed,
|
||||||
|
const Tensor& weight,
|
||||||
|
const Tensor& scales,
|
||||||
|
const std::optional<Tensor>& bias,
|
||||||
|
const int64_t n,
|
||||||
|
const int64_t k) {
|
||||||
|
const auto& ukernel = kernel.ukernel;
|
||||||
|
const size_t nr = ukernel.get_nr();
|
||||||
|
const size_t kr = ukernel.get_kr();
|
||||||
|
const size_t sr = ukernel.get_sr();
|
||||||
|
auto weight_packed_data =
|
||||||
|
reinterpret_cast<uint8_t*>(weight_packed.data_ptr());
|
||||||
|
const auto weight_data = weight.data_ptr<uint8_t>();
|
||||||
|
const auto scales_data = scales.data_ptr<float>();
|
||||||
|
|
||||||
|
if (weight_data == nullptr) {
|
||||||
|
AT_ERROR("kai_pack_rhs_channelwise_int4: Weight data pointer is null");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (scales_data == nullptr) {
|
||||||
|
AT_ERROR("kai_pack_rhs_channelwise_int4: Scales data pointer is null");
|
||||||
|
}
|
||||||
|
|
||||||
|
float* bias_ptr = bias.has_value() ? bias.value().data_ptr<float>() : NULL;
|
||||||
|
auto& params = kernel.rhs_pack_params;
|
||||||
|
|
||||||
|
kernel.kai_run_rhs_pack(
|
||||||
|
/*num_groups=*/1,
|
||||||
|
n,
|
||||||
|
k,
|
||||||
|
nr,
|
||||||
|
kr,
|
||||||
|
sr,
|
||||||
|
(const uint8_t*)(weight_data),
|
||||||
|
(const float*)(bias_ptr),
|
||||||
|
(const float*)(scales_data),
|
||||||
|
weight_packed_data,
|
||||||
|
0,
|
||||||
|
¶ms);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace at::native::kleidiai
|
||||||
|
|
||||||
|
#endif
|
72  aten/src/ATen/native/kleidiai/kai_ukernel_interface.cpp  (new file)
@@ -0,0 +1,72 @@
|
|||||||
|
#include <ATen/native/kleidiai/kai_ukernel_interface.h>
|
||||||
|
|
||||||
|
#if AT_KLEIDIAI_ENABLED()
|
||||||
|
|
||||||
|
namespace at::native::kleidiai {
|
||||||
|
|
||||||
|
// Kernel Mapping - Groupwise
|
||||||
|
std::unordered_map<kai_kernel_id, kai_matmul_ukernel_f32_qa8dxp_qs4c32p> groupwise_8bit_4bit_kernels =
|
||||||
|
{{kai_kernel_id::
|
||||||
|
matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
|
||||||
|
{{kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
|
||||||
|
kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
|
||||||
|
kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
|
||||||
|
kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
|
||||||
|
kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
|
||||||
|
kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
|
||||||
|
kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
|
||||||
|
kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
|
||||||
|
kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
|
||||||
|
kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod,
|
||||||
|
kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod}}},
|
||||||
|
{kai_kernel_id::matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_4x8x32_neon_i8mm,
|
||||||
|
{{kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm,
|
||||||
|
kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm,
|
||||||
|
kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm,
|
||||||
|
kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm,
|
||||||
|
kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm,
|
||||||
|
kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm,
|
||||||
|
kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm,
|
||||||
|
kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm,
|
||||||
|
kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm,
|
||||||
|
kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm,
|
||||||
|
kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm}}}};
|
||||||
|
|
||||||
|
kai_matmul_ukernel_f32_qa8dxp_qs4c32p kai_select_groupwise_matmul_ukernel(
|
||||||
|
kai_kernel_id id) {
|
||||||
|
return groupwise_8bit_4bit_kernels.at(id);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Kernel Mapping - Channelwise
|
||||||
|
std::unordered_map<kai_kernel_id, kai_matmul_ukernel_f32_qa8dxp_qs4cxp> channelwise_8bit_4bit_kernels =
|
||||||
|
{{kai_kernel_id::matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod,
|
||||||
|
{{kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod,
|
||||||
|
kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod,
|
||||||
|
kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod,
|
||||||
|
kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod,
|
||||||
|
kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod,
|
||||||
|
kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod,
|
||||||
|
kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod,
|
||||||
|
kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod,
|
||||||
|
kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod,
|
||||||
|
kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod,
|
||||||
|
kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod}}},
|
||||||
|
{kai_kernel_id::matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm,
|
||||||
|
{{kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm,
|
||||||
|
kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm,
|
||||||
|
kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm,
|
||||||
|
kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm,
|
||||||
|
kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm,
|
||||||
|
kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm,
|
||||||
|
kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm,
|
||||||
|
kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm,
|
||||||
|
kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm,
|
||||||
|
kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm,
|
||||||
|
kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm}}}};
|
||||||
|
|
||||||
|
kai_matmul_ukernel_f32_qa8dxp_qs4cxp kai_select_channelwise_matmul_ukernel(
|
||||||
|
const kai_kernel_id id) {
|
||||||
|
return channelwise_8bit_4bit_kernels.at(id);
|
||||||
|
}
|
||||||
|
} // namespace at::native::kleidiai
|
||||||
|
#endif
|
aten/src/ATen/native/kleidiai/kai_ukernel_interface.h (new file, 144 lines)
@@ -0,0 +1,144 @@
#pragma once
#include <ATen/Config.h>
#include <unordered_map>
#if AT_KLEIDIAI_ENABLED()

#include <kai_common.h>
#include <kai_lhs_quant_pack_qai8dxp_f32.h>
#include <kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod.h>
#include <kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod.h>
#include <kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm.h>
#include <kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm.h>
#include <kai_matmul_clamp_f32_qai8dxp_qsi4c32p_interface.h>
#include <kai_matmul_clamp_f32_qai8dxp_qsi4cxp_interface.h>
#include <kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h>
#include <kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.h>

namespace at::native::kleidiai {

enum class kai_kernel_id {
  matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod =
      0, // Groupwise 4 bit GEMV
  matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_4x8x32_neon_i8mm =
      1, // Groupwise 4 bit GEMM
  matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod =
      2, // Channelwise 4 bit GEMV
  matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm =
      3 // Channelwise 4 bit GEMM
};

// Channelwise Kernel mapping
struct kai_matmul_ukernel_f32_qa8dxp_qs4cxp {
  struct kai_matmul_clamp_f32_qai8dxp_qsi4cxp_ukernel ukernel;
  struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params rhs_pack_params;
  size_t (*kai_get_lhs_packed_size)(
      size_t m,
      size_t k,
      size_t mr,
      size_t kr,
      size_t sr);
  size_t (*kai_get_rhs_packed_size)(
      size_t n,
      size_t k,
      size_t nr,
      size_t kr,
      size_t sr);
  void (*kai_run_lhs_quant_pack)(
      size_t m,
      size_t k,
      size_t mr,
      size_t kr,
      size_t sr,
      size_t m_idx_start,
      const float* lhs,
      size_t lhs_stride,
      void* lhs_packed);
  void (*kai_run_rhs_pack)(
      size_t num_groups,
      size_t n,
      size_t k,
      size_t nr,
      size_t kr,
      size_t sr,
      const uint8_t* rhs,
      const float* bias,
      const float* scale,
      void* rhs_packed,
      size_t extra_bytes,
      const struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params* params);

  kai_matmul_ukernel_f32_qa8dxp_qs4cxp(
      const kai_matmul_clamp_f32_qai8dxp_qsi4cxp_ukernel& kernel)
      : ukernel(kernel),
        kai_get_lhs_packed_size(
            &kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32),
        kai_get_rhs_packed_size(
            &kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0),
        kai_run_lhs_quant_pack(&kai_run_lhs_quant_pack_qai8dxp_f32),
        kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0) {}
};

struct kai_matmul_ukernel_f32_qa8dxp_qs4cxp
kai_select_channelwise_matmul_ukernel(const kai_kernel_id id);

// Groupwise Kernel mapping
struct kai_matmul_ukernel_f32_qa8dxp_qs4c32p {
  struct kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel ukernel;
  struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params rhs_pack_params;
  size_t (*kai_get_lhs_packed_size)(
      size_t m,
      size_t k,
      size_t mr,
      size_t kr,
      size_t sr);
  size_t (*kai_get_rhs_packed_size)(
      size_t n,
      size_t k,
      size_t nr,
      size_t kr,
      size_t sr,
      size_t bl,
      enum kai_datatype scale_dt);
  void (*kai_run_lhs_quant_pack)(
      size_t m,
      size_t k,
      size_t mr,
      size_t kr,
      size_t sr,
      size_t m_idx_start,
      const float* lhs,
      size_t lhs_stride,
      void* lhs_packed);
  void (*kai_run_rhs_pack)(
      size_t num_groups,
      size_t n,
      size_t k,
      size_t nr,
      size_t kr,
      size_t sr,
      size_t bl,
      const uint8_t* rhs,
      size_t rhs_stride,
      const float* bias,
      const void* scale,
      size_t scale_stride,
      void* rhs_packed,
      size_t extra_bytes,
      const struct kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params* params);

  kai_matmul_ukernel_f32_qa8dxp_qs4c32p(
      const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& kernel)
      : ukernel(kernel),
        kai_get_lhs_packed_size(
            &kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32),
        kai_get_rhs_packed_size(
            &kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0),
        kai_run_lhs_quant_pack(&kai_run_lhs_quant_pack_qai8dxp_f32),
        kai_run_rhs_pack(&kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0) {}
};

struct kai_matmul_ukernel_f32_qa8dxp_qs4c32p kai_select_groupwise_matmul_ukernel(
    const kai_kernel_id id);

} // namespace at::native::kleidiai
#endif
@@ -4177,6 +4177,14 @@
   dispatch:
     CPU: _weight_int4pack_mm_cpu
 
+- func: _dyn_quant_pack_4bit_weight(Tensor weights, Tensor scales_zeros, Tensor? bias, int block_size, int in_features, int out_features) -> Tensor
+  dispatch:
+    CPU: _dyn_quant_pack_4bit_weight_cpu
+
+- func: _dyn_quant_matmul_4bit(Tensor inp, Tensor packed_weights, int block_size, int in_features, int out_features) -> Tensor
+  dispatch:
+    CPU: _dyn_quant_matmul_4bit_cpu
+
 - func: _weight_int8pack_mm(Tensor self, Tensor mat2, Tensor scales) -> Tensor
   dispatch:
     CPU: _weight_int8pack_mm_cpu
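The two schemas above are all that is needed to drive the ops from Python. As an illustrative, hypothetical sketch (not part of this diff), a module can pack its already-quantized 4-bit weight once and reuse the packed buffer on every forward pass; the class and variable names are invented for the sketch, and only the two op calls follow the schemas registered above.

```python
# Hypothetical wrapper; assumes `q_weight_uint8` holds two 4-bit values per byte
# and `scales_and_zeros` matches the chosen block_size, as the schemas require.
import torch


class DynQuantLinear4bit(torch.nn.Module):
    def __init__(self, q_weight_uint8, scales_and_zeros, bias, block_size,
                 in_features, out_features):
        super().__init__()
        self.block_size = block_size
        self.in_features = in_features
        self.out_features = out_features
        # Pack once; this is the buffer the matmul kernel consumes.
        self.register_buffer(
            "packed_weights",
            torch.ops.aten._dyn_quant_pack_4bit_weight(
                q_weight_uint8, scales_and_zeros, bias,
                block_size, in_features, out_features,
            ),
        )

    def forward(self, x):
        # x is a 2D float32 tensor; activations are quantized dynamically inside the op.
        return torch.ops.aten._dyn_quant_matmul_4bit(
            x, self.packed_weights, self.block_size,
            self.in_features, self.out_features,
        )
```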
@@ -1070,6 +1070,7 @@ def define_buck_targets(
         ],
     )
 
+    # TODO: Enable support for KleidiAI bazel build
     # @lint-ignore BUCKLINT
     fb_native.genrule(
         name = "generate_aten_config",
@@ -1122,6 +1123,9 @@ def define_buck_targets(
             "--replace",
             "@AT_BLAS_USE_CBLAS_DOT@",
             "AT_BLAS_USE_CBLAS_DOT_FBXPLAT",
+            "--replace",
+            "@AT_KLEIDIAI_ENABLED@",
+            "0",
         ]),
         outs = {
             "Config.h": ["Config.h"],
@@ -152,6 +152,7 @@ endif()
 set(AT_MKLDNN_ACL_ENABLED 0)
 set(AT_MKLDNN_ENABLED 0)
 set(AT_MKL_ENABLED 0)
+set(AT_KLEIDIAI_ENABLED 0)
 # setting default preferred BLAS options if not already present.
 if(NOT INTERN_BUILD_MOBILE)
   set(BLAS "MKL" CACHE STRING "Selected BLAS library")
@@ -1480,6 +1481,35 @@ if(NOT INTERN_BUILD_MOBILE)
     message("disabling MKLDNN because USE_MKLDNN is not set")
   endif()
 
+  if(USE_KLEIDIAI)
+    if(CMAKE_C_COMPILER_ID STREQUAL "Clang" AND CMAKE_C_COMPILER_VERSION VERSION_LESS "11" )
+      message(WARNING "KleidiAI: Using non-supported Clang version. Expected 11 or newer, received ${CMAKE_C_COMPILER_VERSION}.")
+    endif()
+    if(CMAKE_C_COMPILER_ID STREQUAL "GNU" AND CMAKE_C_COMPILER_VERSION VERSION_LESS "11" )
+      message(WARNING "KleidiAI: Using non-supported GCC version. Expected 11 or newer, received ${CMAKE_C_COMPILER_VERSION}.")
+    endif()
+    set(TEMP_BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS})
+    set(BUILD_SHARED_LIBS OFF CACHE BOOL "Build shared libs" FORCE)
+    set(AT_KLEIDIAI_ENABLED 1)
+    set(KLEIDIAI_BUILD_TESTS OFF) # Disable building KLEIDIAI tests
+    set(KLEIDIAI_SRC "${PROJECT_SOURCE_DIR}/third_party/kleidiai")
+    add_subdirectory(${KLEIDIAI_SRC})
+    set(KLEIDIAI_INCLUDE_DIRS
+      ${KLEIDIAI_SRC}/
+      ${KLEIDIAI_SRC}/kai/
+      ${KLEIDIAI_SRC}/kai/ukernels/
+      ${KLEIDIAI_SRC}/kai/ukernels/matmul/
+      ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/
+      ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/
+      ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/
+      ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/
+    )
+    include_directories(SYSTEM INTERFACE ${KLEIDIAI_INCLUDE_DIRS})
+    list(APPEND Caffe2_DEPENDENCY_LIBS kleidiai)
+    # Recover build options.
+    set(BUILD_SHARED_LIBS ${TEMP_BUILD_SHARED_LIBS} CACHE BOOL "Build shared libs" FORCE)
+  endif()
+
   if(UNIX AND NOT APPLE)
     include(CheckLibraryExists)
     # https://github.com/libgit2/libgit2/issues/2128#issuecomment-35649830
@@ -153,6 +153,9 @@ function(caffe2_print_configuration_summary)
     message(STATUS "  USE_MKLDNN_ACL        : ${USE_MKLDNN_ACL}")
     message(STATUS "  USE_MKLDNN_CBLAS      : ${USE_MKLDNN_CBLAS}")
   endif()
+  if(${USE_KLEIDIAI})
+    message(STATUS "  USE_KLEIDIAI          : ${USE_KLEIDIAI}")
+  endif()
   message(STATUS "  USE_UCC               : ${USE_UCC}")
   if(${USE_UCC})
     message(STATUS "  USE_SYSTEM_UCC        : ${USE_SYSTEM_UCC}")
@@ -97,6 +97,10 @@ else()
   append_torchlib_if_found(microkernels-prod)
 endif()
 
+if(@USE_KLEIDIAI@)
+  append_torchlib_if_found(kleidiai)
+endif()
+
 append_torchlib_if_found(caffe2_protos protobuf-lite protobuf protoc)
 append_torchlib_if_found(onnx onnx_proto)
@@ -203,6 +203,7 @@ torch.backends.openmp
    .. add anything to the rendered page for now.
 .. py:module:: torch.backends.quantized
 .. py:module:: torch.backends.xnnpack
+.. py:module:: torch.backends.kleidiai
 
 
 torch.backends.opt_einsum
setup.py (1 line changed)
@@ -1221,6 +1221,7 @@ def main():
                 "include/ATen/native/cuda/*.cuh",
                 "include/ATen/native/hip/*.h",
                 "include/ATen/native/hip/*.cuh",
+                "include/ATen/native/kleidiai/*.h",
                 "include/ATen/native/mps/*.h",
                 "include/ATen/native/mkldnn/xpu/*.h",
                 "include/ATen/native/mkldnn/xpu/detail/*.h",
@@ -92,6 +92,8 @@ aten::_dimI
 aten::_dimV
 aten::_dirichlet_grad
 aten::_dirichlet_grad.out
+aten::_dyn_quant_matmul_4bit
+aten::_dyn_quant_pack_4bit_weight
 aten::_efficient_attention_backward
 aten::_efficient_attention_forward
 aten::_efficientzerotensor
@@ -76,6 +76,7 @@ from torch.testing._internal.common_device_type import (
 from torch.testing._internal.common_dtype import all_types, get_all_dtypes
 from torch.testing._internal.common_quantization import (
     _dynamically_quantize_per_channel,
+    _group_quantize_tensor_symmetric,
 )
 from torch.testing._internal.common_utils import (
     DeterministicGuard,
@@ -2224,6 +2225,85 @@ class CommonTemplate:
         b_int8pack, b_scales = convert_weight_to_int8pack(b)
         self.common(fn, (a, b_int8pack, b_scales, c))
 
+    @xfail_if_triton_cpu
+    @skipCUDAIf(True, "No _dyn_quant_pack_4bit_weight implementation on CUDA")
+    @skipIfRocm
+    @skipIfXpu(msg="No _dyn_quant_pack_4bit_weight implementation on XPU")
+    def test__dyn_quant_pack_4bit_weight(self):
+        q_group = 32
+        k = 128
+        n = 128
+
+        torch.manual_seed(1)
+        b = torch.rand((k, n), dtype=torch.float32)
+        in_features = b.size(0)
+        out_features = b.size(1)
+
+        def dyn_quant_pack_4bit_weight(b, in_features, out_features):
+            b_uint8, b_scales_and_zeros = _group_quantize_tensor_symmetric(
+                b, n_bit=4, groupsize=q_group
+            )
+
+            if q_group == in_features:
+                b_scales_and_zeros = b_scales_and_zeros.to(torch.float)
+            else:
+                b_scales_and_zeros = b_scales_and_zeros.to(torch.bfloat16)
+            b_int4pack = torch._dyn_quant_pack_4bit_weight(
+                b_uint8, b_scales_and_zeros, None, q_group, in_features, out_features
+            )
+
+            return b_int4pack, b_scales_and_zeros
+
+        def fn(b, in_features, out_features):
+            b_int4pack, _ = dyn_quant_pack_4bit_weight(b, in_features, out_features)
+            return b_int4pack
+
+        self.common(fn, (b, in_features, out_features))
+
+    @xfail_if_triton_cpu
+    @skipCUDAIf(True, "No _dyn_quant_matmul_4bit implementation on CUDA")
+    @skipIfRocm
+    @skipIfXpu(msg="No _dyn_quant_matmul_4bit implementation on XPU")
+    def test__dyn_quant_matmul_4bit(self):
+        q_group = 32
+        m = 32
+        k = 128
+        n = 128
+
+        torch.manual_seed(1)
+        a = torch.rand((m, k), dtype=torch.float32)
+        b = torch.rand((k, n), dtype=torch.float32)
+        in_features = b.size(0)
+        out_features = b.size(1)
+
+        def dyn_quant_pack_4bit_weight(b, in_features, out_features):
+            b_uint8, b_scales_and_zeros = _group_quantize_tensor_symmetric(
+                b, n_bit=4, groupsize=q_group
+            )
+
+            if q_group == in_features:
+                b_scales_and_zeros = b_scales_and_zeros.to(torch.float)
+            else:
+                b_scales_and_zeros = b_scales_and_zeros.to(torch.bfloat16)
+            b_int4pack = torch._dyn_quant_pack_4bit_weight(
+                b_uint8, b_scales_and_zeros, None, q_group, in_features, out_features
+            )
+
+            return b_int4pack, b_scales_and_zeros
+
+        def fn(a, q_group, in_features, out_features):
+            b_int4pack, _ = dyn_quant_pack_4bit_weight(b, in_features, out_features)
+            res = torch._dyn_quant_matmul_4bit(
+                a,
+                b_int4pack,
+                q_group,
+                in_features,
+                out_features,
+            )
+            return res
+
+        self.common(fn, (a, q_group, in_features, out_features))
+
     def test_expanded_reduction(self):
         def fn(x, y):
             z = x * y
@@ -34,7 +34,8 @@ from torch.testing._internal.common_dtype import (
 )
 from torch.testing._internal.common_cuda import SM53OrLater, SM80OrLater, SM90OrLater, tf32_on_and_off, _get_magma_version, \
     _get_torch_cuda_version, CDNA2OrLater, TEST_MULTIGPU
-from torch.testing._internal.common_quantization import _group_quantize_tensor, _dynamically_quantize_per_channel
+from torch.testing._internal.common_quantization import _group_quantize_tensor, _dynamically_quantize_per_channel, \
+    _group_quantize_tensor_symmetric
 from torch.testing._internal.common_mkldnn import bf32_on_and_off
 from torch.distributions.binomial import Binomial
 import torch.backends.opt_einsum as opt_einsum
@@ -909,7 +910,6 @@ class TestLinalg(TestCase):
                    torch.randn((3, 52, 52), device=device, dtype=dtype),
                    torch.randn((4, 2, 26, 26), device=device, dtype=dtype))
-
 
         ops = (torch.det, torch.Tensor.det,
                torch.linalg.det)
         for t in tensors:
@@ -1438,7 +1438,6 @@ class TestLinalg(TestCase):
                     continue
                 run_test_case(make_arg(shape), ord, dim, keepdim)
-
 
     @onlyCUDA
     @dtypes(torch.bfloat16, torch.float16)
     def test_norm_fused_type_promotion(self, device, dtype):
@@ -4344,7 +4343,6 @@ class TestLinalg(TestCase):
             triangular_solve_zero_batch_helper((batchsize, 5, 5), (batchsize, 5, 10),
                                                upper, unitriangular, transpose)
-
 
     @slowTest
     @skipCUDAIfNoMagma
     @skipCPUIfNoLapack
@@ -4465,7 +4463,6 @@ class TestLinalg(TestCase):
             self.assertTrue("An output with one or more elements was resized" in str(w[0].message))
             self.assertTrue("An output with one or more elements was resized" in str(w[1].message))
-
 
     def check_single_matmul(self, x, y):
 
         def assertEqual(answer, expected):
@@ -5701,7 +5698,6 @@ class TestLinalg(TestCase):
             else:
                 self.assertEqual(B_, X_ @ A)
-
 
         sizes = ((3, 3), (5, 5), (4, 2), (3, 4), (0, 0), (0, 1), (1, 0))
         batches = ((0,), (), (1,), (2,), (3,), (1, 0), (3, 5))
         # Non pivoting just implemented for CUDA
@@ -5734,7 +5730,6 @@ class TestLinalg(TestCase):
             with self.assertRaisesRegex(RuntimeError, 'LU without pivoting is not implemented on the CPU'):
                 f(torch.empty(1, 2, 2), pivot=False)
-
 
     @precisionOverride({torch.float32: 1e-2, torch.complex64: 1e-2})
     @skipCUDAIfNoMagmaAndNoCusolver
     @skipCPUIfNoLapack
@@ -5762,7 +5757,6 @@ class TestLinalg(TestCase):
             for b, n in shapes:
                 yield make_arg((b, n, n)), make_arg((b, n, rhs))
-
 
         for A, B in gen_matrices():
             LU, pivots = torch.linalg.lu_factor(A)
             for backend in backends:
@@ -5777,7 +5771,6 @@ class TestLinalg(TestCase):
                 else:
                     self.assertEqual(B_left, X @ A_adj)
-
 
     @onlyCPU
     @dtypes(*floating_and_complex_types())
     def test_linalg_lu_cpu_errors(self, device, dtype):
@@ -5818,7 +5811,6 @@ class TestLinalg(TestCase):
         with self.assertRaisesRegex(RuntimeError, r"between 1 and LU.size\(-2\)."):
             torch.lu_unpack(LU, pivots)
-
 
         # Rectangular tests
         sample = torch.randn(2, 3, 5, device=device, dtype=dtype)
         B = torch.randn(2, 3, 5, device=device, dtype=dtype)
@@ -5835,7 +5827,6 @@ class TestLinalg(TestCase):
         with self.assertRaisesRegex(RuntimeError, r"between 1 and LU.size\(-2\)."):
             torch.lu_unpack(LU, pivots)
-
 
     @skipCPUIfNoLapack
     @skipCUDAIfNoMagma
     @dtypes(torch.double)
@@ -6444,7 +6435,6 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
         y = torch.baddbmm(input_tensor, batch1, batch2, beta=0.0, out=out)
         self.assertEqual(out, y_ref)
-
 
     @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
     @onlyCUDA
     def test_matmul_45724(self, device):
@@ -6617,7 +6607,6 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
         torch._int_mm(a_int8, b_int8, out=c_int32_result)
         self.assertEqual(c_int32_result.float(), torch.mm(a_float, b_float))
-
 
     @unittest.skipIf(IS_WINDOWS, "Skipped on Windows!")
     @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
     @onlyNativeDeviceTypes
@@ -6680,7 +6669,6 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
         mean_err = ((res - ref).abs() / ref).mean()
         self.assertTrue(mean_err < 0.05)
-
 
     @unittest.skipIf(IS_WINDOWS, "Skipped on Windows!")
     @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
     @onlyNativeDeviceTypes
@@ -6733,6 +6721,168 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
         mean_err = ((res - ref).abs() / ref).mean()
         self.assertTrue(mean_err < 0.05)
 
+    @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
+    @unittest.skipIf(TEST_WITH_ROCM and IS_REMOTE_GPU, "ROCM is unsupported")
+    @onlyNativeDeviceTypes
+    @parametrize("k", [64, 256])
+    @parametrize("n", [32, 48, 64, 128])
+    def test__dyn_quant_pack_4bit_weight(self, device, k, n):
+        # TODO: Fix https://github.com/pytorch/pytorch/issues/131425 and use OpInfo instead
+        # Weight shape is [K x N]
+        if self.device_type == "cuda":
+            self.skipTest("CUDA Backend is unsupported")
+
+        torch.manual_seed(1)
+        block_size = 32
+        b = torch.rand((k, n), dtype=torch.bfloat16, device=device)
+        in_features = b.size(0)
+        out_features = b.size(1)
+        b_uint8, b_scales_and_zeros = _group_quantize_tensor_symmetric(
+            b, n_bit=4, groupsize=block_size
+        )
+        b_int4pack = torch._dyn_quant_pack_4bit_weight(
+            b_uint8, b_scales_and_zeros, None, block_size, in_features, out_features
+        )
+        b_int4pack_meta = torch._dyn_quant_pack_4bit_weight(
+            b_uint8, b_scales_and_zeros, None, block_size, in_features, out_features
+        )
+        self.assertEqual(b_int4pack.shape, b_int4pack_meta.shape)
+
+    @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
+    @unittest.skipIf(TEST_WITH_ROCM and IS_REMOTE_GPU, "ROCM is unsupported")
+    @onlyNativeDeviceTypes
+    @parametrize("m", [1, 32])
+    @parametrize("k", [64, 128])
+    @parametrize("n", [4096, 11008])
+    def test__dyn_quant_matmul_4bit(self, device, m, k, n):
+        if self.device_type == "cuda":
+            self.skipTest("CUDA is unsupported")
+
+        q_group = 32
+
+        torch.manual_seed(1)
+        a_float32 = torch.rand((m, k), dtype=torch.float32, device=device)
+        b_float32 = torch.rand((k, n), dtype=torch.float32, device=device)
+        in_features = b_float32.size(0)
+        out_features = b_float32.size(1)
+
+        def dyn_quant_pack_4bit_weight(b, in_features, out_features):
+            b_uint8, b_scales_and_zeros = _group_quantize_tensor_symmetric(
+                b, n_bit=4, groupsize=q_group
+            )
+
+            if q_group == in_features:
+                b_scales_and_zeros = b_scales_and_zeros.to(torch.float)
+            else:
+                b_scales_and_zeros = b_scales_and_zeros.to(torch.bfloat16)
+            b_int4pack = torch._dyn_quant_pack_4bit_weight(
+                b_uint8, b_scales_and_zeros, None, q_group, in_features, out_features
+            )
+
+            return b_int4pack, b_scales_and_zeros
+
+        def dyn_quant_matmul_4bit(
+            a, b_int4pack, q_group, in_features, out_features
+        ):
+            return torch._dyn_quant_matmul_4bit(
+                a,
+                b_int4pack,
+                q_group,
+                in_features,
+                out_features,
+            )
+
+        b_int4pack, b_scales_and_zeros = dyn_quant_pack_4bit_weight(
+            b_float32, in_features, out_features
+        )
+
+        dtypes = [torch.float32]
+
+        for dtype in dtypes:
+            a = a_float32.to(dtype=dtype)
+            b = b_float32.to(dtype=dtype)
+            ref = torch.mm(a, b)
+            res = dyn_quant_matmul_4bit(
+                a,
+                b_int4pack,
+                q_group,
+                in_features,
+                out_features,
+            )
+            mean_err = ((res - ref).abs() / ref).mean()
+            self.assertTrue(mean_err < 0.05)
+            elementwise_diff = (res - ref).abs()
+            elementwise_relative_error = elementwise_diff / ref.abs().clamp(
+                min=torch.finfo(ref.dtype).eps
+            )
+            all_elements_within_threshold = torch.all(elementwise_relative_error < 0.06)
+            self.assertTrue(
+                all_elements_within_threshold, "Some elements have error >= 0.06"
+            )
+
+    @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error")
+    @unittest.skipIf(TEST_WITH_ROCM and IS_REMOTE_GPU, "ROCM is unsupported")
+    @onlyNativeDeviceTypes
+    @parametrize("m", [1, 32])
+    @parametrize("k", [64, 128])
+    @parametrize("n", [4096, 11008])
+    def test_compile_dyn_quant_matmul_4bit(self, device, m, k, n):
+        if self.device_type == "cuda":
+            self.skipTest("CUDA is unsupported")
+
+        q_group = 32
+
+        torch.manual_seed(1)
+        a_float32 = torch.rand((m, k), dtype=torch.float32, device=device)
+        b_float32 = torch.rand((k, n), dtype=torch.float32, device=device)
+        in_features = b_float32.size(0)
+        out_features = b_float32.size(1)
+
+        b_uint8, b_scales_and_zeros = _group_quantize_tensor_symmetric(
+            b_float32, n_bit=4, groupsize=q_group
+        )
+
+        if q_group == in_features:
+            b_scales_and_zeros = b_scales_and_zeros.to(dtype=torch.float)
+        else:
+            b_scales_and_zeros = b_scales_and_zeros.to(dtype=torch.bfloat16)
+
+        @torch.compile
+        def dyn_quant_matmul_4bit(
+            a, b_uint8, b_scales_and_zeros, q_group, in_features, out_features
+        ):
+            b_int4pack = torch._dyn_quant_pack_4bit_weight(
+                b_uint8, b_scales_and_zeros, None, q_group, in_features, out_features
+            )
+            return torch._dyn_quant_matmul_4bit(
+                a,
+                b_int4pack,
+                q_group,
+                in_features,
+                out_features,
+            )
+
+        res = dyn_quant_matmul_4bit(
+            a_float32,
+            b_uint8,
+            b_scales_and_zeros,
+            q_group,
+            in_features,
+            out_features,
+        )
+        ref = torch.mm(a_float32, b_float32)
+
+        mean_err = ((res - ref).abs() / ref).mean()
+        self.assertTrue(mean_err < 0.05)
+        elementwise_diff = (res - ref).abs()
+        elementwise_relative_error = elementwise_diff / ref.abs().clamp(
+            min=torch.finfo(ref.dtype).eps
+        )
+        all_elements_within_threshold = torch.all(elementwise_relative_error < 0.06)
+        self.assertTrue(
+            all_elements_within_threshold, "Some elements have error >= 0.06"
+        )
+
     @onlyCPU
     @parametrize("m", [32, 64])
     @parametrize("k", [32, 64])
@@ -8664,8 +8814,6 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
 
         self.assertEqual(ck_out, cpu_out)
-
-
 
     def test_permute_matmul(self):
         a = torch.ones([2, 5, 24, 24])
         b = torch.ones([3, 2, 5, 24, 24])
@@ -8745,7 +8893,6 @@ scipy_lobpcg | {eq_err_scipy:10.2e} | {eq_err_general_scipy:10.2e} | {iters2:
         ref = alpha * A @ B + beta * C
         self.assertEqual(rc, ref)
-
 
     @dtypes(torch.float, torch.double)
     @precisionOverride({torch.float32: 1e-4})
     def test_1_sized_with_0_strided(self, device, dtype):
third_party/kleidiai (vendored submodule)
Submodule third_party/kleidiai added at 202603f38a
@@ -1299,6 +1299,7 @@ def _valgrind_toggle_and_dump_stats() -> None: ...  # CALLGRIND_TOGGLE_COLLECT a
 
 has_openmp: _bool
 has_mkl: _bool
+_has_kleidiai: _bool
 _has_mps: _bool
 has_lapack: _bool
 _has_cuda: _bool
@@ -1372,6 +1372,8 @@ torch_c_binding_in_graph_functions = dict.fromkeys(
     "torch._dim_arange",
     "torch._dirichlet_grad",
     "torch._disable_functionalization",
+    "torch._dyn_quant_matmul_4bit",
+    "torch._dyn_quant_pack_4bit_weight",
     "torch._efficientzerotensor",
     "torch._embedding_bag_forward_only",
     "torch._embedding_bag",
@@ -94,3 +94,6 @@ def register_woq_mm_ops() -> None:
         return autotune_select_algorithm(
             "_weight_int8pack_mm", choices, [mat1, mat2, scale], aten_layout
         )
+
+    lowering.make_fallback(aten._dyn_quant_matmul_4bit)
+    lowering.make_fallback(aten._dyn_quant_pack_4bit_weight)
@@ -3344,6 +3344,161 @@ def meta__weight_int4pack_mm_for_cpu(x, w, q_group_size, q_scale_and_zeros):
     return x.new_empty(x.size(0), w.size(0), dtype=x.dtype)
 
 
+def kai_roundup(a: int, b: int) -> int:
+    return ((a + b - 1) // b) * b
+
+
+def get_kai_packed_weight_size(n_bits, N, K, groupsize):
+    if n_bits == 4:
+        if groupsize == K:  # channelwise
+            # dotprod params only [1x8x32_neon_dotprod]
+            kai_nr = 8
+            kai_kr = 16
+            kai_sr = 2
+            kai_num_bytes_sum_rhs = 4  # sizeof(int32_t)
+            kai_num_bytes_multiplier_rhs = 4  # sizeof(float)
+            kai_num_bytes_bias = 4  # sizeof(float)
+
+            def kai_k_roundedup(k, kr, sr):
+                # Since we pack a float and int32 value at the end of the row,
+                # we must make sure that k is a multiple of 4 for alignment
+                kr_sr_roundedup4 = kai_roundup(kr * sr, 4)
+                return kai_roundup(k, kr_sr_roundedup4)
+
+            def kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(
+                k, nr, kr, sr
+            ):
+                k_internal = kai_k_roundedup(k, kr, sr)
+
+                assert (k_internal % 2) == 0, "k_internal must be even"
+
+                return nr * (
+                    (k_internal // 2)
+                    + kai_num_bytes_multiplier_rhs
+                    + kai_num_bytes_sum_rhs
+                    + kai_num_bytes_bias
+                )
+
+            def kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(
+                n, k, nr, kr, sr
+            ):
+                num_rows = kai_roundup(n, nr) // nr
+
+                return (
+                    num_rows
+                    * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(
+                        k, nr, kr, sr
+                    )
+                )
+
+            return kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(
+                N, K, kai_nr, kai_kr, kai_sr
+            )
+        elif groupsize % 32 == 0 and K % groupsize == 0:  # groupwise
+            kai_nr = 8
+            kai_kr = 16
+            kai_sr = 2
+            kai_num_bytes_sum_rhs = 4
+            kai_num_bytes_bias = 4
+            kai_nr_multiple_of = 4
+            kai_bl_multiple_of = 32
+
+            def kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(
+                n, k, nr, kr, sr, bl
+            ):
+                assert (bl % kr) == 0
+                assert (nr % kai_nr_multiple_of) == 0
+                assert (bl % kai_bl_multiple_of) == 0
+
+                num_rows = kai_roundup(n, nr) // nr
+
+                return (
+                    num_rows
+                    * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(
+                        k, nr, kr, sr, bl
+                    )
+                )
+
+            def kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(
+                k, nr, kr, sr, bl
+            ):
+                assert (bl % kr) == 0
+                assert (nr % kai_nr_multiple_of) == 0
+                assert (bl % kai_bl_multiple_of) == 0
+
+                # kr and sr are unused in the calculation
+                num_bytes_multiplier_rhs = kai_get_bf16_datatype_size_in_bytes()
+                num_blocks_per_row = kai_num_blocks_per_row(k, bl)
+                num_bytes_per_block = kai_num_bytes_per_block(
+                    bl, num_bytes_multiplier_rhs
+                )
+
+                return nr * (
+                    (num_bytes_per_block * num_blocks_per_row)
+                    + kai_num_bytes_sum_rhs
+                    + kai_num_bytes_bias
+                )
+
+            # This function returns the size of these datatypes stored as an enum. We modify it to just return the bf16 datatype size.
+            # https://gitlab.arm.com/kleidi/kleidiai/-/blob/main/kai/kai_common.h?ref_type=heads#L55
+            def kai_get_bf16_datatype_size_in_bytes():
+                return 2  # 2 bytes
+
+            def kai_num_blocks_per_row(k, bl):
+                assert (bl % kai_bl_multiple_of) == 0
+                return kai_roundup(k, bl) // bl
+
+            def kai_num_bytes_per_block(bl, num_bytes_multiplier_rhs):
+                assert (bl % kai_bl_multiple_of) == 0
+                return (bl // 2) + num_bytes_multiplier_rhs
+
+            return kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(
+                N, K, kai_nr, kai_kr, kai_sr, groupsize
+            )
+
+
+@register_meta([aten._dyn_quant_pack_4bit_weight])
+def meta__dyn_quant_pack_4bit_weight(
+    weights, scales_zeros, bias: Optional[Tensor], block_size, in_features, out_features
+):
+    torch._check(
+        weights.dtype is torch.uint8,
+        lambda: f"expected w to be uint8, got {weights.dtype}",
+    )
+    if torch.backends.kleidiai.is_available() and (
+        (block_size == in_features and scales_zeros.dtype == torch.float)
+        or (
+            block_size < in_features
+            and block_size % 32 == 0
+            and in_features % block_size == 0
+            and scales_zeros.dtype == torch.bfloat16
+        )
+    ):
+        packed_weight_size = get_kai_packed_weight_size(
+            4, out_features, in_features, block_size
+        )
+        return weights.new_empty(int(packed_weight_size), dtype=torch.uint8)
+    packed_weight_size = weights.numel() + scales_zeros.numel()
+    return weights.new_empty(packed_weight_size, dtype=torch.float)
+
+
+@register_meta([aten._dyn_quant_matmul_4bit])
+def meta__dyn_quant_matmul_4bit(
+    inp,
+    packed_weights,
+    block_size,
+    in_features,
+    out_features,
+):
+    torch._check(inp.dim() == 2, lambda: "input must be a 2D tensor")
+    torch._check(
+        inp.dtype in [torch.float32],
+        lambda: f"expected input to be f32, got {inp.dtype}",
+    )
+    M = inp.size(0)
+    return inp.new_empty(M, out_features, dtype=inp.dtype)
+
+
 @register_meta([aten._weight_int8pack_mm])
 def meta__weight_int8pack_mm(x, w, q_scales):
     torch._check(x.dim() == 2, lambda: "x must be a 2D tensor")
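As a sanity check of the channelwise branch above (an illustrative calculation, not part of the diff): for out_features N = 64, in_features K = 128 and groupsize == K, the constants are nr = 8, kr = 16, sr = 2, so k_internal = roundup(128, roundup(16 * 2, 4)) = 128, each packed row block is 8 * (128 / 2 + 4 + 4 + 4) = 608 bytes, and the total size is roundup(64, 8) / 8 * 608 = 4864 bytes. The matmul meta kernel can also be exercised on its own with meta tensors; the snippet below is a minimal sketch under the assumption that the ops are registered exactly as in this diff.

```python
# Shape inference only -- no real data is touched on the "meta" device.
import torch

m, k, n, block_size = 4, 128, 64, 32
inp = torch.empty(m, k, dtype=torch.float32, device="meta")
packed = torch.empty(1, dtype=torch.uint8, device="meta")  # placeholder packed weight
out = torch.ops.aten._dyn_quant_matmul_4bit(inp, packed, block_size, k, n)
print(out.shape)  # torch.Size([4, 64]): M x out_features, per meta__dyn_quant_matmul_4bit
```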
@@ -62,6 +62,7 @@ from torch.backends import (
     cuda as cuda,
     cudnn as cudnn,
     cusparselt as cusparselt,
+    kleidiai as kleidiai,
     mha as mha,
     mkl as mkl,
     mkldnn as mkldnn,
torch/backends/kleidiai/__init__.py (new file, 7 lines)
@@ -0,0 +1,7 @@
# mypy: allow-untyped-defs
import torch


def is_available():
    r"""Return whether PyTorch is built with KleidiAI support."""
    return torch._C._has_kleidiai
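Callers can gate on the new flag before relying on the KleidiAI-backed packing path; a minimal sketch:

```python
import torch

# False on builds compiled without KleidiAI (e.g. non-AArch64 or USE_KLEIDIAI disabled).
if torch.backends.kleidiai.is_available():
    print("KleidiAI micro-kernels are available for the 4-bit dynamic quant ops.")
else:
    print("This build does not use KleidiAI-specific weight packing.")
```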
@@ -1939,6 +1939,8 @@ Call this whenever a new thread is created in order to propagate values from
   ASSERT_TRUE(
       set_module_attr("has_openmp", at::hasOpenMP() ? Py_True : Py_False));
   ASSERT_TRUE(set_module_attr("has_mkl", at::hasMKL() ? Py_True : Py_False));
+  ASSERT_TRUE(
+      set_module_attr("_has_kleidiai", at::hasKleidiAI() ? Py_True : Py_False));
   ASSERT_TRUE(
       set_module_attr("has_lapack", at::hasLAPACK() ? Py_True : Py_False));
@@ -18,6 +18,8 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__adaptive_avg_pool3d_backward(At
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__addmm_activation(AtenTensorHandle self, AtenTensorHandle mat1, AtenTensorHandle mat2, double beta, double alpha, int32_t use_gelu, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__cdist_backward(AtenTensorHandle grad, AtenTensorHandle x1, AtenTensorHandle x2, double p, AtenTensorHandle cdist, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__cdist_forward(AtenTensorHandle x1, AtenTensorHandle x2, double p, int64_t* compute_mode, AtenTensorHandle* ret0);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__dyn_quant_matmul_4bit(AtenTensorHandle inp, AtenTensorHandle packed_weights, int64_t block_size, int64_t in_features, int64_t out_features, AtenTensorHandle* ret0);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__dyn_quant_pack_4bit_weight(AtenTensorHandle weights, AtenTensorHandle scales_zeros, AtenTensorHandle* bias, int64_t block_size, int64_t in_features, int64_t out_features, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__efficientzerotensor(const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__embedding_bag(AtenTensorHandle weight, AtenTensorHandle indices, AtenTensorHandle offsets, int32_t scale_grad_by_freq, int64_t mode, int32_t sparse, AtenTensorHandle* per_sample_weights, int32_t include_last_offset, int64_t padding_idx, AtenTensorHandle* ret0, AtenTensorHandle* ret1, AtenTensorHandle* ret2, AtenTensorHandle* ret3);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__embedding_bag_dense_backward(AtenTensorHandle grad, AtenTensorHandle indices, AtenTensorHandle offset2bag, AtenTensorHandle bag_size, AtenTensorHandle maximum_indices, int64_t num_weights, int32_t scale_grad_by_freq, int64_t mode, AtenTensorHandle* per_sample_weights, int64_t padding_idx, AtenTensorHandle* ret0);
@@ -498,6 +498,39 @@ def _group_quantize_tensor(w, n_bit=4, q_group_size=16):
     return out, scales_and_zeros
 
 
+def _group_quantize_tensor_symmetric(
+    w, n_bit=4, groupsize=32
+):
+    # W is of shape [K x N]
+    # We transpose W as Quantization is applied on [N x K]
+    w = w.transpose(0, 1).contiguous()
+    assert w.dim() == 2
+    assert groupsize > 1
+    assert w.shape[-1] % groupsize == 0
+    # Calculate scale and zeros
+    to_quant = w.reshape(-1, groupsize)
+    max_val = to_quant.abs().amax(dim=1, keepdim=True)
+    eps = torch.finfo(max_val.dtype).eps
+    max_int = 2 ** (n_bit - 1) - 1  # For 4-bit, this is 7
+    scales = max_val.clamp(min=eps) / max_int
+    zeros = torch.zeros_like(scales)
+
+    # Quantize the weight
+    scales = scales.to(torch.float32).reshape(w.shape[0], -1)
+    zeros = zeros.to(torch.float32).reshape(w.shape[0], -1)
+    scales = scales.reshape(-1, 1)
+    zeros = zeros.reshape(-1, 1)
+    max_int = 2**n_bit - 1
+    w_int8 = to_quant.div(scales).add(8.5).to(torch.int8).clamp(max=max_int)
+    # We pack 2 signed int4 values in unsigned uint8 container.
+    # This reduces the weight size by half and improves load perf
+    out_uint8 = (w_int8[::, 1::2] << 4 | w_int8[::, ::2]).to(torch.uint8)
+
+    scales_and_zeros = scales.squeeze().contiguous()
+
+    return out_uint8, scales_and_zeros
+
+
 def _dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype):
     # source: https://github.com/pytorch-labs/gpt-fast/blob/main/quantize.py
     # default setup for affine quantization of activations
@@ -530,7 +563,6 @@ def _dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype):
     return quant, scales.to(x_dtype), zero_points
-
 
 # QuantizationTestCase used as a base class for testing quantization on modules
 class QuantizationTestCase(TestCase):
     def setUp(self):
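A short usage sketch of the helper above (shapes chosen only for illustration): the weight comes in as [K x N], the helper transposes it internally, packs two 4-bit values per uint8 byte, and returns one float32 scale per 32-element group.

```python
# Illustrative call of the new test helper defined above.
import torch
from torch.testing._internal.common_quantization import _group_quantize_tensor_symmetric

k, n, groupsize = 128, 64, 32
w = torch.rand(k, n, dtype=torch.float32)           # weight laid out as [K x N]
w_uint8, scales = _group_quantize_tensor_symmetric(w, n_bit=4, groupsize=groupsize)
print(w_uint8.dtype, scales.dtype)                  # torch.uint8 torch.float32
print(w_uint8.numel() * 2 == w.numel())             # True: two 4-bit values per byte
```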
@@ -45,6 +45,8 @@ inductor_fallback_ops = {
     "aten.cummin.default",
     "aten.cumprod.default",
     "aten.cumsum.default",
+    "aten._dyn_quant_matmul_4bit.default",
+    "aten._dyn_quant_pack_4bit_weight.default",
     "aten._efficient_attention_backward.default",
     "aten._efficient_attention_forward.default",
     "aten._efficientzerotensor.default",