[Quant][CPU] Enable fp8 qlinear (#155678)
**Summary**
Enable fp8 qlinear on CPU. This is part of the plan to enable fp8 static quantization on CPU. This PR only adds FP8 support to the existing int8 qlinear op; it does not add a new op, nor does it affect the frontend or the quantization flow. The schema of the qlinear op is unchanged. So, the FP8 qlinear shares the same op as the INT8 qlinear; the only difference is that the src/wei dtype is fp8 instead of int8. The output dtype can be fp8/float32/bfloat16. The implementation uses the oneDNN library.

The differences between qlinear and `_scaled_mm` are that
- qlinear supports post-op fusion while `_scaled_mm` does not
- weights are prepacked for qlinear

**Test plan**
```
pytest test/quantization/core/test_quantized_op.py -k "qlinear and fp8"
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155678
Approved by: https://github.com/leslie-fang-intel, https://github.com/jerryzh168
Committed by: PyTorch MergeBot
Parent: 19ffb5e6f7
Commit: c2185dc4a5
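For orientation, below is a minimal sketch of how the FP8 path can be exercised through the existing onednn qlinear ops, modeled on the new test helper added in this diff. The shapes, scales, and the bfloat16 output dtype are illustrative only; it assumes a CPU build with oneDNN that includes this PR's FP8 support.

```python
import torch

# Toy activation/weight; real coverage is in test/quantization/core/test_quantized_op.py (see tests below).
ic, oc = 4, 16
x = torch.rand(1, ic) * 10
w = torch.rand(oc, ic) * 10

# Per-tensor symmetric quantization to float8_e4m3fn.
fp8_max = torch.finfo(torch.float8_e4m3fn).max
x_scale = (x.abs().max() / fp8_max).item()
w_scale = w.abs().max().reshape(1) / fp8_max
qx = (x / x_scale).to(torch.float8_e4m3fn)
qw = (w / w_scale).to(torch.float8_e4m3fn)

# Same ops as the INT8 path; only the src/wei dtype is fp8.
qw_packed = torch.ops.onednn.qlinear_prepack(qw, x.shape)
y = torch.ops.onednn.qlinear_pointwise(
    qx, x_scale, 0,                   # input, input scale, input zero point
    qw_packed, w_scale,               # packed weight, weight scales
    torch.zeros(1, dtype=torch.int),  # weight zero points
    None,                             # bias
    1.0, 0,                           # output scale / zero point
    torch.bfloat16,                   # output dtype: None (fp8), float32, or bfloat16
    "none", [], "none")               # unary post-op, args, algorithm
```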
@@ -3,6 +3,7 @@
#include <ATen/Parallel.h>
#include <ATen/TensorOperators.h>
#include <ATen/core/Tensor.h>
#include <ATen/core/List.h>
#include <ATen/native/mkldnn/MKLDNNCommon.h>
#include <ATen/native/quantized/PackedParams.h>
#include <ATen/native/quantized/cpu/ACLUtils.h>
@@ -27,6 +28,14 @@
#include <ATen/ops/quantize_per_tensor_native.h> // for quantize_per_te...
#include <ATen/ops/zeros.h>
#include <ATen/ops/_weight_int4pack_mm_for_cpu.h>
#include <ATen/ops/linear.h>
#include <ATen/ops/relu.h>
#include <ATen/ops/leaky_relu.h>
#include <ATen/ops/tanh.h>
#include <ATen/ops/gelu.h>
#include <ATen/ops/hardtanh.h>
#include <ATen/ops/hardswish.h>
#include <ATen/ops/sigmoid.h>
#endif

#include <c10/util/irange.h>
@@ -918,6 +927,118 @@ at::Tensor PackedLinearWeightsOnednn::apply_tanh(
      std::move(input), output_scale, output_zero_point);
}

static at::Tensor fp8_qlinear_onednn_ref(
    at::Tensor input,
    double input_scale,
    at::Tensor weight, // expect plain weight
    at::Tensor weight_scales,
    std::optional<at::Tensor> bias, // plain tensor
    double output_scale,
    std::optional<c10::ScalarType> output_dtype,
    std::optional<at::Tensor> other, // extra input for binary post-op
    double other_scale,
    const std::string_view& binary_post_op, // e.g. "none", "sum", "add"
    double binary_alpha,
    const std::string_view& unary_post_op, // e.g. "none", "relu"
    torch::List<std::optional<at::Scalar>>& unary_post_op_args,
    std::string_view& unary_post_op_algorithm) {
  TORCH_CHECK(
      input.scalar_type() == at::ScalarType::Float8_e4m3fn && weight.scalar_type() == at::ScalarType::Float8_e4m3fn,
      "FP8 qlinear: Unexpected dtype of input and weight:", input.scalar_type(), ", ", weight.scalar_type());
  const int64_t dim = input.dim();
  auto input_contig =
      dim == 2 ? input.contiguous() : input.reshape({-1, input.size(dim - 1)}).contiguous();
  auto N = weight.size(0);
  auto output_size = input.sizes().vec();
  output_size[dim - 1] = N;
  auto dqx = input_contig.to(at::kFloat) * input_scale;
  std::vector<int64_t> w_scales_new_shape(weight.dim(), 1);
  w_scales_new_shape[0] = -1;
  auto dqw = weight.to(at::kFloat) * weight_scales.reshape(w_scales_new_shape);
  auto y_f32 = at::linear(dqx, dqw, bias);
if (binary_post_op == "none") {
|
||||
if (unary_post_op == "relu") {
|
||||
at::relu_(y_f32);
|
||||
} else if (unary_post_op == "leaky_relu") {
|
||||
TORCH_CHECK(
|
||||
unary_post_op_args.size() == 1,
|
||||
"onednn qlinear: expect one argument for post op leaky_relu but got ", unary_post_op_args.size(), " args");
|
||||
auto element = unary_post_op_args.get(0);
|
||||
auto alpha = element.value().to<float>();
|
||||
at::leaky_relu_(y_f32, alpha);
|
||||
} else if (unary_post_op == "tanh") {
|
||||
at::tanh_(y_f32);
|
||||
} else if (unary_post_op == "gelu") {
|
||||
TORCH_CHECK(
|
||||
unary_post_op_algorithm == "none" || unary_post_op_algorithm == "tanh",
|
||||
"onednn qlinear: algorithm for post op gelu must be none or tanh but got ", unary_post_op_algorithm);
|
||||
at::gelu_(y_f32, unary_post_op_algorithm);
|
||||
} else if (unary_post_op == "hardtanh") {
|
||||
TORCH_CHECK(
|
||||
unary_post_op_args.size() == 2 &&
|
||||
unary_post_op_args.get(0).has_value() &&
|
||||
unary_post_op_args.get(1).has_value(),
|
||||
"hardtanh is expected to have two scalar input: min_val and max_val");
|
||||
auto lower_bound_value =
|
||||
unary_post_op_args.get(0).value().to<float>();
|
||||
auto upper_bound_value =
|
||||
unary_post_op_args.get(1).value().to<float>();
|
||||
at::hardtanh_(y_f32, lower_bound_value, upper_bound_value);
|
||||
} else if (unary_post_op == "hardswish") {
|
||||
at::hardswish_(y_f32);
|
||||
} else if (unary_post_op == "swish") {
|
||||
// return ideep::attr_t::fuse_swish();
|
||||
y_f32 = y_f32 * at::sigmoid(y_f32);
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
unary_post_op == "none",
|
||||
"onednn qlinear: unsupported unary post op ", unary_post_op);
|
||||
}
|
||||
} else if (binary_post_op == "sum") {
|
||||
TORCH_CHECK(other.has_value(), "onednn qlinear: the extra input is missing for post op sum");
|
||||
auto x1 = other.value();
|
||||
TORCH_CHECK(x1.sizes().vec() == output_size);
|
||||
auto x1_f32 = x1.to(at::kFloat) * other_scale;
|
||||
x1_f32 = x1_f32.view(y_f32.sizes());
|
||||
if (unary_post_op == "none") {
|
||||
y_f32.add_(x1_f32);
|
||||
} else if (unary_post_op == "relu") {
|
||||
y_f32.add_(x1_f32).relu_();
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"onednn qlinear: unsupported unary post op ", unary_post_op, " with binary post op sum");
|
||||
}
|
||||
y_f32.div_(output_scale);
|
||||
x1.copy_(y_f32.to(x1.scalar_type()).view(x1.sizes()));
|
||||
return x1;
|
||||
} else if (binary_post_op == "add") {
|
||||
TORCH_CHECK(other.has_value(), "onednn qlinear: the extra input is missing for post op sum");
|
||||
auto x1 = other.value();
|
||||
TORCH_CHECK(x1.sizes().vec() == output_size);
|
||||
auto x1_f32 = x1.to(at::kFloat) * other_scale;
|
||||
x1_f32 = x1_f32.view(y_f32.sizes());
|
||||
if (unary_post_op == "none") {
|
||||
y_f32.add_(x1_f32);
|
||||
} else if (unary_post_op == "relu") {
|
||||
y_f32.add_(x1_f32).relu_();
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"onednn qlinear: unsupported unary post op ", unary_post_op, " with binary post op add");
|
||||
}
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"onednn qlinear: unsupported binary post op ", binary_post_op);
|
||||
}
|
||||
|
||||
y_f32.div_(output_scale);
|
||||
y_f32 = y_f32.view(output_size);
|
||||
auto out_dtype = output_dtype.has_value() ? output_dtype.value() : at::kFloat8_e4m3fn;
|
||||
return y_f32.to(out_dtype);
|
||||
}
|
||||
|
||||
static at::Tensor linear_int8_with_onednn_weight(
    at::Tensor input, // int8 CPU Tensor, not QTensor
    double input_scale,
@@ -939,10 +1060,18 @@ static at::Tensor linear_int8_with_onednn_weight(
    std::string_view& unary_post_op_algorithm) {
  using ideep::tensor;
  const int64_t dim = input.dim();
  TORCH_CHECK(input.scalar_type() == c10::ScalarType::Byte || input.scalar_type() == c10::ScalarType::Char,
      "qlinear with mkldnn tensor: data type of input should be uint8 or int8 (unsigned char or char).");
  TORCH_CHECK(onednn_weight.scalar_type() == c10::ScalarType::Char,
      "qlinear with mkldnn tensor: data type of weight should be int8 (char).");
  TORCH_CHECK(input.scalar_type() == c10::ScalarType::Byte || input.scalar_type() == c10::ScalarType::Char || input.scalar_type() == c10::ScalarType::Float8_e4m3fn,
      "qlinear with mkldnn tensor: data type of input should be uint8, int8 or float8_e4m3fn.");
  TORCH_CHECK(onednn_weight.scalar_type() == c10::ScalarType::Char || onednn_weight.scalar_type() == c10::ScalarType::Float8_e4m3fn,
      "qlinear with mkldnn tensor: data type of weight should be int8 or float8_e4m3fn.");
  bool is_fp8 = false;
  if (input.scalar_type() == c10::ScalarType::Float8_e4m3fn || onednn_weight.scalar_type() == c10::ScalarType::Float8_e4m3fn) {
    TORCH_CHECK(
        input.scalar_type() == c10::ScalarType::Float8_e4m3fn && onednn_weight.scalar_type() == c10::ScalarType::Float8_e4m3fn,
        "qlinear with mkldnn tensor: data type of input and weight should be the same for fp8, but got ",
        input.scalar_type(), " and ", onednn_weight.scalar_type());
    is_fp8 = true;
  }
  TORCH_CHECK(
      weight_scales.scalar_type() == c10::ScalarType::Float, "weight scales should be dtype c10::ScalarType::Float.");
  TORCH_CHECK(
@@ -976,7 +1105,7 @@ static at::Tensor linear_int8_with_onednn_weight(
    );
  }
  if (binary_post_op == "sum") {
    auto expected_dtype = output_dtype.has_value() ? output_dtype.value() : c10::kByte;
    auto expected_dtype = output_dtype.has_value() ? output_dtype.value() : input.scalar_type();
    TORCH_CHECK(
        other.value().scalar_type() == expected_dtype,
        "onednn qlinear: the dtype of extra input for binary post op should be ", expected_dtype,
@@ -984,6 +1113,14 @@ static at::Tensor linear_int8_with_onednn_weight(
    );
  }
  }
  if (is_fp8 && !cpuinfo_has_x86_amx_int8()) {
    // Fall back to ref impl on old platforms because not supported
    return fp8_qlinear_onednn_ref(
        input, input_scale, onednn_weight, weight_scales, bias,
        output_scale, output_dtype, other, other_scale,
        binary_post_op, binary_alpha, unary_post_op,
        unary_post_op_args, unary_post_op_algorithm);
  }

  // If the input has more than two dimensions, we will reshape it to a 2-dimensional form
  // for calculation and subsequently reshape the output back.
@@ -1016,7 +1153,7 @@ static at::Tensor linear_int8_with_onednn_weight(
      at::empty(
          dst_dims,
          at::device(c10::kCPU)
              .dtype(fp32_output ? c10::kFloat : (bf16_output ? c10::kBFloat16 : c10::kByte))
              .dtype(fp32_output ? c10::kFloat : (bf16_output ? c10::kBFloat16 : input.scalar_type()))
      );
  if (output.numel() == 0) {
    return output;
@@ -1029,7 +1166,7 @@ static at::Tensor linear_int8_with_onednn_weight(
      empty_tensor;

  // Create onednn primitive
  auto src_dtype = input.scalar_type() == c10::kByte ? ideep::data_type::u8 : ideep::data_type::s8;
  auto src_dtype = at::native::get_mkldnn_dtype(input.scalar_type());
  auto src_desc = tensor::desc(src_dims, src_dtype, ideep::format_tag::any);
  auto weights_desc = packed_weight.get_desc();
  auto dst_dtype = dst.get_data_type();
@@ -1463,5 +1600,16 @@ TORCH_LIBRARY_IMPL(onednn, MkldnnCPU, m) {
      TORCH_FN(at::native::QLinearOnednn::run_pointwise_binary_tensor));
}

TORCH_LIBRARY_IMPL(onednn, CPU, m) {
  m.impl(TORCH_SELECTIVE_NAME("onednn::qlinear_pointwise"),
      TORCH_FN(QLinearOnednn::run_pointwise));
  m.impl(TORCH_SELECTIVE_NAME("onednn::qlinear_pointwise.tensor"),
      TORCH_FN(at::native::QLinearOnednn::run_pointwise_tensor));
  m.impl(TORCH_SELECTIVE_NAME("onednn::qlinear_pointwise.binary"),
      TORCH_FN(QLinearOnednn::run_pointwise_binary));
  m.impl(TORCH_SELECTIVE_NAME("onednn::qlinear_pointwise.binary_tensor"),
      TORCH_FN(at::native::QLinearOnednn::run_pointwise_binary_tensor));
}

} // namespace
} // namespace at::native
@@ -297,14 +297,32 @@ c10::intrusive_ptr<LinearPackedParamsBase> PackedLinearWeightsOnednn::prepack(
static inline at::Tensor pack_weight_to_onednn_tensor(
    const at::Tensor& weight,
    std::optional<torch::List<int64_t>>& input_shape) {
  at::ScalarType weigh_dtype = weight.scalar_type();
  TORCH_CHECK(
      weigh_dtype == at::kChar || weigh_dtype == at::kFloat8_e4m3fn,
      "Weight should be of type int8 or float8_e4m3fn");
  bool is_fp8 = weigh_dtype == at::kFloat8_e4m3fn;
  if (is_fp8 && !cpuinfo_has_x86_amx_int8()) {
    // oneDNN's fp8 requires AMX support
    // If AMX is not available, fall back to reference implementation
    return weight;
  }
  std::vector<int64_t> w_dims = weight.sizes().vec();
  ideep::tensor wei = ideep::tensor({w_dims, dnnl::memory::data_type::s8}, weight.data_ptr());
  auto w_data_type = is_fp8
      ? dnnl::memory::data_type::f8_e4m3
      : dnnl::memory::data_type::s8;
  ideep::tensor wei = ideep::tensor({w_dims, w_data_type}, weight.data_ptr());
  wei.transpose_(0, 1); // oneDNN requires transposed weight
  ideep::dims input_dims = input_shape.has_value() ? input_shape.value().vec() : ideep::dims();
  ideep::attr_t op_attr;
  op_attr.set_zero_points_mask(DNNL_ARG_SRC, 0);
  if (!is_fp8) {
    op_attr.set_zero_points_mask(DNNL_ARG_SRC, 0);
  }
  auto x_data_type = is_fp8
      ? dnnl::memory::data_type::f8_e4m3
      : dnnl::memory::data_type::u8;
  auto w_desc = ideep::matmul_forward::expected_weights_desc(
      wei.get_dims(), input_dims, dnnl::memory::data_type::s8, dnnl::memory::data_type::u8, op_attr);
      wei.get_dims(), input_dims, w_data_type, x_data_type, op_attr);
  ideep::tensor expected_weight(w_desc);
  expected_weight.feed_from(wei);
  auto packed_weight = at::native::new_with_itensor_mkldnn(
@@ -4511,7 +4511,7 @@ class TestQuantizedLinear(TestCase):
        qlinear_op,
        post_op="none",
        unary_post_op_args=(),
        post_op_algorithms=("none"),
        post_op_algorithms=("none",),
    ):
        qlinear_prepack = torch.ops.onednn.qlinear_prepack
        linear_op = F.linear
@@ -4678,6 +4678,184 @@ class TestQuantizedLinear(TestCase):
        qlinear = torch.ops.onednn.qlinear_pointwise.binary
        self._test_qlinear_pt2e_helper(qlinear, "add_relu")

    def _quantize_fp8e4m3(self, t: torch.Tensor, channelwise: bool, scale: Optional[torch.Tensor] = None):
        quant_max = torch.finfo(torch.float8_e4m3fn).max
        eps = torch.Tensor([torch.finfo(torch.float32).eps])
        if channelwise:
            scale = scale or t.reshape(t.shape[0], -1).abs().max(-1)[0] / quant_max
            scale = torch.max(scale, eps)
            scale_reshape = scale.reshape((-1,) + (1,) * (t.dim() - 1))
            qt = t / scale_reshape
        else:
            scale = scale or t.abs().max().reshape([1]) / quant_max
            scale = torch.max(scale, eps) if isinstance(scale, torch.Tensor) else max(scale, eps.item())
            qt = t / scale
        qt = qt.to(torch.float8_e4m3fn)
        return qt, scale

    def _dequantize_fp8e4m3(self, qt: torch.Tensor, scale: torch.Tensor):
        dqt = qt.float()
        if scale.numel() == 1:
            # per tensor
            dqt = dqt * scale
        else:
            # per channel
            scale_reshape = scale.reshape((-1,) + (1,) * (qt.dim() - 1))
            dqt = dqt * scale_reshape
        return dqt

    def _test_qlinear_fp8_helper(
        self,
        qlinear_op,
        post_op="none",
        unary_post_op_args=(),
        post_op_algorithms=("none",),
    ):
        qlinear_prepack = torch.ops.onednn.qlinear_prepack
        linear_op = F.linear
        in_channels_list = [4, 8]
        out_channels_list = [16, 32]
        batch_size = 1
        use_bias_list = [True, False]
        weight_quant_per_channel_list = [True, False]
        output_dtype_list = [None, torch.float32, torch.bfloat16]
        y_scale, y_zp = 0.07, 0
        input_dim_list = [2, 3]
        cases = itertools.product(
            in_channels_list, out_channels_list, use_bias_list,
            weight_quant_per_channel_list, output_dtype_list, post_op_algorithms, input_dim_list)
        with override_quantized_engine('onednn'):
            for ic, oc, use_bias, weight_quant_per_channel, output_dtype, post_op_algo, input_dim in cases:
                used_y_scale = y_scale
                used_y_zp = y_zp
                fp32_out = output_dtype == torch.float32
                bfloat16_out = output_dtype == torch.bfloat16
                if fp32_out or bfloat16_out:
                    used_y_scale = 1.0
                    x2_scale, x2_zp = 1.0, 0
                else:
                    x2_scale, x2_zp = 0.3, 0
                x = torch.rand(batch_size, (ic + 1), ic) * 10 if input_dim == 3 else torch.rand(batch_size, ic) * 10
                w = torch.rand(oc, ic) * 10
                qx, x_scale = self._quantize_fp8e4m3(x, channelwise=False)
                qw, w_scales = self._quantize_fp8e4m3(w, channelwise=weight_quant_per_channel)
                if use_bias:
                    b = torch.rand(oc) * 10
                else:
                    b = None

                # compute reference result
                x_ref = self._dequantize_fp8e4m3(qx, x_scale)
                w_ref = self._dequantize_fp8e4m3(qw, w_scales)
                y_ref = linear_op(x_ref, w_ref, b)

                # compute fp8 linear
                qw_packed = qlinear_prepack(qw, x.shape)
                x_zp = 0
                w_zps = torch.zeros_like(w_scales, dtype=torch.int)

                if post_op in ("none", "relu", "gelu"):
                    qy = qlinear_op(
                        qx, x_scale, x_zp, qw_packed, w_scales, w_zps,
                        b, used_y_scale, used_y_zp, output_dtype,
                        post_op, unary_post_op_args, post_op_algo
                    )
                    if post_op == "relu":
                        y_ref = F.relu(y_ref)
                    elif post_op == "gelu":
                        y_ref = F.gelu(y_ref, approximate=post_op_algo)
                elif post_op in ("sum", "sum_relu"):
                    x2 = torch.rand_like(y_ref)
                    x2_q, x2_scale = self._quantize_fp8e4m3(x2, channelwise=False)
                    x2_dq = self._dequantize_fp8e4m3(x2_q, x2_scale)
                    unary_post_op = "relu" if post_op == "sum_relu" else "none"
                    binary_alpha = 1.0  # we only support alpha=1.0 now
                    # if output_dtype is fp32 or bf16, accumulate on x2
                    # if output_dtype is None (fp8), accumulate on x2_dq
                    accum = x2_q if output_dtype is None else x2
                    accum_ref = x2_dq if output_dtype is None else x2.clone()
                    x2_scale = x2_scale if output_dtype is None else 1.0
                    if bfloat16_out:
                        accum = accum.bfloat16()
                        accum_ref = accum_ref.bfloat16()
                    qy = qlinear_op(
                        qx, x_scale, x_zp, qw_packed, w_scales, w_zps,
                        accum, b, used_y_scale, used_y_zp, output_dtype,
                        x2_scale, x2_zp, "sum", binary_alpha,
                        unary_post_op, unary_post_op_args, post_op_algo
                    )
                    y_ref = y_ref + accum_ref * binary_alpha
                    if unary_post_op == "relu":
                        y_ref = F.relu(y_ref)
                elif post_op in ("add", "add_relu"):
                    if output_dtype is not None:
                        # Only support fp8 output
                        continue
                    x2 = torch.rand_like(y_ref)
                    unary_post_op = "relu" if post_op == "add_relu" else "none"
                    binary_alpha = 1.0  # we only support alpha=1.0 now
                    qy = qlinear_op(
                        qx, x_scale, x_zp, qw_packed, w_scales, w_zps,
                        x2, b, used_y_scale, used_y_zp, output_dtype,
                        1.0, 0, "add", binary_alpha,
                        unary_post_op, unary_post_op_args, post_op_algo
                    )
                    y_ref = y_ref + x2 * binary_alpha
                    if unary_post_op == "relu":
                        y_ref = F.relu(y_ref)

                # Compare results
                if output_dtype is None:
                    y_ref = self._quantize_fp8e4m3(y_ref, False, used_y_scale)[0]
                else:
                    y_ref = y_ref.to(output_dtype)

                self.assertEqual(x.dim(), qy.dim())
                self.assertEqual(y_ref.float(), qy.float())

    @unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode")
    @skipIfNoONEDNN
    def test_qlinear_fp8(self):
        qlinear = torch.ops.onednn.qlinear_pointwise
        self._test_qlinear_fp8_helper(qlinear, "none")

    @unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode")
    @skipIfNoONEDNN
    def test_qlinear_relu_fp8(self):
        qlinear = torch.ops.onednn.qlinear_pointwise
        self._test_qlinear_fp8_helper(qlinear, "relu")

    @unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode")
    @skipIfNoONEDNN
    def test_qlinear_gelu_fp8(self):
        qlinear = torch.ops.onednn.qlinear_pointwise
        post_op_algorithms = ['none', 'tanh']
        self._test_qlinear_fp8_helper(qlinear, "gelu", post_op_algorithms=post_op_algorithms)

    @unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode")
    @skipIfNoONEDNN
    def test_qlinear_sum_fp8(self):
        qlinear = torch.ops.onednn.qlinear_pointwise.binary
        self._test_qlinear_fp8_helper(qlinear, "sum")

    @unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode")
    @skipIfNoONEDNN
    def test_qlinear_sum_relu_fp8(self):
        qlinear = torch.ops.onednn.qlinear_pointwise.binary
        self._test_qlinear_fp8_helper(qlinear, "sum_relu")

    @unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode")
    @skipIfNoONEDNN
    def test_qlinear_add_fp8(self):
        qlinear = torch.ops.onednn.qlinear_pointwise.binary
        self._test_qlinear_fp8_helper(qlinear, "add")

    @unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode")
    @skipIfNoONEDNN
    def test_qlinear_add_relu_fp8(self):
        qlinear = torch.ops.onednn.qlinear_pointwise.binary
        self._test_qlinear_fp8_helper(qlinear, "add_relu")


@unittest.skipIf(IS_MACOS, "Known test failure on Mac.")
class TestQuantizedEmbeddingOps(TestCase):