mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[Quant][X86] add ops to compute uint8 pointwise add/add_relu (#152411)
**Summary** This PR adds two new ops, `onednn.qadd.tensor` and `onednn.qadd_relu.tensor`, for int8 elementwise add; they accept inputs on the CPU device (instead of QuantizedCPU). The new ops are implemented with AVX512 instructions and provide similar or better performance, depending on shape, than their counterparts for the QuantizedCPU device, `quantized.add` and `quantized.add_relu`. The new ops also support output dtypes other than uint8 (fp32, fp16 and bf16 are supported). **Test plan** ``` pytest test/quantization/core/test_quantized_op.py -k test_int8_add_onednn ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/152411 Approved by: https://github.com/leslie-fang-intel, https://github.com/jerryzh168
This commit is contained in:
committed by
PyTorch MergeBot
parent
a762dd1f67
commit
55784be01b
@ -3167,6 +3167,40 @@ class TestQuantizedOps(TestCase):
|
||||
c = torch.ops.onednn.qmul.tensor(a_int8, s_a, z_a, b_int8, s_b, z_b, s_c, z_c, output_dtype)
|
||||
self.assertEqual(c, c_ref)
|
||||
|
||||
@skipIfNoONEDNN
@given(relu_fused=st.booleans())
def test_int8_add_onednn(self, relu_fused):
    """Check the onednn int8 add ops against a dequantize-and-add reference.

    For every combination of input shape and output dtype, two random
    tensors are quantized to quint8, and their integer representations are
    passed to ``torch.ops.onednn.qadd.tensor`` (or ``qadd_relu.tensor`` when
    ``relu_fused`` is true). The result is compared with a reference built
    by dequantizing both inputs, adding them, optionally applying ReLU and,
    for uint8 output only, requantizing with
    ``quantized_decomposed.quantize_per_tensor``.

    Args:
        relu_fused: Drawn by hypothesis; selects the fused add+ReLU op and
            adds a ReLU to the reference computation.
    """
    # The op under test depends only on relu_fused, so resolve it once.
    add_op = (
        torch.ops.onednn.qadd_relu.tensor
        if relu_fused
        else torch.ops.onednn.qadd.tensor
    )

    shapes = [(16, 64), (15, 63)]
    out_dtypes = [torch.uint8, torch.float, torch.bfloat16, torch.half]

    # Input quantization parameters are the same for every case.
    scale_a, zp_a = 0.1, 1
    scale_b, zp_b = 0.2, 2

    for shape, output_dtype in itertools.product(shapes, out_dtypes):
        quantized_out = output_dtype == torch.uint8
        # Non-uint8 outputs use scale 1 and zero point 0 for the output.
        scale_c, zp_c = (0.3, 3) if quantized_out else (1, 0)

        lhs = torch.randn(shape)
        rhs = torch.randn(shape)
        q_lhs = torch.quantize_per_tensor(lhs, scale_a, zp_a, torch.quint8)
        q_rhs = torch.quantize_per_tensor(rhs, scale_b, zp_b, torch.quint8)

        # Reference result computed from the dequantized inputs.
        expected = q_lhs.dequantize() + q_rhs.dequantize()
        if relu_fused:
            expected = torch.nn.functional.relu(expected)
        if quantized_out:
            expected = torch.ops.quantized_decomposed.quantize_per_tensor.default(
                expected, scale_c, zp_c, 0, 255, torch.uint8
            )
        expected = expected.to(output_dtype)

        # The onednn ops take the raw integer representations of the inputs.
        actual = add_op(
            q_lhs.int_repr(),
            scale_a,
            zp_a,
            q_rhs.int_repr(),
            scale_b,
            zp_b,
            scale_c,
            zp_c,
            output_dtype,
        )
        self.assertEqual(actual, expected)
|
||||
|
||||
|
||||
class TestDynamicQuantizedOps(TestCase):
|
||||
"""Tests the correctness of the dynamic quantized linear and linear_relu op."""
|
||||
|
Reference in New Issue
Block a user