Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[Quant][X86] add an op to compute uint8 batch norm 2d (#152811)
[Quant][X86] add an op to compute uint8 batch norm 2d (#152811)

**Summary**
This PR adds a new op, `onednn.qbatch_norm2d`, which accepts uint8 inputs on the CPU device (instead of QuantizedCPU). The new op is implemented with AVX512 instructions and provides performance similar to that of its QuantizedCPU counterpart, `quantized.batch_norm2d`. Besides uint8, it supports fp32, fp16, and bf16 output dtypes.

**Test plan**
```
pytest test/quantization/core/test_quantized_op.py -k test_int8_batch_norm_onednn
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152811
Approved by: https://github.com/leslie-fang-intel, https://github.com/jerryzh168, https://github.com/jgong5
ghstack dependencies: #152411
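For reference, the numerics the new op is expected to match are simply inference-mode batch norm applied to the dequantized input, requantized when the output dtype is uint8. Below is a minimal sketch of that reference computation in plain PyTorch; the function name `ref_uint8_batch_norm2d` is illustrative and not part of this PR:

```python
import torch
import torch.nn.functional as F

def ref_uint8_batch_norm2d(qx, x_scale, x_zp, weight, bias, mean, var, eps,
                           y_scale, y_zp, out_dtype):
    # Dequantize the uint8 NCHW input: x = (q - zero_point) * scale.
    x = (qx.float() - x_zp) * x_scale
    # Inference-mode batch norm using the given running statistics.
    y = F.batch_norm(x, running_mean=mean, running_var=var,
                     weight=weight, bias=bias, training=False, eps=eps)
    if out_dtype == torch.uint8:
        # Requantize: round, shift by the zero point, clamp to [0, 255].
        return (torch.round(y / y_scale) + y_zp).clamp(0, 255).to(torch.uint8)
    # fp32 / fp16 / bf16 outputs are returned without requantization.
    return y.to(out_dtype)
```

The new test in the diff below checks `onednn.qbatch_norm2d` against exactly this dequantize-compute-requantize path, built from the `quantized_decomposed` ops and `F.batch_norm`.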
Committed by: PyTorch MergeBot
Parent: 7e16cb99b6
Commit: 1a722f62c2
```diff
@@ -3201,6 +3201,43 @@ class TestQuantizedOps(TestCase):
             c = torch.ops.onednn.qadd.tensor(a_int8, s_a, z_a, b_int8, s_b, z_b, s_c, z_c, output_dtype)
             self.assertEqual(c, c_ref)
 
+    @skipIfNoONEDNN
+    def test_int8_batch_norm_onednn(self):
+        # hypothesis too slow for this test, create test cases manually
+        channel_len_list = (8, 64, 100, 120, 128)
+        output_dtype_list = [torch.uint8, torch.float, torch.bfloat16, torch.half]
+        x_scale, x_zero_point = 0.1, 1
+        cases = itertools.product(channel_len_list, output_dtype_list)
+        for channels, out_dtype in cases:
+            shapes = [8, channels, 8, 8]
+            y_scale, y_zero_point = (0.2, 2) if out_dtype == torch.uint8 else (1, 0)
+
+            x = torch.randn(shapes, dtype=torch.float32)
+            mean = torch.rand(channels).float()
+            var = torch.rand(channels).float()
+            weight = torch.rand(channels).float()
+            bias = torch.rand(channels).float()
+            eps = 0.001
+            qx = torch.ops.quantized_decomposed.quantize_per_tensor.default(
+                x, x_scale, x_zero_point, 0, 255, torch.uint8
+            )
+            y = torch.ops.onednn.qbatch_norm2d(
+                qx, x_scale, x_zero_point, weight, bias, mean, var, eps, y_scale, y_zero_point, out_dtype
+            )
+
+            dqx = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
+                qx, x_scale, x_zero_point, 0, 255, torch.uint8
+            )
+            y_ref = F.batch_norm(dqx, weight=weight, bias=bias,
+                                 running_mean=mean, running_var=var, training=False,
+                                 momentum=0, eps=eps)
+            if out_dtype == torch.uint8:
+                y_ref = torch.ops.quantized_decomposed.quantize_per_tensor.default(
+                    y_ref, y_scale, y_zero_point, 0, 255, torch.uint8
+                )
+            y_ref = y_ref.to(out_dtype)
+            self.assertEqual(y, y_ref, msg=f"{y} vs {y_ref}")
+
 
 class TestDynamicQuantizedOps(TestCase):
     """Tests the correctness of the dynamic quantized linear and linear_relu op."""
```