Enable optimized dynamic quantization on aarch64 (#126687)

oneDNN+ACL has optimized kernels for s8s8 matmul, so on aarch64 the input is now quantized as signed int8; behaviour on all other platforms is unchanged. This change requires https://github.com/intel/ideep/pull/313 and oneDNN 3.5 for the optimized kernels, and it speeds up dynamic quantized linear by ~10x.
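For reference, here is a minimal standalone sketch of the affine range-to-qparams math that quant_utils::ChooseQuantizationParams performs (my own illustration, not PyTorch's implementation; QParams and choose_qparams are hypothetical names). It shows why switching qmin/qmax from [0, 255] to [-128, 127] keeps the same 256-level resolution and only shifts the zero point:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <limits>

// Hypothetical stand-in for the (scale, zero_point) pair that
// quant_utils::ChooseQuantizationParams computes in PyTorch.
struct QParams { float scale; int32_t zero_point; };

template <typename Q>
QParams choose_qparams(float x_min, float x_max) {
  // The float range must contain zero so 0.0f maps exactly to an integer.
  x_min = std::min(x_min, 0.0f);
  x_max = std::max(x_max, 0.0f);
  const int32_t qmin = std::numeric_limits<Q>::min();
  const int32_t qmax = std::numeric_limits<Q>::max();
  const float scale = (x_max - x_min) / static_cast<float>(qmax - qmin);
  const auto zp = qmin - static_cast<int32_t>(std::nearbyint(x_min / scale));
  return {scale, std::clamp(zp, qmin, qmax)};
}

int main() {
  const QParams u = choose_qparams<uint8_t>(-1.0f, 1.0f);  // zp ~ 128
  const QParams s = choose_qparams<int8_t>(-1.0f, 1.0f);   // zp ~ 0
  std::printf("u8: scale=%g zp=%d\n", u.scale, u.zero_point);
  std::printf("s8: scale=%g zp=%d\n", s.scale, s.zero_point);
}

The scale is the same either way; only the zero point moves, which is why the signed path can reuse the existing calibration logic unchanged.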

Also, do you have a policy on copyright headers? Arm's usual policy when contributing to open source projects is to include a copyright header on any file which is modified. Would this be acceptable? If not, is there somewhere else suitable to note copyright?

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126687
Approved by: https://github.com/jgong5, https://github.com/malfet, https://github.com/snadampal

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Jonathan Deakin
2024-08-24 18:40:12 +00:00
committed by PyTorch MergeBot
parent f71c3d265a
commit 50d5aa8c10


@@ -26,6 +26,7 @@
 #include <algorithm>
 #include <string>
+#include <type_traits>
 
 int register_linear_params();
@@ -530,12 +531,19 @@ at::Tensor PackedLinearWeightsOnednn::apply_dynamic_impl(
     x_min = t_min.item<float>();
   }
 #endif
-  const int precision = 8;
+#if defined(__aarch64__) && AT_MKLDNN_ACL_ENABLED()
+  // oneDNN+ACL has optimized kernels for s8s8 matmul, so input is signed
+  using input_qtype = int8_t;
+#else
+  using input_qtype = uint8_t;
+#endif
   auto q_params = quant_utils::ChooseQuantizationParams(
       /*min=*/x_min,
       /*max=*/x_max,
-      /*qmin=*/0,
-      /*qmax=*/(1 << precision) - 1,
+      /*qmin=*/std::numeric_limits<input_qtype>::min(),
+      /*qmax=*/std::numeric_limits<input_qtype>::max(),
       /*preserve_sparsity=*/false,
       /*force_scale_power_of_two=*/false,
       /*reduce_range=*/reduce_range);
@@ -573,7 +581,8 @@ at::Tensor PackedLinearWeightsOnednn::apply_dynamic_impl(
     ideep::matmul_forward::prepare</*is_dynamic=*/true>(
         params, x, w, b, y,
         src_scales, weights_scales, ideep::scale_t(),
-        src_zero_point, ideep::zero_point_t(), 1.0f, 1.0f, op_attr);
+        src_zero_point, ideep::zero_point_t(), 1.0f, 1.0f, op_attr,
+        ideep::tensor::data_type::f32, std::is_signed_v<input_qtype> ? ideep::s8s8 : ideep::u8s8);
     get_cache() = LinearPrimitiveCache(cache_key, params);
     w = w.reorder_if_differ_in(params.pd.weights_desc());
   });
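One note on the last hunk: because input_qtype is fixed per platform at compile time, std::is_signed_v resolves the s8s8-vs-u8s8 choice without threading any runtime flag through the call site. A self-contained sketch of the same pattern (MatmulTag is a made-up enum standing in for ideep's s8s8/u8s8 tags, and the PyTorch-specific AT_MKLDNN_ACL_ENABLED guard is omitted):

#include <cstdint>
#include <cstdio>
#include <type_traits>

// Made-up stand-in for ideep's s8s8/u8s8 matmul descriptors.
enum class MatmulTag { s8s8, u8s8 };

#if defined(__aarch64__)
using input_qtype = int8_t;   // signed activations feed the ACL s8s8 kernels
#else
using input_qtype = uint8_t;  // other platforms keep unsigned activations
#endif

// Resolved entirely at compile time from the signedness of input_qtype.
constexpr MatmulTag kTag =
    std::is_signed_v<input_qtype> ? MatmulTag::s8s8 : MatmulTag::u8s8;

int main() {
  std::printf("input_qtype is %s, matmul tag is %s\n",
              std::is_signed_v<input_qtype> ? "int8_t" : "uint8_t",
              kTag == MatmulTag::s8s8 ? "s8s8" : "u8s8");
}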