[Quant][X86] add an op to compute uint8 batch norm 2d (#152811)

**Summary**
This PR adds a new op, `onednn.qbatch_norm2d`, which accepts uint8 inputs on the CPU device (instead of QuantizedCPU).
The new op is implemented with AVX512 instructions and provides performance similar to its QuantizedCPU counterpart, `quantized.batch_norm2d`.
The new op also supports non-uint8 output dtypes (fp32, fp16, and bf16).
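For reference, a minimal usage sketch of the new op (the signature here is taken from the test added in this PR; it assumes a PyTorch build with oneDNN support, and the scales, zero points, and shapes are placeholder values):

```python
import torch

# Quantize an fp32 NCHW tensor to uint8 (per-tensor affine quantization).
x = torch.randn(8, 64, 8, 8)
x_scale, x_zero_point = 0.1, 1
qx = torch.ops.quantized_decomposed.quantize_per_tensor.default(
    x, x_scale, x_zero_point, 0, 255, torch.uint8
)

# Per-channel batch norm parameters and running stats.
weight, bias = torch.rand(64), torch.rand(64)
mean, var = torch.rand(64), torch.rand(64)

# out_dtype may be torch.uint8, torch.float, torch.bfloat16, or torch.half.
# For float output dtypes the test uses y_scale=1, y_zero_point=0.
y = torch.ops.onednn.qbatch_norm2d(
    qx, x_scale, x_zero_point, weight, bias, mean, var,
    1e-3,        # eps
    0.2, 2,      # output scale / zero point
    torch.uint8  # output dtype
)
```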

**Test plan**
```
pytest test/quantization/core/test_quantized_op.py -k test_int8_batch_norm_onednn
```
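The test compares the new op against a dequantize -> `F.batch_norm` -> requantize reference. As a sketch of that reference semantics (this mirrors the test code below; the helper name is ours, not part of the PR):

```python
import torch
import torch.nn.functional as F

def reference_qbatch_norm2d(qx, x_scale, x_zp, weight, bias, mean, var, eps,
                            y_scale, y_zp, out_dtype):
    # Dequantize: x = (qx - zero_point) * scale
    dqx = (qx.float() - x_zp) * x_scale
    # Inference-mode batch norm with fixed running stats.
    y = F.batch_norm(dqx, weight=weight, bias=bias, running_mean=mean,
                     running_var=var, training=False, momentum=0, eps=eps)
    if out_dtype == torch.uint8:
        # Requantize: q = clamp(round(y / scale) + zero_point, 0, 255)
        y = torch.clamp(torch.round(y / y_scale) + y_zp, 0, 255)
    return y.to(out_dtype)
```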

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152811
Approved by: https://github.com/leslie-fang-intel, https://github.com/jerryzh168, https://github.com/jgong5
ghstack dependencies: #152411
Commit 1a722f62c2 (parent 7e16cb99b6) by Xia, Weiwen, 2025-05-14 19:13:13 -07:00, committed by PyTorch MergeBot.
5 changed files with 289 additions and 0 deletions.


```python
@@ -3201,6 +3201,43 @@ class TestQuantizedOps(TestCase):
            c = torch.ops.onednn.qadd.tensor(a_int8, s_a, z_a, b_int8, s_b, z_b, s_c, z_c, output_dtype)
            self.assertEqual(c, c_ref)

    @skipIfNoONEDNN
    def test_int8_batch_norm_onednn(self):
        # hypothesis too slow for this test, create test cases manually
        channel_len_list = (8, 64, 100, 120, 128)
        output_dtype_list = [torch.uint8, torch.float, torch.bfloat16, torch.half]
        x_scale, x_zero_point = 0.1, 1
        cases = itertools.product(channel_len_list, output_dtype_list)
        for channels, out_dtype in cases:
            shapes = [8, channels, 8, 8]
            y_scale, y_zero_point = (0.2, 2) if out_dtype == torch.uint8 else (1, 0)
            x = torch.randn(shapes, dtype=torch.float32)
            mean = torch.rand(channels).float()
            var = torch.rand(channels).float()
            weight = torch.rand(channels).float()
            bias = torch.rand(channels).float()
            eps = 0.001
            qx = torch.ops.quantized_decomposed.quantize_per_tensor.default(
                x, x_scale, x_zero_point, 0, 255, torch.uint8
            )
            y = torch.ops.onednn.qbatch_norm2d(
                qx, x_scale, x_zero_point, weight, bias, mean, var, eps, y_scale, y_zero_point, out_dtype
            )
            dqx = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
                qx, x_scale, x_zero_point, 0, 255, torch.uint8
            )
            y_ref = F.batch_norm(dqx, weight=weight, bias=bias,
                                 running_mean=mean, running_var=var, training=False,
                                 momentum=0, eps=eps)
            if out_dtype == torch.uint8:
                y_ref = torch.ops.quantized_decomposed.quantize_per_tensor.default(
                    y_ref, y_scale, y_zero_point, 0, 255, torch.uint8
                )
            y_ref = y_ref.to(out_dtype)
            self.assertEqual(y, y_ref, msg=f"{y} vs {y_ref}")


class TestDynamicQuantizedOps(TestCase):
    """Tests the correctness of the dynamic quantized linear and linear_relu op."""
```