[Quant][CPU] Enable fp8 qlinear (#155678)

**Summary**
Enable fp8 qlinear on CPU. It's part of the plan to enable fp8 static quantization on CPU. This PR only adds FP8 support of the existing int8 qlinear op. It does not add a new op nor does it affect frontend or quantization flow. The schema of the qlinear op is not changed either.

The FP8 qlinear therefore shares the same op as the INT8 qlinear; the only difference is that the src/wei dtype is fp8 instead of int8. The output dtype can be fp8, float32, or bfloat16. The implementation uses the oneDNN library.
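To make the shared schema concrete, here is a minimal sketch of the fp8 path. It mirrors the `qlinear_prepack` / `qlinear_pointwise` calls in the test added below; shapes and scale values are arbitrary, and a CPU build with oneDNN is assumed.

```python
import torch

# Quantize fp32 activation/weight to fp8 e4m3 with per-tensor scales
fp8_max = torch.finfo(torch.float8_e4m3fn).max
x = torch.rand(4, 8) * 10
w = torch.rand(16, 8) * 10
x_scale = (x.abs().max() / fp8_max).reshape([1])
w_scale = (w.abs().max() / fp8_max).reshape([1])
qx = (x / x_scale).to(torch.float8_e4m3fn)  # fp8 src
qw = (w / w_scale).to(torch.float8_e4m3fn)  # fp8 wei

# Same prepack/compute ops as the int8 path; only the dtypes differ
qw_packed = torch.ops.onednn.qlinear_prepack(qw, x.shape)
y = torch.ops.onednn.qlinear_pointwise(
    qx, x_scale, 0,                                         # src, src scale, src zero point
    qw_packed, w_scale, torch.zeros(1, dtype=torch.int),    # packed wei, wei scales, wei zero points
    None,                                                   # bias
    1.0, 0,                                                 # output scale, output zero point
    torch.bfloat16,                                         # output dtype: None (fp8), float32, or bfloat16
    "none", (), "none",                                     # post op name, post op args, post op algorithm
)
```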

The differences between qlinear and `_scaled_mm`, illustrated by the sketch after this list, are that
- Qlinear supports post op fusion while `_scaled_mm` does not
- Weights are prepacked for qlinear
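
For example, a fused ReLU only changes the post-op arguments of the same call, and the prepacked weight from the sketch above is reused as-is. This is again a sketch mirroring the test below; with `_scaled_mm` the activation would have to be applied as a separate op on the output.

```python
# Reuses qx / x_scale / qw_packed / w_scale from the sketch above.
y_relu = torch.ops.onednn.qlinear_pointwise(
    qx, x_scale, 0,
    qw_packed, w_scale, torch.zeros(1, dtype=torch.int),
    None,
    1.0, 0,
    torch.bfloat16,
    "relu", (), "none",  # ReLU fused into the linear via the unary post op
)
```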

**Test plan**
```
pytest test/quantization/core/test_quantized_op.py -k "qlinear and fp8"
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155678
Approved by: https://github.com/leslie-fang-intel, https://github.com/jerryzh168
Author: Xia, Weiwen
Date: 2025-06-25 10:01:03 +00:00
Committed by: PyTorch MergeBot
Parent: 19ffb5e6f7
Commit: c2185dc4a5
3 changed files with 355 additions and 11 deletions


@@ -4511,7 +4511,7 @@ class TestQuantizedLinear(TestCase):
        qlinear_op,
        post_op="none",
        unary_post_op_args=(),
        post_op_algorithms=("none"),
        post_op_algorithms=("none",),
    ):
        qlinear_prepack = torch.ops.onednn.qlinear_prepack
        linear_op = F.linear
@@ -4678,6 +4678,184 @@ class TestQuantizedLinear(TestCase):
        qlinear = torch.ops.onednn.qlinear_pointwise.binary
        self._test_qlinear_pt2e_helper(qlinear, "add_relu")

    def _quantize_fp8e4m3(self, t: torch.Tensor, channelwise: bool, scale: Optional[torch.Tensor] = None):
        quant_max = torch.finfo(torch.float8_e4m3fn).max
        eps = torch.Tensor([torch.finfo(torch.float32).eps])
        if channelwise:
            scale = scale or t.reshape(t.shape[0], -1).abs().max(-1)[0] / quant_max
            scale = torch.max(scale, eps)
            scale_reshape = scale.reshape((-1,) + (1,) * (t.dim() - 1))
            qt = t / scale_reshape
        else:
            scale = scale or t.abs().max().reshape([1]) / quant_max
            scale = torch.max(scale, eps) if isinstance(scale, torch.Tensor) else max(scale, eps.item())
            qt = t / scale
        qt = qt.to(torch.float8_e4m3fn)
        return qt, scale

    def _dequantize_fp8e4m3(self, qt: torch.Tensor, scale: torch.Tensor):
        dqt = qt.float()
        if scale.numel() == 1:
            # per tensor
            dqt = dqt * scale
        else:
            # per channel
            scale_reshape = scale.reshape((-1,) + (1,) * (qt.dim() - 1))
            dqt = dqt * scale_reshape
        return dqt

    def _test_qlinear_fp8_helper(
        self,
        qlinear_op,
        post_op="none",
        unary_post_op_args=(),
        post_op_algorithms=("none",),
    ):
        qlinear_prepack = torch.ops.onednn.qlinear_prepack
        linear_op = F.linear
        in_channels_list = [4, 8]
        out_channels_list = [16, 32]
        batch_size = 1
        use_bias_list = [True, False]
        weight_quant_per_channel_list = [True, False]
        output_dtype_list = [None, torch.float32, torch.bfloat16]
        y_scale, y_zp = 0.07, 0
        input_dim_list = [2, 3]
        cases = itertools.product(
            in_channels_list, out_channels_list, use_bias_list,
            weight_quant_per_channel_list, output_dtype_list, post_op_algorithms, input_dim_list)
        with override_quantized_engine('onednn'):
            for ic, oc, use_bias, weight_quant_per_channel, output_dtype, post_op_algo, input_dim in cases:
                used_y_scale = y_scale
                used_y_zp = y_zp
                fp32_out = output_dtype == torch.float32
                bfloat16_out = output_dtype == torch.bfloat16
                if fp32_out or bfloat16_out:
                    used_y_scale = 1.0
                    x2_scale, x2_zp = 1.0, 0
                else:
                    x2_scale, x2_zp = 0.3, 0
                x = torch.rand(batch_size, (ic + 1), ic) * 10 if input_dim == 3 else torch.rand(batch_size, ic) * 10
                w = torch.rand(oc, ic) * 10
                qx, x_scale = self._quantize_fp8e4m3(x, channelwise=False)
                qw, w_scales = self._quantize_fp8e4m3(w, channelwise=weight_quant_per_channel)
                if use_bias:
                    b = torch.rand(oc) * 10
                else:
                    b = None

                # compute reference result
                x_ref = self._dequantize_fp8e4m3(qx, x_scale)
                w_ref = self._dequantize_fp8e4m3(qw, w_scales)
                y_ref = linear_op(x_ref, w_ref, b)

                # compute fp8 linear
                qw_packed = qlinear_prepack(qw, x.shape)
                x_zp = 0
                w_zps = torch.zeros_like(w_scales, dtype=torch.int)

                if post_op in ("none", "relu", "gelu"):
                    qy = qlinear_op(
                        qx, x_scale, x_zp, qw_packed, w_scales, w_zps,
                        b, used_y_scale, used_y_zp, output_dtype,
                        post_op, unary_post_op_args, post_op_algo
                    )
                    if post_op == "relu":
                        y_ref = F.relu(y_ref)
                    elif post_op == "gelu":
                        y_ref = F.gelu(y_ref, approximate=post_op_algo)
                elif post_op in ("sum", "sum_relu"):
                    x2 = torch.rand_like(y_ref)
                    x2_q, x2_scale = self._quantize_fp8e4m3(x2, channelwise=False)
                    x2_dq = self._dequantize_fp8e4m3(x2_q, x2_scale)
                    unary_post_op = "relu" if post_op == "sum_relu" else "none"
                    binary_alpha = 1.0  # we only support alpha=1.0 now
                    # if output_dtype is fp32 or bf16, accumulate on x2
                    # if output_dtype is None (fp8), accumulate on x2_dq
                    accum = x2_q if output_dtype is None else x2
                    accum_ref = x2_dq if output_dtype is None else x2.clone()
                    x2_scale = x2_scale if output_dtype is None else 1.0
                    if bfloat16_out:
                        accum = accum.bfloat16()
                        accum_ref = accum_ref.bfloat16()
                    qy = qlinear_op(
                        qx, x_scale, x_zp, qw_packed, w_scales, w_zps,
                        accum, b, used_y_scale, used_y_zp, output_dtype,
                        x2_scale, x2_zp, "sum", binary_alpha,
                        unary_post_op, unary_post_op_args, post_op_algo
                    )
                    y_ref = y_ref + accum_ref * binary_alpha
                    if unary_post_op == "relu":
                        y_ref = F.relu(y_ref)
                elif post_op in ("add", "add_relu"):
                    if output_dtype is not None:
                        # Only support fp8 output
                        continue
                    x2 = torch.rand_like(y_ref)
                    unary_post_op = "relu" if post_op == "add_relu" else "none"
                    binary_alpha = 1.0  # we only support alpha=1.0 now
                    qy = qlinear_op(
                        qx, x_scale, x_zp, qw_packed, w_scales, w_zps,
                        x2, b, used_y_scale, used_y_zp, output_dtype,
                        1.0, 0, "add", binary_alpha,
                        unary_post_op, unary_post_op_args, post_op_algo
                    )
                    y_ref = y_ref + x2 * binary_alpha
                    if unary_post_op == "relu":
                        y_ref = F.relu(y_ref)

                # Compare results
                if output_dtype is None:
                    y_ref = self._quantize_fp8e4m3(y_ref, False, used_y_scale)[0]
                else:
                    y_ref = y_ref.to(output_dtype)
                self.assertEqual(x.dim(), qy.dim())
                self.assertEqual(y_ref.float(), qy.float())

    @unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode")
    @skipIfNoONEDNN
    def test_qlinear_fp8(self):
        qlinear = torch.ops.onednn.qlinear_pointwise
        self._test_qlinear_fp8_helper(qlinear, "none")

    @unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode")
    @skipIfNoONEDNN
    def test_qlinear_relu_fp8(self):
        qlinear = torch.ops.onednn.qlinear_pointwise
        self._test_qlinear_fp8_helper(qlinear, "relu")

    @unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode")
    @skipIfNoONEDNN
    def test_qlinear_gelu_fp8(self):
        qlinear = torch.ops.onednn.qlinear_pointwise
        post_op_algorithms = ['none', 'tanh']
        self._test_qlinear_fp8_helper(qlinear, "gelu", post_op_algorithms=post_op_algorithms)

    @unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode")
    @skipIfNoONEDNN
    def test_qlinear_sum_fp8(self):
        qlinear = torch.ops.onednn.qlinear_pointwise.binary
        self._test_qlinear_fp8_helper(qlinear, "sum")

    @unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode")
    @skipIfNoONEDNN
    def test_qlinear_sum_relu_fp8(self):
        qlinear = torch.ops.onednn.qlinear_pointwise.binary
        self._test_qlinear_fp8_helper(qlinear, "sum_relu")

    @unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode")
    @skipIfNoONEDNN
    def test_qlinear_add_fp8(self):
        qlinear = torch.ops.onednn.qlinear_pointwise.binary
        self._test_qlinear_fp8_helper(qlinear, "add")

    @unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode")
    @skipIfNoONEDNN
    def test_qlinear_add_relu_fp8(self):
        qlinear = torch.ops.onednn.qlinear_pointwise.binary
        self._test_qlinear_fp8_helper(qlinear, "add_relu")

@unittest.skipIf(IS_MACOS, "Known test failure on Mac.")
class TestQuantizedEmbeddingOps(TestCase):