mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Inductor][Quant] Change the schema of QLinear Binary (#129049)
**Summary**
We change the schema of QLinear Binary so that it is easier to enable the corresponding GEMM template.
- The extra input of the binary post-op is a tensor that needs to be an input node for autotuning, so we move it in front of `output_scale`, which is a scalar.
- We also move it in front of `bias`, since `bias` is an optional tensor for this fusion, whereas `other` is required for the linear binary fusion.

**Test Plan**
```
python -u -m pytest -s -v test/quantization/core/test_quantized_op.py -k qlinear
python -u -m pytest -s -v test/inductor/test_mkldnn_pattern_matcher.py -k qlinear
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129049
Approved by: https://github.com/jgong5, https://github.com/jansel
ghstack dependencies: #128825, #129048
This commit is contained in:
committed by
PyTorch MergeBot
parent
07450e9713
commit
86e2d16ba0
@ -1254,11 +1254,11 @@ class QLinearOnednn final {
|
||||
Tensor onednn_weight, // int8 tensor from MkldnnCPU
|
||||
Tensor weight_scales,
|
||||
Tensor weight_zero_points,
|
||||
std::optional<at::Tensor> other, // extra input for binary post-op
|
||||
std::optional<Tensor> bias,
|
||||
double output_scale,
|
||||
int64_t output_zero_point,
|
||||
std::optional<c10::ScalarType> output_dtype,
|
||||
std::optional<at::Tensor> other, // extra input for binary post-op
|
||||
double other_scale,
|
||||
int64_t other_zero_point,
|
||||
c10::string_view binary_post_op, // e.g. "none", "sum", "add"
|
||||
@ -1286,11 +1286,11 @@ class QLinearOnednn final {
|
||||
Tensor onednn_weight, // int8 tensor from MkldnnCPU
|
||||
Tensor weight_scales,
|
||||
Tensor weight_zero_points,
|
||||
std::optional<at::Tensor> other, // extra input for binary post-op
|
||||
std::optional<Tensor> bias,
|
||||
double output_scale,
|
||||
int64_t output_zero_point,
|
||||
std::optional<c10::ScalarType> output_dtype,
|
||||
std::optional<at::Tensor> other, // extra input for binary post-op
|
||||
double other_scale,
|
||||
int64_t other_zero_point,
|
||||
c10::string_view binary_post_op, // e.g. "none", "sum", "add"
|
||||
|
@ -272,6 +272,6 @@ TORCH_LIBRARY(onednn, m) {
|
||||
m.def(TORCH_SELECTIVE_SCHEMA("onednn::qlinear_pointwise(Tensor qx, float x_scale, int x_zero_point, Tensor qw, Tensor w_scale, Tensor w_zero_point, Tensor? bias, float output_scale, int output_zero_point, ScalarType? output_dtype, str post_op_name, Scalar?[] post_op_args, str post_op_algorithm) -> Tensor"));
|
||||
m.def(TORCH_SELECTIVE_SCHEMA("onednn::qlinear_pointwise.tensor(Tensor qx, Tensor x_scale, Tensor x_zero_point, Tensor qw, Tensor w_scale, Tensor w_zero_point, Tensor? bias, float output_scale, int output_zero_point, ScalarType? output_dtype, str post_op_name, Scalar?[] post_op_args, str post_op_algorithm) -> Tensor"));
|
||||
// Linear with binary postop
|
||||
m.def(TORCH_SELECTIVE_SCHEMA("onednn::qlinear_pointwise.binary(Tensor qx, float x_scale, int x_zero_point, Tensor qw, Tensor w_scale, Tensor w_zero_point, Tensor? bias, float output_scale, int output_zero_point, ScalarType? output_dtype, Tensor? other, float other_scale, int other_zp, str binary_post_op, float binary_alpha, str unary_post_op, Scalar?[] unary_post_op_args, str unary_post_op_algorithm) -> Tensor"));
|
||||
m.def(TORCH_SELECTIVE_SCHEMA("onednn::qlinear_pointwise.binary_tensor(Tensor qx, Tensor x_scale, Tensor x_zero_point, Tensor qw, Tensor w_scale, Tensor w_zero_point, Tensor? bias, float output_scale, int output_zero_point, ScalarType? output_dtype, Tensor? other, float other_scale, int other_zp, str binary_post_op, float binary_alpha, str unary_post_op, Scalar?[] unary_post_op_args, str unary_post_op_algorithm) -> Tensor"));
|
||||
m.def(TORCH_SELECTIVE_SCHEMA("onednn::qlinear_pointwise.binary(Tensor qx, float x_scale, int x_zero_point, Tensor qw, Tensor w_scale, Tensor w_zero_point, Tensor? other, Tensor? bias, float output_scale, int output_zero_point, ScalarType? output_dtype, float other_scale, int other_zp, str binary_post_op, float binary_alpha, str unary_post_op, Scalar?[] unary_post_op_args, str unary_post_op_algorithm) -> Tensor"));
|
||||
m.def(TORCH_SELECTIVE_SCHEMA("onednn::qlinear_pointwise.binary_tensor(Tensor qx, Tensor x_scale, Tensor x_zero_point, Tensor qw, Tensor w_scale, Tensor w_zero_point, Tensor? other, Tensor? bias, float output_scale, int output_zero_point, ScalarType? output_dtype, float other_scale, int other_zp, str binary_post_op, float binary_alpha, str unary_post_op, Scalar?[] unary_post_op_args, str unary_post_op_algorithm) -> Tensor"));
|
||||
}
|
||||
|
@ -141,6 +141,8 @@ ALLOW_LIST = [
|
||||
("onednn::qconv2d_pointwise", datetime.date(2024, 12, 31)),
|
||||
("onednn::qconv3d_pointwise", datetime.date(2024, 12, 31)),
|
||||
("onednn::qconv2d_pointwise.binary", datetime.date(2024, 12, 31)),
|
||||
("onednn::qlinear_pointwise.binary", datetime.date(2024, 12, 31)),
|
||||
("onednn::qlinear_pointwise.binary_tensor", datetime.date(2024, 12, 31)),
|
||||
("aten::_scaled_mm.out", datetime.date(2024, 12, 31)),
|
||||
("aten::_scaled_mm", datetime.date(2024, 12, 31)),
|
||||
# BC-breaking change in can_cast signature: 'from' -> 'from_'
|
||||
|
@ -4332,8 +4332,8 @@ class TestQuantizedLinear(TestCase):
|
||||
accum = accum.bfloat16()
|
||||
qy_cpu = qlinear_op(
|
||||
qx_cpu, x_scale, x_zp, qw_packed, w_scales, w_zps,
|
||||
b, used_y_scale, used_y_zp, output_dtype,
|
||||
accum, x2_scale, x2_zp, "sum", binary_alpha,
|
||||
accum, b, used_y_scale, used_y_zp, output_dtype,
|
||||
x2_scale, x2_zp, "sum", binary_alpha,
|
||||
unary_post_op, unary_post_op_args, post_op_algo
|
||||
)
|
||||
y_ref = y_ref + x2 * binary_alpha
|
||||
@ -4350,8 +4350,8 @@ class TestQuantizedLinear(TestCase):
|
||||
binary_alpha = 1.0 # we only support alpha=1.0 now
|
||||
qy_cpu = qlinear_op(
|
||||
qx_cpu, x_scale, x_zp, qw_packed, w_scales, w_zps,
|
||||
b, used_y_scale, used_y_zp, output_dtype,
|
||||
x2, 1.0, 0, "add", binary_alpha,
|
||||
x2, b, used_y_scale, used_y_zp, output_dtype,
|
||||
1.0, 0, "add", binary_alpha,
|
||||
unary_post_op, unary_post_op_args, post_op_algo
|
||||
)
|
||||
y_ref = y_ref + x2 * binary_alpha
|
||||
|
@ -505,11 +505,11 @@ def _register_quantized_linear_binary_lowering(
|
||||
packed_weight,
|
||||
w_scale,
|
||||
w_zp,
|
||||
x2,
|
||||
b,
|
||||
o_inv_scale,
|
||||
o_zero_point,
|
||||
output_dtype,
|
||||
x2,
|
||||
x2_scale,
|
||||
x2_zp,
|
||||
binary_op_name,
|
||||
|
@ -1374,11 +1374,11 @@ class QLinearPointwiseBinaryPT2E(ExternKernelAlloc):
|
||||
at::Tensor weight,
|
||||
at::Tensor weight_scales,
|
||||
at::Tensor weight_zero_points,
|
||||
c10::optional<at::Tensor> other,
|
||||
c10::optional<at::Tensor> bias,
|
||||
double inv_output_scale,
|
||||
int64_t output_zero_point,
|
||||
c10::optional<c10::ScalarType> output_dtype,
|
||||
c10::optional<at::Tensor> other,
|
||||
double other_scale,
|
||||
int64_t other_zero_point,
|
||||
c10::string_view binary_post_op,
|
||||
@ -1436,11 +1436,11 @@ class QLinearPointwiseBinaryPT2E(ExternKernelAlloc):
|
||||
packed_weight,
|
||||
w_scale,
|
||||
w_zp,
|
||||
other,
|
||||
bias,
|
||||
o_scale,
|
||||
o_zp,
|
||||
output_dtype,
|
||||
other,
|
||||
other_scale,
|
||||
other_zp,
|
||||
binary_attr,
|
||||
@ -1470,11 +1470,11 @@ class QLinearPointwiseBinaryPT2E(ExternKernelAlloc):
|
||||
qw: "TensorBox", # packed_weight
|
||||
w_scale: "TensorBox",
|
||||
w_zero_point: "TensorBox",
|
||||
other: "TensorBox",
|
||||
bias: "TensorBox",
|
||||
output_scale: float,
|
||||
output_zero_point: int,
|
||||
output_dtype,
|
||||
other: "TensorBox",
|
||||
other_scale,
|
||||
other_zp,
|
||||
binary_post_op,
|
||||
|
@ -860,11 +860,11 @@ def register_onednn_fusion_ops():
|
||||
packed_weight: TensorBox,
|
||||
w_scale: TensorBox,
|
||||
w_zp: TensorBox,
|
||||
x2: TensorBox,
|
||||
bias: TensorBox,
|
||||
o_inv_scale,
|
||||
o_zero_point,
|
||||
output_dtype,
|
||||
x2: TensorBox,
|
||||
x2_scale,
|
||||
x2_zp,
|
||||
binary_attr,
|
||||
@ -896,11 +896,11 @@ def register_onednn_fusion_ops():
|
||||
packed_weight,
|
||||
w_scale,
|
||||
w_zp,
|
||||
x2,
|
||||
bias,
|
||||
o_inv_scale,
|
||||
o_zero_point,
|
||||
output_dtype,
|
||||
x2,
|
||||
x2_scale,
|
||||
x2_zp,
|
||||
binary_attr,
|
||||
|
Reference in New Issue
Block a user