[Quant][CPU] Enable fp8 qlinear (#155678)

**Summary**
Enable fp8 qlinear on CPU as part of the plan to enable fp8 static quantization on CPU. This PR only adds FP8 support to the existing int8 qlinear op; it does not add a new op, affect the frontend or the quantization flow, or change the schema of the qlinear op.

The FP8 qlinear therefore shares the same op as the INT8 qlinear; the only difference is that the src/wei (input/weight) dtype is fp8 instead of int8. The output dtype can be fp8, float32, or bfloat16. The implementation uses the oneDNN library.
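
For reference, the quantization scheme the op expects is plain per-tensor (or per-channel for weights) scaling into `float8_e4m3fn`, with no zero point. Below is a minimal sketch mirroring the `_quantize_fp8e4m3` helper added to the tests; the function name is only for illustration:

```python
import torch

def quantize_fp8e4m3_per_tensor(t: torch.Tensor):
    # scale = amax / fp8_max, clamped away from zero, then cast to float8_e4m3fn
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    eps = torch.finfo(torch.float32).eps
    scale = max((t.abs().max() / fp8_max).item(), eps)
    qt = (t / scale).to(torch.float8_e4m3fn)
    return qt, scale

x = torch.rand(4, 8) * 10
qx, x_scale = quantize_fp8e4m3_per_tensor(x)
x_dq = qx.float() * x_scale  # dequantized value, as used to build the reference output in the tests
```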

The differences between qlinear and `_scaled_mm` are that (see the usage sketch after this list):
- Qlinear supports post-op fusion while `_scaled_mm` does not
- Weights are prepacked for qlinear
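
A minimal sketch of driving the fp8 path through these ops, based on the call pattern used in the new tests (`qlinear_prepack` + `qlinear_pointwise`). The shapes, scales, and the fused `relu` post-op are illustrative only, and it assumes a oneDNN-enabled CPU build; on CPUs without AMX the op falls back to the reference implementation added in this PR:

```python
import torch

# Per-tensor fp8 quantization of activation and weight (zero points are 0 for fp8)
x = torch.rand(2, 8) * 10
w = torch.rand(16, 8) * 10
fp8_max = torch.finfo(torch.float8_e4m3fn).max
x_scale = (x.abs().max() / fp8_max).item()
w_scale = (w.abs().max() / fp8_max).item()
qx = (x / x_scale).to(torch.float8_e4m3fn)
qw = (w / w_scale).to(torch.float8_e4m3fn)
bias = torch.rand(16)

# Same prepack and qlinear ops as the int8 path; only the src/wei dtype differs
qw_packed = torch.ops.onednn.qlinear_prepack(qw, x.shape)
y = torch.ops.onednn.qlinear_pointwise(
    qx, x_scale, 0,                        # src, src scale, src zero point
    qw_packed, torch.tensor([w_scale]),    # packed weight, weight scales
    torch.zeros(1, dtype=torch.int),       # weight zero points
    bias,
    1.0, 0,                                # output scale, output zero point
    torch.bfloat16,                        # output dtype; None would give fp8 output
    "relu", [], "none",                    # fused unary post-op, its args, its algorithm
)
```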

**Test plan**
```
pytest test/quantization/core/test_quantized_op.py -k "qlinear and fp8"
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155678
Approved by: https://github.com/leslie-fang-intel, https://github.com/jerryzh168
Author: Xia, Weiwen
Date: 2025-06-25 10:01:03 +00:00
Committed by: PyTorch MergeBot
Parent: 19ffb5e6f7
Commit: c2185dc4a5
3 changed files with 355 additions and 11 deletions

View File

@@ -3,6 +3,7 @@
#include <ATen/Parallel.h>
#include <ATen/TensorOperators.h>
#include <ATen/core/Tensor.h>
#include <ATen/core/List.h>
#include <ATen/native/mkldnn/MKLDNNCommon.h>
#include <ATen/native/quantized/PackedParams.h>
#include <ATen/native/quantized/cpu/ACLUtils.h>
@@ -27,6 +28,14 @@
#include <ATen/ops/quantize_per_tensor_native.h> // for quantize_per_te...
#include <ATen/ops/zeros.h>
#include <ATen/ops/_weight_int4pack_mm_for_cpu.h>
#include <ATen/ops/linear.h>
#include <ATen/ops/relu.h>
#include <ATen/ops/leaky_relu.h>
#include <ATen/ops/tanh.h>
#include <ATen/ops/gelu.h>
#include <ATen/ops/hardtanh.h>
#include <ATen/ops/hardswish.h>
#include <ATen/ops/sigmoid.h>
#endif
#include <c10/util/irange.h>
@@ -918,6 +927,118 @@ at::Tensor PackedLinearWeightsOnednn::apply_tanh(
std::move(input), output_scale, output_zero_point);
}
static at::Tensor fp8_qlinear_onednn_ref(
at::Tensor input,
double input_scale,
at::Tensor weight, // expect plain weight
at::Tensor weight_scales,
std::optional<at::Tensor> bias, // plain tensor
double output_scale,
std::optional<c10::ScalarType> output_dtype,
std::optional<at::Tensor> other, // extra input for binary post-op
double other_scale,
const std::string_view& binary_post_op, // e.g. "none", "sum", "add"
double binary_alpha,
const std::string_view& unary_post_op, // e.g. "none", "relu"
torch::List<std::optional<at::Scalar>>& unary_post_op_args,
std::string_view& unary_post_op_algorithm) {
TORCH_CHECK(
input.scalar_type() == at::ScalarType::Float8_e4m3fn && weight.scalar_type() == at::ScalarType::Float8_e4m3fn,
"FP8 qlinear: Unexpected dtype of input and weight:", input.scalar_type(), ", ", weight.scalar_type());
const int64_t dim = input.dim();
auto input_contig =
dim == 2 ? input.contiguous() : input.reshape({-1, input.size(dim - 1)}).contiguous();
auto N = weight.size(0);
auto output_size = input.sizes().vec();
output_size[dim - 1] = N;
auto dqx = input_contig.to(at::kFloat) * input_scale;
std::vector<int64_t> w_scales_new_shape(weight.dim(), 1);
w_scales_new_shape[0] = -1;
auto dqw = weight.to(at::kFloat) * weight_scales.reshape(w_scales_new_shape);
auto y_f32 = at::linear(dqx, dqw, bias);
if (binary_post_op == "none") {
if (unary_post_op == "relu") {
at::relu_(y_f32);
} else if (unary_post_op == "leaky_relu") {
TORCH_CHECK(
unary_post_op_args.size() == 1,
"onednn qlinear: expect one argument for post op leaky_relu but got ", unary_post_op_args.size(), " args");
auto element = unary_post_op_args.get(0);
auto alpha = element.value().to<float>();
at::leaky_relu_(y_f32, alpha);
} else if (unary_post_op == "tanh") {
at::tanh_(y_f32);
} else if (unary_post_op == "gelu") {
TORCH_CHECK(
unary_post_op_algorithm == "none" || unary_post_op_algorithm == "tanh",
"onednn qlinear: algorithm for post op gelu must be none or tanh but got ", unary_post_op_algorithm);
at::gelu_(y_f32, unary_post_op_algorithm);
} else if (unary_post_op == "hardtanh") {
TORCH_CHECK(
unary_post_op_args.size() == 2 &&
unary_post_op_args.get(0).has_value() &&
unary_post_op_args.get(1).has_value(),
"hardtanh is expected to have two scalar input: min_val and max_val");
auto lower_bound_value =
unary_post_op_args.get(0).value().to<float>();
auto upper_bound_value =
unary_post_op_args.get(1).value().to<float>();
at::hardtanh_(y_f32, lower_bound_value, upper_bound_value);
} else if (unary_post_op == "hardswish") {
at::hardswish_(y_f32);
} else if (unary_post_op == "swish") {
// Equivalent of oneDNN's fused swish post-op: y = y * sigmoid(y)
y_f32 = y_f32 * at::sigmoid(y_f32);
} else {
TORCH_CHECK(
unary_post_op == "none",
"onednn qlinear: unsupported unary post op ", unary_post_op);
}
} else if (binary_post_op == "sum") {
TORCH_CHECK(other.has_value(), "onednn qlinear: the extra input is missing for post op sum");
auto x1 = other.value();
TORCH_CHECK(x1.sizes().vec() == output_size, "onednn qlinear: the size of the extra input for post op sum does not match the output size");
auto x1_f32 = x1.to(at::kFloat) * other_scale;
x1_f32 = x1_f32.view(y_f32.sizes());
if (unary_post_op == "none") {
y_f32.add_(x1_f32);
} else if (unary_post_op == "relu") {
y_f32.add_(x1_f32).relu_();
} else {
TORCH_CHECK(
false,
"onednn qlinear: unsupported unary post op ", unary_post_op, " with binary post op sum");
}
y_f32.div_(output_scale);
x1.copy_(y_f32.to(x1.scalar_type()).view(x1.sizes()));
return x1;
} else if (binary_post_op == "add") {
TORCH_CHECK(other.has_value(), "onednn qlinear: the extra input is missing for post op add");
auto x1 = other.value();
TORCH_CHECK(x1.sizes().vec() == output_size, "onednn qlinear: the size of the extra input for post op add does not match the output size");
auto x1_f32 = x1.to(at::kFloat) * other_scale;
x1_f32 = x1_f32.view(y_f32.sizes());
if (unary_post_op == "none") {
y_f32.add_(x1_f32);
} else if (unary_post_op == "relu") {
y_f32.add_(x1_f32).relu_();
} else {
TORCH_CHECK(
false,
"onednn qlinear: unsupported unary post op ", unary_post_op, " with binary post op add");
}
} else {
TORCH_CHECK(
false,
"onednn qlinear: unsupported binary post op ", binary_post_op);
}
y_f32.div_(output_scale);
y_f32 = y_f32.view(output_size);
auto out_dtype = output_dtype.has_value() ? output_dtype.value() : at::kFloat8_e4m3fn;
return y_f32.to(out_dtype);
}
static at::Tensor linear_int8_with_onednn_weight(
at::Tensor input, // int8 CPU Tensor, not QTensor
double input_scale,
@@ -939,10 +1060,18 @@ static at::Tensor linear_int8_with_onednn_weight(
std::string_view& unary_post_op_algorithm) {
using ideep::tensor;
const int64_t dim = input.dim();
TORCH_CHECK(input.scalar_type() == c10::ScalarType::Byte || input.scalar_type() == c10::ScalarType::Char,
"qlinear with mkldnn tensor: data type of input should be uint8 or int8 (unsigned char or char).");
TORCH_CHECK(onednn_weight.scalar_type() == c10::ScalarType::Char,
"qlinear with mkldnn tensor: data type of weight should be int8 (char).");
TORCH_CHECK(input.scalar_type() == c10::ScalarType::Byte || input.scalar_type() == c10::ScalarType::Char || input.scalar_type() == c10::ScalarType::Float8_e4m3fn,
"qlinear with mkldnn tensor: data type of input should be uint8, int8 or float8_e4m3fn.");
TORCH_CHECK(onednn_weight.scalar_type() == c10::ScalarType::Char || onednn_weight.scalar_type() == c10::ScalarType::Float8_e4m3fn,
"qlinear with mkldnn tensor: data type of weight should be int8 or float8_e4m3fn.");
bool is_fp8 = false;
if (input.scalar_type() == c10::ScalarType::Float8_e4m3fn || onednn_weight.scalar_type() == c10::ScalarType::Float8_e4m3fn) {
TORCH_CHECK(
input.scalar_type() == c10::ScalarType::Float8_e4m3fn && onednn_weight.scalar_type() == c10::ScalarType::Float8_e4m3fn,
"qlinear with mkldnn tensor: data type of input and weight should be the same for fp8, but got ",
input.scalar_type(), " and ", onednn_weight.scalar_type());
is_fp8 = true;
}
TORCH_CHECK(
weight_scales.scalar_type() == c10::ScalarType::Float, "weight scales should be dtype c10::ScalarType::Float.");
TORCH_CHECK(
@@ -976,7 +1105,7 @@ static at::Tensor linear_int8_with_onednn_weight(
);
}
if (binary_post_op == "sum") {
auto expected_dtype = output_dtype.has_value() ? output_dtype.value() : c10::kByte;
auto expected_dtype = output_dtype.has_value() ? output_dtype.value() : input.scalar_type();
TORCH_CHECK(
other.value().scalar_type() == expected_dtype,
"onednn qlinear: the dtype of extra input for binary post op should be ", expected_dtype,
@@ -984,6 +1113,14 @@ static at::Tensor linear_int8_with_onednn_weight(
);
}
}
if (is_fp8 && !cpuinfo_has_x86_amx_int8()) {
// oneDNN's fp8 support requires AMX; fall back to the reference implementation on platforms without it
return fp8_qlinear_onednn_ref(
input, input_scale, onednn_weight, weight_scales, bias,
output_scale, output_dtype, other, other_scale,
binary_post_op, binary_alpha, unary_post_op,
unary_post_op_args, unary_post_op_algorithm);
}
// If the input has more than two dimensions, we will reshape it to a 2-dimensional form
// for calculation and subsequently reshape the output back.
@@ -1016,7 +1153,7 @@ static at::Tensor linear_int8_with_onednn_weight(
at::empty(
dst_dims,
at::device(c10::kCPU)
.dtype(fp32_output ? c10::kFloat : (bf16_output ? c10::kBFloat16 : c10::kByte))
.dtype(fp32_output ? c10::kFloat : (bf16_output ? c10::kBFloat16 : input.scalar_type()))
);
if (output.numel() == 0) {
return output;
@@ -1029,7 +1166,7 @@ static at::Tensor linear_int8_with_onednn_weight(
empty_tensor;
// Create onednn primitive
auto src_dtype = input.scalar_type() == c10::kByte ? ideep::data_type::u8 : ideep::data_type::s8;
auto src_dtype = at::native::get_mkldnn_dtype(input.scalar_type());
auto src_desc = tensor::desc(src_dims, src_dtype, ideep::format_tag::any);
auto weights_desc = packed_weight.get_desc();
auto dst_dtype = dst.get_data_type();
@@ -1463,5 +1600,16 @@ TORCH_LIBRARY_IMPL(onednn, MkldnnCPU, m) {
TORCH_FN(at::native::QLinearOnednn::run_pointwise_binary_tensor));
}
TORCH_LIBRARY_IMPL(onednn, CPU, m) {
m.impl(TORCH_SELECTIVE_NAME("onednn::qlinear_pointwise"),
TORCH_FN(QLinearOnednn::run_pointwise));
m.impl(TORCH_SELECTIVE_NAME("onednn::qlinear_pointwise.tensor"),
TORCH_FN(at::native::QLinearOnednn::run_pointwise_tensor));
m.impl(TORCH_SELECTIVE_NAME("onednn::qlinear_pointwise.binary"),
TORCH_FN(QLinearOnednn::run_pointwise_binary));
m.impl(TORCH_SELECTIVE_NAME("onednn::qlinear_pointwise.binary_tensor"),
TORCH_FN(at::native::QLinearOnednn::run_pointwise_binary_tensor));
}
} // namespace
} // namespace at::native

View File

@@ -297,14 +297,32 @@ c10::intrusive_ptr<LinearPackedParamsBase> PackedLinearWeightsOnednn::prepack(
static inline at::Tensor pack_weight_to_onednn_tensor(
const at::Tensor& weight,
std::optional<torch::List<int64_t>>& input_shape) {
at::ScalarType weight_dtype = weight.scalar_type();
TORCH_CHECK(
weight_dtype == at::kChar || weight_dtype == at::kFloat8_e4m3fn,
"Weight should be of type int8 or float8_e4m3fn");
bool is_fp8 = weight_dtype == at::kFloat8_e4m3fn;
if (is_fp8 && !cpuinfo_has_x86_amx_int8()) {
// oneDNN's fp8 requires AMX support
// If AMX is not available, fall back to reference implementation
return weight;
}
std::vector<int64_t> w_dims = weight.sizes().vec();
ideep::tensor wei = ideep::tensor({w_dims, dnnl::memory::data_type::s8}, weight.data_ptr());
auto w_data_type = is_fp8
? dnnl::memory::data_type::f8_e4m3
: dnnl::memory::data_type::s8;
ideep::tensor wei = ideep::tensor({w_dims, w_data_type}, weight.data_ptr());
wei.transpose_(0, 1); // oneDNN requires transposed weight
ideep::dims input_dims = input_shape.has_value() ? input_shape.value().vec() : ideep::dims();
ideep::attr_t op_attr;
op_attr.set_zero_points_mask(DNNL_ARG_SRC, 0);
if (!is_fp8) {
op_attr.set_zero_points_mask(DNNL_ARG_SRC, 0);
}
auto x_data_type = is_fp8
? dnnl::memory::data_type::f8_e4m3
: dnnl::memory::data_type::u8;
auto w_desc = ideep::matmul_forward::expected_weights_desc(
wei.get_dims(), input_dims, dnnl::memory::data_type::s8, dnnl::memory::data_type::u8, op_attr);
wei.get_dims(), input_dims, w_data_type, x_data_type, op_attr);
ideep::tensor expected_weight(w_desc);
expected_weight.feed_from(wei);
auto packed_weight = at::native::new_with_itensor_mkldnn(

View File

@@ -4511,7 +4511,7 @@ class TestQuantizedLinear(TestCase):
qlinear_op,
post_op="none",
unary_post_op_args=(),
post_op_algorithms=("none"),
post_op_algorithms=("none",),
):
qlinear_prepack = torch.ops.onednn.qlinear_prepack
linear_op = F.linear
@@ -4678,6 +4678,184 @@ class TestQuantizedLinear(TestCase):
qlinear = torch.ops.onednn.qlinear_pointwise.binary
self._test_qlinear_pt2e_helper(qlinear, "add_relu")
def _quantize_fp8e4m3(self, t: torch.Tensor, channelwise: bool, scale: Optional[torch.Tensor] = None):
quant_max = torch.finfo(torch.float8_e4m3fn).max
eps = torch.Tensor([torch.finfo(torch.float32).eps])
if channelwise:
scale = scale or t.reshape(t.shape[0], -1).abs().max(-1)[0] / quant_max
scale = torch.max(scale, eps)
scale_reshape = scale.reshape((-1,) + (1,) * (t.dim() - 1))
qt = t / scale_reshape
else:
scale = scale or t.abs().max().reshape([1]) / quant_max
scale = torch.max(scale, eps) if isinstance(scale, torch.Tensor) else max(scale, eps.item())
qt = t / scale
qt = qt.to(torch.float8_e4m3fn)
return qt, scale
def _dequantize_fp8e4m3(self, qt: torch.Tensor, scale: torch.Tensor):
dqt = qt.float()
if scale.numel() == 1:
# per tensor
dqt = dqt * scale
else:
# per channel
scale_reshape = scale.reshape((-1,) + (1,) * (qt.dim() - 1))
dqt = dqt * scale_reshape
return dqt
def _test_qlinear_fp8_helper(
self,
qlinear_op,
post_op="none",
unary_post_op_args=(),
post_op_algorithms=("none",),
):
qlinear_prepack = torch.ops.onednn.qlinear_prepack
linear_op = F.linear
in_channels_list = [4, 8]
out_channels_list = [16, 32]
batch_size = 1
use_bias_list = [True, False]
weight_quant_per_channel_list = [True, False]
output_dtype_list = [None, torch.float32, torch.bfloat16]
y_scale, y_zp = 0.07, 0
input_dim_list = [2, 3]
cases = itertools.product(
in_channels_list, out_channels_list, use_bias_list,
weight_quant_per_channel_list, output_dtype_list, post_op_algorithms, input_dim_list)
with override_quantized_engine('onednn'):
for ic, oc, use_bias, weight_quant_per_channel, output_dtype, post_op_algo, input_dim in cases:
used_y_scale = y_scale
used_y_zp = y_zp
fp32_out = output_dtype == torch.float32
bfloat16_out = output_dtype == torch.bfloat16
if fp32_out or bfloat16_out:
used_y_scale = 1.0
x2_scale, x2_zp = 1.0, 0
else:
x2_scale, x2_zp = 0.3, 0
x = torch.rand(batch_size, (ic + 1), ic) * 10 if input_dim == 3 else torch.rand(batch_size, ic) * 10
w = torch.rand(oc, ic) * 10
qx, x_scale = self._quantize_fp8e4m3(x, channelwise=False)
qw, w_scales = self._quantize_fp8e4m3(w, channelwise=weight_quant_per_channel)
if use_bias:
b = torch.rand(oc) * 10
else:
b = None
# compute reference result
x_ref = self._dequantize_fp8e4m3(qx, x_scale)
w_ref = self._dequantize_fp8e4m3(qw, w_scales)
y_ref = linear_op(x_ref, w_ref, b)
# compute fp8 linear
qw_packed = qlinear_prepack(qw, x.shape)
x_zp = 0
w_zps = torch.zeros_like(w_scales, dtype=torch.int)
if post_op in ("none", "relu", "gelu"):
qy = qlinear_op(
qx, x_scale, x_zp, qw_packed, w_scales, w_zps,
b, used_y_scale, used_y_zp, output_dtype,
post_op, unary_post_op_args, post_op_algo
)
if post_op == "relu":
y_ref = F.relu(y_ref)
elif post_op == "gelu":
y_ref = F.gelu(y_ref, approximate=post_op_algo)
elif post_op in ("sum", "sum_relu"):
x2 = torch.rand_like(y_ref)
x2_q, x2_scale = self._quantize_fp8e4m3(x2, channelwise=False)
x2_dq = self._dequantize_fp8e4m3(x2_q, x2_scale)
unary_post_op = "relu" if post_op == "sum_relu" else "none"
binary_alpha = 1.0 # we only support alpha=1.0 now
# if output_dtype is fp32 or bf16, accumulate on x2
# if output_dtype is None (fp8), accumulate on x2_dq
accum = x2_q if output_dtype is None else x2
accum_ref = x2_dq if output_dtype is None else x2.clone()
x2_scale = x2_scale if output_dtype is None else 1.0
if bfloat16_out:
accum = accum.bfloat16()
accum_ref = accum_ref.bfloat16()
qy = qlinear_op(
qx, x_scale, x_zp, qw_packed, w_scales, w_zps,
accum, b, used_y_scale, used_y_zp, output_dtype,
x2_scale, x2_zp, "sum", binary_alpha,
unary_post_op, unary_post_op_args, post_op_algo
)
y_ref = y_ref + accum_ref * binary_alpha
if unary_post_op == "relu":
y_ref = F.relu(y_ref)
elif post_op in ("add", "add_relu"):
if output_dtype is not None:
# Only support fp8 output
continue
x2 = torch.rand_like(y_ref)
unary_post_op = "relu" if post_op == "add_relu" else "none"
binary_alpha = 1.0 # we only support alpha=1.0 now
qy = qlinear_op(
qx, x_scale, x_zp, qw_packed, w_scales, w_zps,
x2, b, used_y_scale, used_y_zp, output_dtype,
1.0, 0, "add", binary_alpha,
unary_post_op, unary_post_op_args, post_op_algo
)
y_ref = y_ref + x2 * binary_alpha
if unary_post_op == "relu":
y_ref = F.relu(y_ref)
# Compare results
if output_dtype is None:
y_ref = self._quantize_fp8e4m3(y_ref, False, used_y_scale)[0]
else:
y_ref = y_ref.to(output_dtype)
self.assertEqual(x.dim(), qy.dim())
self.assertEqual(y_ref.float(), qy.float())
@unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode")
@skipIfNoONEDNN
def test_qlinear_fp8(self):
qlinear = torch.ops.onednn.qlinear_pointwise
self._test_qlinear_fp8_helper(qlinear, "none")
@unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode")
@skipIfNoONEDNN
def test_qlinear_relu_fp8(self):
qlinear = torch.ops.onednn.qlinear_pointwise
self._test_qlinear_fp8_helper(qlinear, "relu")
@unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode")
@skipIfNoONEDNN
def test_qlinear_gelu_fp8(self):
qlinear = torch.ops.onednn.qlinear_pointwise
post_op_algorithms = ['none', 'tanh']
self._test_qlinear_fp8_helper(qlinear, "gelu", post_op_algorithms=post_op_algorithms)
@unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode")
@skipIfNoONEDNN
def test_qlinear_sum_fp8(self):
qlinear = torch.ops.onednn.qlinear_pointwise.binary
self._test_qlinear_fp8_helper(qlinear, "sum")
@unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode")
@skipIfNoONEDNN
def test_qlinear_sum_relu_fp8(self):
qlinear = torch.ops.onednn.qlinear_pointwise.binary
self._test_qlinear_fp8_helper(qlinear, "sum_relu")
@unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode")
@skipIfNoONEDNN
def test_qlinear_add_fp8(self):
qlinear = torch.ops.onednn.qlinear_pointwise.binary
self._test_qlinear_fp8_helper(qlinear, "add")
@unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode")
@skipIfNoONEDNN
def test_qlinear_add_relu_fp8(self):
qlinear = torch.ops.onednn.qlinear_pointwise.binary
self._test_qlinear_fp8_helper(qlinear, "add_relu")
@unittest.skipIf(IS_MACOS, "Known test failure on Mac.")
class TestQuantizedEmbeddingOps(TestCase):