mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Quant][CPU] Enable fp8 qlinear (#155678)
**Summary** Enable fp8 qlinear on CPU. It's part of the plan to enable fp8 static quantization on CPU. This PR only adds FP8 support of the existing int8 qlinear op. It does not add a new op nor does it affect frontend or quantization flow. The schema of the qlinear op is not changed either. So, the FP8 qlinear shares the same op as INT8 qlinear and the difference is that src/wei dtype is fp8 instead of int8. The output dtype can be fp8/float32/bfloat16. The implementation uses the oneDNN library. The differences of qlinear from `_scaled_mm` are that - Qlinear supports post op fusion while `_scaled_mm` does not - Weights are prepacked for qlinear **Test plan** ``` pytest test/quantization/core/test_quantized_op.py -k "qlinear and fp8" ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/155678 Approved by: https://github.com/leslie-fang-intel, https://github.com/jerryzh168
This commit is contained in:
committed by
PyTorch MergeBot
parent
19ffb5e6f7
commit
c2185dc4a5
@ -4511,7 +4511,7 @@ class TestQuantizedLinear(TestCase):
|
||||
qlinear_op,
|
||||
post_op="none",
|
||||
unary_post_op_args=(),
|
||||
post_op_algorithms=("none"),
|
||||
post_op_algorithms=("none",),
|
||||
):
|
||||
qlinear_prepack = torch.ops.onednn.qlinear_prepack
|
||||
linear_op = F.linear
|
||||
@ -4678,6 +4678,184 @@ class TestQuantizedLinear(TestCase):
|
||||
qlinear = torch.ops.onednn.qlinear_pointwise.binary
|
||||
self._test_qlinear_pt2e_helper(qlinear, "add_relu")
|
||||
|
||||
def _quantize_fp8e4m3(self, t: torch.Tensor, channelwise: bool, scale: Optional[torch.Tensor] = None):
|
||||
quant_max = torch.finfo(torch.float8_e4m3fn).max
|
||||
eps = torch.Tensor([torch.finfo(torch.float32).eps])
|
||||
if channelwise:
|
||||
scale = scale or t.reshape(t.shape[0], -1).abs().max(-1)[0] / quant_max
|
||||
scale = torch.max(scale, eps)
|
||||
scale_reshape = scale.reshape((-1,) + (1,) * (t.dim() - 1))
|
||||
qt = t / scale_reshape
|
||||
else:
|
||||
scale = scale or t.abs().max().reshape([1]) / quant_max
|
||||
scale = torch.max(scale, eps) if isinstance(scale, torch.Tensor) else max(scale, eps.item())
|
||||
qt = t / scale
|
||||
qt = qt.to(torch.float8_e4m3fn)
|
||||
return qt, scale
|
||||
|
||||
def _dequantize_fp8e4m3(self, qt: torch.Tensor, scale: torch.Tensor):
|
||||
dqt = qt.float()
|
||||
if scale.numel() == 1:
|
||||
# per tensor
|
||||
dqt = dqt * scale
|
||||
else:
|
||||
# per channel
|
||||
scale_reshape = scale.reshape((-1,) + (1,) * (qt.dim() - 1))
|
||||
dqt = dqt * scale_reshape
|
||||
return dqt
|
||||
|
||||
def _test_qlinear_fp8_helper(
|
||||
self,
|
||||
qlinear_op,
|
||||
post_op="none",
|
||||
unary_post_op_args=(),
|
||||
post_op_algorithms=("none",),
|
||||
):
|
||||
qlinear_prepack = torch.ops.onednn.qlinear_prepack
|
||||
linear_op = F.linear
|
||||
in_channels_list = [4, 8]
|
||||
out_channels_list = [16, 32]
|
||||
batch_size = 1
|
||||
use_bias_list = [True, False]
|
||||
weight_quant_per_channel_list = [True, False]
|
||||
output_dtype_list = [None, torch.float32, torch.bfloat16]
|
||||
y_scale, y_zp = 0.07, 0
|
||||
input_dim_list = [2, 3]
|
||||
cases = itertools.product(
|
||||
in_channels_list, out_channels_list, use_bias_list,
|
||||
weight_quant_per_channel_list, output_dtype_list, post_op_algorithms, input_dim_list)
|
||||
with override_quantized_engine('onednn'):
|
||||
for ic, oc, use_bias, weight_quant_per_channel, output_dtype, post_op_algo, input_dim in cases:
|
||||
used_y_scale = y_scale
|
||||
used_y_zp = y_zp
|
||||
fp32_out = output_dtype == torch.float32
|
||||
bfloat16_out = output_dtype == torch.bfloat16
|
||||
if fp32_out or bfloat16_out:
|
||||
used_y_scale = 1.0
|
||||
x2_scale, x2_zp = 1.0, 0
|
||||
else:
|
||||
x2_scale, x2_zp = 0.3, 0
|
||||
x = torch.rand(batch_size, (ic + 1), ic) * 10 if input_dim == 3 else torch.rand(batch_size, ic) * 10
|
||||
w = torch.rand(oc, ic) * 10
|
||||
qx, x_scale = self._quantize_fp8e4m3(x, channelwise=False)
|
||||
qw, w_scales = self._quantize_fp8e4m3(w, channelwise=weight_quant_per_channel)
|
||||
if use_bias:
|
||||
b = torch.rand(oc) * 10
|
||||
else:
|
||||
b = None
|
||||
|
||||
# compute reference result
|
||||
x_ref = self._dequantize_fp8e4m3(qx, x_scale)
|
||||
w_ref = self._dequantize_fp8e4m3(qw, w_scales)
|
||||
y_ref = linear_op(x_ref, w_ref, b)
|
||||
|
||||
# compute fp8 linear
|
||||
qw_packed = qlinear_prepack(qw, x.shape)
|
||||
x_zp = 0
|
||||
w_zps = torch.zeros_like(w_scales, dtype=torch.int)
|
||||
|
||||
if post_op in ("none", "relu", "gelu"):
|
||||
qy = qlinear_op(
|
||||
qx, x_scale, x_zp, qw_packed, w_scales, w_zps,
|
||||
b, used_y_scale, used_y_zp, output_dtype,
|
||||
post_op, unary_post_op_args, post_op_algo
|
||||
)
|
||||
if post_op == "relu":
|
||||
y_ref = F.relu(y_ref)
|
||||
elif post_op == "gelu":
|
||||
y_ref = F.gelu(y_ref, approximate=post_op_algo)
|
||||
elif post_op in ("sum", "sum_relu"):
|
||||
x2 = torch.rand_like(y_ref)
|
||||
x2_q, x2_scale = self._quantize_fp8e4m3(x2, channelwise=False)
|
||||
x2_dq = self._dequantize_fp8e4m3(x2_q, x2_scale)
|
||||
unary_post_op = "relu" if post_op == "sum_relu" else "none"
|
||||
binary_alpha = 1.0 # we only support alpha=1.0 now
|
||||
# if output_dtype is fp32 or bf16, accumulate on x2
|
||||
# if output_dtype is None (fp8), accumulate on x2_dq
|
||||
accum = x2_q if output_dtype is None else x2
|
||||
accum_ref = x2_dq if output_dtype is None else x2.clone()
|
||||
x2_scale = x2_scale if output_dtype is None else 1.0
|
||||
if bfloat16_out:
|
||||
accum = accum.bfloat16()
|
||||
accum_ref = accum_ref.bfloat16()
|
||||
qy = qlinear_op(
|
||||
qx, x_scale, x_zp, qw_packed, w_scales, w_zps,
|
||||
accum, b, used_y_scale, used_y_zp, output_dtype,
|
||||
x2_scale, x2_zp, "sum", binary_alpha,
|
||||
unary_post_op, unary_post_op_args, post_op_algo
|
||||
)
|
||||
y_ref = y_ref + accum_ref * binary_alpha
|
||||
if unary_post_op == "relu":
|
||||
y_ref = F.relu(y_ref)
|
||||
elif post_op in ("add", "add_relu"):
|
||||
if output_dtype is not None:
|
||||
# Only support fp8 output
|
||||
continue
|
||||
x2 = torch.rand_like(y_ref)
|
||||
unary_post_op = "relu" if post_op == "add_relu" else "none"
|
||||
binary_alpha = 1.0 # we only support alpha=1.0 now
|
||||
qy = qlinear_op(
|
||||
qx, x_scale, x_zp, qw_packed, w_scales, w_zps,
|
||||
x2, b, used_y_scale, used_y_zp, output_dtype,
|
||||
1.0, 0, "add", binary_alpha,
|
||||
unary_post_op, unary_post_op_args, post_op_algo
|
||||
)
|
||||
y_ref = y_ref + x2 * binary_alpha
|
||||
if unary_post_op == "relu":
|
||||
y_ref = F.relu(y_ref)
|
||||
|
||||
# Compare results
|
||||
if output_dtype is None:
|
||||
y_ref = self._quantize_fp8e4m3(y_ref, False, used_y_scale)[0]
|
||||
else:
|
||||
y_ref = y_ref.to(output_dtype)
|
||||
|
||||
self.assertEqual(x.dim(), qy.dim())
|
||||
self.assertEqual(y_ref.float(), qy.float())
|
||||
|
||||
@unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode")
|
||||
@skipIfNoONEDNN
|
||||
def test_qlinear_fp8(self):
|
||||
qlinear = torch.ops.onednn.qlinear_pointwise
|
||||
self._test_qlinear_fp8_helper(qlinear, "none")
|
||||
|
||||
@unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode")
|
||||
@skipIfNoONEDNN
|
||||
def test_qlinear_relu_fp8(self):
|
||||
qlinear = torch.ops.onednn.qlinear_pointwise
|
||||
self._test_qlinear_fp8_helper(qlinear, "relu")
|
||||
|
||||
@unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode")
|
||||
@skipIfNoONEDNN
|
||||
def test_qlinear_gelu_fp8(self):
|
||||
qlinear = torch.ops.onednn.qlinear_pointwise
|
||||
post_op_algorithms = ['none', 'tanh']
|
||||
self._test_qlinear_fp8_helper(qlinear, "gelu", post_op_algorithms=post_op_algorithms)
|
||||
|
||||
@unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode")
|
||||
@skipIfNoONEDNN
|
||||
def test_qlinear_sum_fp8(self):
|
||||
qlinear = torch.ops.onednn.qlinear_pointwise.binary
|
||||
self._test_qlinear_fp8_helper(qlinear, "sum")
|
||||
|
||||
@unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode")
|
||||
@skipIfNoONEDNN
|
||||
def test_qlinear_sum_relu_fp8(self):
|
||||
qlinear = torch.ops.onednn.qlinear_pointwise.binary
|
||||
self._test_qlinear_fp8_helper(qlinear, "sum_relu")
|
||||
|
||||
@unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode")
|
||||
@skipIfNoONEDNN
|
||||
def test_qlinear_add_fp8(self):
|
||||
qlinear = torch.ops.onednn.qlinear_pointwise.binary
|
||||
self._test_qlinear_fp8_helper(qlinear, "add")
|
||||
|
||||
@unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode")
|
||||
@skipIfNoONEDNN
|
||||
def test_qlinear_add_relu_fp8(self):
|
||||
qlinear = torch.ops.onednn.qlinear_pointwise.binary
|
||||
self._test_qlinear_fp8_helper(qlinear, "add_relu")
|
||||
|
||||
|
||||
@unittest.skipIf(IS_MACOS, "Known test failure on Mac.")
|
||||
class TestQuantizedEmbeddingOps(TestCase):
|
||||
|
Reference in New Issue
Block a user