Enable s8s8s8 for qlinear with mkl-dnn (#139887)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139887
Approved by: https://github.com/huydhn
Annop Wongwathanarat
2025-01-15 23:20:07 +00:00
committed by PyTorch MergeBot
parent 4e1834f5f3
commit 7265dc0622
2 changed files with 21 additions and 12 deletions
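
What the change enables: the oneDNN qlinear kernel previously required quint8 (u8) activations and produced quint8 outputs (u8s8u8); with this patch it also accepts qint8 (s8) activations and returns qint8 outputs (s8s8s8). A minimal usage sketch, not taken from the PR — the engine selection and scale/zero-point values are illustrative, and the qint8 path is only exercised by the tests below on ARM64:

import torch

torch.backends.quantized.engine = "onednn"  # assumes the onednn engine is available in this build

x = torch.randn(4, 16)
w = torch.randn(8, 16)
b = torch.randn(8)

# Quantize the activation to qint8 (s8) instead of the usual quint8 (u8).
x_q = torch.quantize_per_tensor(x, scale=0.05, zero_point=0, dtype=torch.qint8)
w_q = torch.quantize_per_tensor(w, scale=0.02, zero_point=0, dtype=torch.qint8)

w_packed = torch.ops.quantized.linear_prepack(w_q, b)
y_q = torch.ops.quantized.linear(x_q, w_packed, 0.1, 0)
print(y_q.dtype)  # with this patch the output dtype follows the input: torch.qint8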

View File

@@ -796,14 +796,15 @@ at::Tensor PackedLinearWeightsOnednn::apply_impl(
   TORCH_CHECK(
       dim != 0,
       "qlinear (ONEDNN): input dim should be at least 1, but got 0");
-  TORCH_CHECK(input.scalar_type() == c10::ScalarType::QUInt8,
-      "qlinear (ONEDNN): data type of input should be QUint8.");
+  TORCH_CHECK(input.scalar_type() == c10::ScalarType::QUInt8 || input.scalar_type() == c10::ScalarType::QInt8,
+      "qlinear (ONEDNN): data type of input should be QUInt8 or QInt8.");
+  auto is_input_qint8 = input.scalar_type() == c10::ScalarType::QInt8;
   auto input_contig = input.expect_contiguous();
   auto& w = *(weight_.get());
   auto K = input.size(dim - 1), M = input.numel() / K, N = w.get_dim(1);
   auto input_dims = {M, K};
-  auto input_data_type = dnnl::memory::data_type::u8;
+  auto input_data_type = is_input_qint8 ? dnnl::memory::data_type::s8 : dnnl::memory::data_type::u8;
   auto input_desc = ideep::tensor::desc(input_dims, input_data_type);
   ideep::attr_t op_attr = ideep::attr_t();
   if (post_op == Relu) {
@@ -813,7 +814,7 @@ at::Tensor PackedLinearWeightsOnednn::apply_impl(
   } else if (post_op == Tanh) {
     op_attr = ideep::attr_t::fuse_tanh();
   }
-  ideep::tensor x(input_desc, input_contig->data_ptr<c10::quint8>());
+  ideep::tensor x(input_desc, input_contig->data_ptr());
   auto dst_dims = {M, N};
   double input_scale = input.q_scale();
   int64_t input_zero_point = input.q_zero_point();
@@ -827,13 +828,15 @@ at::Tensor PackedLinearWeightsOnednn::apply_impl(
   // Allocate output Tensor
   at::Tensor output = at::_empty_affine_quantized(
       dst_dims,
-      at::device(c10::kCPU).dtype(c10::kQUInt8),
+      at::device(c10::kCPU).dtype(is_input_qint8 ? c10::kQInt8 : c10::kQUInt8),
       output_scale,
       output_zero_point);
   if (output.numel() == 0) {
     return output;
   }
-  ideep::tensor y({dst_dims, ideep::tensor::data_type::u8,
+  auto output_ideep_data_type = is_input_qint8 ? ideep::tensor::data_type::s8 : ideep::tensor::data_type::u8;
+  auto ideep_lowp_kind = is_input_qint8 ? ideep::s8s8 : ideep::u8s8;
+  ideep::tensor y({dst_dims, output_ideep_data_type,
                    {output.strides().cbegin(), output.strides().cend()}},
                   output.data_ptr());
   bool with_bias = bias_.has_value();
@@ -855,7 +858,9 @@ at::Tensor PackedLinearWeightsOnednn::apply_impl(
     ideep::matmul_forward::prepare</*is_dynamic=*/false>(
         params, x, w, b, y,
         src_scales, weights_scales, dst_scales,
-        src_zero_point, dst_zero_point, 1.0f, 1.0f, op_attr);
+        src_zero_point, dst_zero_point, 1.0f, 1.0f, op_attr,
+        output_ideep_data_type,
+        ideep_lowp_kind);
     get_cache() = LinearPrimitiveCache(cache_key, params);
     w = w.reorder_if_differ_in(params.pd.weights_desc());
   });
@@ -865,7 +870,9 @@ at::Tensor PackedLinearWeightsOnednn::apply_impl(
   } else {
     ideep::matmul_forward::compute(x, w, b, y, src_scales, weights_scales,
                                    dst_scales, src_zero_point, dst_zero_point,
-                                   1.0f, 1.0f, op_attr);
+                                   1.0f, 1.0f, op_attr,
+                                   output_ideep_data_type,
+                                   ideep_lowp_kind);
   }
   auto out_sizes = input.sizes().vec();
   out_sizes.back() = N;
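
The kernel keys everything off the activation dtype: the oneDNN memory data types (s8 vs u8), the ideep lowp_kind (s8s8 vs u8s8), and the dtype of the allocated output tensor. A small sketch of the resulting behaviour, assuming the onednn engine is active (hypothetical helper, not from the PR):

import torch

def qlinear_out_dtype(act_dtype):
    # Output dtype follows the activation dtype, mirroring the C++ selection above.
    x_q = torch.quantize_per_tensor(torch.randn(2, 8), 0.1, 0, act_dtype)
    w_q = torch.quantize_per_tensor(torch.randn(4, 8), 0.1, 0, torch.qint8)
    packed = torch.ops.quantized.linear_prepack(w_q, None)
    return torch.ops.quantized.linear(x_q, packed, 0.2, 0).dtype

# quint8 activations keep the u8s8u8 path; qint8 activations now take s8s8s8.
# Expected: torch.quint8, then torch.qint8.
print(qlinear_out_dtype(torch.quint8))
print(qlinear_out_dtype(torch.qint8))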

View File

@@ -25,7 +25,7 @@ hu.assert_deadline_disabled()
 from torch.testing._internal.common_cuda import SM80OrLater
 from torch.testing._internal.common_utils import TestCase
-from torch.testing._internal.common_utils import IS_PPC, TEST_WITH_UBSAN, IS_MACOS, IS_SANDCASTLE, IS_FBCODE
+from torch.testing._internal.common_utils import IS_PPC, TEST_WITH_UBSAN, IS_MACOS, IS_SANDCASTLE, IS_FBCODE, IS_ARM64
 from torch.testing._internal.common_quantization import skipIfNoFBGEMM, skipIfNoQNNPACK, skipIfNoONEDNN
 from torch.testing._internal.common_quantized import _quantize, _dequantize, _calculate_dynamic_qparams, \
     override_quantized_engine, supported_qengines, override_qengines, _snr
@@ -3829,10 +3829,12 @@ class TestQuantizedLinear(TestCase):
         if torch.backends.xnnpack.enabled:
             dtypes.append(torch.qint8)
 
+        if qengine_is_onednn() and IS_ARM64:
+            dtypes.append(torch.qint8)
+
         for dtype in dtypes:
             # No support for channelwise in xnnpack (int8)
-            # ONEDNN does not support qint8
-            if dtype == torch.qint8 and (use_channelwise or qengine_is_onednn()):
+            if dtype == torch.qint8 and use_channelwise:
                 return
 
             nptype = np_dtype[dtype]
@@ -3878,7 +3880,7 @@ class TestQuantizedLinear(TestCase):
             np.random.rand(output_channels) *
             (b_value_max - b_value_min) + b_value_min
         ).astype(np.int32) if use_bias else None
-        if torch.backends.quantized.engine in ('x86', 'fbgemm', 'onednn'):
+        if torch.backends.quantized.engine in ('x86', 'fbgemm', 'onednn') and not IS_ARM64:
             avoid_vpmaddubsw_overflow_linear(
                 batch_size,
                 input_channels,
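
For reference, the s8s8s8 path computes the same affine-quantized matmul as the existing u8s8u8 path, only with a signed activation and a signed output. A hand-rolled sketch of the reference math, modelled on how these tests typically build their expected results (an assumption, not code from the PR):

import numpy as np

def ref_qlinear_s8s8s8(x_q, x_scale, x_zp, w_q, w_scale, w_zp, y_scale, y_zp, bias=None):
    # Accumulate in int32 with the zero points removed.
    acc = (x_q.astype(np.int32) - x_zp) @ (w_q.astype(np.int32) - w_zp).T
    if bias is not None:
        acc += bias  # bias pre-quantized to int32 at scale x_scale * w_scale
    # Requantize the int32 accumulator to the signed int8 output range.
    y = np.round(acc * (x_scale * w_scale) / y_scale) + y_zp
    return np.clip(y, -128, 127).astype(np.int8)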