[Inductor][Quant] Change the schema of QLinear Binary (#129049)

**Summary**
We change the schema of QLinear Binary to make it easier to enable the corresponding gemm template.

- The extra input of the binary post-op is a tensor that needs to be an input node for autotuning, so we move it in front of `output_scale`, which is a scalar.
- We also move it in front of `bias`, since `bias` is an optional tensor for this fusion while `other` is required for linear binary fusion (see the sketch below for the updated argument order).
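
A minimal sketch to check the reordering (assuming a PyTorch build where the `onednn` ops are registered; this inspection snippet is illustrative only and not part of the PR):
```python
import torch

# Print the updated schema of the binary QLinear op. After this change,
# `other` (the extra input of the binary post-op) comes right after the
# weight quantization params and before `bias`.
print(torch.ops.onednn.qlinear_pointwise.binary._schema)
# Abridged expected ordering:
#   qx, x_scale, x_zero_point, qw, w_scale, w_zero_point,
#   other, bias, output_scale, output_zero_point, output_dtype,
#   other_scale, other_zp, binary_post_op, binary_alpha,
#   unary_post_op, unary_post_op_args, unary_post_op_algorithm
```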

**Test Plan**
```
python -u -m pytest -s -v test/quantization/core/test_quantized_op.py -k qlinear
python -u -m pytest -s -v test/inductor/test_mkldnn_pattern_matcher.py -k qlinear
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129049
Approved by: https://github.com/jgong5, https://github.com/jansel
ghstack dependencies: #128825, #129048
leslie-fang-intel
2024-06-27 18:28:42 -07:00
committed by PyTorch MergeBot
parent 07450e9713
commit 86e2d16ba0
7 changed files with 16 additions and 14 deletions

View File

@@ -1254,11 +1254,11 @@ class QLinearOnednn final {
 Tensor onednn_weight, // int8 tensor from MkldnnCPU
 Tensor weight_scales,
 Tensor weight_zero_points,
+std::optional<at::Tensor> other, // extra input for binary post-op
 std::optional<Tensor> bias,
 double output_scale,
 int64_t output_zero_point,
 std::optional<c10::ScalarType> output_dtype,
-std::optional<at::Tensor> other, // extra input for binary post-op
 double other_scale,
 int64_t other_zero_point,
 c10::string_view binary_post_op, // e.g. "none", "sum", "add"
@@ -1286,11 +1286,11 @@ class QLinearOnednn final {
 Tensor onednn_weight, // int8 tensor from MkldnnCPU
 Tensor weight_scales,
 Tensor weight_zero_points,
+std::optional<at::Tensor> other, // extra input for binary post-op
 std::optional<Tensor> bias,
 double output_scale,
 int64_t output_zero_point,
 std::optional<c10::ScalarType> output_dtype,
-std::optional<at::Tensor> other, // extra input for binary post-op
 double other_scale,
 int64_t other_zero_point,
 c10::string_view binary_post_op, // e.g. "none", "sum", "add"

View File

@@ -272,6 +272,6 @@ TORCH_LIBRARY(onednn, m) {
 m.def(TORCH_SELECTIVE_SCHEMA("onednn::qlinear_pointwise(Tensor qx, float x_scale, int x_zero_point, Tensor qw, Tensor w_scale, Tensor w_zero_point, Tensor? bias, float output_scale, int output_zero_point, ScalarType? output_dtype, str post_op_name, Scalar?[] post_op_args, str post_op_algorithm) -> Tensor"));
 m.def(TORCH_SELECTIVE_SCHEMA("onednn::qlinear_pointwise.tensor(Tensor qx, Tensor x_scale, Tensor x_zero_point, Tensor qw, Tensor w_scale, Tensor w_zero_point, Tensor? bias, float output_scale, int output_zero_point, ScalarType? output_dtype, str post_op_name, Scalar?[] post_op_args, str post_op_algorithm) -> Tensor"));
 // Linear with binary postop
-m.def(TORCH_SELECTIVE_SCHEMA("onednn::qlinear_pointwise.binary(Tensor qx, float x_scale, int x_zero_point, Tensor qw, Tensor w_scale, Tensor w_zero_point, Tensor? bias, float output_scale, int output_zero_point, ScalarType? output_dtype, Tensor? other, float other_scale, int other_zp, str binary_post_op, float binary_alpha, str unary_post_op, Scalar?[] unary_post_op_args, str unary_post_op_algorithm) -> Tensor"));
-m.def(TORCH_SELECTIVE_SCHEMA("onednn::qlinear_pointwise.binary_tensor(Tensor qx, Tensor x_scale, Tensor x_zero_point, Tensor qw, Tensor w_scale, Tensor w_zero_point, Tensor? bias, float output_scale, int output_zero_point, ScalarType? output_dtype, Tensor? other, float other_scale, int other_zp, str binary_post_op, float binary_alpha, str unary_post_op, Scalar?[] unary_post_op_args, str unary_post_op_algorithm) -> Tensor"));
+m.def(TORCH_SELECTIVE_SCHEMA("onednn::qlinear_pointwise.binary(Tensor qx, float x_scale, int x_zero_point, Tensor qw, Tensor w_scale, Tensor w_zero_point, Tensor? other, Tensor? bias, float output_scale, int output_zero_point, ScalarType? output_dtype, float other_scale, int other_zp, str binary_post_op, float binary_alpha, str unary_post_op, Scalar?[] unary_post_op_args, str unary_post_op_algorithm) -> Tensor"));
+m.def(TORCH_SELECTIVE_SCHEMA("onednn::qlinear_pointwise.binary_tensor(Tensor qx, Tensor x_scale, Tensor x_zero_point, Tensor qw, Tensor w_scale, Tensor w_zero_point, Tensor? other, Tensor? bias, float output_scale, int output_zero_point, ScalarType? output_dtype, float other_scale, int other_zp, str binary_post_op, float binary_alpha, str unary_post_op, Scalar?[] unary_post_op_args, str unary_post_op_algorithm) -> Tensor"));
 }

View File

@@ -141,6 +141,8 @@ ALLOW_LIST = [
 ("onednn::qconv2d_pointwise", datetime.date(2024, 12, 31)),
 ("onednn::qconv3d_pointwise", datetime.date(2024, 12, 31)),
 ("onednn::qconv2d_pointwise.binary", datetime.date(2024, 12, 31)),
+("onednn::qlinear_pointwise.binary", datetime.date(2024, 12, 31)),
+("onednn::qlinear_pointwise.binary_tensor", datetime.date(2024, 12, 31)),
 ("aten::_scaled_mm.out", datetime.date(2024, 12, 31)),
 ("aten::_scaled_mm", datetime.date(2024, 12, 31)),
 # BC-breaking change in can_cast signature: 'from' -> 'from_'

View File

@@ -4332,8 +4332,8 @@ class TestQuantizedLinear(TestCase):
 accum = accum.bfloat16()
 qy_cpu = qlinear_op(
 qx_cpu, x_scale, x_zp, qw_packed, w_scales, w_zps,
-b, used_y_scale, used_y_zp, output_dtype,
-accum, x2_scale, x2_zp, "sum", binary_alpha,
+accum, b, used_y_scale, used_y_zp, output_dtype,
+x2_scale, x2_zp, "sum", binary_alpha,
 unary_post_op, unary_post_op_args, post_op_algo
 )
 y_ref = y_ref + x2 * binary_alpha
@@ -4350,8 +4350,8 @@ class TestQuantizedLinear(TestCase):
 binary_alpha = 1.0 # we only support alpha=1.0 now
 qy_cpu = qlinear_op(
 qx_cpu, x_scale, x_zp, qw_packed, w_scales, w_zps,
-b, used_y_scale, used_y_zp, output_dtype,
-x2, 1.0, 0, "add", binary_alpha,
+x2, b, used_y_scale, used_y_zp, output_dtype,
+1.0, 0, "add", binary_alpha,
 unary_post_op, unary_post_op_args, post_op_algo
 )
 y_ref = y_ref + x2 * binary_alpha

View File

@@ -505,11 +505,11 @@ def _register_quantized_linear_binary_lowering(
 packed_weight,
 w_scale,
 w_zp,
+x2,
 b,
 o_inv_scale,
 o_zero_point,
 output_dtype,
-x2,
 x2_scale,
 x2_zp,
 binary_op_name,

View File

@@ -1374,11 +1374,11 @@ class QLinearPointwiseBinaryPT2E(ExternKernelAlloc):
 at::Tensor weight,
 at::Tensor weight_scales,
 at::Tensor weight_zero_points,
+c10::optional<at::Tensor> other,
 c10::optional<at::Tensor> bias,
 double inv_output_scale,
 int64_t output_zero_point,
 c10::optional<c10::ScalarType> output_dtype,
-c10::optional<at::Tensor> other,
 double other_scale,
 int64_t other_zero_point,
 c10::string_view binary_post_op,
@@ -1436,11 +1436,11 @@ class QLinearPointwiseBinaryPT2E(ExternKernelAlloc):
 packed_weight,
 w_scale,
 w_zp,
+other,
 bias,
 o_scale,
 o_zp,
 output_dtype,
-other,
 other_scale,
 other_zp,
 binary_attr,
@@ -1470,11 +1470,11 @@ class QLinearPointwiseBinaryPT2E(ExternKernelAlloc):
 qw: "TensorBox", # packed_weight
 w_scale: "TensorBox",
 w_zero_point: "TensorBox",
+other: "TensorBox",
 bias: "TensorBox",
 output_scale: float,
 output_zero_point: int,
 output_dtype,
-other: "TensorBox",
 other_scale,
 other_zp,
 binary_post_op,

View File

@@ -860,11 +860,11 @@ def register_onednn_fusion_ops():
 packed_weight: TensorBox,
 w_scale: TensorBox,
 w_zp: TensorBox,
+x2: TensorBox,
 bias: TensorBox,
 o_inv_scale,
 o_zero_point,
 output_dtype,
-x2: TensorBox,
 x2_scale,
 x2_zp,
 binary_attr,
@@ -896,11 +896,11 @@ def register_onednn_fusion_ops():
 packed_weight,
 w_scale,
 w_zp,
+x2,
 bias,
 o_inv_scale,
 o_zero_point,
 output_dtype,
-x2,
 x2_scale,
 x2_zp,
 binary_attr,