Enable s8s8s8 for qlinear with mkl-dnn (#139887)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139887
Approved by: https://github.com/huydhn
Committed by: PyTorch MergeBot
Parent: 4e1834f5f3
Commit: 7265dc0622
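The change lets the oneDNN (mkl-dnn) qlinear path accept signed int8 (s8) activations in addition to unsigned int8 (u8), and produce an s8 output when the input is s8, i.e. the s8s8s8 (src/weight/dst all signed int8) configuration; the updated test only enables this combination for the onednn engine on ARM64. Below is a minimal Python-level sketch of what this enables, assuming a PyTorch build with the onednn quantized engine on hardware where the s8s8s8 path is supported; shapes, scales, and zero points are illustrative and not taken from the PR.

import torch

# Sketch only: quantized linear with qint8 (s8) activations on the onednn engine.
torch.backends.quantized.engine = "onednn"

x_fp = torch.randn(4, 8)
w_fp = torch.randn(16, 8)
bias = torch.randn(16)

# qint8 activations with qint8 weights -> the s8s8 lowp path in oneDNN.
x_q = torch.quantize_per_tensor(x_fp, scale=0.1, zero_point=0, dtype=torch.qint8)
w_q = torch.quantize_per_tensor(w_fp, scale=0.05, zero_point=0, dtype=torch.qint8)

packed = torch.ops.quantized.linear_prepack(w_q, bias)
# With this change the output tensor is also qint8 (s8) when the input is qint8;
# previously the onednn backend required quint8 (u8) activations.
y_q = torch.ops.quantized.linear(x_q, packed, 0.2, 0)
print(y_q.dtype)  # expected: torch.qint8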
@@ -796,14 +796,15 @@ at::Tensor PackedLinearWeightsOnednn::apply_impl(
   TORCH_CHECK(
       dim != 0,
       "qlinear (ONEDNN): input dim should be at least 1, but got 0");
-  TORCH_CHECK(input.scalar_type() == c10::ScalarType::QUInt8,
-      "qlinear (ONEDNN): data type of input should be QUint8.");
+  TORCH_CHECK(input.scalar_type() == c10::ScalarType::QUInt8 || input.scalar_type() == c10::ScalarType::QInt8,
+      "qlinear (ONEDNN): data type of input should be QUInt8 or QInt8.");
 
+  auto is_input_qint8 = input.scalar_type() == c10::ScalarType::QInt8;
   auto input_contig = input.expect_contiguous();
   auto& w = *(weight_.get());
   auto K = input.size(dim - 1), M = input.numel() / K, N = w.get_dim(1);
   auto input_dims = {M, K};
-  auto input_data_type = dnnl::memory::data_type::u8;
+  auto input_data_type = is_input_qint8 ? dnnl::memory::data_type::s8 : dnnl::memory::data_type::u8;
   auto input_desc = ideep::tensor::desc(input_dims, input_data_type);
   ideep::attr_t op_attr = ideep::attr_t();
   if (post_op == Relu) {
@@ -813,7 +814,7 @@ at::Tensor PackedLinearWeightsOnednn::apply_impl(
   } else if (post_op == Tanh) {
     op_attr = ideep::attr_t::fuse_tanh();
   }
-  ideep::tensor x(input_desc, input_contig->data_ptr<c10::quint8>());
+  ideep::tensor x(input_desc, input_contig->data_ptr());
   auto dst_dims = {M, N};
   double input_scale = input.q_scale();
   int64_t input_zero_point = input.q_zero_point();
@@ -827,13 +828,15 @@ at::Tensor PackedLinearWeightsOnednn::apply_impl(
   // Allocate output Tensor
   at::Tensor output = at::_empty_affine_quantized(
       dst_dims,
-      at::device(c10::kCPU).dtype(c10::kQUInt8),
+      at::device(c10::kCPU).dtype(is_input_qint8 ? c10::kQInt8 : c10::kQUInt8),
       output_scale,
       output_zero_point);
   if (output.numel() == 0) {
     return output;
   }
-  ideep::tensor y({dst_dims, ideep::tensor::data_type::u8,
+  auto output_ideep_data_type = is_input_qint8 ? ideep::tensor::data_type::s8 : ideep::tensor::data_type::u8;
+  auto ideep_lowp_kind = is_input_qint8 ? ideep::s8s8 : ideep::u8s8;
+  ideep::tensor y({dst_dims, output_ideep_data_type,
                    {output.strides().cbegin(), output.strides().cend()}},
                   output.data_ptr());
   bool with_bias = bias_.has_value();
@@ -855,7 +858,9 @@ at::Tensor PackedLinearWeightsOnednn::apply_impl(
     ideep::matmul_forward::prepare</*is_dynamic=*/false>(
         params, x, w, b, y,
         src_scales, weights_scales, dst_scales,
-        src_zero_point, dst_zero_point, 1.0f, 1.0f, op_attr);
+        src_zero_point, dst_zero_point, 1.0f, 1.0f, op_attr,
+        output_ideep_data_type,
+        ideep_lowp_kind);
     get_cache() = LinearPrimitiveCache(cache_key, params);
     w = w.reorder_if_differ_in(params.pd.weights_desc());
   });
@@ -865,7 +870,9 @@ at::Tensor PackedLinearWeightsOnednn::apply_impl(
   } else {
     ideep::matmul_forward::compute(x, w, b, y, src_scales, weights_scales,
                                    dst_scales, src_zero_point, dst_zero_point,
-                                   1.0f, 1.0f, op_attr);
+                                   1.0f, 1.0f, op_attr,
+                                   output_ideep_data_type,
+                                   ideep_lowp_kind);
   }
   auto out_sizes = input.sizes().vec();
   out_sizes.back() = N;
@@ -25,7 +25,7 @@ hu.assert_deadline_disabled()
 
 from torch.testing._internal.common_cuda import SM80OrLater
 from torch.testing._internal.common_utils import TestCase
-from torch.testing._internal.common_utils import IS_PPC, TEST_WITH_UBSAN, IS_MACOS, IS_SANDCASTLE, IS_FBCODE
+from torch.testing._internal.common_utils import IS_PPC, TEST_WITH_UBSAN, IS_MACOS, IS_SANDCASTLE, IS_FBCODE, IS_ARM64
 from torch.testing._internal.common_quantization import skipIfNoFBGEMM, skipIfNoQNNPACK, skipIfNoONEDNN
 from torch.testing._internal.common_quantized import _quantize, _dequantize, _calculate_dynamic_qparams, \
     override_quantized_engine, supported_qengines, override_qengines, _snr
@@ -3829,10 +3829,12 @@ class TestQuantizedLinear(TestCase):
         if torch.backends.xnnpack.enabled:
             dtypes.append(torch.qint8)
 
+        if qengine_is_onednn() and IS_ARM64:
+            dtypes.append(torch.qint8)
+
         for dtype in dtypes:
             # No support for channelwise in xnnpack (int8)
-            # ONEDNN does not support qint8
-            if dtype == torch.qint8 and (use_channelwise or qengine_is_onednn()):
+            if dtype == torch.qint8 and use_channelwise:
                 return
 
             nptype = np_dtype[dtype]
@@ -3878,7 +3880,7 @@ class TestQuantizedLinear(TestCase):
             np.random.rand(output_channels) *
             (b_value_max - b_value_min) + b_value_min
         ).astype(np.int32) if use_bias else None
-        if torch.backends.quantized.engine in ('x86', 'fbgemm', 'onednn'):
+        if torch.backends.quantized.engine in ('x86', 'fbgemm', 'onednn') and not IS_ARM64:
             avoid_vpmaddubsw_overflow_linear(
                 batch_size,
                 input_channels,